aboutsummaryrefslogtreecommitdiff
path: root/runtime/lua/vim/treesitter/_fold.lua
diff options
context:
space:
mode:
Diffstat (limited to 'runtime/lua/vim/treesitter/_fold.lua')
-rw-r--r--runtime/lua/vim/treesitter/_fold.lua173
1 files changed, 173 insertions, 0 deletions
diff --git a/runtime/lua/vim/treesitter/_fold.lua b/runtime/lua/vim/treesitter/_fold.lua
new file mode 100644
index 0000000000..a66cc6d543
--- /dev/null
+++ b/runtime/lua/vim/treesitter/_fold.lua
@@ -0,0 +1,173 @@
+local api = vim.api
+
+local M = {}
+
+--- Memoizes a function based on the buffer tick of the provided bufnr.
+--- The cache entry is cleared when the buffer is detached to avoid memory leaks.
+---@generic F: function
+---@param fn F fn to memoize, taking the bufnr as first argument
+---@return F
+local function memoize_by_changedtick(fn)
+ ---@type table<integer,{result:any,last_tick:integer}>
+ local cache = {}
+
+ ---@param bufnr integer
+ return function(bufnr, ...)
+ local tick = api.nvim_buf_get_changedtick(bufnr)
+
+ if cache[bufnr] then
+ if cache[bufnr].last_tick == tick then
+ return cache[bufnr].result
+ end
+ else
+ local function detach_handler()
+ cache[bufnr] = nil
+ end
+
+ -- Clean up logic only!
+ api.nvim_buf_attach(bufnr, false, {
+ on_detach = detach_handler,
+ on_reload = detach_handler,
+ })
+ end
+
+ cache[bufnr] = {
+ result = fn(bufnr, ...),
+ last_tick = tick,
+ }
+
+ return cache[bufnr].result
+ end
+end
+
+---@param bufnr integer
+---@param capture string
+---@param query_name string
+---@param callback fun(id: integer, node:TSNode, metadata: TSMetadata)
+local function iter_matches_with_capture(bufnr, capture, query_name, callback)
+ local parser = vim.treesitter.get_parser(bufnr)
+
+ if not parser then
+ return
+ end
+
+ parser:for_each_tree(function(tree, lang_tree)
+ local lang = lang_tree:lang()
+ local query = vim.treesitter.query.get_query(lang, query_name)
+ if query then
+ local root = tree:root()
+ local start, _, stop = root:range()
+ for _, match, metadata in query:iter_matches(root, bufnr, start, stop) do
+ for id, node in pairs(match) do
+ if query.captures[id] == capture then
+ callback(id, node, metadata)
+ end
+ end
+ end
+ end
+ end)
+end
+
+---@private
+--- TODO(lewis6991): copied from languagetree.lua. Consolidate
+---@param node TSNode
+---@param id integer
+---@param metadata TSMetadata
+---@return Range
+local function get_range_from_metadata(node, id, metadata)
+ if metadata[id] and metadata[id].range then
+ return metadata[id].range --[[@as Range]]
+ end
+ return { node:range() }
+end
+
+-- This is cached on buf tick to avoid computing that multiple times
+-- Especially not for every line in the file when `zx` is hit
+---@param bufnr integer
+---@return table<integer,string>
+local folds_levels = memoize_by_changedtick(function(bufnr)
+ local max_fold_level = vim.wo.foldnestmax
+ local function trim_level(level)
+ if level > max_fold_level then
+ return max_fold_level
+ end
+ return level
+ end
+
+ -- start..stop is an inclusive range
+ local start_counts = {} ---@type table<integer,integer>
+ local stop_counts = {} ---@type table<integer,integer>
+
+ local prev_start = -1
+ local prev_stop = -1
+
+ local min_fold_lines = vim.wo.foldminlines
+
+ iter_matches_with_capture(bufnr, 'fold', 'folds', function(id, node, metadata)
+ local range = get_range_from_metadata(node, id, metadata)
+ local start, stop, stop_col = range[1], range[3], range[4]
+
+ if stop_col == 0 then
+ stop = stop - 1
+ end
+
+ local fold_length = stop - start + 1
+
+ -- Fold only multiline nodes that are not exactly the same as previously met folds
+ -- Checking against just the previously found fold is sufficient if nodes
+ -- are returned in preorder or postorder when traversing tree
+ if fold_length > min_fold_lines and not (start == prev_start and stop == prev_stop) then
+ start_counts[start] = (start_counts[start] or 0) + 1
+ stop_counts[stop] = (stop_counts[stop] or 0) + 1
+ prev_start = start
+ prev_stop = stop
+ end
+ end)
+
+ ---@type table<integer,string>
+ local levels = {}
+ local current_level = 0
+
+ -- We now have the list of fold opening and closing, fill the gaps and mark where fold start
+ for lnum = 0, api.nvim_buf_line_count(bufnr) do
+ local last_trimmed_level = trim_level(current_level)
+ current_level = current_level + (start_counts[lnum] or 0)
+ local trimmed_level = trim_level(current_level)
+ current_level = current_level - (stop_counts[lnum] or 0)
+
+ -- Determine if it's the start/end of a fold
+ -- NB: vim's fold-expr interface does not have a mechanism to indicate that
+ -- two (or more) folds start at this line, so it cannot distinguish between
+ -- ( \n ( \n )) \n (( \n ) \n )
+ -- versus
+ -- ( \n ( \n ) \n ( \n ) \n )
+ -- If it did have such a mechanism, (trimmed_level - last_trimmed_level)
+ -- would be the correct number of starts to pass on.
+ local prefix = ''
+ if trimmed_level - last_trimmed_level > 0 then
+ prefix = '>'
+ end
+
+ levels[lnum + 1] = prefix .. tostring(trimmed_level)
+ end
+
+ return levels
+end)
+
+---@param lnum integer|nil
+---@return string
+function M.foldexpr(lnum)
+ lnum = lnum or vim.v.lnum
+ local bufnr = api.nvim_get_current_buf()
+
+ ---@diagnostic disable-next-line:invisible
+ if not vim.treesitter._has_parser(bufnr) or not lnum then
+ return '0'
+ end
+
+ local levels = folds_levels(bufnr) or {}
+
+ return levels[lnum] or '0'
+end
+
+return M