aboutsummaryrefslogtreecommitdiff
path: root/runtime/lua/vim/treesitter/_fold.lua
diff options
context:
space:
mode:
Diffstat (limited to 'runtime/lua/vim/treesitter/_fold.lua')
-rw-r--r--runtime/lua/vim/treesitter/_fold.lua325
1 files changed, 163 insertions, 162 deletions
diff --git a/runtime/lua/vim/treesitter/_fold.lua b/runtime/lua/vim/treesitter/_fold.lua
index 7237d2e7d4..38318347a7 100644
--- a/runtime/lua/vim/treesitter/_fold.lua
+++ b/runtime/lua/vim/treesitter/_fold.lua
@@ -19,76 +19,36 @@ local api = vim.api
---The range on which to evaluate foldexpr.
---When in insert mode, the evaluation is deferred to InsertLeave.
---@field foldupdate_range? Range2
+---
+---The treesitter parser associated with this buffer.
+---@field parser? vim.treesitter.LanguageTree
local FoldInfo = {}
FoldInfo.__index = FoldInfo
---@private
-function FoldInfo.new()
+---@param bufnr integer
+function FoldInfo.new(bufnr)
return setmetatable({
levels0 = {},
levels = {},
+ parser = ts.get_parser(bufnr, nil, { error = false }),
}, FoldInfo)
end
---- Efficiently remove items from middle of a list a list.
----
---- Calling table.remove() in a loop will re-index the tail of the table on
---- every iteration, instead this function will re-index the table exactly
---- once.
----
---- Based on https://stackoverflow.com/questions/12394841/safely-remove-items-from-an-array-table-while-iterating/53038524#53038524
----
----@param t any[]
----@param first integer
----@param last integer
-local function list_remove(t, first, last)
- local n = #t
- for i = 0, n - first do
- t[first + i] = t[last + 1 + i]
- t[last + 1 + i] = nil
- end
-end
-
---@package
---@param srow integer
---@param erow integer 0-indexed, exclusive
function FoldInfo:remove_range(srow, erow)
- list_remove(self.levels, srow + 1, erow)
- list_remove(self.levels0, srow + 1, erow)
-end
-
---- Efficiently insert items into the middle of a list.
----
---- Calling table.insert() in a loop will re-index the tail of the table on
---- every iteration, instead this function will re-index the table exactly
---- once.
----
---- Based on https://stackoverflow.com/questions/12394841/safely-remove-items-from-an-array-table-while-iterating/53038524#53038524
----
----@param t any[]
----@param first integer
----@param last integer
----@param v any
-local function list_insert(t, first, last, v)
- local n = #t
-
- -- Shift table forward
- for i = n - first, 0, -1 do
- t[last + 1 + i] = t[first + i]
- end
-
- -- Fill in new values
- for i = first, last do
- t[i] = v
- end
+ vim._list_remove(self.levels, srow + 1, erow)
+ vim._list_remove(self.levels0, srow + 1, erow)
end
---@package
---@param srow integer
---@param erow integer 0-indexed, exclusive
function FoldInfo:add_range(srow, erow)
- list_insert(self.levels, srow + 1, erow, -1)
- list_insert(self.levels0, srow + 1, erow, -1)
+ vim._list_insert(self.levels, srow + 1, erow, -1)
+ vim._list_insert(self.levels0, srow + 1, erow, -1)
end
---@param range Range2
@@ -109,111 +69,122 @@ end
---@param info TS.FoldInfo
---@param srow integer?
---@param erow integer? 0-indexed, exclusive
----@param parse_injections? boolean
-local function compute_folds_levels(bufnr, info, srow, erow, parse_injections)
+---@param callback function?
+local function compute_folds_levels(bufnr, info, srow, erow, callback)
srow = srow or 0
erow = erow or api.nvim_buf_line_count(bufnr)
- local parser = assert(ts.get_parser(bufnr, nil, { error = false }))
-
- parser:parse(parse_injections and { srow, erow } or nil)
-
- local enter_counts = {} ---@type table<integer, integer>
- local leave_counts = {} ---@type table<integer, integer>
- local prev_start = -1
- local prev_stop = -1
+ local parser = info.parser
+ if not parser then
+ return
+ end
- parser:for_each_tree(function(tree, ltree)
- local query = ts.query.get(ltree:lang(), 'folds')
- if not query then
+ parser:parse(nil, function(_, trees)
+ if not trees then
return
end
- -- Collect folds starting from srow - 1, because we should first subtract the folds that end at
- -- srow - 1 from the level of srow - 1 to get accurate level of srow.
- for _, match, metadata in query:iter_matches(tree:root(), bufnr, math.max(srow - 1, 0), erow) do
- for id, nodes in pairs(match) do
- if query.captures[id] == 'fold' then
- local range = ts.get_range(nodes[1], bufnr, metadata[id])
- local start, _, stop, stop_col = Range.unpack4(range)
-
- if #nodes > 1 then
- -- assumes nodes are ordered by range
- local end_range = ts.get_range(nodes[#nodes], bufnr, metadata[id])
- local _, _, end_stop, end_stop_col = Range.unpack4(end_range)
- stop = end_stop
- stop_col = end_stop_col
- end
+ local enter_counts = {} ---@type table<integer, integer>
+ local leave_counts = {} ---@type table<integer, integer>
+ local prev_start = -1
+ local prev_stop = -1
- if stop_col == 0 then
- stop = stop - 1
- end
+ parser:for_each_tree(function(tree, ltree)
+ local query = ts.query.get(ltree:lang(), 'folds')
+ if not query then
+ return
+ end
- local fold_length = stop - start + 1
-
- -- Fold only multiline nodes that are not exactly the same as previously met folds
- -- Checking against just the previously found fold is sufficient if nodes
- -- are returned in preorder or postorder when traversing tree
- if
- fold_length > vim.wo.foldminlines and not (start == prev_start and stop == prev_stop)
- then
- enter_counts[start + 1] = (enter_counts[start + 1] or 0) + 1
- leave_counts[stop + 1] = (leave_counts[stop + 1] or 0) + 1
- prev_start = start
- prev_stop = stop
+ -- Collect folds starting from srow - 1, because we should first subtract the folds that end at
+ -- srow - 1 from the level of srow - 1 to get accurate level of srow.
+ for _, match, metadata in query:iter_matches(tree:root(), bufnr, math.max(srow - 1, 0), erow) do
+ for id, nodes in pairs(match) do
+ if query.captures[id] == 'fold' then
+ local range = ts.get_range(nodes[1], bufnr, metadata[id])
+ local start, _, stop, stop_col = Range.unpack4(range)
+
+ if #nodes > 1 then
+ -- assumes nodes are ordered by range
+ local end_range = ts.get_range(nodes[#nodes], bufnr, metadata[id])
+ local _, _, end_stop, end_stop_col = Range.unpack4(end_range)
+ stop = end_stop
+ stop_col = end_stop_col
+ end
+
+ if stop_col == 0 then
+ stop = stop - 1
+ end
+
+ local fold_length = stop - start + 1
+
+ -- Fold only multiline nodes that are not exactly the same as previously met folds
+ -- Checking against just the previously found fold is sufficient if nodes
+ -- are returned in preorder or postorder when traversing tree
+ if
+ fold_length > vim.wo.foldminlines and not (start == prev_start and stop == prev_stop)
+ then
+ enter_counts[start + 1] = (enter_counts[start + 1] or 0) + 1
+ leave_counts[stop + 1] = (leave_counts[stop + 1] or 0) + 1
+ prev_start = start
+ prev_stop = stop
+ end
end
end
end
- end
- end)
+ end)
- local nestmax = vim.wo.foldnestmax
- local level0_prev = info.levels0[srow] or 0
- local leave_prev = leave_counts[srow] or 0
-
- -- We now have the list of fold opening and closing, fill the gaps and mark where fold start
- for lnum = srow + 1, erow do
- local enter_line = enter_counts[lnum] or 0
- local leave_line = leave_counts[lnum] or 0
- local level0 = level0_prev - leave_prev + enter_line
-
- -- Determine if it's the start/end of a fold
- -- NB: vim's fold-expr interface does not have a mechanism to indicate that
- -- two (or more) folds start at this line, so it cannot distinguish between
- -- ( \n ( \n )) \n (( \n ) \n )
- -- versus
- -- ( \n ( \n ) \n ( \n ) \n )
- -- Both are represented by ['>1', '>2', '2', '>2', '2', '1'], and
- -- vim interprets as the second case.
- -- If it did have such a mechanism, (clamped - clamped_prev)
- -- would be the correct number of starts to pass on.
- local adjusted = level0 ---@type integer
- local prefix = ''
- if enter_line > 0 then
- prefix = '>'
- if leave_line > 0 then
- -- If this line ends a fold f1 and starts a fold f2, then move f1's end to the previous line
- -- so that f2 gets the correct level on this line. This may reduce the size of f1 below
- -- foldminlines, but we don't handle it for simplicity.
- adjusted = level0 - leave_line
- leave_line = 0
+ local nestmax = vim.wo.foldnestmax
+ local level0_prev = info.levels0[srow] or 0
+ local leave_prev = leave_counts[srow] or 0
+
+ -- We now have the list of fold opening and closing, fill the gaps and mark where fold start
+ for lnum = srow + 1, erow do
+ local enter_line = enter_counts[lnum] or 0
+ local leave_line = leave_counts[lnum] or 0
+ local level0 = level0_prev - leave_prev + enter_line
+
+ -- Determine if it's the start/end of a fold
+ -- NB: vim's fold-expr interface does not have a mechanism to indicate that
+ -- two (or more) folds start at this line, so it cannot distinguish between
+ -- ( \n ( \n )) \n (( \n ) \n )
+ -- versus
+ -- ( \n ( \n ) \n ( \n ) \n )
+ -- Both are represented by ['>1', '>2', '2', '>2', '2', '1'], and
+ -- vim interprets as the second case.
+ -- If it did have such a mechanism, (clamped - clamped_prev)
+ -- would be the correct number of starts to pass on.
+ local adjusted = level0 ---@type integer
+ local prefix = ''
+ if enter_line > 0 then
+ prefix = '>'
+ if leave_line > 0 then
+ -- If this line ends a fold f1 and starts a fold f2, then move f1's end to the previous line
+ -- so that f2 gets the correct level on this line. This may reduce the size of f1 below
+ -- foldminlines, but we don't handle it for simplicity.
+ adjusted = level0 - leave_line
+ leave_line = 0
+ end
end
- end
- -- Clamp at foldnestmax.
- local clamped = adjusted
- if adjusted > nestmax then
- prefix = ''
- clamped = nestmax
- end
+ -- Clamp at foldnestmax.
+ local clamped = adjusted
+ if adjusted > nestmax then
+ prefix = ''
+ clamped = nestmax
+ end
- -- Record the "real" level, so that it can be used as "base" of later compute_folds_levels().
- info.levels0[lnum] = adjusted
- info.levels[lnum] = prefix .. tostring(clamped)
+ -- Record the "real" level, so that it can be used as "base" of later compute_folds_levels().
+ info.levels0[lnum] = adjusted
+ info.levels[lnum] = prefix .. tostring(clamped)
- leave_prev = leave_line
- level0_prev = adjusted
- end
+ leave_prev = leave_line
+ level0_prev = adjusted
+ end
+
+ if callback then
+ callback()
+ end
+ end)
end
local M = {}
@@ -221,7 +192,7 @@ local M = {}
---@type table<integer,TS.FoldInfo>
local foldinfos = {}
-local group = api.nvim_create_augroup('treesitter/fold', {})
+local group = api.nvim_create_augroup('nvim.treesitter.fold', {})
--- Update the folds in the windows that contain the buffer and use expr foldmethod (assuming that
--- the user doesn't use different foldexpr for the same buffer).
@@ -298,12 +269,19 @@ local function schedule_if_loaded(bufnr, fn)
end
---@param bufnr integer
----@param foldinfo TS.FoldInfo
---@param tree_changes Range4[]
-local function on_changedtree(bufnr, foldinfo, tree_changes)
+local function on_changedtree(bufnr, tree_changes)
schedule_if_loaded(bufnr, function()
+ -- Buffer reload clears `foldinfos[bufnr]`, which may still be nil when callback is invoked.
+ local foldinfo = foldinfos[bufnr]
+ if not foldinfo then
+ return
+ end
+
local srow_upd, erow_upd ---@type integer?, integer?
local max_erow = api.nvim_buf_line_count(bufnr)
+ -- TODO(ribru17): Replace this with a proper .all() awaiter once #19624 is resolved
+ local iterations = 0
for _, change in ipairs(tree_changes) do
local srow, _, erow, ecol = Range.unpack4(change)
-- If a parser doesn't have any ranges explicitly set, treesitter will
@@ -317,24 +295,31 @@ local function on_changedtree(bufnr, foldinfo, tree_changes)
end
-- Start from `srow - foldminlines`, because this edit may have shrunken the fold below limit.
srow = math.max(srow - vim.wo.foldminlines, 0)
- compute_folds_levels(bufnr, foldinfo, srow, erow)
srow_upd = srow_upd and math.min(srow_upd, srow) or srow
erow_upd = erow_upd and math.max(erow_upd, erow) or erow
- end
- if #tree_changes > 0 then
- foldinfo:foldupdate(bufnr, srow_upd, erow_upd)
+ compute_folds_levels(bufnr, foldinfo, srow, erow, function()
+ iterations = iterations + 1
+ if iterations == #tree_changes then
+ foldinfo:foldupdate(bufnr, srow_upd, erow_upd)
+ end
+ end)
end
end)
end
---@param bufnr integer
----@param foldinfo TS.FoldInfo
---@param start_row integer
---@param old_row integer
---@param old_col integer
---@param new_row integer
---@param new_col integer
-local function on_bytes(bufnr, foldinfo, start_row, start_col, old_row, old_col, new_row, new_col)
+local function on_bytes(bufnr, start_row, start_col, old_row, old_col, new_row, new_col)
+ -- Buffer reload clears `foldinfos[bufnr]`, which may still be nil when callback is invoked.
+ local foldinfo = foldinfos[bufnr]
+ if not foldinfo then
+ return
+ end
+
-- extend the end to fully include the range
local end_row_old = start_row + old_row + 1
local end_row_new = start_row + new_row + 1
@@ -373,15 +358,16 @@ local function on_bytes(bufnr, foldinfo, start_row, start_col, old_row, old_col,
-- is invoked. For example, `J` with non-zero count triggers multiple on_bytes before executing
-- the scheduled callback. So we accumulate the edited ranges in `on_bytes_range`.
schedule_if_loaded(bufnr, function()
- if not foldinfo.on_bytes_range then
+ if not (foldinfo.on_bytes_range and foldinfos[bufnr]) then
return
end
local srow, erow = foldinfo.on_bytes_range[1], foldinfo.on_bytes_range[2]
foldinfo.on_bytes_range = nil
-- Start from `srow - foldminlines`, because this edit may have shrunken the fold below limit.
srow = math.max(srow - vim.wo.foldminlines, 0)
- compute_folds_levels(bufnr, foldinfo, srow, erow)
- foldinfo:foldupdate(bufnr, srow, erow)
+ compute_folds_levels(bufnr, foldinfo, srow, erow, function()
+ foldinfo:foldupdate(bufnr, srow, erow)
+ end)
end)
end
end
@@ -392,22 +378,30 @@ function M.foldexpr(lnum)
lnum = lnum or vim.v.lnum
local bufnr = api.nvim_get_current_buf()
- local parser = ts.get_parser(bufnr, nil, { error = false })
- if not parser then
- return '0'
- end
-
if not foldinfos[bufnr] then
- foldinfos[bufnr] = FoldInfo.new()
+ foldinfos[bufnr] = FoldInfo.new(bufnr)
+ api.nvim_create_autocmd({ 'BufUnload', 'VimEnter' }, {
+ buffer = bufnr,
+ once = true,
+ callback = function()
+ foldinfos[bufnr] = nil
+ end,
+ })
+
+ local parser = foldinfos[bufnr].parser
+ if not parser then
+ return '0'
+ end
+
compute_folds_levels(bufnr, foldinfos[bufnr])
parser:register_cbs({
on_changedtree = function(tree_changes)
- on_changedtree(bufnr, foldinfos[bufnr], tree_changes)
+ on_changedtree(bufnr, tree_changes)
end,
on_bytes = function(_, _, start_row, start_col, _, old_row, old_col, _, new_row, new_col, _)
- on_bytes(bufnr, foldinfos[bufnr], start_row, start_col, old_row, old_col, new_row, new_col)
+ on_bytes(bufnr, start_row, start_col, old_row, old_col, new_row, new_col)
end,
on_detach = function()
@@ -423,10 +417,17 @@ api.nvim_create_autocmd('OptionSet', {
pattern = { 'foldminlines', 'foldnestmax' },
desc = 'Refresh treesitter folds',
callback = function()
- for bufnr, _ in pairs(foldinfos) do
- foldinfos[bufnr] = FoldInfo.new()
- compute_folds_levels(bufnr, foldinfos[bufnr])
- foldinfos[bufnr]:foldupdate(bufnr, 0, api.nvim_buf_line_count(bufnr))
+ local buf = api.nvim_get_current_buf()
+ local bufs = vim.v.option_type == 'global' and vim.tbl_keys(foldinfos)
+ or foldinfos[buf] and { buf }
+ or {}
+ for _, bufnr in ipairs(bufs) do
+ foldinfos[bufnr] = FoldInfo.new(bufnr)
+ api.nvim_buf_call(bufnr, function()
+ compute_folds_levels(bufnr, foldinfos[bufnr], nil, nil, function()
+ foldinfos[bufnr]:foldupdate(bufnr, 0, api.nvim_buf_line_count(bufnr))
+ end)
+ end)
end
end,
})