diff options
author | Lewis Russell <lewis6991@gmail.com> | 2023-03-09 15:28:55 +0000 |
---|---|---|
committer | Lewis Russell <lewis6991@gmail.com> | 2023-03-10 11:51:33 +0000 |
commit | 46b73bf22cb951151de9bf0712d42e194000b677 (patch) | |
tree | 670fd0241b94031b26ba2c7a9fb5e48033dbff45 /runtime/lua/vim/treesitter/_fold.lua | |
parent | c5b9643bf1b0f6d5166b4abf6a7c3f29532aefeb (diff) | |
download | rneovim-46b73bf22cb951151de9bf0712d42e194000b677.tar.gz rneovim-46b73bf22cb951151de9bf0712d42e194000b677.tar.bz2 rneovim-46b73bf22cb951151de9bf0712d42e194000b677.zip |
perf(treesitter): more efficient foldexpr
Diffstat (limited to 'runtime/lua/vim/treesitter/_fold.lua')
-rw-r--r-- | runtime/lua/vim/treesitter/_fold.lua | 290 |
1 files changed, 188 insertions, 102 deletions
diff --git a/runtime/lua/vim/treesitter/_fold.lua b/runtime/lua/vim/treesitter/_fold.lua index a66cc6d543..435cb9fdb6 100644 --- a/runtime/lua/vim/treesitter/_fold.lua +++ b/runtime/lua/vim/treesitter/_fold.lua @@ -1,139 +1,157 @@ +local Range = require('vim.treesitter._range') + local api = vim.api -local M = {} +---@class FoldInfo +---@field levels table<integer,string> +---@field levels0 table<integer,integer> +---@field private start_counts table<integer,integer> +---@field private stop_counts table<integer,integer> +local FoldInfo = {} +FoldInfo.__index = FoldInfo ---- Memoizes a function based on the buffer tick of the provided bufnr. ---- The cache entry is cleared when the buffer is detached to avoid memory leaks. ----@generic F: function ----@param fn F fn to memoize, taking the bufnr as first argument ----@return F -local function memoize_by_changedtick(fn) - ---@type table<integer,{result:any,last_tick:integer}> - local cache = {} - - ---@param bufnr integer - return function(bufnr, ...) - local tick = api.nvim_buf_get_changedtick(bufnr) - - if cache[bufnr] then - if cache[bufnr].last_tick == tick then - return cache[bufnr].result - end - else - local function detach_handler() - cache[bufnr] = nil - end +function FoldInfo.new() + return setmetatable({ + start_counts = {}, + stop_counts = {}, + levels0 = {}, + levels = {}, + }, FoldInfo) +end - -- Clean up logic only! - api.nvim_buf_attach(bufnr, false, { - on_detach = detach_handler, - on_reload = detach_handler, - }) - end +---@param srow integer +---@param erow integer +function FoldInfo:invalidate_range(srow, erow) + for i = srow, erow do + self.start_counts[i + 1] = nil + self.stop_counts[i + 1] = nil + self.levels0[i + 1] = nil + self.levels[i + 1] = nil + end +end - cache[bufnr] = { - result = fn(bufnr, ...), - last_tick = tick, - } +---@param srow integer +---@param erow integer +function FoldInfo:remove_range(srow, erow) + for i = erow - 1, srow, -1 do + table.remove(self.levels, i + 1) + table.remove(self.levels0, i + 1) + table.remove(self.start_counts, i + 1) + table.remove(self.stop_counts, i + 1) + end +end - return cache[bufnr].result +---@param srow integer +---@param erow integer +function FoldInfo:add_range(srow, erow) + for i = srow, erow - 1 do + table.insert(self.levels, i + 1, '-1') + table.insert(self.levels0, i + 1, -1) + table.insert(self.start_counts, i + 1, nil) + table.insert(self.stop_counts, i + 1, nil) end end ----@param bufnr integer ----@param capture string ----@param query_name string ----@param callback fun(id: integer, node:TSNode, metadata: TSMetadata) -local function iter_matches_with_capture(bufnr, capture, query_name, callback) - local parser = vim.treesitter.get_parser(bufnr) +---@param lnum integer +function FoldInfo:add_start(lnum) + self.start_counts[lnum] = (self.start_counts[lnum] or 0) + 1 +end - if not parser then - return - end +---@param lnum integer +function FoldInfo:add_stop(lnum) + self.stop_counts[lnum] = (self.stop_counts[lnum] or 0) + 1 +end - parser:for_each_tree(function(tree, lang_tree) - local lang = lang_tree:lang() - local query = vim.treesitter.query.get_query(lang, query_name) - if query then - local root = tree:root() - local start, _, stop = root:range() - for _, match, metadata in query:iter_matches(root, bufnr, start, stop) do - for id, node in pairs(match) do - if query.captures[id] == capture then - callback(id, node, metadata) - end - end - end - end - end) +---@param lnum integer +---@return integer +function FoldInfo:get_start(lnum) + return self.start_counts[lnum] or 0 +end + +---@param lnum integer +---@return integer +function FoldInfo:get_stop(lnum) + return self.stop_counts[lnum] or 0 end ---@private --- TODO(lewis6991): copied from languagetree.lua. Consolidate ---@param node TSNode ----@param id integer ---@param metadata TSMetadata ----@return Range -local function get_range_from_metadata(node, id, metadata) - if metadata[id] and metadata[id].range then - return metadata[id].range --[[@as Range]] +---@return Range4 +local function get_range_from_metadata(node, metadata) + if metadata and metadata.range then + return metadata.range --[[@as Range4]] end return { node:range() } end --- This is cached on buf tick to avoid computing that multiple times --- Especially not for every line in the file when `zx` is hit ----@param bufnr integer ----@return table<integer,string> -local folds_levels = memoize_by_changedtick(function(bufnr) +local function trim_level(level) local max_fold_level = vim.wo.foldnestmax - local function trim_level(level) - if level > max_fold_level then - return max_fold_level - end - return level + if level > max_fold_level then + return max_fold_level end + return level +end - -- start..stop is an inclusive range - local start_counts = {} ---@type table<integer,integer> - local stop_counts = {} ---@type table<integer,integer> +---@param bufnr integer +---@param info FoldInfo +---@param srow integer? +---@param erow integer? +local function get_folds_levels(bufnr, info, srow, erow) + srow = srow or 0 + erow = erow or api.nvim_buf_line_count(bufnr) + + info:invalidate_range(srow, erow) local prev_start = -1 local prev_stop = -1 - local min_fold_lines = vim.wo.foldminlines + vim.treesitter.get_parser(bufnr):for_each_tree(function(tree, ltree) + local query = vim.treesitter.query.get_query(ltree:lang(), 'folds') + if not query then + return + end + + -- erow in query is end-exclusive + local q_erow = erow and erow + 1 or -1 - iter_matches_with_capture(bufnr, 'fold', 'folds', function(id, node, metadata) - local range = get_range_from_metadata(node, id, metadata) - local start, stop, stop_col = range[1], range[3], range[4] + for id, node, metadata in query:iter_captures(tree:root(), bufnr, srow or 0, q_erow) do + if query.captures[id] == 'fold' then + local range = get_range_from_metadata(node, metadata[id]) + local start, _, stop, stop_col = Range.unpack4(range) - if stop_col == 0 then - stop = stop - 1 - end + if stop_col == 0 then + stop = stop - 1 + end - local fold_length = stop - start + 1 + local fold_length = stop - start + 1 - -- Fold only multiline nodes that are not exactly the same as previously met folds - -- Checking against just the previously found fold is sufficient if nodes - -- are returned in preorder or postorder when traversing tree - if fold_length > min_fold_lines and not (start == prev_start and stop == prev_stop) then - start_counts[start] = (start_counts[start] or 0) + 1 - stop_counts[stop] = (stop_counts[stop] or 0) + 1 - prev_start = start - prev_stop = stop + -- Fold only multiline nodes that are not exactly the same as previously met folds + -- Checking against just the previously found fold is sufficient if nodes + -- are returned in preorder or postorder when traversing tree + if + fold_length > vim.wo.foldminlines and not (start == prev_start and stop == prev_stop) + then + info:add_start(start + 1) + info:add_stop(stop + 1) + prev_start = start + prev_stop = stop + end + end end end) - ---@type table<integer,string> - local levels = {} - local current_level = 0 + local current_level = info.levels0[srow] or 0 -- We now have the list of fold opening and closing, fill the gaps and mark where fold start - for lnum = 0, api.nvim_buf_line_count(bufnr) do + for lnum = srow + 1, erow + 1 do local last_trimmed_level = trim_level(current_level) - current_level = current_level + (start_counts[lnum] or 0) + current_level = current_level + info:get_start(lnum) + info.levels0[lnum] = current_level + local trimmed_level = trim_level(current_level) - current_level = current_level - (stop_counts[lnum] or 0) + current_level = current_level - info:get_stop(lnum) -- Determine if it's the start/end of a fold -- NB: vim's fold-expr interface does not have a mechanism to indicate that @@ -148,11 +166,61 @@ local folds_levels = memoize_by_changedtick(function(bufnr) prefix = '>' end - levels[lnum + 1] = prefix .. tostring(trimmed_level) + info.levels[lnum] = prefix .. tostring(trimmed_level) + end +end + +local M = {} + +---@type table<integer,FoldInfo> +local foldinfos = {} + +local function recompute_folds() + if api.nvim_get_mode().mode == 'i' then + -- foldUpdate() is guarded in insert mode. So update folds on InsertLeave + api.nvim_create_autocmd('InsertLeave', { + once = true, + callback = vim._foldupdate, + }) + return end - return levels -end) + vim._foldupdate() +end + +---@param bufnr integer +---@param foldinfo FoldInfo +---@param tree_changes Range4[] +local function on_changedtree(bufnr, foldinfo, tree_changes) + -- For some reason, queries seem to use the old buffer state in on_bytes. + -- Get around this by scheduling and manually updating folds. + vim.schedule(function() + for _, change in ipairs(tree_changes) do + local srow, _, erow = Range.unpack4(change) + get_folds_levels(bufnr, foldinfo, srow, erow) + end + recompute_folds() + end) +end + +---@param bufnr integer +---@param foldinfo FoldInfo +---@param start_row integer +---@param old_row integer +---@param new_row integer +local function on_bytes(bufnr, foldinfo, start_row, old_row, new_row) + local end_row_old = start_row + old_row + local end_row_new = start_row + new_row + if new_row < old_row then + foldinfo:remove_range(end_row_old, end_row_new) + elseif new_row > old_row then + foldinfo:add_range(start_row, end_row_new) + vim.schedule(function() + get_folds_levels(bufnr, foldinfo, start_row, end_row_new) + recompute_folds() + end) + end +end ---@param lnum integer|nil ---@return string @@ -165,9 +233,27 @@ function M.foldexpr(lnum) return '0' end - local levels = folds_levels(bufnr) or {} + if not foldinfos[bufnr] then + foldinfos[bufnr] = FoldInfo.new() + get_folds_levels(bufnr, foldinfos[bufnr]) + + local parser = vim.treesitter.get_parser(bufnr) + parser:register_cbs({ + on_changedtree = function(tree_changes) + on_changedtree(bufnr, foldinfos[bufnr], tree_changes) + end, + + on_bytes = function(_, _, start_row, _, _, old_row, _, _, new_row, _, _) + on_bytes(bufnr, foldinfos[bufnr], start_row, old_row, new_row) + end, + + on_detach = function() + foldinfos[bufnr] = nil + end, + }) + end - return levels[lnum] or '0' + return foldinfos[bufnr].levels[lnum] or '0' end return M |