diff options
author | Josh Rahm <joshuarahm@gmail.com> | 2025-02-05 23:09:29 +0000 |
---|---|---|
committer | Josh Rahm <joshuarahm@gmail.com> | 2025-02-05 23:09:29 +0000 |
commit | d5f194ce780c95821a855aca3c19426576d28ae0 (patch) | |
tree | d45f461b19f9118ad2bb1f440a7a08973ad18832 /runtime/lua/vim/treesitter/_fold.lua | |
parent | c5d770d311841ea5230426cc4c868e8db27300a8 (diff) | |
parent | 44740e561fc93afe3ebecfd3618bda2d2abeafb0 (diff) | |
download | rneovim-d5f194ce780c95821a855aca3c19426576d28ae0.tar.gz rneovim-d5f194ce780c95821a855aca3c19426576d28ae0.tar.bz2 rneovim-d5f194ce780c95821a855aca3c19426576d28ae0.zip |
Diffstat (limited to 'runtime/lua/vim/treesitter/_fold.lua')
-rw-r--r-- | runtime/lua/vim/treesitter/_fold.lua | 325 |
1 files changed, 163 insertions, 162 deletions
diff --git a/runtime/lua/vim/treesitter/_fold.lua b/runtime/lua/vim/treesitter/_fold.lua index 7237d2e7d4..38318347a7 100644 --- a/runtime/lua/vim/treesitter/_fold.lua +++ b/runtime/lua/vim/treesitter/_fold.lua @@ -19,76 +19,36 @@ local api = vim.api ---The range on which to evaluate foldexpr. ---When in insert mode, the evaluation is deferred to InsertLeave. ---@field foldupdate_range? Range2 +--- +---The treesitter parser associated with this buffer. +---@field parser? vim.treesitter.LanguageTree local FoldInfo = {} FoldInfo.__index = FoldInfo ---@private -function FoldInfo.new() +---@param bufnr integer +function FoldInfo.new(bufnr) return setmetatable({ levels0 = {}, levels = {}, + parser = ts.get_parser(bufnr, nil, { error = false }), }, FoldInfo) end ---- Efficiently remove items from middle of a list a list. ---- ---- Calling table.remove() in a loop will re-index the tail of the table on ---- every iteration, instead this function will re-index the table exactly ---- once. ---- ---- Based on https://stackoverflow.com/questions/12394841/safely-remove-items-from-an-array-table-while-iterating/53038524#53038524 ---- ----@param t any[] ----@param first integer ----@param last integer -local function list_remove(t, first, last) - local n = #t - for i = 0, n - first do - t[first + i] = t[last + 1 + i] - t[last + 1 + i] = nil - end -end - ---@package ---@param srow integer ---@param erow integer 0-indexed, exclusive function FoldInfo:remove_range(srow, erow) - list_remove(self.levels, srow + 1, erow) - list_remove(self.levels0, srow + 1, erow) -end - ---- Efficiently insert items into the middle of a list. ---- ---- Calling table.insert() in a loop will re-index the tail of the table on ---- every iteration, instead this function will re-index the table exactly ---- once. ---- ---- Based on https://stackoverflow.com/questions/12394841/safely-remove-items-from-an-array-table-while-iterating/53038524#53038524 ---- ----@param t any[] ----@param first integer ----@param last integer ----@param v any -local function list_insert(t, first, last, v) - local n = #t - - -- Shift table forward - for i = n - first, 0, -1 do - t[last + 1 + i] = t[first + i] - end - - -- Fill in new values - for i = first, last do - t[i] = v - end + vim._list_remove(self.levels, srow + 1, erow) + vim._list_remove(self.levels0, srow + 1, erow) end ---@package ---@param srow integer ---@param erow integer 0-indexed, exclusive function FoldInfo:add_range(srow, erow) - list_insert(self.levels, srow + 1, erow, -1) - list_insert(self.levels0, srow + 1, erow, -1) + vim._list_insert(self.levels, srow + 1, erow, -1) + vim._list_insert(self.levels0, srow + 1, erow, -1) end ---@param range Range2 @@ -109,111 +69,122 @@ end ---@param info TS.FoldInfo ---@param srow integer? ---@param erow integer? 0-indexed, exclusive ----@param parse_injections? boolean -local function compute_folds_levels(bufnr, info, srow, erow, parse_injections) +---@param callback function? +local function compute_folds_levels(bufnr, info, srow, erow, callback) srow = srow or 0 erow = erow or api.nvim_buf_line_count(bufnr) - local parser = assert(ts.get_parser(bufnr, nil, { error = false })) - - parser:parse(parse_injections and { srow, erow } or nil) - - local enter_counts = {} ---@type table<integer, integer> - local leave_counts = {} ---@type table<integer, integer> - local prev_start = -1 - local prev_stop = -1 + local parser = info.parser + if not parser then + return + end - parser:for_each_tree(function(tree, ltree) - local query = ts.query.get(ltree:lang(), 'folds') - if not query then + parser:parse(nil, function(_, trees) + if not trees then return end - -- Collect folds starting from srow - 1, because we should first subtract the folds that end at - -- srow - 1 from the level of srow - 1 to get accurate level of srow. - for _, match, metadata in query:iter_matches(tree:root(), bufnr, math.max(srow - 1, 0), erow) do - for id, nodes in pairs(match) do - if query.captures[id] == 'fold' then - local range = ts.get_range(nodes[1], bufnr, metadata[id]) - local start, _, stop, stop_col = Range.unpack4(range) - - if #nodes > 1 then - -- assumes nodes are ordered by range - local end_range = ts.get_range(nodes[#nodes], bufnr, metadata[id]) - local _, _, end_stop, end_stop_col = Range.unpack4(end_range) - stop = end_stop - stop_col = end_stop_col - end + local enter_counts = {} ---@type table<integer, integer> + local leave_counts = {} ---@type table<integer, integer> + local prev_start = -1 + local prev_stop = -1 - if stop_col == 0 then - stop = stop - 1 - end + parser:for_each_tree(function(tree, ltree) + local query = ts.query.get(ltree:lang(), 'folds') + if not query then + return + end - local fold_length = stop - start + 1 - - -- Fold only multiline nodes that are not exactly the same as previously met folds - -- Checking against just the previously found fold is sufficient if nodes - -- are returned in preorder or postorder when traversing tree - if - fold_length > vim.wo.foldminlines and not (start == prev_start and stop == prev_stop) - then - enter_counts[start + 1] = (enter_counts[start + 1] or 0) + 1 - leave_counts[stop + 1] = (leave_counts[stop + 1] or 0) + 1 - prev_start = start - prev_stop = stop + -- Collect folds starting from srow - 1, because we should first subtract the folds that end at + -- srow - 1 from the level of srow - 1 to get accurate level of srow. + for _, match, metadata in query:iter_matches(tree:root(), bufnr, math.max(srow - 1, 0), erow) do + for id, nodes in pairs(match) do + if query.captures[id] == 'fold' then + local range = ts.get_range(nodes[1], bufnr, metadata[id]) + local start, _, stop, stop_col = Range.unpack4(range) + + if #nodes > 1 then + -- assumes nodes are ordered by range + local end_range = ts.get_range(nodes[#nodes], bufnr, metadata[id]) + local _, _, end_stop, end_stop_col = Range.unpack4(end_range) + stop = end_stop + stop_col = end_stop_col + end + + if stop_col == 0 then + stop = stop - 1 + end + + local fold_length = stop - start + 1 + + -- Fold only multiline nodes that are not exactly the same as previously met folds + -- Checking against just the previously found fold is sufficient if nodes + -- are returned in preorder or postorder when traversing tree + if + fold_length > vim.wo.foldminlines and not (start == prev_start and stop == prev_stop) + then + enter_counts[start + 1] = (enter_counts[start + 1] or 0) + 1 + leave_counts[stop + 1] = (leave_counts[stop + 1] or 0) + 1 + prev_start = start + prev_stop = stop + end end end end - end - end) + end) - local nestmax = vim.wo.foldnestmax - local level0_prev = info.levels0[srow] or 0 - local leave_prev = leave_counts[srow] or 0 - - -- We now have the list of fold opening and closing, fill the gaps and mark where fold start - for lnum = srow + 1, erow do - local enter_line = enter_counts[lnum] or 0 - local leave_line = leave_counts[lnum] or 0 - local level0 = level0_prev - leave_prev + enter_line - - -- Determine if it's the start/end of a fold - -- NB: vim's fold-expr interface does not have a mechanism to indicate that - -- two (or more) folds start at this line, so it cannot distinguish between - -- ( \n ( \n )) \n (( \n ) \n ) - -- versus - -- ( \n ( \n ) \n ( \n ) \n ) - -- Both are represented by ['>1', '>2', '2', '>2', '2', '1'], and - -- vim interprets as the second case. - -- If it did have such a mechanism, (clamped - clamped_prev) - -- would be the correct number of starts to pass on. - local adjusted = level0 ---@type integer - local prefix = '' - if enter_line > 0 then - prefix = '>' - if leave_line > 0 then - -- If this line ends a fold f1 and starts a fold f2, then move f1's end to the previous line - -- so that f2 gets the correct level on this line. This may reduce the size of f1 below - -- foldminlines, but we don't handle it for simplicity. - adjusted = level0 - leave_line - leave_line = 0 + local nestmax = vim.wo.foldnestmax + local level0_prev = info.levels0[srow] or 0 + local leave_prev = leave_counts[srow] or 0 + + -- We now have the list of fold opening and closing, fill the gaps and mark where fold start + for lnum = srow + 1, erow do + local enter_line = enter_counts[lnum] or 0 + local leave_line = leave_counts[lnum] or 0 + local level0 = level0_prev - leave_prev + enter_line + + -- Determine if it's the start/end of a fold + -- NB: vim's fold-expr interface does not have a mechanism to indicate that + -- two (or more) folds start at this line, so it cannot distinguish between + -- ( \n ( \n )) \n (( \n ) \n ) + -- versus + -- ( \n ( \n ) \n ( \n ) \n ) + -- Both are represented by ['>1', '>2', '2', '>2', '2', '1'], and + -- vim interprets as the second case. + -- If it did have such a mechanism, (clamped - clamped_prev) + -- would be the correct number of starts to pass on. + local adjusted = level0 ---@type integer + local prefix = '' + if enter_line > 0 then + prefix = '>' + if leave_line > 0 then + -- If this line ends a fold f1 and starts a fold f2, then move f1's end to the previous line + -- so that f2 gets the correct level on this line. This may reduce the size of f1 below + -- foldminlines, but we don't handle it for simplicity. + adjusted = level0 - leave_line + leave_line = 0 + end end - end - -- Clamp at foldnestmax. - local clamped = adjusted - if adjusted > nestmax then - prefix = '' - clamped = nestmax - end + -- Clamp at foldnestmax. + local clamped = adjusted + if adjusted > nestmax then + prefix = '' + clamped = nestmax + end - -- Record the "real" level, so that it can be used as "base" of later compute_folds_levels(). - info.levels0[lnum] = adjusted - info.levels[lnum] = prefix .. tostring(clamped) + -- Record the "real" level, so that it can be used as "base" of later compute_folds_levels(). + info.levels0[lnum] = adjusted + info.levels[lnum] = prefix .. tostring(clamped) - leave_prev = leave_line - level0_prev = adjusted - end + leave_prev = leave_line + level0_prev = adjusted + end + + if callback then + callback() + end + end) end local M = {} @@ -221,7 +192,7 @@ local M = {} ---@type table<integer,TS.FoldInfo> local foldinfos = {} -local group = api.nvim_create_augroup('treesitter/fold', {}) +local group = api.nvim_create_augroup('nvim.treesitter.fold', {}) --- Update the folds in the windows that contain the buffer and use expr foldmethod (assuming that --- the user doesn't use different foldexpr for the same buffer). @@ -298,12 +269,19 @@ local function schedule_if_loaded(bufnr, fn) end ---@param bufnr integer ----@param foldinfo TS.FoldInfo ---@param tree_changes Range4[] -local function on_changedtree(bufnr, foldinfo, tree_changes) +local function on_changedtree(bufnr, tree_changes) schedule_if_loaded(bufnr, function() + -- Buffer reload clears `foldinfos[bufnr]`, which may still be nil when callback is invoked. + local foldinfo = foldinfos[bufnr] + if not foldinfo then + return + end + local srow_upd, erow_upd ---@type integer?, integer? local max_erow = api.nvim_buf_line_count(bufnr) + -- TODO(ribru17): Replace this with a proper .all() awaiter once #19624 is resolved + local iterations = 0 for _, change in ipairs(tree_changes) do local srow, _, erow, ecol = Range.unpack4(change) -- If a parser doesn't have any ranges explicitly set, treesitter will @@ -317,24 +295,31 @@ local function on_changedtree(bufnr, foldinfo, tree_changes) end -- Start from `srow - foldminlines`, because this edit may have shrunken the fold below limit. srow = math.max(srow - vim.wo.foldminlines, 0) - compute_folds_levels(bufnr, foldinfo, srow, erow) srow_upd = srow_upd and math.min(srow_upd, srow) or srow erow_upd = erow_upd and math.max(erow_upd, erow) or erow - end - if #tree_changes > 0 then - foldinfo:foldupdate(bufnr, srow_upd, erow_upd) + compute_folds_levels(bufnr, foldinfo, srow, erow, function() + iterations = iterations + 1 + if iterations == #tree_changes then + foldinfo:foldupdate(bufnr, srow_upd, erow_upd) + end + end) end end) end ---@param bufnr integer ----@param foldinfo TS.FoldInfo ---@param start_row integer ---@param old_row integer ---@param old_col integer ---@param new_row integer ---@param new_col integer -local function on_bytes(bufnr, foldinfo, start_row, start_col, old_row, old_col, new_row, new_col) +local function on_bytes(bufnr, start_row, start_col, old_row, old_col, new_row, new_col) + -- Buffer reload clears `foldinfos[bufnr]`, which may still be nil when callback is invoked. + local foldinfo = foldinfos[bufnr] + if not foldinfo then + return + end + -- extend the end to fully include the range local end_row_old = start_row + old_row + 1 local end_row_new = start_row + new_row + 1 @@ -373,15 +358,16 @@ local function on_bytes(bufnr, foldinfo, start_row, start_col, old_row, old_col, -- is invoked. For example, `J` with non-zero count triggers multiple on_bytes before executing -- the scheduled callback. So we accumulate the edited ranges in `on_bytes_range`. schedule_if_loaded(bufnr, function() - if not foldinfo.on_bytes_range then + if not (foldinfo.on_bytes_range and foldinfos[bufnr]) then return end local srow, erow = foldinfo.on_bytes_range[1], foldinfo.on_bytes_range[2] foldinfo.on_bytes_range = nil -- Start from `srow - foldminlines`, because this edit may have shrunken the fold below limit. srow = math.max(srow - vim.wo.foldminlines, 0) - compute_folds_levels(bufnr, foldinfo, srow, erow) - foldinfo:foldupdate(bufnr, srow, erow) + compute_folds_levels(bufnr, foldinfo, srow, erow, function() + foldinfo:foldupdate(bufnr, srow, erow) + end) end) end end @@ -392,22 +378,30 @@ function M.foldexpr(lnum) lnum = lnum or vim.v.lnum local bufnr = api.nvim_get_current_buf() - local parser = ts.get_parser(bufnr, nil, { error = false }) - if not parser then - return '0' - end - if not foldinfos[bufnr] then - foldinfos[bufnr] = FoldInfo.new() + foldinfos[bufnr] = FoldInfo.new(bufnr) + api.nvim_create_autocmd({ 'BufUnload', 'VimEnter' }, { + buffer = bufnr, + once = true, + callback = function() + foldinfos[bufnr] = nil + end, + }) + + local parser = foldinfos[bufnr].parser + if not parser then + return '0' + end + compute_folds_levels(bufnr, foldinfos[bufnr]) parser:register_cbs({ on_changedtree = function(tree_changes) - on_changedtree(bufnr, foldinfos[bufnr], tree_changes) + on_changedtree(bufnr, tree_changes) end, on_bytes = function(_, _, start_row, start_col, _, old_row, old_col, _, new_row, new_col, _) - on_bytes(bufnr, foldinfos[bufnr], start_row, start_col, old_row, old_col, new_row, new_col) + on_bytes(bufnr, start_row, start_col, old_row, old_col, new_row, new_col) end, on_detach = function() @@ -423,10 +417,17 @@ api.nvim_create_autocmd('OptionSet', { pattern = { 'foldminlines', 'foldnestmax' }, desc = 'Refresh treesitter folds', callback = function() - for bufnr, _ in pairs(foldinfos) do - foldinfos[bufnr] = FoldInfo.new() - compute_folds_levels(bufnr, foldinfos[bufnr]) - foldinfos[bufnr]:foldupdate(bufnr, 0, api.nvim_buf_line_count(bufnr)) + local buf = api.nvim_get_current_buf() + local bufs = vim.v.option_type == 'global' and vim.tbl_keys(foldinfos) + or foldinfos[buf] and { buf } + or {} + for _, bufnr in ipairs(bufs) do + foldinfos[bufnr] = FoldInfo.new(bufnr) + api.nvim_buf_call(bufnr, function() + compute_folds_levels(bufnr, foldinfos[bufnr], nil, nil, function() + foldinfos[bufnr]:foldupdate(bufnr, 0, api.nvim_buf_line_count(bufnr)) + end) + end) end end, }) |