diff options
Diffstat (limited to 'runtime/lua/vim/treesitter/_fold.lua')
-rw-r--r-- | runtime/lua/vim/treesitter/_fold.lua | 207 |
1 files changed, 119 insertions, 88 deletions
diff --git a/runtime/lua/vim/treesitter/_fold.lua b/runtime/lua/vim/treesitter/_fold.lua index d96cc966de..eecf1ad6b1 100644 --- a/runtime/lua/vim/treesitter/_fold.lua +++ b/runtime/lua/vim/treesitter/_fold.lua @@ -4,10 +4,21 @@ local Range = require('vim.treesitter._range') local api = vim.api +---Treesitter folding is done in two steps: +---(1) compute the fold levels with the syntax tree and cache the result (`compute_folds_levels`) +---(2) evaluate foldexpr for each window, which reads from the cache (`foldupdate`) ---@class TS.FoldInfo ----@field levels string[] the foldexpr result for each line ----@field levels0 integer[] the raw fold levels ----@field edits? {[1]: integer, [2]: integer} line range edited since the last invocation of the callback scheduled in on_bytes. 0-indexed, end-exclusive. +--- +---@field levels string[] the cached foldexpr result for each line +---@field levels0 integer[] the cached raw fold levels +--- +---The range edited since the last invocation of the callback scheduled in on_bytes. +---Should compute fold levels in this range. +---@field on_bytes_range? Range2 +--- +---The range on which to evaluate foldexpr. +---When in insert mode, the evaluation is deferred to InsertLeave. +---@field foldupdate_range? Range2 local FoldInfo = {} FoldInfo.__index = FoldInfo @@ -80,45 +91,16 @@ function FoldInfo:add_range(srow, erow) list_insert(self.levels0, srow + 1, erow, -1) end ----@package +---@param range Range2 ---@param srow integer ---@param erow_old integer ---@param erow_new integer 0-indexed, exclusive -function FoldInfo:edit_range(srow, erow_old, erow_new) - if self.edits then - self.edits[1] = math.min(srow, self.edits[1]) - if erow_old <= self.edits[2] then - self.edits[2] = self.edits[2] + (erow_new - erow_old) - end - self.edits[2] = math.max(self.edits[2], erow_new) - else - self.edits = { srow, erow_new } +local function edit_range(range, srow, erow_old, erow_new) + range[1] = math.min(srow, range[1]) + if erow_old <= range[2] then + range[2] = range[2] + (erow_new - erow_old) end -end - ----@package ----@return integer? srow ----@return integer? erow 0-indexed, exclusive -function FoldInfo:flush_edit() - if self.edits then - local srow, erow = self.edits[1], self.edits[2] - self.edits = nil - return srow, erow - end -end - ---- If a parser doesn't have any ranges explicitly set, treesitter will ---- return a range with end_row and end_bytes with a value of UINT32_MAX, ---- so clip end_row to the max buffer line. ---- ---- TODO(lewis6991): Handle this generally ---- ---- @param bufnr integer ---- @param erow integer? 0-indexed, exclusive ---- @return integer -local function normalise_erow(bufnr, erow) - local max_erow = api.nvim_buf_line_count(bufnr) - return math.min(erow or max_erow, max_erow) + range[2] = math.max(range[2], erow_new) end -- TODO(lewis6991): Setup a decor provider so injections folds can be parsed @@ -128,9 +110,9 @@ end ---@param srow integer? ---@param erow integer? 0-indexed, exclusive ---@param parse_injections? boolean -local function get_folds_levels(bufnr, info, srow, erow, parse_injections) +local function compute_folds_levels(bufnr, info, srow, erow, parse_injections) srow = srow or 0 - erow = normalise_erow(bufnr, erow) + erow = erow or api.nvim_buf_line_count(bufnr) local parser = ts.get_parser(bufnr) @@ -149,27 +131,43 @@ local function get_folds_levels(bufnr, info, srow, erow, parse_injections) -- Collect folds starting from srow - 1, because we should first subtract the folds that end at -- srow - 1 from the level of srow - 1 to get accurate level of srow. - for id, node, metadata in query:iter_captures(tree:root(), bufnr, math.max(srow - 1, 0), erow) do - if query.captures[id] == 'fold' then - local range = ts.get_range(node, bufnr, metadata[id]) - local start, _, stop, stop_col = Range.unpack4(range) - - if stop_col == 0 then - stop = stop - 1 - end - - local fold_length = stop - start + 1 - - -- Fold only multiline nodes that are not exactly the same as previously met folds - -- Checking against just the previously found fold is sufficient if nodes - -- are returned in preorder or postorder when traversing tree - if - fold_length > vim.wo.foldminlines and not (start == prev_start and stop == prev_stop) - then - enter_counts[start + 1] = (enter_counts[start + 1] or 0) + 1 - leave_counts[stop + 1] = (leave_counts[stop + 1] or 0) + 1 - prev_start = start - prev_stop = stop + for _, match, metadata in + query:iter_matches(tree:root(), bufnr, math.max(srow - 1, 0), erow, { all = true }) + do + for id, nodes in pairs(match) do + if query.captures[id] == 'fold' then + local range = ts.get_range(nodes[1], bufnr, metadata[id]) + local start, _, stop, stop_col = Range.unpack4(range) + + for i = 2, #nodes, 1 do + local node_range = ts.get_range(nodes[i], bufnr, metadata[id]) + local node_start, _, node_stop, node_stop_col = Range.unpack4(node_range) + if node_start < start then + start = node_start + end + if node_stop > stop then + stop = node_stop + stop_col = node_stop_col + end + end + + if stop_col == 0 then + stop = stop - 1 + end + + local fold_length = stop - start + 1 + + -- Fold only multiline nodes that are not exactly the same as previously met folds + -- Checking against just the previously found fold is sufficient if nodes + -- are returned in preorder or postorder when traversing tree + if + fold_length > vim.wo.foldminlines and not (start == prev_start and stop == prev_stop) + then + enter_counts[start + 1] = (enter_counts[start + 1] or 0) + 1 + leave_counts[stop + 1] = (leave_counts[stop + 1] or 0) + 1 + prev_start = start + prev_stop = stop + end end end end @@ -215,7 +213,7 @@ local function get_folds_levels(bufnr, info, srow, erow, parse_injections) clamped = nestmax end - -- Record the "real" level, so that it can be used as "base" of later get_folds_levels(). + -- Record the "real" level, so that it can be used as "base" of later compute_folds_levels(). info.levels0[lnum] = adjusted info.levels[lnum] = prefix .. tostring(clamped) @@ -236,18 +234,17 @@ local group = api.nvim_create_augroup('treesitter/fold', {}) --- --- Nvim usually automatically updates folds when text changes, but it doesn't work here because --- FoldInfo update is scheduled. So we do it manually. -local function foldupdate(bufnr) - local function do_update() - for _, win in ipairs(vim.fn.win_findbuf(bufnr)) do - api.nvim_win_call(win, function() - if vim.wo.foldmethod == 'expr' then - vim._foldupdate() - end - end) - end +---@package +---@param srow integer +---@param erow integer 0-indexed, exclusive +function FoldInfo:foldupdate(bufnr, srow, erow) + if self.foldupdate_range then + edit_range(self.foldupdate_range, srow, erow, erow) + else + self.foldupdate_range = { srow, erow } end - if api.nvim_get_mode().mode == 'i' then + if api.nvim_get_mode().mode:match('^i') then -- foldUpdate() is guarded in insert mode. So update folds on InsertLeave if #(api.nvim_get_autocmds({ group = group, @@ -259,12 +256,25 @@ local function foldupdate(bufnr) group = group, buffer = bufnr, once = true, - callback = do_update, + callback = function() + self:do_foldupdate(bufnr) + end, }) return end - do_update() + self:do_foldupdate(bufnr) +end + +---@package +function FoldInfo:do_foldupdate(bufnr) + local srow, erow = self.foldupdate_range[1], self.foldupdate_range[2] + self.foldupdate_range = nil + for _, win in ipairs(vim.fn.win_findbuf(bufnr)) do + if vim.wo[win].foldmethod == 'expr' then + vim._foldupdate(win, srow, erow) + end + end end --- Schedule a function only if bufnr is loaded. @@ -272,7 +282,7 @@ end --- * queries seem to use the old buffer state in on_bytes for some unknown reason; --- * to avoid textlock; --- * to avoid infinite recursion: ---- get_folds_levels → parse → _do_callback → on_changedtree → get_folds_levels. +--- compute_folds_levels → parse → _do_callback → on_changedtree → compute_folds_levels. ---@param bufnr integer ---@param fn function local function schedule_if_loaded(bufnr, fn) @@ -289,16 +299,27 @@ end ---@param tree_changes Range4[] local function on_changedtree(bufnr, foldinfo, tree_changes) schedule_if_loaded(bufnr, function() + local srow_upd, erow_upd ---@type integer?, integer? + local max_erow = api.nvim_buf_line_count(bufnr) for _, change in ipairs(tree_changes) do local srow, _, erow, ecol = Range.unpack4(change) - if ecol > 0 then + -- If a parser doesn't have any ranges explicitly set, treesitter will + -- return a range with end_row and end_bytes with a value of UINT32_MAX, + -- so clip end_row to the max buffer line. + -- TODO(lewis6991): Handle this generally + if erow > max_erow then + erow = max_erow + elseif ecol > 0 then erow = erow + 1 end -- Start from `srow - foldminlines`, because this edit may have shrunken the fold below limit. - get_folds_levels(bufnr, foldinfo, math.max(srow - vim.wo.foldminlines, 0), erow) + srow = math.max(srow - vim.wo.foldminlines, 0) + compute_folds_levels(bufnr, foldinfo, srow, erow) + srow_upd = srow_upd and math.min(srow_upd, srow) or srow + erow_upd = erow_upd and math.max(erow_upd, erow) or erow end if #tree_changes > 0 then - foldupdate(bufnr) + foldinfo:foldupdate(bufnr, srow_upd, erow_upd) end end) end @@ -335,19 +356,29 @@ local function on_bytes(bufnr, foldinfo, start_row, start_col, old_row, old_col, foldinfo:add_range(end_row_old, end_row_new) end end - foldinfo:edit_range(start_row, end_row_old, end_row_new) + + if foldinfo.on_bytes_range then + edit_range(foldinfo.on_bytes_range, start_row, end_row_old, end_row_new) + else + foldinfo.on_bytes_range = { start_row, end_row_new } + end + if foldinfo.foldupdate_range then + edit_range(foldinfo.foldupdate_range, start_row, end_row_old, end_row_new) + end -- This callback must not use on_bytes arguments, because they can be outdated when the callback -- is invoked. For example, `J` with non-zero count triggers multiple on_bytes before executing - -- the scheduled callback. So we should collect the edits. + -- the scheduled callback. So we accumulate the edited ranges in `on_bytes_range`. schedule_if_loaded(bufnr, function() - local srow, erow = foldinfo:flush_edit() - if not srow then + if not foldinfo.on_bytes_range then return end + local srow, erow = foldinfo.on_bytes_range[1], foldinfo.on_bytes_range[2] + foldinfo.on_bytes_range = nil -- Start from `srow - foldminlines`, because this edit may have shrunken the fold below limit. - get_folds_levels(bufnr, foldinfo, math.max(srow - vim.wo.foldminlines, 0), erow) - foldupdate(bufnr) + srow = math.max(srow - vim.wo.foldminlines, 0) + compute_folds_levels(bufnr, foldinfo, srow, erow) + foldinfo:foldupdate(bufnr, srow, erow) end) end end @@ -366,7 +397,7 @@ function M.foldexpr(lnum) if not foldinfos[bufnr] then foldinfos[bufnr] = FoldInfo.new() - get_folds_levels(bufnr, foldinfos[bufnr]) + compute_folds_levels(bufnr, foldinfos[bufnr]) parser:register_cbs({ on_changedtree = function(tree_changes) @@ -390,10 +421,10 @@ api.nvim_create_autocmd('OptionSet', { pattern = { 'foldminlines', 'foldnestmax' }, desc = 'Refresh treesitter folds', callback = function() - for _, bufnr in ipairs(vim.tbl_keys(foldinfos)) do + for bufnr, _ in pairs(foldinfos) do foldinfos[bufnr] = FoldInfo.new() - get_folds_levels(bufnr, foldinfos[bufnr]) - foldupdate(bufnr) + compute_folds_levels(bufnr, foldinfos[bufnr]) + foldinfos[bufnr]:foldupdate(bufnr, 0, api.nvim_buf_line_count(bufnr)) end end, }) |