diff options
-rw-r--r-- | runtime/lua/vim/treesitter.lua | 25 | ||||
-rw-r--r-- | runtime/lua/vim/treesitter/_fold.lua | 290 | ||||
-rw-r--r-- | runtime/lua/vim/treesitter/language.lua | 14 | ||||
-rw-r--r-- | src/nvim/lua/stdlib.c | 36 | ||||
-rw-r--r-- | src/nvim/lua/treesitter.c | 4 | ||||
-rw-r--r-- | test/functional/treesitter/language_spec.lua | 5 | ||||
-rw-r--r-- | test/functional/treesitter/parser_spec.lua | 55 |
7 files changed, 284 insertions, 145 deletions
diff --git a/runtime/lua/vim/treesitter.lua b/runtime/lua/vim/treesitter.lua index ab9f8968c8..5723cce563 100644 --- a/runtime/lua/vim/treesitter.lua +++ b/runtime/lua/vim/treesitter.lua @@ -99,13 +99,28 @@ function M.get_parser(bufnr, lang, opts) if bufnr == nil or bufnr == 0 then bufnr = a.nvim_get_current_buf() end + if lang == nil then local ft = vim.bo[bufnr].filetype - lang = language.get_lang(ft) or ft - -- TODO(lewis6991): we should error here and not default to ft - -- if not lang then - -- error(string.format('filetype %s of buffer %d is not associated with any lang', ft, bufnr)) - -- end + if ft ~= '' then + lang = language.get_lang(ft) or ft + -- TODO(lewis6991): we should error here and not default to ft + -- if not lang then + -- error(string.format('filetype %s of buffer %d is not associated with any lang', ft, bufnr)) + -- end + else + if parsers[bufnr] then + return parsers[bufnr] + end + error( + string.format( + 'There is no parser available for buffer %d and one could not be' + .. ' created because lang could not be determined. Either pass lang' + .. ' or set the buffer filetype', + bufnr + ) + ) + end end if parsers[bufnr] == nil or parsers[bufnr]:lang() ~= lang then diff --git a/runtime/lua/vim/treesitter/_fold.lua b/runtime/lua/vim/treesitter/_fold.lua index a66cc6d543..435cb9fdb6 100644 --- a/runtime/lua/vim/treesitter/_fold.lua +++ b/runtime/lua/vim/treesitter/_fold.lua @@ -1,139 +1,157 @@ +local Range = require('vim.treesitter._range') + local api = vim.api -local M = {} +---@class FoldInfo +---@field levels table<integer,string> +---@field levels0 table<integer,integer> +---@field private start_counts table<integer,integer> +---@field private stop_counts table<integer,integer> +local FoldInfo = {} +FoldInfo.__index = FoldInfo ---- Memoizes a function based on the buffer tick of the provided bufnr. ---- The cache entry is cleared when the buffer is detached to avoid memory leaks. ----@generic F: function ----@param fn F fn to memoize, taking the bufnr as first argument ----@return F -local function memoize_by_changedtick(fn) - ---@type table<integer,{result:any,last_tick:integer}> - local cache = {} - - ---@param bufnr integer - return function(bufnr, ...) - local tick = api.nvim_buf_get_changedtick(bufnr) - - if cache[bufnr] then - if cache[bufnr].last_tick == tick then - return cache[bufnr].result - end - else - local function detach_handler() - cache[bufnr] = nil - end +function FoldInfo.new() + return setmetatable({ + start_counts = {}, + stop_counts = {}, + levels0 = {}, + levels = {}, + }, FoldInfo) +end - -- Clean up logic only! - api.nvim_buf_attach(bufnr, false, { - on_detach = detach_handler, - on_reload = detach_handler, - }) - end +---@param srow integer +---@param erow integer +function FoldInfo:invalidate_range(srow, erow) + for i = srow, erow do + self.start_counts[i + 1] = nil + self.stop_counts[i + 1] = nil + self.levels0[i + 1] = nil + self.levels[i + 1] = nil + end +end - cache[bufnr] = { - result = fn(bufnr, ...), - last_tick = tick, - } +---@param srow integer +---@param erow integer +function FoldInfo:remove_range(srow, erow) + for i = erow - 1, srow, -1 do + table.remove(self.levels, i + 1) + table.remove(self.levels0, i + 1) + table.remove(self.start_counts, i + 1) + table.remove(self.stop_counts, i + 1) + end +end - return cache[bufnr].result +---@param srow integer +---@param erow integer +function FoldInfo:add_range(srow, erow) + for i = srow, erow - 1 do + table.insert(self.levels, i + 1, '-1') + table.insert(self.levels0, i + 1, -1) + table.insert(self.start_counts, i + 1, nil) + table.insert(self.stop_counts, i + 1, nil) end end ----@param bufnr integer ----@param capture string ----@param query_name string ----@param callback fun(id: integer, node:TSNode, metadata: TSMetadata) -local function iter_matches_with_capture(bufnr, capture, query_name, callback) - local parser = vim.treesitter.get_parser(bufnr) +---@param lnum integer +function FoldInfo:add_start(lnum) + self.start_counts[lnum] = (self.start_counts[lnum] or 0) + 1 +end - if not parser then - return - end +---@param lnum integer +function FoldInfo:add_stop(lnum) + self.stop_counts[lnum] = (self.stop_counts[lnum] or 0) + 1 +end - parser:for_each_tree(function(tree, lang_tree) - local lang = lang_tree:lang() - local query = vim.treesitter.query.get_query(lang, query_name) - if query then - local root = tree:root() - local start, _, stop = root:range() - for _, match, metadata in query:iter_matches(root, bufnr, start, stop) do - for id, node in pairs(match) do - if query.captures[id] == capture then - callback(id, node, metadata) - end - end - end - end - end) +---@param lnum integer +---@return integer +function FoldInfo:get_start(lnum) + return self.start_counts[lnum] or 0 +end + +---@param lnum integer +---@return integer +function FoldInfo:get_stop(lnum) + return self.stop_counts[lnum] or 0 end ---@private --- TODO(lewis6991): copied from languagetree.lua. Consolidate ---@param node TSNode ----@param id integer ---@param metadata TSMetadata ----@return Range -local function get_range_from_metadata(node, id, metadata) - if metadata[id] and metadata[id].range then - return metadata[id].range --[[@as Range]] +---@return Range4 +local function get_range_from_metadata(node, metadata) + if metadata and metadata.range then + return metadata.range --[[@as Range4]] end return { node:range() } end --- This is cached on buf tick to avoid computing that multiple times --- Especially not for every line in the file when `zx` is hit ----@param bufnr integer ----@return table<integer,string> -local folds_levels = memoize_by_changedtick(function(bufnr) +local function trim_level(level) local max_fold_level = vim.wo.foldnestmax - local function trim_level(level) - if level > max_fold_level then - return max_fold_level - end - return level + if level > max_fold_level then + return max_fold_level end + return level +end - -- start..stop is an inclusive range - local start_counts = {} ---@type table<integer,integer> - local stop_counts = {} ---@type table<integer,integer> +---@param bufnr integer +---@param info FoldInfo +---@param srow integer? +---@param erow integer? +local function get_folds_levels(bufnr, info, srow, erow) + srow = srow or 0 + erow = erow or api.nvim_buf_line_count(bufnr) + + info:invalidate_range(srow, erow) local prev_start = -1 local prev_stop = -1 - local min_fold_lines = vim.wo.foldminlines + vim.treesitter.get_parser(bufnr):for_each_tree(function(tree, ltree) + local query = vim.treesitter.query.get_query(ltree:lang(), 'folds') + if not query then + return + end + + -- erow in query is end-exclusive + local q_erow = erow and erow + 1 or -1 - iter_matches_with_capture(bufnr, 'fold', 'folds', function(id, node, metadata) - local range = get_range_from_metadata(node, id, metadata) - local start, stop, stop_col = range[1], range[3], range[4] + for id, node, metadata in query:iter_captures(tree:root(), bufnr, srow or 0, q_erow) do + if query.captures[id] == 'fold' then + local range = get_range_from_metadata(node, metadata[id]) + local start, _, stop, stop_col = Range.unpack4(range) - if stop_col == 0 then - stop = stop - 1 - end + if stop_col == 0 then + stop = stop - 1 + end - local fold_length = stop - start + 1 + local fold_length = stop - start + 1 - -- Fold only multiline nodes that are not exactly the same as previously met folds - -- Checking against just the previously found fold is sufficient if nodes - -- are returned in preorder or postorder when traversing tree - if fold_length > min_fold_lines and not (start == prev_start and stop == prev_stop) then - start_counts[start] = (start_counts[start] or 0) + 1 - stop_counts[stop] = (stop_counts[stop] or 0) + 1 - prev_start = start - prev_stop = stop + -- Fold only multiline nodes that are not exactly the same as previously met folds + -- Checking against just the previously found fold is sufficient if nodes + -- are returned in preorder or postorder when traversing tree + if + fold_length > vim.wo.foldminlines and not (start == prev_start and stop == prev_stop) + then + info:add_start(start + 1) + info:add_stop(stop + 1) + prev_start = start + prev_stop = stop + end + end end end) - ---@type table<integer,string> - local levels = {} - local current_level = 0 + local current_level = info.levels0[srow] or 0 -- We now have the list of fold opening and closing, fill the gaps and mark where fold start - for lnum = 0, api.nvim_buf_line_count(bufnr) do + for lnum = srow + 1, erow + 1 do local last_trimmed_level = trim_level(current_level) - current_level = current_level + (start_counts[lnum] or 0) + current_level = current_level + info:get_start(lnum) + info.levels0[lnum] = current_level + local trimmed_level = trim_level(current_level) - current_level = current_level - (stop_counts[lnum] or 0) + current_level = current_level - info:get_stop(lnum) -- Determine if it's the start/end of a fold -- NB: vim's fold-expr interface does not have a mechanism to indicate that @@ -148,11 +166,61 @@ local folds_levels = memoize_by_changedtick(function(bufnr) prefix = '>' end - levels[lnum + 1] = prefix .. tostring(trimmed_level) + info.levels[lnum] = prefix .. tostring(trimmed_level) + end +end + +local M = {} + +---@type table<integer,FoldInfo> +local foldinfos = {} + +local function recompute_folds() + if api.nvim_get_mode().mode == 'i' then + -- foldUpdate() is guarded in insert mode. So update folds on InsertLeave + api.nvim_create_autocmd('InsertLeave', { + once = true, + callback = vim._foldupdate, + }) + return end - return levels -end) + vim._foldupdate() +end + +---@param bufnr integer +---@param foldinfo FoldInfo +---@param tree_changes Range4[] +local function on_changedtree(bufnr, foldinfo, tree_changes) + -- For some reason, queries seem to use the old buffer state in on_bytes. + -- Get around this by scheduling and manually updating folds. + vim.schedule(function() + for _, change in ipairs(tree_changes) do + local srow, _, erow = Range.unpack4(change) + get_folds_levels(bufnr, foldinfo, srow, erow) + end + recompute_folds() + end) +end + +---@param bufnr integer +---@param foldinfo FoldInfo +---@param start_row integer +---@param old_row integer +---@param new_row integer +local function on_bytes(bufnr, foldinfo, start_row, old_row, new_row) + local end_row_old = start_row + old_row + local end_row_new = start_row + new_row + if new_row < old_row then + foldinfo:remove_range(end_row_old, end_row_new) + elseif new_row > old_row then + foldinfo:add_range(start_row, end_row_new) + vim.schedule(function() + get_folds_levels(bufnr, foldinfo, start_row, end_row_new) + recompute_folds() + end) + end +end ---@param lnum integer|nil ---@return string @@ -165,9 +233,27 @@ function M.foldexpr(lnum) return '0' end - local levels = folds_levels(bufnr) or {} + if not foldinfos[bufnr] then + foldinfos[bufnr] = FoldInfo.new() + get_folds_levels(bufnr, foldinfos[bufnr]) + + local parser = vim.treesitter.get_parser(bufnr) + parser:register_cbs({ + on_changedtree = function(tree_changes) + on_changedtree(bufnr, foldinfos[bufnr], tree_changes) + end, + + on_bytes = function(_, _, start_row, _, _, old_row, _, _, new_row, _, _) + on_bytes(bufnr, foldinfos[bufnr], start_row, old_row, new_row) + end, + + on_detach = function() + foldinfos[bufnr] = nil + end, + }) + end - return levels[lnum] or '0' + return foldinfos[bufnr].levels[lnum] or '0' end return M diff --git a/runtime/lua/vim/treesitter/language.lua b/runtime/lua/vim/treesitter/language.lua index 5f34d9cd56..47375fd5e6 100644 --- a/runtime/lua/vim/treesitter/language.lua +++ b/runtime/lua/vim/treesitter/language.lua @@ -60,16 +60,6 @@ function M.add(lang, opts) filetype = { filetype, { 'string', 'table' }, true }, }) - if filetype == '' then - error(string.format("'%s' is not a valid filetype", filetype)) - elseif type(filetype) == 'table' then - for _, f in ipairs(filetype) do - if f == '' then - error(string.format("'%s' is not a valid filetype", filetype)) - end - end - end - M.register(lang, filetype or lang) if vim._ts_has_language(lang) then @@ -109,7 +99,9 @@ function M.register(lang, filetype) end for _, f in ipairs(filetypes) do - ft_to_lang[f] = lang + if f ~= '' then + ft_to_lang[f] = lang + end end end diff --git a/src/nvim/lua/stdlib.c b/src/nvim/lua/stdlib.c index d9682ff63d..b6e56c35d6 100644 --- a/src/nvim/lua/stdlib.c +++ b/src/nvim/lua/stdlib.c @@ -26,6 +26,7 @@ #include "nvim/eval/typval.h" #include "nvim/eval/typval_defs.h" #include "nvim/ex_eval.h" +#include "nvim/fold.h" #include "nvim/globals.h" #include "nvim/lua/converter.h" #include "nvim/lua/spell.h" @@ -528,6 +529,31 @@ static int nlua_iconv(lua_State *lstate) return 1; } +// Like 'zx' but don't call newFoldLevel() +static int nlua_foldupdate(lua_State *lstate) +{ + curwin->w_foldinvalid = true; // recompute folds + foldOpenCursor(); + + return 0; +} + +// Access to internal functions. For use in runtime/ +static void nlua_state_add_internal(lua_State *const lstate) +{ + // _getvar + lua_pushcfunction(lstate, &nlua_getvar); + lua_setfield(lstate, -2, "_getvar"); + + // _setvar + lua_pushcfunction(lstate, &nlua_setvar); + lua_setfield(lstate, -2, "_setvar"); + + // _updatefolds + lua_pushcfunction(lstate, &nlua_foldupdate); + lua_setfield(lstate, -2, "_foldupdate"); +} + void nlua_state_add_stdlib(lua_State *const lstate, bool is_thread) { if (!is_thread) { @@ -562,14 +588,6 @@ void nlua_state_add_stdlib(lua_State *const lstate, bool is_thread) lua_setfield(lstate, -2, "__index"); // [meta] lua_pop(lstate, 1); // don't use metatable now - // _getvar - lua_pushcfunction(lstate, &nlua_getvar); - lua_setfield(lstate, -2, "_getvar"); - - // _setvar - lua_pushcfunction(lstate, &nlua_setvar); - lua_setfield(lstate, -2, "_setvar"); - // vim.spell luaopen_spell(lstate); lua_setfield(lstate, -2, "spell"); @@ -578,6 +596,8 @@ void nlua_state_add_stdlib(lua_State *const lstate, bool is_thread) // depends on p_ambw, p_emoji lua_pushcfunction(lstate, &nlua_iconv); lua_setfield(lstate, -2, "iconv"); + + nlua_state_add_internal(lstate); } // vim.mpack diff --git a/src/nvim/lua/treesitter.c b/src/nvim/lua/treesitter.c index ae69f3f120..289a0cb9b4 100644 --- a/src/nvim/lua/treesitter.c +++ b/src/nvim/lua/treesitter.c @@ -405,8 +405,6 @@ static int parser_parse(lua_State *L) old_tree = tmp ? *tmp : NULL; } - bool include_bytes = (lua_gettop(L) >= 3) && lua_toboolean(L, 3); - TSTree *new_tree = NULL; size_t len; const char *str; @@ -443,6 +441,8 @@ static int parser_parse(lua_State *L) return luaL_argerror(L, 3, "expected either string or buffer handle"); } + bool include_bytes = (lua_gettop(L) >= 4) && lua_toboolean(L, 4); + // Sometimes parsing fails (timeout, or wrong parser ABI) // In those case, just return an error. if (!new_tree) { diff --git a/test/functional/treesitter/language_spec.lua b/test/functional/treesitter/language_spec.lua index 747aea54b7..48e7b4b018 100644 --- a/test/functional/treesitter/language_spec.lua +++ b/test/functional/treesitter/language_spec.lua @@ -36,11 +36,6 @@ describe('treesitter language API', function() pcall_err(exec_lua, 'vim.treesitter.add("/foo/")')) end) - it('shows error for invalid filetype', function() - eq('.../language.lua:0: \'\' is not a valid filetype', - pcall_err(exec_lua, [[vim.treesitter.add('foo', { filetype = '' })]])) - end) - it('inspects language', function() local keys, fields, symbols = unpack(exec_lua([[ local lang = vim.treesitter.inspect_language('c') diff --git a/test/functional/treesitter/parser_spec.lua b/test/functional/treesitter/parser_spec.lua index 27f2e81ab2..0f00fcfe0d 100644 --- a/test/functional/treesitter/parser_spec.lua +++ b/test/functional/treesitter/parser_spec.lua @@ -128,7 +128,9 @@ void ui_refresh(void) it('does not get parser for empty filetype', function() insert(test_text); - eq(".../language.lua:0: '' is not a valid filetype", + eq('.../treesitter.lua:0: There is no parser available for buffer 1 and one' + .. ' could not be created because lang could not be determined. Either' + .. ' pass lang or set the buffer filetype', pcall_err(exec_lua, 'vim.treesitter.get_parser(0)')) -- Must provide language for buffers with an empty filetype @@ -886,18 +888,20 @@ int x = INT_MAX; it("can fold via foldexpr", function() insert(test_text) - local levels = exec_lua([[ - vim.opt.filetype = 'c' - vim.treesitter.get_parser(0, "c") - local res = {} - for i = 1, vim.api.nvim_buf_line_count(0) do - res[i] = vim.treesitter.foldexpr(i) - end - return res - ]]) + local function get_fold_levels() + return exec_lua([[ + local res = {} + for i = 1, vim.api.nvim_buf_line_count(0) do + res[i] = vim.treesitter.foldexpr(i) + end + return res + ]]) + end + + exec_lua([[vim.treesitter.get_parser(0, "c")]]) eq({ - [1] = '>1', + [1] = '>1', [2] = '1', [3] = '1', [4] = '1', @@ -915,6 +919,33 @@ int x = INT_MAX; [16] = '3', [17] = '3', [18] = '2', - [19] = '1' }, levels) + [19] = '1' }, get_fold_levels()) + + helpers.command('1,2d') + helpers.poke_eventloop() + + exec_lua([[vim.treesitter.get_parser():parse()]]) + + helpers.poke_eventloop() + helpers.sleep(100) + + eq({ + [1] = '0', + [2] = '0', + [3] = '>1', + [4] = '1', + [5] = '1', + [6] = '0', + [7] = '0', + [8] = '>1', + [9] = '1', + [10] = '1', + [11] = '1', + [12] = '1', + [13] = '>2', + [14] = '2', + [15] = '2', + [16] = '1', + [17] = '0' }, get_fold_levels()) end) end) |