aboutsummaryrefslogtreecommitdiff
path: root/runtime/lua/vim/treesitter/_fold.lua
diff options
context:
space:
mode:
Diffstat (limited to 'runtime/lua/vim/treesitter/_fold.lua')
-rw-r--r--runtime/lua/vim/treesitter/_fold.lua115
1 files changed, 95 insertions, 20 deletions
diff --git a/runtime/lua/vim/treesitter/_fold.lua b/runtime/lua/vim/treesitter/_fold.lua
index 7df93d1b2e..f6425d7cb9 100644
--- a/runtime/lua/vim/treesitter/_fold.lua
+++ b/runtime/lua/vim/treesitter/_fold.lua
@@ -1,3 +1,5 @@
+local ts = vim.treesitter
+
local Range = require('vim.treesitter._range')
local api = vim.api
@@ -32,15 +34,58 @@ function FoldInfo:invalidate_range(srow, erow)
end
end
+--- Efficiently remove items from middle of a list a list.
+---
+--- Calling table.remove() in a loop will re-index the tail of the table on
+--- every iteration, instead this function will re-index the table exactly
+--- once.
+---
+--- Based on https://stackoverflow.com/questions/12394841/safely-remove-items-from-an-array-table-while-iterating/53038524#53038524
+---
+---@param t any[]
+---@param first integer
+---@param last integer
+local function list_remove(t, first, last)
+ local n = #t
+ for i = 0, n - first do
+ t[first + i] = t[last + 1 + i]
+ t[last + 1 + i] = nil
+ end
+end
+
---@package
---@param srow integer
---@param erow integer
function FoldInfo:remove_range(srow, erow)
- for i = erow - 1, srow, -1 do
- table.remove(self.levels, i + 1)
- table.remove(self.levels0, i + 1)
- table.remove(self.start_counts, i + 1)
- table.remove(self.stop_counts, i + 1)
+ list_remove(self.levels, srow + 1, erow)
+ list_remove(self.levels0, srow + 1, erow)
+ list_remove(self.start_counts, srow + 1, erow)
+ list_remove(self.stop_counts, srow + 1, erow)
+end
+
+--- Efficiently insert items into the middle of a list.
+---
+--- Calling table.insert() in a loop will re-index the tail of the table on
+--- every iteration, instead this function will re-index the table exactly
+--- once.
+---
+--- Based on https://stackoverflow.com/questions/12394841/safely-remove-items-from-an-array-table-while-iterating/53038524#53038524
+---
+---@param t any[]
+---@param first integer
+---@param last integer
+---@param v any
+local function list_insert(t, first, last, v)
+ local n = #t
+
+ -- Shift table forward
+ for i = n - first, 0, -1 do
+ t[last + 1 + i] = t[first + i]
+ end
+
+ -- Fill in new values
+ for i = first, last do
+ t[i] = v
end
end
@@ -48,12 +93,10 @@ end
---@param srow integer
---@param erow integer
function FoldInfo:add_range(srow, erow)
- for i = srow, erow - 1 do
- table.insert(self.levels, i + 1, '-1')
- table.insert(self.levels0, i + 1, -1)
- table.insert(self.start_counts, i + 1, nil)
- table.insert(self.stop_counts, i + 1, nil)
- end
+ list_insert(self.levels, srow + 1, erow, '-1')
+ list_insert(self.levels0, srow + 1, erow, -1)
+ list_insert(self.start_counts, srow + 1, erow, nil)
+ list_insert(self.stop_counts, srow + 1, erow, nil)
end
---@package
@@ -90,21 +133,41 @@ local function trim_level(level)
return level
end
+--- If a parser doesn't have any ranges explicitly set, treesitter will
+--- return a range with end_row and end_bytes with a value of UINT32_MAX,
+--- so clip end_row to the max buffer line.
+---
+--- TODO(lewis6991): Handle this generally
+---
+--- @param bufnr integer
+--- @param erow integer?
+--- @return integer
+local function normalise_erow(bufnr, erow)
+ local max_erow = api.nvim_buf_line_count(bufnr) - 1
+ return math.min(erow or max_erow, max_erow)
+end
+
---@param bufnr integer
---@param info TS.FoldInfo
---@param srow integer?
---@param erow integer?
local function get_folds_levels(bufnr, info, srow, erow)
srow = srow or 0
- erow = erow or api.nvim_buf_line_count(bufnr)
+ erow = normalise_erow(bufnr, erow)
info:invalidate_range(srow, erow)
local prev_start = -1
local prev_stop = -1
- vim.treesitter.get_parser(bufnr):for_each_tree(function(tree, ltree)
- local query = vim.treesitter.query.get(ltree:lang(), 'folds')
+ local parser = ts.get_parser(bufnr)
+
+ if not parser:is_valid() then
+ return
+ end
+
+ parser:for_each_tree(function(tree, ltree)
+ local query = ts.query.get(ltree:lang(), 'folds')
if not query then
return
end
@@ -112,9 +175,9 @@ local function get_folds_levels(bufnr, info, srow, erow)
-- erow in query is end-exclusive
local q_erow = erow and erow + 1 or -1
- for id, node, metadata in query:iter_captures(tree:root(), bufnr, srow or 0, q_erow) do
+ for id, node, metadata in query:iter_captures(tree:root(), bufnr, srow, q_erow) do
if query.captures[id] == 'fold' then
- local range = vim.treesitter.get_range(node, bufnr, metadata[id])
+ local range = ts.get_range(node, bufnr, metadata[id])
local start, _, stop, stop_col = Range.unpack4(range)
if stop_col == 0 then
@@ -184,13 +247,25 @@ local function recompute_folds()
vim._foldupdate()
end
+--- Schedule a function only if bufnr is loaded
+---@param bufnr integer
+---@param fn function
+local function schedule_if_loaded(bufnr, fn)
+ vim.schedule(function()
+ if not api.nvim_buf_is_loaded(bufnr) then
+ return
+ end
+ fn()
+ end)
+end
+
---@param bufnr integer
---@param foldinfo TS.FoldInfo
---@param tree_changes Range4[]
local function on_changedtree(bufnr, foldinfo, tree_changes)
-- For some reason, queries seem to use the old buffer state in on_bytes.
-- Get around this by scheduling and manually updating folds.
- vim.schedule(function()
+ schedule_if_loaded(bufnr, function()
for _, change in ipairs(tree_changes) do
local srow, _, erow = Range.unpack4(change)
get_folds_levels(bufnr, foldinfo, srow, erow)
@@ -212,7 +287,7 @@ local function on_bytes(bufnr, foldinfo, start_row, old_row, new_row)
foldinfo:remove_range(end_row_new, end_row_old)
elseif new_row > old_row then
foldinfo:add_range(start_row, end_row_new)
- vim.schedule(function()
+ schedule_if_loaded(bufnr, function()
get_folds_levels(bufnr, foldinfo, start_row, end_row_new)
recompute_folds()
end)
@@ -226,7 +301,7 @@ function M.foldexpr(lnum)
lnum = lnum or vim.v.lnum
local bufnr = api.nvim_get_current_buf()
- if not vim.treesitter._has_parser(bufnr) or not lnum then
+ if not ts._has_parser(bufnr) or not lnum then
return '0'
end
@@ -234,7 +309,7 @@ function M.foldexpr(lnum)
foldinfos[bufnr] = FoldInfo.new()
get_folds_levels(bufnr, foldinfos[bufnr])
- local parser = vim.treesitter.get_parser(bufnr)
+ local parser = ts.get_parser(bufnr)
parser:register_cbs({
on_changedtree = function(tree_changes)
on_changedtree(bufnr, foldinfos[bufnr], tree_changes)