diff options
author | Lewis Russell <lewis6991@gmail.com> | 2023-02-23 17:05:20 +0000 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-02-23 17:05:20 +0000 |
commit | 1df3f5ec6aca24cbe7b78ead5c37ad06a65c84e8 (patch) | |
tree | d1523f1b7ef2a813859572ba0bdce2874de82699 /runtime/lua/vim | |
parent | 8c339aa04b2a1ca99a297b2eada8ebc6218f5f1c (diff) | |
download | rneovim-1df3f5ec6aca24cbe7b78ead5c37ad06a65c84e8.tar.gz rneovim-1df3f5ec6aca24cbe7b78ead5c37ad06a65c84e8.tar.bz2 rneovim-1df3f5ec6aca24cbe7b78ead5c37ad06a65c84e8.zip |
feat(treesitter): upstream foldexpr from nvim-treesitter
Diffstat (limited to 'runtime/lua/vim')
-rw-r--r-- | runtime/lua/vim/treesitter.lua | 20 | ||||
-rw-r--r-- | runtime/lua/vim/treesitter/_fold.lua | 173 |
2 files changed, 193 insertions, 0 deletions
diff --git a/runtime/lua/vim/treesitter.lua b/runtime/lua/vim/treesitter.lua index 44922bbc4d..fead7b7b1b 100644 --- a/runtime/lua/vim/treesitter.lua +++ b/runtime/lua/vim/treesitter.lua @@ -115,6 +115,16 @@ function M.get_parser(bufnr, lang, opts) return parsers[bufnr] end +---@private +---@param bufnr (integer|nil) Buffer number +---@return boolean +function M._has_parser(bufnr) + if bufnr == nil or bufnr == 0 then + bufnr = a.nvim_get_current_buf() + end + return parsers[bufnr] ~= nil +end + --- Returns a string parser --- ---@param str string Text to parse @@ -612,4 +622,14 @@ function M.show_tree(opts) }) end +--- Returns the fold level for {lnum} in the current buffer. Can be set directly to 'foldexpr': +--- <pre>lua +--- vim.wo.foldexpr = 'v:lua.vim.treesitter.foldexpr()' +--- </pre> +---@param lnum integer|nil Line number to calculate fold level for +---@return string +function M.foldexpr(lnum) + return require('vim.treesitter._fold').foldexpr(lnum) +end + return M diff --git a/runtime/lua/vim/treesitter/_fold.lua b/runtime/lua/vim/treesitter/_fold.lua new file mode 100644 index 0000000000..a66cc6d543 --- /dev/null +++ b/runtime/lua/vim/treesitter/_fold.lua @@ -0,0 +1,173 @@ +local api = vim.api + +local M = {} + +--- Memoizes a function based on the buffer tick of the provided bufnr. +--- The cache entry is cleared when the buffer is detached to avoid memory leaks. +---@generic F: function +---@param fn F fn to memoize, taking the bufnr as first argument +---@return F +local function memoize_by_changedtick(fn) + ---@type table<integer,{result:any,last_tick:integer}> + local cache = {} + + ---@param bufnr integer + return function(bufnr, ...) + local tick = api.nvim_buf_get_changedtick(bufnr) + + if cache[bufnr] then + if cache[bufnr].last_tick == tick then + return cache[bufnr].result + end + else + local function detach_handler() + cache[bufnr] = nil + end + + -- Clean up logic only! + api.nvim_buf_attach(bufnr, false, { + on_detach = detach_handler, + on_reload = detach_handler, + }) + end + + cache[bufnr] = { + result = fn(bufnr, ...), + last_tick = tick, + } + + return cache[bufnr].result + end +end + +---@param bufnr integer +---@param capture string +---@param query_name string +---@param callback fun(id: integer, node:TSNode, metadata: TSMetadata) +local function iter_matches_with_capture(bufnr, capture, query_name, callback) + local parser = vim.treesitter.get_parser(bufnr) + + if not parser then + return + end + + parser:for_each_tree(function(tree, lang_tree) + local lang = lang_tree:lang() + local query = vim.treesitter.query.get_query(lang, query_name) + if query then + local root = tree:root() + local start, _, stop = root:range() + for _, match, metadata in query:iter_matches(root, bufnr, start, stop) do + for id, node in pairs(match) do + if query.captures[id] == capture then + callback(id, node, metadata) + end + end + end + end + end) +end + +---@private +--- TODO(lewis6991): copied from languagetree.lua. Consolidate +---@param node TSNode +---@param id integer +---@param metadata TSMetadata +---@return Range +local function get_range_from_metadata(node, id, metadata) + if metadata[id] and metadata[id].range then + return metadata[id].range --[[@as Range]] + end + return { node:range() } +end + +-- This is cached on buf tick to avoid computing that multiple times +-- Especially not for every line in the file when `zx` is hit +---@param bufnr integer +---@return table<integer,string> +local folds_levels = memoize_by_changedtick(function(bufnr) + local max_fold_level = vim.wo.foldnestmax + local function trim_level(level) + if level > max_fold_level then + return max_fold_level + end + return level + end + + -- start..stop is an inclusive range + local start_counts = {} ---@type table<integer,integer> + local stop_counts = {} ---@type table<integer,integer> + + local prev_start = -1 + local prev_stop = -1 + + local min_fold_lines = vim.wo.foldminlines + + iter_matches_with_capture(bufnr, 'fold', 'folds', function(id, node, metadata) + local range = get_range_from_metadata(node, id, metadata) + local start, stop, stop_col = range[1], range[3], range[4] + + if stop_col == 0 then + stop = stop - 1 + end + + local fold_length = stop - start + 1 + + -- Fold only multiline nodes that are not exactly the same as previously met folds + -- Checking against just the previously found fold is sufficient if nodes + -- are returned in preorder or postorder when traversing tree + if fold_length > min_fold_lines and not (start == prev_start and stop == prev_stop) then + start_counts[start] = (start_counts[start] or 0) + 1 + stop_counts[stop] = (stop_counts[stop] or 0) + 1 + prev_start = start + prev_stop = stop + end + end) + + ---@type table<integer,string> + local levels = {} + local current_level = 0 + + -- We now have the list of fold opening and closing, fill the gaps and mark where fold start + for lnum = 0, api.nvim_buf_line_count(bufnr) do + local last_trimmed_level = trim_level(current_level) + current_level = current_level + (start_counts[lnum] or 0) + local trimmed_level = trim_level(current_level) + current_level = current_level - (stop_counts[lnum] or 0) + + -- Determine if it's the start/end of a fold + -- NB: vim's fold-expr interface does not have a mechanism to indicate that + -- two (or more) folds start at this line, so it cannot distinguish between + -- ( \n ( \n )) \n (( \n ) \n ) + -- versus + -- ( \n ( \n ) \n ( \n ) \n ) + -- If it did have such a mechanism, (trimmed_level - last_trimmed_level) + -- would be the correct number of starts to pass on. + local prefix = '' + if trimmed_level - last_trimmed_level > 0 then + prefix = '>' + end + + levels[lnum + 1] = prefix .. tostring(trimmed_level) + end + + return levels +end) + +---@param lnum integer|nil +---@return string +function M.foldexpr(lnum) + lnum = lnum or vim.v.lnum + local bufnr = api.nvim_get_current_buf() + + ---@diagnostic disable-next-line:invisible + if not vim.treesitter._has_parser(bufnr) or not lnum then + return '0' + end + + local levels = folds_levels(bufnr) or {} + + return levels[lnum] or '0' +end + +return M |