1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
|
local api = vim.api
local M = {}
--- Memoizes a function based on the changedtick of the provided bufnr.
--- The cached entry is dropped when the buffer is reloaded or detached, so
--- results for buffers that go away are not kept alive.
---@generic F: function
---@param fn F fn to memoize, taking the bufnr as first argument
---@return F
local function memoize_by_changedtick(fn)
  ---@type table<integer,{result:any,last_tick:integer}>
  local cache = {}
  -- Buffers we currently hold a (clean-up only) attachment on. This is
  -- tracked separately from `cache` because 'on_reload' clears the cache
  -- entry while the attachment itself stays active; deciding to attach based
  -- on the cache entry would add one more attachment on every reload.
  ---@type table<integer,true>
  local attached = {}

  ---@param bufnr integer
  return function(bufnr, ...)
    local tick = api.nvim_buf_get_changedtick(bufnr)

    local entry = cache[bufnr]
    if entry and entry.last_tick == tick then
      return entry.result
    end

    if not attached[bufnr] then
      -- Clean up logic only!
      local ok = api.nvim_buf_attach(bufnr, false, {
        on_reload = function()
          cache[bufnr] = nil
        end,
        on_detach = function()
          cache[bufnr] = nil
          attached[bufnr] = nil
        end,
      })
      -- nvim_buf_attach returns false when it could not attach (e.g. the
      -- buffer is not loaded); leave `attached` unset so a later call retries.
      if ok then
        attached[bufnr] = true
      end
    end

    entry = {
      result = fn(bufnr, ...),
      last_tick = tick,
    }
    cache[bufnr] = entry
    return entry.result
  end
end
--- Calls `callback` for every node captured as `capture` by the query named
--- `query_name`, for each tree of the buffer's parser (as visited by
--- `for_each_tree`). Languages without such a query are skipped.
---@param bufnr integer
---@param capture string capture name to report (e.g. 'fold')
---@param query_name string name of the query to load for each language
---@param callback fun(id: integer, node:TSNode, metadata: TSMetadata)
local function iter_matches_with_capture(bufnr, capture, query_name, callback)
  local parser = vim.treesitter.get_parser(bufnr)
  if not parser then
    return
  end
  -- Newer Neovim exposes the query loader as `query.get` (the old
  -- `query.get_query` name is deprecated there); older versions only have
  -- `get_query`. Resolve whichever exists.
  local ts_query = vim.treesitter.query
  local get_query = ts_query.get or ts_query.get_query
  parser:for_each_tree(function(tree, lang_tree)
    local query = get_query(lang_tree:lang(), query_name)
    if query then
      local root = tree:root()
      local start, _, stop = root:range()
      for _, match, metadata in query:iter_matches(root, bufnr, start, stop) do
        for id, node in pairs(match) do
          if query.captures[id] == capture then
            callback(id, node, metadata)
          end
        end
      end
    end
  end)
end
---@private
--- Returns the range of a captured node. A `range` stored in the capture's
--- metadata entry takes precedence over the node's own range.
--- TODO(lewis6991): copied from languagetree.lua. Consolidate
---@param node TSNode
---@param id integer capture id, used to index `metadata`
---@param metadata TSMetadata
---@return Range
local function get_range_from_metadata(node, id, metadata)
  local capture_metadata = metadata[id]
  local override = capture_metadata and capture_metadata.range
  if not override then
    return { node:range() }
  end
  return override --[[@as Range]]
