aboutsummaryrefslogtreecommitdiff
path: root/runtime/lua/vim/treesitter/_fold.lua
diff options
context:
space:
mode:
Diffstat (limited to 'runtime/lua/vim/treesitter/_fold.lua')
-rw-r--r--runtime/lua/vim/treesitter/_fold.lua300
1 files changed, 122 insertions, 178 deletions
diff --git a/runtime/lua/vim/treesitter/_fold.lua b/runtime/lua/vim/treesitter/_fold.lua
index 5c1cc06908..d96cc966de 100644
--- a/runtime/lua/vim/treesitter/_fold.lua
+++ b/runtime/lua/vim/treesitter/_fold.lua
@@ -5,35 +5,20 @@ local Range = require('vim.treesitter._range')
local api = vim.api
---@class TS.FoldInfo
----@field levels table<integer,string>
----@field levels0 table<integer,integer>
----@field private start_counts table<integer,integer>
----@field private stop_counts table<integer,integer>
+---@field levels string[] the foldexpr result for each line
+---@field levels0 integer[] the raw fold levels
+---@field edits? {[1]: integer, [2]: integer} line range edited since the last invocation of the callback scheduled in on_bytes. 0-indexed, end-exclusive.
local FoldInfo = {}
FoldInfo.__index = FoldInfo
---@private
function FoldInfo.new()
return setmetatable({
- start_counts = {},
- stop_counts = {},
levels0 = {},
levels = {},
}, FoldInfo)
end
----@package
----@param srow integer
----@param erow integer
-function FoldInfo:invalidate_range(srow, erow)
- for i = srow, erow do
- self.start_counts[i + 1] = nil
- self.stop_counts[i + 1] = nil
- self.levels0[i + 1] = nil
- self.levels[i + 1] = nil
- end
-end
-
--- Efficiently remove items from middle of a list a list.
---
--- Calling table.remove() in a loop will re-index the tail of the table on
@@ -55,12 +40,10 @@ end
---@package
---@param srow integer
----@param erow integer
+---@param erow integer 0-indexed, exclusive
function FoldInfo:remove_range(srow, erow)
list_remove(self.levels, srow + 1, erow)
list_remove(self.levels0, srow + 1, erow)
- list_remove(self.start_counts, srow + 1, erow)
- list_remove(self.stop_counts, srow + 1, erow)
end
--- Efficiently insert items into the middle of a list.
@@ -91,46 +74,37 @@ end
---@package
---@param srow integer
----@param erow integer
+---@param erow integer 0-indexed, exclusive
function FoldInfo:add_range(srow, erow)
- list_insert(self.levels, srow + 1, erow, '-1')
+ list_insert(self.levels, srow + 1, erow, '=')
list_insert(self.levels0, srow + 1, erow, -1)
- list_insert(self.start_counts, srow + 1, erow, nil)
- list_insert(self.stop_counts, srow + 1, erow, nil)
end
---@package
----@param lnum integer
-function FoldInfo:add_start(lnum)
- self.start_counts[lnum] = (self.start_counts[lnum] or 0) + 1
-end
-
----@package
----@param lnum integer
-function FoldInfo:add_stop(lnum)
- self.stop_counts[lnum] = (self.stop_counts[lnum] or 0) + 1
-end
-
----@package
----@param lnum integer
----@return integer
-function FoldInfo:get_start(lnum)
- return self.start_counts[lnum] or 0
+---@param srow integer
+---@param erow_old integer
+---@param erow_new integer 0-indexed, exclusive
+function FoldInfo:edit_range(srow, erow_old, erow_new)
+ if self.edits then
+ self.edits[1] = math.min(srow, self.edits[1])
+ if erow_old <= self.edits[2] then
+ self.edits[2] = self.edits[2] + (erow_new - erow_old)
+ end
+ self.edits[2] = math.max(self.edits[2], erow_new)
+ else
+ self.edits = { srow, erow_new }
+ end
end
---@package
----@param lnum integer
----@return integer
-function FoldInfo:get_stop(lnum)
- return self.stop_counts[lnum] or 0
-end
-
-local function trim_level(level)
- local max_fold_level = vim.wo.foldnestmax
- if level > max_fold_level then
- return max_fold_level
+---@return integer? srow
+---@return integer? erow 0-indexed, exclusive
+function FoldInfo:flush_edit()
+ if self.edits then
+ local srow, erow = self.edits[1], self.edits[2]
+ self.edits = nil
+ return srow, erow
end
- return level
end
--- If a parser doesn't have any ranges explicitly set, treesitter will
@@ -140,10 +114,10 @@ end
--- TODO(lewis6991): Handle this generally
---
--- @param bufnr integer
---- @param erow integer?
+--- @param erow integer? 0-indexed, exclusive
--- @return integer
local function normalise_erow(bufnr, erow)
- local max_erow = api.nvim_buf_line_count(bufnr) - 1
+ local max_erow = api.nvim_buf_line_count(bufnr)
return math.min(erow or max_erow, max_erow)
end
@@ -152,31 +126,30 @@ end
---@param bufnr integer
---@param info TS.FoldInfo
---@param srow integer?
----@param erow integer?
+---@param erow integer? 0-indexed, exclusive
---@param parse_injections? boolean
local function get_folds_levels(bufnr, info, srow, erow, parse_injections)
srow = srow or 0
erow = normalise_erow(bufnr, erow)
- info:invalidate_range(srow, erow)
-
- local prev_start = -1
- local prev_stop = -1
-
local parser = ts.get_parser(bufnr)
parser:parse(parse_injections and { srow, erow } or nil)
+ local enter_counts = {} ---@type table<integer, integer>
+ local leave_counts = {} ---@type table<integer, integer>
+ local prev_start = -1
+ local prev_stop = -1
+
parser:for_each_tree(function(tree, ltree)
local query = ts.query.get(ltree:lang(), 'folds')
if not query then
return
end
- -- erow in query is end-exclusive
- local q_erow = erow and erow + 1 or -1
-
- for id, node, metadata in query:iter_captures(tree:root(), bufnr, srow, q_erow) do
+ -- Collect folds starting from srow - 1, because we should first subtract the folds that end at
+ -- srow - 1 from the level of srow - 1 to get accurate level of srow.
+ for id, node, metadata in query:iter_captures(tree:root(), bufnr, math.max(srow - 1, 0), erow) do
if query.captures[id] == 'fold' then
local range = ts.get_range(node, bufnr, metadata[id])
local start, _, stop, stop_col = Range.unpack4(range)
@@ -193,8 +166,8 @@ local function get_folds_levels(bufnr, info, srow, erow, parse_injections)
if
fold_length > vim.wo.foldminlines and not (start == prev_start and stop == prev_stop)
then
- info:add_start(start + 1)
- info:add_stop(stop + 1)
+ enter_counts[start + 1] = (enter_counts[start + 1] or 0) + 1
+ leave_counts[stop + 1] = (leave_counts[stop + 1] or 0) + 1
prev_start = start
prev_stop = stop
end
@@ -202,16 +175,15 @@ local function get_folds_levels(bufnr, info, srow, erow, parse_injections)
end
end)
- local current_level = info.levels0[srow] or 0
+ local nestmax = vim.wo.foldnestmax
+ local level0_prev = info.levels0[srow] or 0
+ local leave_prev = leave_counts[srow] or 0
-- We now have the list of fold opening and closing, fill the gaps and mark where fold start
- for lnum = srow + 1, erow + 1 do
- local last_trimmed_level = trim_level(current_level)
- current_level = current_level + info:get_start(lnum)
- info.levels0[lnum] = current_level
-
- local trimmed_level = trim_level(current_level)
- current_level = current_level - info:get_stop(lnum)
+ for lnum = srow + 1, erow do
+ local enter_line = enter_counts[lnum] or 0
+ local leave_line = leave_counts[lnum] or 0
+ local level0 = level0_prev - leave_prev + enter_line
-- Determine if it's the start/end of a fold
-- NB: vim's fold-expr interface does not have a mechanism to indicate that
@@ -219,14 +191,36 @@ local function get_folds_levels(bufnr, info, srow, erow, parse_injections)
-- ( \n ( \n )) \n (( \n ) \n )
-- versus
-- ( \n ( \n ) \n ( \n ) \n )
- -- If it did have such a mechanism, (trimmed_level - last_trimmed_level)
+ -- Both are represented by ['>1', '>2', '2', '>2', '2', '1'], and
+ -- vim interprets as the second case.
+ -- If it did have such a mechanism, (clamped - clamped_prev)
-- would be the correct number of starts to pass on.
+ local adjusted = level0 ---@type integer
local prefix = ''
- if trimmed_level - last_trimmed_level > 0 then
+ if enter_line > 0 then
prefix = '>'
+ if leave_line > 0 then
+ -- If this line ends a fold f1 and starts a fold f2, then move f1's end to the previous line
+ -- so that f2 gets the correct level on this line. This may reduce the size of f1 below
+ -- foldminlines, but we don't handle it for simplicity.
+ adjusted = level0 - leave_line
+ leave_line = 0
+ end
+ end
+
+ -- Clamp at foldnestmax.
+ local clamped = adjusted
+ if adjusted > nestmax then
+ prefix = ''
+ clamped = nestmax
end
- info.levels[lnum] = prefix .. tostring(trimmed_level)
+ -- Record the "real" level, so that it can be used as "base" of later get_folds_levels().
+ info.levels0[lnum] = adjusted
+ info.levels[lnum] = prefix .. tostring(clamped)
+
+ leave_prev = leave_line
+ level0_prev = adjusted
end
end
@@ -296,8 +290,12 @@ end
local function on_changedtree(bufnr, foldinfo, tree_changes)
schedule_if_loaded(bufnr, function()
for _, change in ipairs(tree_changes) do
- local srow, _, erow = Range.unpack4(change)
- get_folds_levels(bufnr, foldinfo, srow, erow)
+ local srow, _, erow, ecol = Range.unpack4(change)
+ if ecol > 0 then
+ erow = erow + 1
+ end
+ -- Start from `srow - foldminlines`, because this edit may have shrunken the fold below limit.
+ get_folds_levels(bufnr, foldinfo, math.max(srow - vim.wo.foldminlines, 0), erow)
end
if #tree_changes > 0 then
foldupdate(bufnr)
@@ -309,19 +307,46 @@ end
---@param foldinfo TS.FoldInfo
---@param start_row integer
---@param old_row integer
+---@param old_col integer
---@param new_row integer
-local function on_bytes(bufnr, foldinfo, start_row, old_row, new_row)
- local end_row_old = start_row + old_row
- local end_row_new = start_row + new_row
+---@param new_col integer
+local function on_bytes(bufnr, foldinfo, start_row, start_col, old_row, old_col, new_row, new_col)
+ -- extend the end to fully include the range
+ local end_row_old = start_row + old_row + 1
+ local end_row_new = start_row + new_row + 1
if new_row ~= old_row then
+ -- foldexpr can be evaluated before the scheduled callback is invoked. So it may observe the
+ -- outdated levels, which may spuriously open the folds that didn't change. So we should shift
+ -- folds as accurately as possible. For this to be perfectly accurate, we should track the
+ -- actual TSNodes that account for each fold, and compare the node's range with the edited
+ -- range. But for simplicity, we just check whether the start row is completely removed (e.g.,
+ -- `dd`) or shifted (e.g., `o`).
if new_row < old_row then
- foldinfo:remove_range(end_row_new, end_row_old)
+ if start_col == 0 and new_row == 0 and new_col == 0 then
+ foldinfo:remove_range(start_row, start_row + (end_row_old - end_row_new))
+ else
+ foldinfo:remove_range(end_row_new, end_row_old)
+ end
else
- foldinfo:add_range(start_row, end_row_new)
+ if start_col == 0 and old_row == 0 and old_col == 0 then
+ foldinfo:add_range(start_row, start_row + (end_row_new - end_row_old))
+ else
+ foldinfo:add_range(end_row_old, end_row_new)
+ end
end
+ foldinfo:edit_range(start_row, end_row_old, end_row_new)
+
+ -- This callback must not use on_bytes arguments, because they can be outdated when the callback
+ -- is invoked. For example, `J` with non-zero count triggers multiple on_bytes before executing
+ -- the scheduled callback. So we should collect the edits.
schedule_if_loaded(bufnr, function()
- get_folds_levels(bufnr, foldinfo, start_row, end_row_new)
+ local srow, erow = foldinfo:flush_edit()
+ if not srow then
+ return
+ end
+ -- Start from `srow - foldminlines`, because this edit may have shrunken the fold below limit.
+ get_folds_levels(bufnr, foldinfo, math.max(srow - vim.wo.foldminlines, 0), erow)
foldupdate(bufnr)
end)
end
@@ -348,8 +373,8 @@ function M.foldexpr(lnum)
on_changedtree(bufnr, foldinfos[bufnr], tree_changes)
end,
- on_bytes = function(_, _, start_row, _, _, old_row, _, _, new_row, _, _)
- on_bytes(bufnr, foldinfos[bufnr], start_row, old_row, new_row)
+ on_bytes = function(_, _, start_row, start_col, _, old_row, old_col, _, new_row, new_col, _)
+ on_bytes(bufnr, foldinfos[bufnr], start_row, start_col, old_row, old_col, new_row, new_col)
end,
on_detach = function()
@@ -361,96 +386,15 @@ function M.foldexpr(lnum)
return foldinfos[bufnr].levels[lnum] or '0'
end
----@package
----@return { [1]: string, [2]: string[] }[]|string
-function M.foldtext()
- local foldstart = vim.v.foldstart
- local bufnr = api.nvim_get_current_buf()
-
- ---@type boolean, LanguageTree
- local ok, parser = pcall(ts.get_parser, bufnr)
- if not ok then
- return vim.fn.foldtext()
- end
-
- local query = ts.query.get(parser:lang(), 'highlights')
- if not query then
- return vim.fn.foldtext()
- end
-
- local tree = parser:parse({ foldstart - 1, foldstart })[1]
-
- local line = api.nvim_buf_get_lines(bufnr, foldstart - 1, foldstart, false)[1]
- if not line then
- return vim.fn.foldtext()
- end
-
- ---@type { [1]: string, [2]: string[], range: { [1]: integer, [2]: integer } }[] | { [1]: string, [2]: string[] }[]
- local result = {}
-
- local line_pos = 0
-
- for id, node, metadata in query:iter_captures(tree:root(), 0, foldstart - 1, foldstart) do
- local name = query.captures[id]
- local start_row, start_col, end_row, end_col = node:range()
-
- local priority = tonumber(metadata.priority or vim.highlight.priorities.treesitter)
-
- if start_row == foldstart - 1 and end_row == foldstart - 1 then
- -- check for characters ignored by treesitter
- if start_col > line_pos then
- table.insert(result, {
- line:sub(line_pos + 1, start_col),
- {},
- range = { line_pos, start_col },
- })
- end
- line_pos = end_col
-
- local text = line:sub(start_col + 1, end_col)
- table.insert(result, { text, { { '@' .. name, priority } }, range = { start_col, end_col } })
- end
- end
-
- local i = 1
- while i <= #result do
- -- find first capture that is not in current range and apply highlights on the way
- local j = i + 1
- while
- j <= #result
- and result[j].range[1] >= result[i].range[1]
- and result[j].range[2] <= result[i].range[2]
- do
- for k, v in ipairs(result[i][2]) do
- if not vim.tbl_contains(result[j][2], v) then
- table.insert(result[j][2], k, v)
- end
- end
- j = j + 1
- end
-
- -- remove the parent capture if it is split into children
- if j > i + 1 then
- table.remove(result, i)
- else
- -- highlights need to be sorted by priority, on equal prio, the deeper nested capture (earlier
- -- in list) should be considered higher prio
- if #result[i][2] > 1 then
- table.sort(result[i][2], function(a, b)
- return a[2] < b[2]
- end)
- end
-
- result[i][2] = vim.tbl_map(function(tbl)
- return tbl[1]
- end, result[i][2])
- result[i] = { result[i][1], result[i][2] }
-
- i = i + 1
+api.nvim_create_autocmd('OptionSet', {
+ pattern = { 'foldminlines', 'foldnestmax' },
+ desc = 'Refresh treesitter folds',
+ callback = function()
+ for _, bufnr in ipairs(vim.tbl_keys(foldinfos)) do
+ foldinfos[bufnr] = FoldInfo.new()
+ get_folds_levels(bufnr, foldinfos[bufnr])
+ foldupdate(bufnr)
end
- end
-
- return result
-end
-
+ end,
+})
return M