local ts = vim.treesitter

local Range = require('vim.treesitter._range')

local api = vim.api

--- Per-buffer cache of fold information, indexed by 1-based line number.
---
---@class TS.FoldInfo
---@field levels table<integer,string> fold-expr strings handed back to Vim
---@field levels0 table<integer,integer> untrimmed nesting level per line
---@field private start_counts table<integer,integer> folds opening on a line
---@field private stop_counts table<integer,integer> folds closing on a line
local FoldInfo = {}
FoldInfo.__index = FoldInfo

--- Create an empty FoldInfo.
---
---@private
function FoldInfo.new()
  local self = {
    start_counts = {},
    stop_counts = {},
    levels0 = {},
    levels = {},
  }
  return setmetatable(self, FoldInfo)
end

--- Forget everything cached for rows `srow`..`erow` (0-indexed, inclusive).
---
---@package
---@param srow integer
---@param erow integer
function FoldInfo:invalidate_range(srow, erow)
  -- The tables are keyed by 1-based line number, hence the +1 on both bounds.
  for lnum = srow + 1, erow + 1 do
    self.start_counts[lnum] = nil
    self.stop_counts[lnum] = nil
    self.levels0[lnum] = nil
    self.levels[lnum] = nil
  end
end

--- Efficiently remove items from the middle of a list.
---
--- Calling table.remove() in a loop will re-index the tail of the table on
--- every iteration, instead this function will re-index the table exactly
--- once.
---
--- Based on https://stackoverflow.com/questions/12394841/safely-remove-items-from-an-array-table-while-iterating/53038524#53038524
---
---@param t any[]
---@param first integer index of the first removed item
---@param last integer index of the last removed item
local function list_remove(t, first, last)
  local n = #t
  -- Distance every surviving tail element moves towards the front.
  local shift = last + 1 - first
  for dst = first, n do
    local src = dst + shift
    t[dst] = t[src]
    t[src] = nil
  end
end

--- Drop the cached entries at list indices `srow + 1`..`erow` from every
--- per-line table, shifting the following entries up.
---
---@package
---@param srow integer
---@param erow integer
function FoldInfo:remove_range(srow, erow)
  local first = srow + 1
  for _, field in ipairs({ 'levels', 'levels0', 'start_counts', 'stop_counts' }) do
    list_remove(self[field], first, erow)
  end
end

--- Efficiently insert items into the middle of a list.
---
--- Calling table.insert() in a loop will re-index the tail of the table on
--- every iteration, instead this function will re-index the table exactly
--- once.
---
--- Based on https://stackoverflow.com/questions/12394841/safely-remove-items-from-an-array-table-while-iterating/53038524#53038524
---
---@param t any[]
---@param first integer
---@param last integer
---@param v any
local function list_insert(t, first, last, v)
  local n = #t
  -- Number of slots being opened up, i.e. how far the tail moves back.
  local shift = last + 1 - first

  -- Walk from the end so that no element is overwritten before it is moved.
  for src = n, first, -1 do
    t[src + shift] = t[src]
  end

  -- Populate the freshly opened slots with the placeholder value.
  for slot = first, last do
    t[slot] = v
  end
end

--- Open up list indices `srow + 1`..`erow` in every per-line table and fill
--- them with placeholders ('-1' / -1 for the levels, nil for the counts).
---
---@package
---@param srow integer
---@param erow integer
function FoldInfo:add_range(srow, erow)
  local first = srow + 1
  list_insert(self.levels, first, erow, '-1')
  list_insert(self.levels0, first, erow, -1)
  list_insert(self.start_counts, first, erow, nil)
  list_insert(self.stop_counts, first, erow, nil)
end

--- Record one more fold opening on line `lnum`.
---
---@package
---@param lnum integer
function FoldInfo:add_start(lnum)
  local counts = self.start_counts
  counts[lnum] = (counts[lnum] or 0) + 1
end

--- Record one more fold closing on line `lnum`.
---
---@package
---@param lnum integer
function FoldInfo:add_stop(lnum)
  local counts = self.stop_counts
  counts[lnum] = (counts[lnum] or 0) + 1
end

