aboutsummaryrefslogtreecommitdiff
path: root/runtime/lua/vim/treesitter/_fold.lua
diff options
context:
space:
mode:
authorJosh Rahm <joshuarahm@gmail.com>2023-11-29 22:40:31 +0000
committerJosh Rahm <joshuarahm@gmail.com>2023-11-29 22:40:31 +0000
commit339e2d15cc26fe86988ea06468d912a46c8d6f29 (patch)
treea6167fc8fcfc6ae2dc102f57b2473858eac34063 /runtime/lua/vim/treesitter/_fold.lua
parent067dc73729267c0262438a6fdd66e586f8496946 (diff)
parent4a8bf24ac690004aedf5540fa440e788459e5e34 (diff)
downloadrneovim-339e2d15cc26fe86988ea06468d912a46c8d6f29.tar.gz
rneovim-339e2d15cc26fe86988ea06468d912a46c8d6f29.tar.bz2
rneovim-339e2d15cc26fe86988ea06468d912a46c8d6f29.zip
Merge remote-tracking branch 'upstream/master' into fix_repeatcmdline
Diffstat (limited to 'runtime/lua/vim/treesitter/_fold.lua')
-rw-r--r--runtime/lua/vim/treesitter/_fold.lua456
1 files changed, 456 insertions, 0 deletions
diff --git a/runtime/lua/vim/treesitter/_fold.lua b/runtime/lua/vim/treesitter/_fold.lua
new file mode 100644
index 0000000000..5c1cc06908
--- /dev/null
+++ b/runtime/lua/vim/treesitter/_fold.lua
@@ -0,0 +1,456 @@
+local ts = vim.treesitter
+
+local Range = require('vim.treesitter._range')
+
+local api = vim.api
+
+---@class TS.FoldInfo
+---@field levels table<integer,string>
+---@field levels0 table<integer,integer>
+---@field private start_counts table<integer,integer>
+---@field private stop_counts table<integer,integer>
+local FoldInfo = {}
+FoldInfo.__index = FoldInfo
+
+---@private
+function FoldInfo.new()
+ return setmetatable({
+ start_counts = {},
+ stop_counts = {},
+ levels0 = {},
+ levels = {},
+ }, FoldInfo)
+end
+
+---@package
+---@param srow integer
+---@param erow integer
+function FoldInfo:invalidate_range(srow, erow)
+ for i = srow, erow do
+ self.start_counts[i + 1] = nil
+ self.stop_counts[i + 1] = nil
+ self.levels0[i + 1] = nil
+ self.levels[i + 1] = nil
+ end
+end
+
+--- Efficiently remove items from middle of a list a list.
+---
+--- Calling table.remove() in a loop will re-index the tail of the table on
+--- every iteration, instead this function will re-index the table exactly
+--- once.
+---
+--- Based on https://stackoverflow.com/questions/12394841/safely-remove-items-from-an-array-table-while-iterating/53038524#53038524
+---
+---@param t any[]
+---@param first integer
+---@param last integer
+local function list_remove(t, first, last)
+ local n = #t
+ for i = 0, n - first do
+ t[first + i] = t[last + 1 + i]
+ t[last + 1 + i] = nil
+ end
+end
+
+---@package
+---@param srow integer
+---@param erow integer
+function FoldInfo:remove_range(srow, erow)
+ list_remove(self.levels, srow + 1, erow)
+ list_remove(self.levels0, srow + 1, erow)
+ list_remove(self.start_counts, srow + 1, erow)
+ list_remove(self.stop_counts, srow + 1, erow)
+end
+
+--- Efficiently insert items into the middle of a list.
+---
+--- Calling table.insert() in a loop will re-index the tail of the table on
+--- every iteration, instead this function will re-index the table exactly
+--- once.
+---
+--- Based on https://stackoverflow.com/questions/12394841/safely-remove-items-from-an-array-table-while-iterating/53038524#53038524
+---
+---@param t any[]
+---@param first integer
+---@param last integer
+---@param v any
+local function list_insert(t, first, last, v)
+ local n = #t
+
+ -- Shift table forward
+ for i = n - first, 0, -1 do
+ t[last + 1 + i] = t[first + i]
+ end
+
+ -- Fill in new values
+ for i = first, last do
+ t[i] = v
+ end
+end
+
+---@package
+---@param srow integer
+---@param erow integer
+function FoldInfo:add_range(srow, erow)
+ list_insert(self.levels, srow + 1, erow, '-1')
+ list_insert(self.levels0, srow + 1, erow, -1)
+ list_insert(self.start_counts, srow + 1, erow, nil)
+ list_insert(self.stop_counts, srow + 1, erow, nil)
+end
+
+---@package
+---@param lnum integer
+function FoldInfo:add_start(lnum)
+ self.start_counts[lnum] = (self.start_counts[lnum] or 0) + 1
+end
+
+---@package
+---@param lnum integer
+function FoldInfo:add_stop(lnum)
+ self.stop_counts[lnum] = (self.stop_counts[lnum] or 0) + 1
+end
+
+---@package
+---@param lnum integer
+---@return integer
+function FoldInfo:get_start(lnum)
+ return self.start_counts[lnum] or 0
+end
+
+---@package
+---@param lnum integer
+---@return integer
+function FoldInfo:get_stop(lnum)
+ return self.stop_counts[lnum] or 0
+end
+
+local function trim_level(level)
+ local max_fold_level = vim.wo.foldnestmax
+ if level > max_fold_level then
+ return max_fold_level
+ end
+ return level
+end
+
+--- If a parser doesn't have any ranges explicitly set, treesitter will
+--- return a range with end_row and end_bytes with a value of UINT32_MAX,
+--- so clip end_row to the max buffer line.
+---
+--- TODO(lewis6991): Handle this generally
+---
+--- @param bufnr integer
+--- @param erow integer?
+--- @return integer
+local function normalise_erow(bufnr, erow)
+ local max_erow = api.nvim_buf_line_count(bufnr) - 1
+ return math.min(erow or max_erow, max_erow)
+end
+
+-- TODO(lewis6991): Setup a decor provider so injections folds can be parsed
+-- as the window is redrawn
+---@param bufnr integer
+---@param info TS.FoldInfo
+---@param srow integer?
+---@param erow integer?
+---@param parse_injections? boolean
+local function get_folds_levels(bufnr, info, srow, erow, parse_injections)
+ srow = srow or 0
+ erow = normalise_erow(bufnr, erow)
+
+ info:invalidate_range(srow, erow)
+
+ local prev_start = -1
+ local prev_stop = -1
+
+ local parser = ts.get_parser(bufnr)
+
+ parser:parse(parse_injections and { srow, erow } or nil)
+
+ parser:for_each_tree(function(tree, ltree)
+ local query = ts.query.get(ltree:lang(), 'folds')
+ if not query then
+ return
+ end
+
+ -- erow in query is end-exclusive
+ local q_erow = erow and erow + 1 or -1
+
+ for id, node, metadata in query:iter_captures(tree:root(), bufnr, srow, q_erow) do
+ if query.captures[id] == 'fold' then
+ local range = ts.get_range(node, bufnr, metadata[id])
+ local start, _, stop, stop_col = Range.unpack4(range)
+
+ if stop_col == 0 then
+ stop = stop - 1
+ end
+
+ local fold_length = stop - start + 1
+
+ -- Fold only multiline nodes that are not exactly the same as previously met folds
+ -- Checking against just the previously found fold is sufficient if nodes
+ -- are returned in preorder or postorder when traversing tree
+ if
+ fold_length > vim.wo.foldminlines and not (start == prev_start and stop == prev_stop)
+ then
+ info:add_start(start + 1)
+ info:add_stop(stop + 1)
+ prev_start = start
+ prev_stop = stop
+ end
+ end
+ end
+ end)
+
+ local current_level = info.levels0[srow] or 0
+
+ -- We now have the list of fold opening and closing, fill the gaps and mark where fold start
+ for lnum = srow + 1, erow + 1 do
+ local last_trimmed_level = trim_level(current_level)
+ current_level = current_level + info:get_start(lnum)
+ info.levels0[lnum] = current_level
+
+ local trimmed_level = trim_level(current_level)
+ current_level = current_level - info:get_stop(lnum)
+
+ -- Determine if it's the start/end of a fold
+ -- NB: vim's fold-expr interface does not have a mechanism to indicate that
+ -- two (or more) folds start at this line, so it cannot distinguish between
+ -- ( \n ( \n )) \n (( \n ) \n )
+ -- versus
+ -- ( \n ( \n ) \n ( \n ) \n )
+ -- If it did have such a mechanism, (trimmed_level - last_trimmed_level)
+ -- would be the correct number of starts to pass on.
+ local prefix = ''
+ if trimmed_level - last_trimmed_level > 0 then
+ prefix = '>'
+ end
+
+ info.levels[lnum] = prefix .. tostring(trimmed_level)
+ end
+end
+
+local M = {}
+
+---@type table<integer,TS.FoldInfo>
+local foldinfos = {}
+
+local group = api.nvim_create_augroup('treesitter/fold', {})
+
+--- Update the folds in the windows that contain the buffer and use expr foldmethod (assuming that
+--- the user doesn't use different foldexpr for the same buffer).
+---
+--- Nvim usually automatically updates folds when text changes, but it doesn't work here because
+--- FoldInfo update is scheduled. So we do it manually.
+local function foldupdate(bufnr)
+ local function do_update()
+ for _, win in ipairs(vim.fn.win_findbuf(bufnr)) do
+ api.nvim_win_call(win, function()
+ if vim.wo.foldmethod == 'expr' then
+ vim._foldupdate()
+ end
+ end)
+ end
+ end
+
+ if api.nvim_get_mode().mode == 'i' then
+ -- foldUpdate() is guarded in insert mode. So update folds on InsertLeave
+ if #(api.nvim_get_autocmds({
+ group = group,
+ buffer = bufnr,
+ })) > 0 then
+ return
+ end
+ api.nvim_create_autocmd('InsertLeave', {
+ group = group,
+ buffer = bufnr,
+ once = true,
+ callback = do_update,
+ })
+ return
+ end
+
+ do_update()
+end
+
+--- Schedule a function only if bufnr is loaded.
+--- We schedule fold level computation for the following reasons:
+--- * queries seem to use the old buffer state in on_bytes for some unknown reason;
+--- * to avoid textlock;
+--- * to avoid infinite recursion:
+--- get_folds_levels → parse → _do_callback → on_changedtree → get_folds_levels.
+---@param bufnr integer
+---@param fn function
+local function schedule_if_loaded(bufnr, fn)
+ vim.schedule(function()
+ if not api.nvim_buf_is_loaded(bufnr) then
+ return
+ end
+ fn()
+ end)
+end
+
+---@param bufnr integer
+---@param foldinfo TS.FoldInfo
+---@param tree_changes Range4[]
+local function on_changedtree(bufnr, foldinfo, tree_changes)
+ schedule_if_loaded(bufnr, function()
+ for _, change in ipairs(tree_changes) do
+ local srow, _, erow = Range.unpack4(change)
+ get_folds_levels(bufnr, foldinfo, srow, erow)
+ end
+ if #tree_changes > 0 then
+ foldupdate(bufnr)
+ end
+ end)
+end
+
+---@param bufnr integer
+---@param foldinfo TS.FoldInfo
+---@param start_row integer
+---@param old_row integer
+---@param new_row integer
+local function on_bytes(bufnr, foldinfo, start_row, old_row, new_row)
+ local end_row_old = start_row + old_row
+ local end_row_new = start_row + new_row
+
+ if new_row ~= old_row then
+ if new_row < old_row then
+ foldinfo:remove_range(end_row_new, end_row_old)
+ else
+ foldinfo:add_range(start_row, end_row_new)
+ end
+ schedule_if_loaded(bufnr, function()
+ get_folds_levels(bufnr, foldinfo, start_row, end_row_new)
+ foldupdate(bufnr)
+ end)
+ end
+end
+
+---@package
+---@param lnum integer|nil
+---@return string
+function M.foldexpr(lnum)
+ lnum = lnum or vim.v.lnum
+ local bufnr = api.nvim_get_current_buf()
+
+ local parser = vim.F.npcall(ts.get_parser, bufnr)
+ if not parser then
+ return '0'
+ end
+
+ if not foldinfos[bufnr] then
+ foldinfos[bufnr] = FoldInfo.new()
+ get_folds_levels(bufnr, foldinfos[bufnr])
+
+ parser:register_cbs({
+ on_changedtree = function(tree_changes)
+ on_changedtree(bufnr, foldinfos[bufnr], tree_changes)
+ end,
+
+ on_bytes = function(_, _, start_row, _, _, old_row, _, _, new_row, _, _)
+ on_bytes(bufnr, foldinfos[bufnr], start_row, old_row, new_row)
+ end,
+
+ on_detach = function()
+ foldinfos[bufnr] = nil
+ end,
+ })
+ end
+
+ return foldinfos[bufnr].levels[lnum] or '0'
+end
+
+---@package
+---@return { [1]: string, [2]: string[] }[]|string
+function M.foldtext()
+ local foldstart = vim.v.foldstart
+ local bufnr = api.nvim_get_current_buf()
+
+ ---@type boolean, LanguageTree
+ local ok, parser = pcall(ts.get_parser, bufnr)
+ if not ok then
+ return vim.fn.foldtext()
+ end
+
+ local query = ts.query.get(parser:lang(), 'highlights')
+ if not query then
+ return vim.fn.foldtext()
+ end
+
+ local tree = parser:parse({ foldstart - 1, foldstart })[1]
+
+ local line = api.nvim_buf_get_lines(bufnr, foldstart - 1, foldstart, false)[1]
+ if not line then
+ return vim.fn.foldtext()
+ end
+
+ ---@type { [1]: string, [2]: string[], range: { [1]: integer, [2]: integer } }[] | { [1]: string, [2]: string[] }[]
+ local result = {}
+
+ local line_pos = 0
+
+ for id, node, metadata in query:iter_captures(tree:root(), 0, foldstart - 1, foldstart) do
+ local name = query.captures[id]
+ local start_row, start_col, end_row, end_col = node:range()
+
+ local priority = tonumber(metadata.priority or vim.highlight.priorities.treesitter)
+
+ if start_row == foldstart - 1 and end_row == foldstart - 1 then
+ -- check for characters ignored by treesitter
+ if start_col > line_pos then
+ table.insert(result, {
+ line:sub(line_pos + 1, start_col),
+ {},
+ range = { line_pos, start_col },
+ })
+ end
+ line_pos = end_col
+
+ local text = line:sub(start_col + 1, end_col)
+ table.insert(result, { text, { { '@' .. name, priority } }, range = { start_col, end_col } })
+ end
+ end
+
+ local i = 1
+ while i <= #result do
+ -- find first capture that is not in current range and apply highlights on the way
+ local j = i + 1
+ while
+ j <= #result
+ and result[j].range[1] >= result[i].range[1]
+ and result[j].range[2] <= result[i].range[2]
+ do
+ for k, v in ipairs(result[i][2]) do
+ if not vim.tbl_contains(result[j][2], v) then
+ table.insert(result[j][2], k, v)
+ end
+ end
+ j = j + 1
+ end
+
+ -- remove the parent capture if it is split into children
+ if j > i + 1 then
+ table.remove(result, i)
+ else
+ -- highlights need to be sorted by priority, on equal prio, the deeper nested capture (earlier
+ -- in list) should be considered higher prio
+ if #result[i][2] > 1 then
+ table.sort(result[i][2], function(a, b)
+ return a[2] < b[2]
+ end)
+ end
+
+ result[i][2] = vim.tbl_map(function(tbl)
+ return tbl[1]
+ end, result[i][2])
+ result[i] = { result[i][1], result[i][2] }
+
+ i = i + 1
+ end
+ end
+
+ return result
+end
+
+return M