diff options
author | Josh Rahm <joshuarahm@gmail.com> | 2023-11-30 20:35:25 +0000 |
---|---|---|
committer | Josh Rahm <joshuarahm@gmail.com> | 2023-11-30 20:35:25 +0000 |
commit | 1b7b916b7631ddf73c38e3a0070d64e4636cb2f3 (patch) | |
tree | cd08258054db80bb9a11b1061bb091c70b76926a /runtime/lua/vim/treesitter/_fold.lua | |
parent | eaa89c11d0f8aefbb512de769c6c82f61a8baca3 (diff) | |
parent | 4a8bf24ac690004aedf5540fa440e788459e5e34 (diff) | |
download | rneovim-aucmd_textputpost.tar.gz rneovim-aucmd_textputpost.tar.bz2 rneovim-aucmd_textputpost.zip |
Merge remote-tracking branch 'upstream/master' into aucmd_textputpostaucmd_textputpost
Diffstat (limited to 'runtime/lua/vim/treesitter/_fold.lua')
-rw-r--r-- | runtime/lua/vim/treesitter/_fold.lua | 456 |
1 files changed, 456 insertions, 0 deletions
diff --git a/runtime/lua/vim/treesitter/_fold.lua b/runtime/lua/vim/treesitter/_fold.lua new file mode 100644 index 0000000000..5c1cc06908 --- /dev/null +++ b/runtime/lua/vim/treesitter/_fold.lua @@ -0,0 +1,456 @@ +local ts = vim.treesitter + +local Range = require('vim.treesitter._range') + +local api = vim.api + +---@class TS.FoldInfo +---@field levels table<integer,string> +---@field levels0 table<integer,integer> +---@field private start_counts table<integer,integer> +---@field private stop_counts table<integer,integer> +local FoldInfo = {} +FoldInfo.__index = FoldInfo + +---@private +function FoldInfo.new() + return setmetatable({ + start_counts = {}, + stop_counts = {}, + levels0 = {}, + levels = {}, + }, FoldInfo) +end + +---@package +---@param srow integer +---@param erow integer +function FoldInfo:invalidate_range(srow, erow) + for i = srow, erow do + self.start_counts[i + 1] = nil + self.stop_counts[i + 1] = nil + self.levels0[i + 1] = nil + self.levels[i + 1] = nil + end +end + +--- Efficiently remove items from middle of a list a list. +--- +--- Calling table.remove() in a loop will re-index the tail of the table on +--- every iteration, instead this function will re-index the table exactly +--- once. +--- +--- Based on https://stackoverflow.com/questions/12394841/safely-remove-items-from-an-array-table-while-iterating/53038524#53038524 +--- +---@param t any[] +---@param first integer +---@param last integer +local function list_remove(t, first, last) + local n = #t + for i = 0, n - first do + t[first + i] = t[last + 1 + i] + t[last + 1 + i] = nil + end +end + +---@package +---@param srow integer +---@param erow integer +function FoldInfo:remove_range(srow, erow) + list_remove(self.levels, srow + 1, erow) + list_remove(self.levels0, srow + 1, erow) + list_remove(self.start_counts, srow + 1, erow) + list_remove(self.stop_counts, srow + 1, erow) +end + +--- Efficiently insert items into the middle of a list. +--- +--- Calling table.insert() in a loop will re-index the tail of the table on +--- every iteration, instead this function will re-index the table exactly +--- once. +--- +--- Based on https://stackoverflow.com/questions/12394841/safely-remove-items-from-an-array-table-while-iterating/53038524#53038524 +--- +---@param t any[] +---@param first integer +---@param last integer +---@param v any +local function list_insert(t, first, last, v) + local n = #t + + -- Shift table forward + for i = n - first, 0, -1 do + t[last + 1 + i] = t[first + i] + end + + -- Fill in new values + for i = first, last do + t[i] = v + end +end + +---@package +---@param srow integer +---@param erow integer +function FoldInfo:add_range(srow, erow) + list_insert(self.levels, srow + 1, erow, '-1') + list_insert(self.levels0, srow + 1, erow, -1) + list_insert(self.start_counts, srow + 1, erow, nil) + list_insert(self.stop_counts, srow + 1, erow, nil) +end + +---@package +---@param lnum integer +function FoldInfo:add_start(lnum) + self.start_counts[lnum] = (self.start_counts[lnum] or 0) + 1 +end + +---@package +---@param lnum integer +function FoldInfo:add_stop(lnum) + self.stop_counts[lnum] = (self.stop_counts[lnum] or 0) + 1 +end + +---@package +---@param lnum integer +---@return integer +function FoldInfo:get_start(lnum) + return self.start_counts[lnum] or 0 +end + +---@package +---@param lnum integer +---@return integer +function FoldInfo:get_stop(lnum) + return self.stop_counts[lnum] or 0 +end + +local function trim_level(level) + local max_fold_level = vim.wo.foldnestmax + if level > max_fold_level then + return max_fold_level + end + return level +end + +--- If a parser doesn't have any ranges explicitly set, treesitter will +--- return a range with end_row and end_bytes with a value of UINT32_MAX, +--- so clip end_row to the max buffer line. +--- +--- TODO(lewis6991): Handle this generally +--- +--- @param bufnr integer +--- @param erow integer? +--- @return integer +local function normalise_erow(bufnr, erow) + local max_erow = api.nvim_buf_line_count(bufnr) - 1 + return math.min(erow or max_erow, max_erow) +end + +-- TODO(lewis6991): Setup a decor provider so injections folds can be parsed +-- as the window is redrawn +---@param bufnr integer +---@param info TS.FoldInfo +---@param srow integer? +---@param erow integer? +---@param parse_injections? boolean +local function get_folds_levels(bufnr, info, srow, erow, parse_injections) + srow = srow or 0 + erow = normalise_erow(bufnr, erow) + + info:invalidate_range(srow, erow) + + local prev_start = -1 + local prev_stop = -1 + + local parser = ts.get_parser(bufnr) + + parser:parse(parse_injections and { srow, erow } or nil) + + parser:for_each_tree(function(tree, ltree) + local query = ts.query.get(ltree:lang(), 'folds') + if not query then + return + end + + -- erow in query is end-exclusive + local q_erow = erow and erow + 1 or -1 + + for id, node, metadata in query:iter_captures(tree:root(), bufnr, srow, q_erow) do + if query.captures[id] == 'fold' then + local range = ts.get_range(node, bufnr, metadata[id]) + local start, _, stop, stop_col = Range.unpack4(range) + + if stop_col == 0 then + stop = stop - 1 + end + + local fold_length = stop - start + 1 + + -- Fold only multiline nodes that are not exactly the same as previously met folds + -- Checking against just the previously found fold is sufficient if nodes + -- are returned in preorder or postorder when traversing tree + if + fold_length > vim.wo.foldminlines and not (start == prev_start and stop == prev_stop) + then + info:add_start(start + 1) + info:add_stop(stop + 1) + prev_start = start + prev_stop = stop + end + end + end + end) + + local current_level = info.levels0[srow] or 0 + + -- We now have the list of fold opening and closing, fill the gaps and mark where fold start + for lnum = srow + 1, erow + 1 do + local last_trimmed_level = trim_level(current_level) + current_level = current_level + info:get_start(lnum) + info.levels0[lnum] = current_level + + local trimmed_level = trim_level(current_level) + current_level = current_level - info:get_stop(lnum) + + -- Determine if it's the start/end of a fold + -- NB: vim's fold-expr interface does not have a mechanism to indicate that + -- two (or more) folds start at this line, so it cannot distinguish between + -- ( \n ( \n )) \n (( \n ) \n ) + -- versus + -- ( \n ( \n ) \n ( \n ) \n ) + -- If it did have such a mechanism, (trimmed_level - last_trimmed_level) + -- would be the correct number of starts to pass on. + local prefix = '' + if trimmed_level - last_trimmed_level > 0 then + prefix = '>' + end + + info.levels[lnum] = prefix .. tostring(trimmed_level) + end +end + +local M = {} + +---@type table<integer,TS.FoldInfo> +local foldinfos = {} + +local group = api.nvim_create_augroup('treesitter/fold', {}) + +--- Update the folds in the windows that contain the buffer and use expr foldmethod (assuming that +--- the user doesn't use different foldexpr for the same buffer). +--- +--- Nvim usually automatically updates folds when text changes, but it doesn't work here because +--- FoldInfo update is scheduled. So we do it manually. +local function foldupdate(bufnr) + local function do_update() + for _, win in ipairs(vim.fn.win_findbuf(bufnr)) do + api.nvim_win_call(win, function() + if vim.wo.foldmethod == 'expr' then + vim._foldupdate() + end + end) + end + end + + if api.nvim_get_mode().mode == 'i' then + -- foldUpdate() is guarded in insert mode. So update folds on InsertLeave + if #(api.nvim_get_autocmds({ + group = group, + buffer = bufnr, + })) > 0 then + return + end + api.nvim_create_autocmd('InsertLeave', { + group = group, + buffer = bufnr, + once = true, + callback = do_update, + }) + return + end + + do_update() +end + +--- Schedule a function only if bufnr is loaded. +--- We schedule fold level computation for the following reasons: +--- * queries seem to use the old buffer state in on_bytes for some unknown reason; +--- * to avoid textlock; +--- * to avoid infinite recursion: +--- get_folds_levels → parse → _do_callback → on_changedtree → get_folds_levels. +---@param bufnr integer +---@param fn function +local function schedule_if_loaded(bufnr, fn) + vim.schedule(function() + if not api.nvim_buf_is_loaded(bufnr) then + return + end + fn() + end) +end + +---@param bufnr integer +---@param foldinfo TS.FoldInfo +---@param tree_changes Range4[] +local function on_changedtree(bufnr, foldinfo, tree_changes) + schedule_if_loaded(bufnr, function() + for _, change in ipairs(tree_changes) do + local srow, _, erow = Range.unpack4(change) + get_folds_levels(bufnr, foldinfo, srow, erow) + end + if #tree_changes > 0 then + foldupdate(bufnr) + end + end) +end + +---@param bufnr integer +---@param foldinfo TS.FoldInfo +---@param start_row integer +---@param old_row integer +---@param new_row integer +local function on_bytes(bufnr, foldinfo, start_row, old_row, new_row) + local end_row_old = start_row + old_row + local end_row_new = start_row + new_row + + if new_row ~= old_row then + if new_row < old_row then + foldinfo:remove_range(end_row_new, end_row_old) + else + foldinfo:add_range(start_row, end_row_new) + end + schedule_if_loaded(bufnr, function() + get_folds_levels(bufnr, foldinfo, start_row, end_row_new) + foldupdate(bufnr) + end) + end +end + +---@package +---@param lnum integer|nil +---@return string +function M.foldexpr(lnum) + lnum = lnum or vim.v.lnum + local bufnr = api.nvim_get_current_buf() + + local parser = vim.F.npcall(ts.get_parser, bufnr) + if not parser then + return '0' + end + + if not foldinfos[bufnr] then + foldinfos[bufnr] = FoldInfo.new() + get_folds_levels(bufnr, foldinfos[bufnr]) + + parser:register_cbs({ + on_changedtree = function(tree_changes) + on_changedtree(bufnr, foldinfos[bufnr], tree_changes) + end, + + on_bytes = function(_, _, start_row, _, _, old_row, _, _, new_row, _, _) + on_bytes(bufnr, foldinfos[bufnr], start_row, old_row, new_row) + end, + + on_detach = function() + foldinfos[bufnr] = nil + end, + }) + end + + return foldinfos[bufnr].levels[lnum] or '0' +end + +---@package +---@return { [1]: string, [2]: string[] }[]|string +function M.foldtext() + local foldstart = vim.v.foldstart + local bufnr = api.nvim_get_current_buf() + + ---@type boolean, LanguageTree + local ok, parser = pcall(ts.get_parser, bufnr) + if not ok then + return vim.fn.foldtext() + end + + local query = ts.query.get(parser:lang(), 'highlights') + if not query then + return vim.fn.foldtext() + end + + local tree = parser:parse({ foldstart - 1, foldstart })[1] + + local line = api.nvim_buf_get_lines(bufnr, foldstart - 1, foldstart, false)[1] + if not line then + return vim.fn.foldtext() + end + + ---@type { [1]: string, [2]: string[], range: { [1]: integer, [2]: integer } }[] | { [1]: string, [2]: string[] }[] + local result = {} + + local line_pos = 0 + + for id, node, metadata in query:iter_captures(tree:root(), 0, foldstart - 1, foldstart) do + local name = query.captures[id] + local start_row, start_col, end_row, end_col = node:range() + + local priority = tonumber(metadata.priority or vim.highlight.priorities.treesitter) + + if start_row == foldstart - 1 and end_row == foldstart - 1 then + -- check for characters ignored by treesitter + if start_col > line_pos then + table.insert(result, { + line:sub(line_pos + 1, start_col), + {}, + range = { line_pos, start_col }, + }) + end + line_pos = end_col + + local text = line:sub(start_col + 1, end_col) + table.insert(result, { text, { { '@' .. name, priority } }, range = { start_col, end_col } }) + end + end + + local i = 1 + while i <= #result do + -- find first capture that is not in current range and apply highlights on the way + local j = i + 1 + while + j <= #result + and result[j].range[1] >= result[i].range[1] + and result[j].range[2] <= result[i].range[2] + do + for k, v in ipairs(result[i][2]) do + if not vim.tbl_contains(result[j][2], v) then + table.insert(result[j][2], k, v) + end + end + j = j + 1 + end + + -- remove the parent capture if it is split into children + if j > i + 1 then + table.remove(result, i) + else + -- highlights need to be sorted by priority, on equal prio, the deeper nested capture (earlier + -- in list) should be considered higher prio + if #result[i][2] > 1 then + table.sort(result[i][2], function(a, b) + return a[2] < b[2] + end) + end + + result[i][2] = vim.tbl_map(function(tbl) + return tbl[1] + end, result[i][2]) + result[i] = { result[i][1], result[i][2] } + + i = i + 1 + end + end + + return result +end + +return M |