aboutsummaryrefslogtreecommitdiff
path: root/runtime/lua/vim/treesitter/_fold.lua
diff options
context:
space:
mode:
authorLewis Russell <lewis6991@gmail.com>2023-03-10 13:35:07 +0000
committerGitHub <noreply@github.com>2023-03-10 13:35:07 +0000
commit845efb8e12cb014b385deac62fb83622a99024ec (patch)
tree3686fc9ffbdd4bd2afeb4419ff649ef6e0f34c55 /runtime/lua/vim/treesitter/_fold.lua
parent75537768ef0b8cc35ef9c6aa906237e449640b46 (diff)
parent46b73bf22cb951151de9bf0712d42e194000b677 (diff)
downloadrneovim-845efb8e12cb014b385deac62fb83622a99024ec.tar.gz
rneovim-845efb8e12cb014b385deac62fb83622a99024ec.tar.bz2
rneovim-845efb8e12cb014b385deac62fb83622a99024ec.zip
Merge pull request #22594 from lewis6991/perf/treefold
Diffstat (limited to 'runtime/lua/vim/treesitter/_fold.lua')
-rw-r--r--runtime/lua/vim/treesitter/_fold.lua290
1 files changed, 188 insertions, 102 deletions
diff --git a/runtime/lua/vim/treesitter/_fold.lua b/runtime/lua/vim/treesitter/_fold.lua
index a66cc6d543..435cb9fdb6 100644
--- a/runtime/lua/vim/treesitter/_fold.lua
+++ b/runtime/lua/vim/treesitter/_fold.lua
@@ -1,139 +1,157 @@
+local Range = require('vim.treesitter._range')
+
local api = vim.api
-local M = {}
+---@class FoldInfo
+---@field levels table<integer,string>
+---@field levels0 table<integer,integer>
+---@field private start_counts table<integer,integer>
+---@field private stop_counts table<integer,integer>
+local FoldInfo = {}
+FoldInfo.__index = FoldInfo
---- Memoizes a function based on the buffer tick of the provided bufnr.
---- The cache entry is cleared when the buffer is detached to avoid memory leaks.
----@generic F: function
----@param fn F fn to memoize, taking the bufnr as first argument
----@return F
-local function memoize_by_changedtick(fn)
- ---@type table<integer,{result:any,last_tick:integer}>
- local cache = {}
-
- ---@param bufnr integer
- return function(bufnr, ...)
- local tick = api.nvim_buf_get_changedtick(bufnr)
-
- if cache[bufnr] then
- if cache[bufnr].last_tick == tick then
- return cache[bufnr].result
- end
- else
- local function detach_handler()
- cache[bufnr] = nil
- end
+function FoldInfo.new()
+ return setmetatable({
+ start_counts = {},
+ stop_counts = {},
+ levels0 = {},
+ levels = {},
+ }, FoldInfo)
+end
- -- Clean up logic only!
- api.nvim_buf_attach(bufnr, false, {
- on_detach = detach_handler,
- on_reload = detach_handler,
- })
- end
+---@param srow integer
+---@param erow integer
+function FoldInfo:invalidate_range(srow, erow)
+ for i = srow, erow do
+ self.start_counts[i + 1] = nil
+ self.stop_counts[i + 1] = nil
+ self.levels0[i + 1] = nil
+ self.levels[i + 1] = nil
+ end
+end
- cache[bufnr] = {
- result = fn(bufnr, ...),
- last_tick = tick,
- }
+---@param srow integer
+---@param erow integer
+function FoldInfo:remove_range(srow, erow)
+ for i = erow - 1, srow, -1 do
+ table.remove(self.levels, i + 1)
+ table.remove(self.levels0, i + 1)
+ table.remove(self.start_counts, i + 1)
+ table.remove(self.stop_counts, i + 1)
+ end
+end
- return cache[bufnr].result
+---@param srow integer
+---@param erow integer
+function FoldInfo:add_range(srow, erow)
+ for i = srow, erow - 1 do
+ table.insert(self.levels, i + 1, '-1')
+ table.insert(self.levels0, i + 1, -1)
+ table.insert(self.start_counts, i + 1, nil)
+ table.insert(self.stop_counts, i + 1, nil)
end
end
----@param bufnr integer
----@param capture string
----@param query_name string
----@param callback fun(id: integer, node:TSNode, metadata: TSMetadata)
-local function iter_matches_with_capture(bufnr, capture, query_name, callback)
- local parser = vim.treesitter.get_parser(bufnr)
+---@param lnum integer
+function FoldInfo:add_start(lnum)
+ self.start_counts[lnum] = (self.start_counts[lnum] or 0) + 1
+end
- if not parser then
- return
- end
+---@param lnum integer
+function FoldInfo:add_stop(lnum)
+ self.stop_counts[lnum] = (self.stop_counts[lnum] or 0) + 1
+end
- parser:for_each_tree(function(tree, lang_tree)
- local lang = lang_tree:lang()
- local query = vim.treesitter.query.get_query(lang, query_name)
- if query then
- local root = tree:root()
- local start, _, stop = root:range()
- for _, match, metadata in query:iter_matches(root, bufnr, start, stop) do
- for id, node in pairs(match) do
- if query.captures[id] == capture then
- callback(id, node, metadata)
- end
- end
- end
- end
- end)
+---@param lnum integer
+---@return integer
+function FoldInfo:get_start(lnum)
+ return self.start_counts[lnum] or 0
+end
+
+---@param lnum integer
+---@return integer
+function FoldInfo:get_stop(lnum)
+ return self.stop_counts[lnum] or 0
end
---@private
--- TODO(lewis6991): copied from languagetree.lua. Consolidate
---@param node TSNode
----@param id integer
---@param metadata TSMetadata
----@return Range
-local function get_range_from_metadata(node, id, metadata)
- if metadata[id] and metadata[id].range then
- return metadata[id].range --[[@as Range]]
+---@return Range4
+local function get_range_from_metadata(node, metadata)
+ if metadata and metadata.range then
+ return metadata.range --[[@as Range4]]
end
return { node:range() }
end
--- This is cached on buf tick to avoid computing that multiple times
--- Especially not for every line in the file when `zx` is hit
----@param bufnr integer
----@return table<integer,string>
-local folds_levels = memoize_by_changedtick(function(bufnr)
+local function trim_level(level)
local max_fold_level = vim.wo.foldnestmax
- local function trim_level(level)
- if level > max_fold_level then
- return max_fold_level
- end
- return level
+ if level > max_fold_level then
+ return max_fold_level
end
+ return level
+end
- -- start..stop is an inclusive range
- local start_counts = {} ---@type table<integer,integer>
- local stop_counts = {} ---@type table<integer,integer>
+---@param bufnr integer
+---@param info FoldInfo
+---@param srow integer?
+---@param erow integer?
+local function get_folds_levels(bufnr, info, srow, erow)
+ srow = srow or 0
+ erow = erow or api.nvim_buf_line_count(bufnr)
+
+ info:invalidate_range(srow, erow)
local prev_start = -1
local prev_stop = -1
- local min_fold_lines = vim.wo.foldminlines
+ vim.treesitter.get_parser(bufnr):for_each_tree(function(tree, ltree)
+ local query = vim.treesitter.query.get_query(ltree:lang(), 'folds')
+ if not query then
+ return
+ end
+
+ -- erow in query is end-exclusive
+ local q_erow = erow and erow + 1 or -1
- iter_matches_with_capture(bufnr, 'fold', 'folds', function(id, node, metadata)
- local range = get_range_from_metadata(node, id, metadata)
- local start, stop, stop_col = range[1], range[3], range[4]
+ for id, node, metadata in query:iter_captures(tree:root(), bufnr, srow or 0, q_erow) do
+ if query.captures[id] == 'fold' then
+ local range = get_range_from_metadata(node, metadata[id])
+ local start, _, stop, stop_col = Range.unpack4(range)
- if stop_col == 0 then
- stop = stop - 1
- end
+ if stop_col == 0 then
+ stop = stop - 1
+ end
- local fold_length = stop - start + 1
+ local fold_length = stop - start + 1
- -- Fold only multiline nodes that are not exactly the same as previously met folds
- -- Checking against just the previously found fold is sufficient if nodes
- -- are returned in preorder or postorder when traversing tree
- if fold_length > min_fold_lines and not (start == prev_start and stop == prev_stop) then
- start_counts[start] = (start_counts[start] or 0) + 1
- stop_counts[stop] = (stop_counts[stop] or 0) + 1
- prev_start = start
- prev_stop = stop
+ -- Fold only multiline nodes that are not exactly the same as previously met folds
+ -- Checking against just the previously found fold is sufficient if nodes
+ -- are returned in preorder or postorder when traversing tree
+ if
+ fold_length > vim.wo.foldminlines and not (start == prev_start and stop == prev_stop)
+ then
+ info:add_start(start + 1)
+ info:add_stop(stop + 1)
+ prev_start = start
+ prev_stop = stop
+ end
+ end
end
end)
- ---@type table<integer,string>
- local levels = {}
- local current_level = 0
+ local current_level = info.levels0[srow] or 0
-- We now have the list of fold opening and closing, fill the gaps and mark where fold start
- for lnum = 0, api.nvim_buf_line_count(bufnr) do
+ for lnum = srow + 1, erow + 1 do
local last_trimmed_level = trim_level(current_level)
- current_level = current_level + (start_counts[lnum] or 0)
+ current_level = current_level + info:get_start(lnum)
+ info.levels0[lnum] = current_level
+
local trimmed_level = trim_level(current_level)
- current_level = current_level - (stop_counts[lnum] or 0)
+ current_level = current_level - info:get_stop(lnum)
-- Determine if it's the start/end of a fold
-- NB: vim's fold-expr interface does not have a mechanism to indicate that
@@ -148,11 +166,61 @@ local folds_levels = memoize_by_changedtick(function(bufnr)
prefix = '>'
end
- levels[lnum + 1] = prefix .. tostring(trimmed_level)
+ info.levels[lnum] = prefix .. tostring(trimmed_level)
+ end
+end
+
+local M = {}
+
+---@type table<integer,FoldInfo>
+local foldinfos = {}
+
+local function recompute_folds()
+ if api.nvim_get_mode().mode == 'i' then
+ -- foldUpdate() is guarded in insert mode. So update folds on InsertLeave
+ api.nvim_create_autocmd('InsertLeave', {
+ once = true,
+ callback = vim._foldupdate,
+ })
+ return
end
- return levels
-end)
+ vim._foldupdate()
+end
+
+---@param bufnr integer
+---@param foldinfo FoldInfo
+---@param tree_changes Range4[]
+local function on_changedtree(bufnr, foldinfo, tree_changes)
+ -- For some reason, queries seem to use the old buffer state in on_bytes.
+ -- Get around this by scheduling and manually updating folds.
+ vim.schedule(function()
+ for _, change in ipairs(tree_changes) do
+ local srow, _, erow = Range.unpack4(change)
+ get_folds_levels(bufnr, foldinfo, srow, erow)
+ end
+ recompute_folds()
+ end)
+end
+
+---@param bufnr integer
+---@param foldinfo FoldInfo
+---@param start_row integer
+---@param old_row integer
+---@param new_row integer
+local function on_bytes(bufnr, foldinfo, start_row, old_row, new_row)
+ local end_row_old = start_row + old_row
+ local end_row_new = start_row + new_row
+ if new_row < old_row then
+ foldinfo:remove_range(end_row_old, end_row_new)
+ elseif new_row > old_row then
+ foldinfo:add_range(start_row, end_row_new)
+ vim.schedule(function()
+ get_folds_levels(bufnr, foldinfo, start_row, end_row_new)
+ recompute_folds()
+ end)
+ end
+end
---@param lnum integer|nil
---@return string
@@ -165,9 +233,27 @@ function M.foldexpr(lnum)
return '0'
end
- local levels = folds_levels(bufnr) or {}
+ if not foldinfos[bufnr] then
+ foldinfos[bufnr] = FoldInfo.new()
+ get_folds_levels(bufnr, foldinfos[bufnr])
+
+ local parser = vim.treesitter.get_parser(bufnr)
+ parser:register_cbs({
+ on_changedtree = function(tree_changes)
+ on_changedtree(bufnr, foldinfos[bufnr], tree_changes)
+ end,
+
+ on_bytes = function(_, _, start_row, _, _, old_row, _, _, new_row, _, _)
+ on_bytes(bufnr, foldinfos[bufnr], start_row, old_row, new_row)
+ end,
+
+ on_detach = function()
+ foldinfos[bufnr] = nil
+ end,
+ })
+ end
- return levels[lnum] or '0'
+ return foldinfos[bufnr].levels[lnum] or '0'
end
return M