--- Number of folds recorded as opening on line `lnum` (0 when none).
---
---@package
---@param lnum integer
---@return integer
function FoldInfo:get_start(lnum)
  local count = self.start_counts[lnum]
  if not count then
    return 0
  end
  return count
end

--- Number of folds recorded as closing on line `lnum` (0 when none).
---
---@package
---@param lnum integer
---@return integer
function FoldInfo:get_stop(lnum)
  local count = self.stop_counts[lnum]
  if not count then
    return 0
  end
  return count
end

--- Clamp a nesting level to the window's 'foldnestmax'.
local function trim_level(level)
  return math.min(level, vim.wo.foldnestmax)
end

--- If a parser doesn't have any ranges explicitly set, treesitter will
--- return a range with end_row and end_bytes with a value of UINT32_MAX,
--- so clip end_row to the max buffer line.
---
--- TODO(lewis6991): Handle this generally
---
--- @param bufnr integer
--- @param erow integer?
--- @return integer
local function normalise_erow(bufnr, erow)
  local max_erow = api.nvim_buf_line_count(bufnr) - 1
  return math.min(erow or max_erow, max_erow)
end

-- TODO(lewis6991): Setup a decor provider so injections folds can be parsed
-- as the window is redrawn

--- Recompute the fold levels for rows `srow`..`erow` and store them in `info`.
---
--- The cached data for the range is invalidated first, then every `@fold`
--- capture of the `folds` query touching the range is counted as a fold
--- start/stop, and finally the per-line levels are rebuilt from those counts.
---
---@param bufnr integer
---@param info TS.FoldInfo
---@param srow integer? 0-indexed first row, defaults to 0
---@param erow integer? 0-indexed last row (inclusive), clipped to the last buffer row
---@param parse_injections? boolean
local function get_folds_levels(bufnr, info, srow, erow, parse_injections)
  srow = srow or 0
  erow = normalise_erow(bufnr, erow)

  info:invalidate_range(srow, erow)

  local prev_start = -1
  local prev_stop = -1

  local parser = ts.get_parser(bufnr)

  parser:parse(parse_injections and { srow, erow } or nil)

  -- Loop-invariant window option: read it once instead of once per capture.
  local foldminlines = vim.wo.foldminlines

  -- erow in query is end-exclusive. `erow` is always a number at this point
  -- (see normalise_erow), so no fallback value is needed.
  local q_erow = erow + 1

  parser:for_each_tree(function(tree, ltree)
    local query = ts.query.get(ltree:lang(), 'folds')
    if not query then
      return
    end

    for id, node, metadata in query:iter_captures(tree:root(), bufnr, srow, q_erow) do
      if query.captures[id] == 'fold' then
        local range = ts.get_range(node, bufnr, metadata[id])
        local start, _, stop, stop_col = Range.unpack4(range)

        -- A node ending at column 0 does not actually cover its end row.
        if stop_col == 0 then
          stop = stop - 1
        end

        local fold_length = stop - start + 1

        -- Fold only multiline nodes that are not exactly the same as previously met folds
        -- Checking against just the previously found fold is sufficient if nodes
        -- are returned in preorder or postorder when traversing tree
        if
          fold_length > foldminlines
          and not (start == prev_start and stop == prev_stop)
        then
          info:add_start(start + 1)
          info:add_stop(stop + 1)
          prev_start = start
          prev_stop = stop
        end
      end
    end
  end)

  -- Resume from the level cached for the line just before the range, if any.
  local current_level = info.levels0[srow] or 0

  -- We now have the list of fold opening and closing, fill the gaps and mark where fold start
  for lnum = srow + 1, erow + 1 do
    local last_trimmed_level = trim_level(current_level)
    current_level = current_level + info:get_start(lnum)
    info.levels0[lnum] = current_level

    local trimmed_level = trim_level(current_level)
    current_level = current_level - info:get_stop(lnum)

    -- Determine if it's the start/end of a fold
    -- NB: vim's fold-expr interface does not have a mechanism to indicate that
    -- two (or more) folds start at this line, so it cannot distinguish between
    --  ( \n ( \n )) \n (( \n ) \n )
    -- versus
    --  ( \n ( \n ) \n ( \n ) \n )
    -- If it did have such a mechanism, (trimmed_level - last_trimmed_level)
    -- would be the correct number of starts to pass on.
    local prefix = ''
    if trimmed_level - last_trimmed_level > 0 then
      prefix = '>'
    end

    info.levels[lnum] = prefix .. tostring(trimmed_level)
  end
