diff options
Diffstat (limited to 'runtime/lua/vim/treesitter/_fold.lua')
-rw-r--r-- | runtime/lua/vim/treesitter/_fold.lua | 43 |
1 files changed, 22 insertions, 21 deletions
diff --git a/runtime/lua/vim/treesitter/_fold.lua b/runtime/lua/vim/treesitter/_fold.lua index 5c1cc06908..d5626d0391 100644 --- a/runtime/lua/vim/treesitter/_fold.lua +++ b/runtime/lua/vim/treesitter/_fold.lua @@ -5,8 +5,8 @@ local Range = require('vim.treesitter._range') local api = vim.api ---@class TS.FoldInfo ----@field levels table<integer,string> ----@field levels0 table<integer,integer> +---@field levels string[] the foldexpr value for each line +---@field levels0 integer[] the raw fold levels ---@field private start_counts table<integer,integer> ---@field private stop_counts table<integer,integer> local FoldInfo = {} @@ -24,13 +24,13 @@ end ---@package ---@param srow integer ----@param erow integer +---@param erow integer 0-indexed, exclusive function FoldInfo:invalidate_range(srow, erow) - for i = srow, erow do - self.start_counts[i + 1] = nil - self.stop_counts[i + 1] = nil - self.levels0[i + 1] = nil - self.levels[i + 1] = nil + for i = srow + 1, erow do + self.start_counts[i] = nil + self.stop_counts[i] = nil + self.levels0[i] = nil + self.levels[i] = nil end end @@ -55,7 +55,7 @@ end ---@package ---@param srow integer ----@param erow integer +---@param erow integer 0-indexed, exclusive function FoldInfo:remove_range(srow, erow) list_remove(self.levels, srow + 1, erow) list_remove(self.levels0, srow + 1, erow) @@ -91,7 +91,7 @@ end ---@package ---@param srow integer ----@param erow integer +---@param erow integer 0-indexed, exclusive function FoldInfo:add_range(srow, erow) list_insert(self.levels, srow + 1, erow, '-1') list_insert(self.levels0, srow + 1, erow, -1) @@ -140,10 +140,10 @@ end --- TODO(lewis6991): Handle this generally --- --- @param bufnr integer ---- @param erow integer? +--- @param erow integer? 0-indexed, exclusive --- @return integer local function normalise_erow(bufnr, erow) - local max_erow = api.nvim_buf_line_count(bufnr) - 1 + local max_erow = api.nvim_buf_line_count(bufnr) return math.min(erow or max_erow, max_erow) end @@ -152,7 +152,7 @@ end ---@param bufnr integer ---@param info TS.FoldInfo ---@param srow integer? ----@param erow integer? +---@param erow integer? 0-indexed, exclusive ---@param parse_injections? boolean local function get_folds_levels(bufnr, info, srow, erow, parse_injections) srow = srow or 0 @@ -173,10 +173,7 @@ local function get_folds_levels(bufnr, info, srow, erow, parse_injections) return end - -- erow in query is end-exclusive - local q_erow = erow and erow + 1 or -1 - - for id, node, metadata in query:iter_captures(tree:root(), bufnr, srow, q_erow) do + for id, node, metadata in query:iter_captures(tree:root(), bufnr, srow, erow) do if query.captures[id] == 'fold' then local range = ts.get_range(node, bufnr, metadata[id]) local start, _, stop, stop_col = Range.unpack4(range) @@ -205,7 +202,7 @@ local function get_folds_levels(bufnr, info, srow, erow, parse_injections) local current_level = info.levels0[srow] or 0 -- We now have the list of fold opening and closing, fill the gaps and mark where fold start - for lnum = srow + 1, erow + 1 do + for lnum = srow + 1, erow do local last_trimmed_level = trim_level(current_level) current_level = current_level + info:get_start(lnum) info.levels0[lnum] = current_level @@ -296,7 +293,10 @@ end local function on_changedtree(bufnr, foldinfo, tree_changes) schedule_if_loaded(bufnr, function() for _, change in ipairs(tree_changes) do - local srow, _, erow = Range.unpack4(change) + local srow, _, erow, ecol = Range.unpack4(change) + if ecol > 0 then + erow = erow + 1 + end get_folds_levels(bufnr, foldinfo, srow, erow) end if #tree_changes > 0 then @@ -311,8 +311,9 @@ end ---@param old_row integer ---@param new_row integer local function on_bytes(bufnr, foldinfo, start_row, old_row, new_row) - local end_row_old = start_row + old_row - local end_row_new = start_row + new_row + -- extend the end to fully include the range + local end_row_old = start_row + old_row + 1 + local end_row_new = start_row + new_row + 1 if new_row ~= old_row then if new_row < old_row then |