end
--- Computes the 'foldexpr' value of every line of `bufnr` from its "folds"
--- query captures.
---@param bufnr integer
---@param max_fold_level integer deepest level reported ('foldnestmax')
---@param min_fold_lines integer a fold must span more lines than this ('foldminlines')
---@return table<integer,string> levels indexed by 1-based line number
local function compute_folds_levels(bufnr, max_fold_level, min_fold_lines)
  -- start..stop is an inclusive, 0-based row range
  local start_counts = {} ---@type table<integer,integer>
  local stop_counts = {} ---@type table<integer,integer>
  local prev_start = -1
  local prev_stop = -1

  iter_matches_with_capture(bufnr, 'fold', 'folds', function(id, node, metadata)
    local range = get_range_from_metadata(node, id, metadata)
    local start, stop, stop_col = range[1], range[3], range[4]
    -- A range that ends at column 0 covers no text on its end row
    if stop_col == 0 then
      stop = stop - 1
    end
    local fold_length = stop - start + 1
    -- Fold only multiline nodes that are not exactly the same as previously met folds
    -- Checking against just the previously found fold is sufficient if nodes
    -- are returned in preorder or postorder when traversing tree
    if fold_length > min_fold_lines and not (start == prev_start and stop == prev_stop) then
      start_counts[start] = (start_counts[start] or 0) + 1
      stop_counts[stop] = (stop_counts[stop] or 0) + 1
      prev_start = start
      prev_stop = stop
    end
  end)

  ---@type table<integer,string>
  local levels = {}
  local current_level = 0
  -- We now have the list of fold opening and closing, fill the gaps and mark
  -- where folds start. Rows are 0-based, so the last row is line_count - 1.
  -- Openings are applied before a row's level is taken and closings after,
  -- so a fold includes its stop row.
  for lnum = 0, api.nvim_buf_line_count(bufnr) - 1 do
    local last_trimmed_level = math.min(current_level, max_fold_level)
    current_level = current_level + (start_counts[lnum] or 0)
    local trimmed_level = math.min(current_level, max_fold_level)
    current_level = current_level - (stop_counts[lnum] or 0)
    -- Determine if it's the start/end of a fold
    -- NB: vim's fold-expr interface does not have a mechanism to indicate that
    -- two (or more) folds start at this line, so it cannot distinguish between
    -- ( \n ( \n )) \n (( \n ) \n )
    -- versus
    -- ( \n ( \n ) \n ( \n ) \n )
    -- If it did have such a mechanism, (trimmed_level - last_trimmed_level)
    -- would be the correct number of starts to pass on.
    local prefix = ''
    if trimmed_level > last_trimmed_level then
      prefix = '>'
    end
    levels[lnum + 1] = prefix .. tostring(trimmed_level)
  end
  return levels
end

-- Per buffer and changedtick, a table of computed levels keyed by the
-- ('foldnestmax', 'foldminlines') values they were computed with. Those
-- window-local options change the result without changing the buffer's
-- changedtick, so caching on the tick alone would return stale levels after
-- the options change or in another window with different values.
local levels_by_options = memoize_by_changedtick(function(_)
  return {}
end)

-- This is cached to avoid computing the levels multiple times,
-- especially not for every line in the file when `zx` is hit.
---@param bufnr integer
---@return table<integer,string>
local function folds_levels(bufnr)
  local max_fold_level = vim.wo.foldnestmax
  local min_fold_lines = vim.wo.foldminlines
  local per_options = levels_by_options(bufnr)
  local key = max_fold_level .. ':' .. min_fold_lines
  local levels = per_options[key]
  if not levels then
    levels = compute_folds_levels(bufnr, max_fold_level, min_fold_lines)
    per_options[key] = levels
  end
  return levels
end
--- 'foldexpr' implementation driven by the buffer's tree-sitter "folds" query.
---@param lnum integer|nil 1-based line number; `vim.v.lnum` is used when omitted
---@return string the line's fold level ('>'-prefixed where a fold starts), or '0'
function M.foldexpr(lnum)
  local line = lnum or vim.v.lnum
  local buf = api.nvim_get_current_buf()
  ---@diagnostic disable-next-line:invisible
  local has_parser = vim.treesitter._has_parser(buf)
  if has_parser and line then
    local levels = folds_levels(buf) or {}
    return levels[line] or '0'
  end
  return '0'
end
return M
|