end

local M = {}

--- FoldInfo cache, keyed by buffer number.
---@type table<integer,TS.FoldInfo>
local foldinfos = {}

local group = api.nvim_create_augroup('treesitter/fold', {})

--- Update the folds in the windows that contain the buffer and use expr foldmethod (assuming that
--- the user doesn't use different foldexpr for the same buffer).
---
--- Nvim usually automatically updates folds when text changes, but it doesn't work here because
--- FoldInfo update is scheduled. So we do it manually.
---
---@param bufnr integer
local function foldupdate(bufnr)
  local function do_update()
    for _, win in ipairs(vim.fn.win_findbuf(bufnr)) do
      api.nvim_win_call(win, function()
        if vim.wo.foldmethod == 'expr' then
          vim._foldupdate()
        end
      end)
    end
  end

  -- foldUpdate() is guarded in insert mode, which also covers the completion
  -- sub-modes ('ic', 'ix'), so match on the leading 'i' rather than on the
  -- exact string. In those modes, update folds on InsertLeave instead.
  if api.nvim_get_mode().mode:sub(1, 1) == 'i' then
    -- An update is already pending for this buffer: nothing more to do.
    if #(api.nvim_get_autocmds({
      group = group,
      buffer = bufnr,
    })) > 0 then
      return
    end
    api.nvim_create_autocmd('InsertLeave', {
      group = group,
      buffer = bufnr,
      once = true,
      callback = do_update,
    })
    return
  end

  do_update()
end

--- Schedule a function only if bufnr is loaded.
--- We schedule fold level computation for the following reasons:
--- * queries seem to use the old buffer state in on_bytes for some unknown reason;
--- * to avoid textlock;
--- * to avoid infinite recursion:
---   get_folds_levels → parse → _do_callback → on_changedtree → get_folds_levels.
---@param bufnr integer
---@param fn function
local function schedule_if_loaded(bufnr, fn)
  vim.schedule(function()
    if not api.nvim_buf_is_loaded(bufnr) then
      return
    end
    fn()
  end)
end

--- Parser callback: recompute the fold levels of every changed range, then
--- refresh the folds of the windows showing the buffer.
---
---@param bufnr integer
---@param foldinfo TS.FoldInfo
---@param tree_changes Range4[]
local function on_changedtree(bufnr, foldinfo, tree_changes)
  schedule_if_loaded(bufnr, function()
    for _, change in ipairs(tree_changes) do
      local srow, _, erow = Range.unpack4(change)
      get_folds_levels(bufnr, foldinfo, srow, erow)
    end
    if #tree_changes > 0 then
      foldupdate(bufnr)
    end
  end)
end

--- Parser callback: keep the per-line tables aligned with the buffer when an
--- edit changes the number of rows, then schedule a recomputation of the
--- edited range.
---
---@param bufnr integer
---@param foldinfo TS.FoldInfo
---@param start_row integer 0-indexed row where the edit starts
---@param old_row integer row extent of the replaced text
---@param new_row integer row extent of the replacement text
local function on_bytes(bufnr, foldinfo, start_row, old_row, new_row)
  local end_row_old = start_row + old_row
  local end_row_new = start_row + new_row

  if new_row ~= old_row then
    if new_row < old_row then
      foldinfo:remove_range(end_row_new, end_row_old)
    else
      -- Only `new_row - old_row` rows were actually added. Inserting from
      -- `start_row` would add `new_row` entries and misalign every following
      -- line whenever the edit replaced a multi-row region (old_row > 0).
      foldinfo:add_range(end_row_old, end_row_new)
    end
    schedule_if_loaded(bufnr, function()
      get_folds_levels(bufnr, foldinfo, start_row, end_row_new)
      foldupdate(bufnr)
    end)
  end
end

--- 'foldexpr' implementation: return the cached fold-expr string for `lnum`
--- in the current buffer, computing the cache and registering the parser
--- callbacks on first use. Returns '0' when no parser is available or no
--- level is cached for the line.
---
---@package
---@param lnum integer|nil defaults to `v:lnum`
---@return string
function M.foldexpr(lnum)
  lnum = lnum or vim.v.lnum
  local bufnr = api.nvim_get_current_buf()

  local parser = vim.F.npcall(ts.get_parser, bufnr)
  if not parser then
    return '0'
  end

  if not foldinfos[bufnr] then
    foldinfos[bufnr] = FoldInfo.new()
    get_folds_levels(bufnr, foldinfos[bufnr])

    parser:register_cbs({
      on_changedtree = function(tree_changes)
        on_changedtree(bufnr, foldinfos[bufnr], tree_changes)
      end,

      on_bytes = function(_, _, start_row, _, _, old_row, _, _, new_row, _, _)
        on_bytes(bufnr, foldinfos[bufnr], start_row, old_row, new_row)
      end,

      on_detach = function()
        foldinfos[bufnr] = nil
      end,
    })
  end

  return foldinfos[bufnr].levels[lnum] or '0'
end

--- 'foldtext' implementation: return the first line of the fold as a list of
--- `{ text, highlight_groups }` chunks built from the `highlights` query.
--- Falls back to `vim.fn.foldtext()` when there is no parser, no query or no
--- line to read.
---
---@package
---@return { [1]: string, [2]: string[] }[]|string
function M.foldtext()
  local foldstart = vim.v.foldstart
  local bufnr = api.nvim_get_current_buf()

  ---@type boolean, LanguageTree
  local ok, parser = pcall(ts.get_parser, bufnr)
  if not ok then
    return vim.fn.foldtext()
  end

  local query = ts.query.get(parser:lang(), 'highlights')
  if not query then
    return vim.fn.foldtext()
  end

  local tree = parser:parse({ foldstart - 1, foldstart })[1]

  local line = api.nvim_buf_get_lines(bufnr, foldstart - 1, foldstart, false)[1]
  if not line then
    return vim.fn.foldtext()
  end

  ---@type { [1]: string, [2]: string[], range: { [1]: integer, [2]: integer } }[] | { [1]: string, [2]: string[] }[]
  local result = {}

  -- Byte column up to which the line has been emitted so far.
  local line_pos = 0

  -- `bufnr` is the current buffer, so this is the same source the literal 0
  -- would designate; it is spelled out for consistency with the calls above.
  for id, node, metadata in query:iter_captures(tree:root(), bufnr, foldstart - 1, foldstart) do
    local name = query.captures[id]
    local start_row, start_col, end_row, end_col = node:range()

    local priority = tonumber(metadata.priority or vim.highlight.priorities.treesitter)

    -- Only captures that lie entirely on the fold's first line are used.
    if start_row == foldstart - 1 and end_row == foldstart - 1 then
      -- check for characters ignored by treesitter
      if start_col > line_pos then
        table.insert(result, {
          line:sub(line_pos + 1, start_col),
          {},
          range = { line_pos, start_col },
        })
      end
      line_pos = end_col

      local text = line:sub(start_col + 1, end_col)
      table.insert(result, { text, { { '@' .. name, priority } }, range = { start_col, end_col } })
    end
  end

  local i = 1
  while i <= #result do
    -- find first capture that is not in current range and apply highlights on the way
    local j = i + 1
    while
      j <= #result
      and result[j].range[1] >= result[i].range[1]
      and result[j].range[2] <= result[i].range[2]
    do
      for k, v in ipairs(result[i][2]) do
        if not vim.tbl_contains(result[j][2], v) then
          table.insert(result[j][2], k, v)
        end
      end
      j = j + 1
    end

    -- remove the parent capture if it is split into children
    if j > i + 1 then
      table.remove(result, i)
    else
      -- highlights need to be sorted by priority, on equal prio, the deeper nested capture (earlier
      -- in list) should be considered higher prio
      if #result[i][2] > 1 then
        table.sort(result[i][2], function(a, b)
          return a[2] < b[2]
        end)
      end

      -- Keep only the highlight names and drop the helper `range` field.
      result[i][2] = vim.tbl_map(function(tbl)
        return tbl[1]
      end, result[i][2])
      result[i] = { result[i][1], result[i][2] }

      i = i + 1
    end
  end

  return result
end

return M