diff options
author | Lewis Russell <lewis6991@gmail.com> | 2023-02-23 15:19:52 +0000 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-02-23 15:19:52 +0000 |
commit | 75e53341f37eeeda7d9be7f934249f7e5e4397e9 (patch) | |
tree | b0b88f28c0ed701a9b2f13dbfe988061d48fa937 | |
parent | 86807157438240757199f925f538d7ad02322754 (diff) | |
download | rneovim-75e53341f37eeeda7d9be7f934249f7e5e4397e9.tar.gz rneovim-75e53341f37eeeda7d9be7f934249f7e5e4397e9.tar.bz2 rneovim-75e53341f37eeeda7d9be7f934249f7e5e4397e9.zip |
perf(treesitter): smarter languagetree invalidation
Problem:
Treesitter injections are slow because all injected trees are invalidated on every change.
Solution:
Implement smarter invalidation to avoid reparsing injected regions.
- In on_bytes, try to update self._regions as best we can. For now, this only offsets the regions that come after the change.
- Add valid flags for each region in self._regions.
- Call on_bytes recursively for all children.
- We still need to run the injection query every time for the top-level tree; it is not yet clear how to avoid this. However, if the new injection ranges are unchanged, the old trees are reused and the children are not reparsed.
This should result in roughly a 2-3x reduction in tree parsing when the comment injections are enabled.
-rw-r--r-- | runtime/lua/vim/treesitter/_range.lua | 126 | ||||
-rw-r--r-- | runtime/lua/vim/treesitter/languagetree.lua | 249 | ||||
-rw-r--r-- | runtime/lua/vim/treesitter/query.lua | 2 | ||||
-rw-r--r-- | scripts/lua2dox.lua | 2 | ||||
-rw-r--r-- | test/functional/treesitter/parser_spec.lua | 35 |
5 files changed, 326 insertions, 88 deletions
diff --git a/runtime/lua/vim/treesitter/_range.lua b/runtime/lua/vim/treesitter/_range.lua
new file mode 100644
index 0000000000..b87542c20f
--- /dev/null
+++ b/runtime/lua/vim/treesitter/_range.lua
@@ -0,0 +1,126 @@
+local api = vim.api
+
+local M = {}
+
+---@alias Range4 {[1]: integer, [2]: integer, [3]: integer, [4]: integer}
+---@alias Range6 {[1]: integer, [2]: integer, [3]: integer, [4]: integer, [5]: integer, [6]: integer}
+
+---@private
+---@param a_row integer
+---@param a_col integer
+---@param b_row integer
+---@param b_col integer
+---@return integer ordering of position a relative to position b (rows first, then columns)
+--- 1: a > b
+--- 0: a == b
+--- -1: a < b
+local function cmp_pos(a_row, a_col, b_row, b_col)
+  if a_row == b_row then
+    if a_col > b_col then
+      return 1
+    elseif a_col < b_col then
+      return -1
+    else
+      return 0
+    end
+  elseif a_row > b_row then
+    return 1
+  end
+
+  return -1
+end
+
+M.cmp_pos = {
+  lt = function(...)
+    return cmp_pos(...) == -1
+  end,
+  le = function(...)
+    return cmp_pos(...) ~= 1
+  end,
+  gt = function(...)
+    return cmp_pos(...) == 1
+  end,
+  ge = function(...)
+    return cmp_pos(...) ~= -1
+  end,
+  eq = function(...)
+    return cmp_pos(...) == 0
+  end,
+  ne = function(...)
+    return cmp_pos(...) ~= 0
+  end,
+}
+
+setmetatable(M.cmp_pos, { __call = function(_, ...) return cmp_pos(...) end }) -- drop self arg
+
+---@private
+---@param r1 Range4|Range6
+---@param r2 Range4|Range6
+---@return boolean whether r1 and r2 overlap (ranges that merely touch do not)
+function M.intercepts(r1, r2)
+  local off_1 = #r1 == 6 and 1 or 0 -- a Range6 holds a byte offset before its end row
+  local off_2 = #r2 == 6 and 1 or 0
+
+  local srow_1, scol_1, erow_1, ecol_1 = r1[1], r1[2], r1[3 + off_1], r1[4 + off_1]
+  local srow_2, scol_2, erow_2, ecol_2 = r2[1], r2[2], r2[3 + off_2], r2[4 + off_2]
+
+  -- r1 ends at or before the start of r2
+  if M.cmp_pos.le(erow_1, ecol_1, srow_2, scol_2) then
+    return false
+  end
+
+  -- r1 starts at or after the end of r2
+  if M.cmp_pos.ge(srow_1, scol_1, erow_2, ecol_2) then
+    return false
+  end
+
+  return true
+end
+
+---@private
+---@param r1 Range4|Range6
+---@param r2 Range4|Range6
+---@return boolean whether r1 contains r2
+function M.contains(r1, r2)
+  local off_1 = #r1 == 6 and 1 or 0 -- a Range6 holds a byte offset before its end row
+  local off_2 = #r2 == 6 and 1 or 0
+
+  local srow_1, scol_1, erow_1, ecol_1 = r1[1], r1[2], r1[3 + off_1], r1[4 + off_1]
+  local srow_2, scol_2, erow_2, ecol_2 = r2[1], r2[2], r2[3 + off_2], r2[4 + off_2]
+
+  -- start doesn't fit: r1 begins after r2 begins
+  if M.cmp_pos.gt(srow_1, scol_1, srow_2, scol_2) then
+    return false
+  end
+
+  -- end doesn't fit: r1 finishes before r2 finishes
+  if M.cmp_pos.lt(erow_1, ecol_1, erow_2, ecol_2) then
+    return false
+  end
+
+  return true
+end
+
+---@private
+---@param source integer|string
+---@param range Range4
+---@return Range6
+function M.add_bytes(source, range)
+  local start_row, start_col, end_row, end_col = range[1], range[2], range[3], range[4]
+  local start_byte = 0
+  local end_byte = 0
+  -- TODO(vigoux): proper byte computation here, and account for EOL ?
+ if type(source) == 'number' then + -- Easy case, this is a buffer parser + start_byte = api.nvim_buf_get_offset(source, start_row) + start_col + end_byte = api.nvim_buf_get_offset(source, end_row) + end_col + elseif type(source) == 'string' then + -- string parser, single `\n` delimited string + start_byte = vim.fn.byteidx(source, start_col) + end_byte = vim.fn.byteidx(source, end_col) + end + + return { start_row, start_col, start_byte, end_row, end_col, end_byte } +end + +return M diff --git a/runtime/lua/vim/treesitter/languagetree.lua b/runtime/lua/vim/treesitter/languagetree.lua index c9fd4bb2ea..2d4e2e595b 100644 --- a/runtime/lua/vim/treesitter/languagetree.lua +++ b/runtime/lua/vim/treesitter/languagetree.lua @@ -1,9 +1,8 @@ local a = vim.api local query = require('vim.treesitter.query') local language = require('vim.treesitter.language') +local Range = require('vim.treesitter._range') ----@alias Range {[1]: integer, [2]: integer, [3]: integer, [4]: integer} --- ---@alias TSCallbackName ---| 'changedtree' ---| 'bytes' @@ -24,11 +23,13 @@ local language = require('vim.treesitter.language') ---@field private _injection_query Query Queries defining injected languages ---@field private _opts table Options ---@field private _parser TSParser Parser for language ----@field private _regions Range[][] List of regions this tree should manage and parse +---@field private _regions Range6[][] List of regions this tree should manage and parse ---@field private _lang string Language name ---@field private _source (integer|string) Buffer or string to parse ---@field private _trees TSTree[] Reference to parsed tree (one for each language) ----@field private _valid boolean If the parsed tree is valid +---@field private _valid boolean|table<integer,true> If the parsed tree is valid +--- TODO(lewis6991): combine _regions, _valid and _trees +---@field private _is_child boolean local LanguageTree = {} ---@class LanguageTreeOpts @@ -114,6 +115,9 @@ end --- If the tree is 
invalid, call `parse()`. --- This will return the updated tree. function LanguageTree:is_valid() + if type(self._valid) == 'table' then + return #self._valid == #self._regions + end return self._valid end @@ -127,6 +131,16 @@ function LanguageTree:source() return self._source end +---@private +---This is only exposed so it can be wrapped for profiling +---@param old_tree TSTree +---@return TSTree, integer[] +function LanguageTree:_parse_tree(old_tree) + local tree, tree_changes = self._parser:parse(old_tree, self._source) + self:_do_callback('changedtree', tree_changes, tree) + return tree, tree_changes +end + --- Parses all defined regions using a treesitter parser --- for the language this tree represents. --- This will run the injection query for this language to @@ -135,35 +149,27 @@ end ---@return TSTree[] ---@return table|nil Change list function LanguageTree:parse() - if self._valid then + if self:is_valid() then return self._trees end - local parser = self._parser local changes = {} - local old_trees = self._trees - self._trees = {} - -- If there are no ranges, set to an empty list -- so the included ranges in the parser are cleared. 
- if self._regions and #self._regions > 0 then + if #self._regions > 0 then for i, ranges in ipairs(self._regions) do - local old_tree = old_trees[i] - parser:set_included_ranges(ranges) - - local tree, tree_changes = parser:parse(old_tree, self._source) - self:_do_callback('changedtree', tree_changes, tree) - - table.insert(self._trees, tree) - vim.list_extend(changes, tree_changes) + if not self._valid or not self._valid[i] then + self._parser:set_included_ranges(ranges) + local tree, tree_changes = self:_parse_tree(self._trees[i]) + self._trees[i] = tree + vim.list_extend(changes, tree_changes) + end end else - local tree, tree_changes = parser:parse(old_trees[1], self._source) - self:_do_callback('changedtree', tree_changes, tree) - - table.insert(self._trees, tree) - vim.list_extend(changes, tree_changes) + local tree, tree_changes = self:_parse_tree(self._trees[1]) + self._trees = { tree } + changes = tree_changes end local injections_by_lang = self:_get_injections() @@ -249,6 +255,7 @@ function LanguageTree:add_child(lang) end self._children[lang] = LanguageTree.new(self._source, lang, self._opts) + self._children[lang]._is_child = true self:invalidate() self:_do_callback('child_added', self._children[lang]) @@ -298,43 +305,35 @@ end --- This allows for embedded languages to be parsed together across different --- nodes, which is useful for templating languages like ERB and EJS. --- ---- Note: This call invalidates the tree and requires it to be parsed again. ---- ---@private ----@param regions integer[][][] List of regions this tree should manage and parse. +---@param regions Range4[][] List of regions this tree should manage and parse. 
function LanguageTree:set_included_regions(regions) -- Transform the tables from 4 element long to 6 element long (with byte offset) for _, region in ipairs(regions) do for i, range in ipairs(region) do if type(range) == 'table' and #range == 4 then - ---@diagnostic disable-next-line:no-unknown - local start_row, start_col, end_row, end_col = unpack(range) - local start_byte = 0 - local end_byte = 0 - local source = self._source - -- TODO(vigoux): proper byte computation here, and account for EOL ? - if type(source) == 'number' then - -- Easy case, this is a buffer parser - start_byte = a.nvim_buf_get_offset(source, start_row) + start_col - end_byte = a.nvim_buf_get_offset(source, end_row) + end_col - elseif type(self._source) == 'string' then - -- string parser, single `\n` delimited string - start_byte = vim.fn.byteidx(self._source, start_col) - end_byte = vim.fn.byteidx(self._source, end_col) - end + region[i] = Range.add_bytes(self._source, range) + end + end + end - region[i] = { start_row, start_col, start_byte, end_row, end_col, end_byte } + if #self._regions ~= #regions then + self._trees = {} + self:invalidate() + elseif self._valid ~= false then + if self._valid == true then + self._valid = {} + end + for i = 1, #regions do + self._valid[i] = true + if not vim.deep_equal(self._regions[i], regions[i]) then + self._valid[i] = nil + self._trees[i] = nil end end end self._regions = regions - -- Trees are no longer valid now that we have changed regions. - -- TODO(vigoux,steelsojka): Look into doing this smarter so we can use some of the - -- old trees for incremental parsing. Currently, this only - -- affects injected languages. 
- self._trees = {} - self:invalidate() end --- Gets the set of included regions @@ -346,10 +345,10 @@ end ---@param node TSNode ---@param id integer ---@param metadata TSMetadata ----@return Range +---@return Range4 local function get_range_from_metadata(node, id, metadata) if metadata[id] and metadata[id].range then - return metadata[id].range --[[@as Range]] + return metadata[id].range --[[@as Range4]] end return { node:range() } end @@ -378,7 +377,7 @@ function LanguageTree:_get_injections() self._injection_query:iter_matches(root_node, self._source, start_line, end_line + 1) do local lang = nil ---@type string - local ranges = {} ---@type Range[] + local ranges = {} ---@type Range4[] local combined = metadata.combined ---@type boolean -- Directives can configure how injections are captured as well as actual node captures. @@ -408,6 +407,7 @@ function LanguageTree:_get_injections() -- Lang should override any other language tag if name == 'language' and not lang then + ---@diagnostic disable-next-line lang = query.get_node_text(node, self._source, { metadata = metadata[id] }) elseif name == 'combined' then combined = true @@ -426,6 +426,8 @@ function LanguageTree:_get_injections() end end + assert(type(lang) == 'string') + -- Each tree index should be isolated from the other nodes. if not injections[tree_index] then injections[tree_index] = {} @@ -446,7 +448,7 @@ function LanguageTree:_get_injections() end end - ---@type table<string,Range[][]> + ---@type table<string,Range4[][]> local result = {} -- Generate a map by lang of node lists. @@ -486,6 +488,45 @@ function LanguageTree:_do_callback(cb_name, ...) 
end ---@private +---@param regions Range6[][] +---@param old_range Range6 +---@param new_range Range6 +---@return table<integer, true> region indices to invalidate +local function update_regions(regions, old_range, new_range) + ---@type table<integer,true> + local valid = {} + + for i, ranges in ipairs(regions or {}) do + valid[i] = true + for j, r in ipairs(ranges) do + if Range.intercepts(r, old_range) then + valid[i] = nil + break + end + + -- Range after change. Adjust + if Range.cmp_pos.gt(r[1], r[2], old_range[4], old_range[5]) then + local byte_offset = new_range[6] - old_range[6] + local row_offset = new_range[4] - old_range[4] + + -- Update the range to avoid invalidation in set_included_regions() + -- which will compare the regions against the parsed injection regions + ranges[j] = { + r[1] + row_offset, + r[2], + r[3] + byte_offset, + r[4] + row_offset, + r[5], + r[6] + byte_offset, + } + end + end + end + + return valid +end + +---@private ---@param bufnr integer ---@param changed_tick integer ---@param start_row integer @@ -510,14 +551,53 @@ function LanguageTree:_on_bytes( new_col, new_byte ) - self:invalidate() - local old_end_col = old_col + ((old_row == 0) and start_col or 0) local new_end_col = new_col + ((new_row == 0) and start_col or 0) - -- Edit all trees recursively, together BEFORE emitting a bytes callback. - -- In most cases this callback should only be called from the root tree. 
- self:for_each_tree(function(tree) + local old_range = { + start_row, + start_col, + start_byte, + start_row + old_row, + old_end_col, + start_byte + old_byte, + } + + local new_range = { + start_row, + start_col, + start_byte, + start_row + new_row, + new_end_col, + start_byte + new_byte, + } + + local valid_regions = update_regions(self._regions, old_range, new_range) + + if #self._regions == 0 or #valid_regions == 0 then + self._valid = false + else + self._valid = valid_regions + end + + for _, child in pairs(self._children) do + child:_on_bytes( + bufnr, + changed_tick, + start_row, + start_col, + start_byte, + old_row, + old_col, + old_byte, + new_row, + new_col, + new_byte + ) + end + + -- Edit trees together BEFORE emitting a bytes callback. + for _, tree in ipairs(self._trees) do tree:edit( start_byte, start_byte + old_byte, @@ -529,22 +609,24 @@ function LanguageTree:_on_bytes( start_row + new_row, new_end_col ) - end) + end - self:_do_callback( - 'bytes', - bufnr, - changed_tick, - start_row, - start_col, - start_byte, - old_row, - old_col, - old_byte, - new_row, - new_col, - new_byte - ) + if not self._is_child then + self:_do_callback( + 'bytes', + bufnr, + changed_tick, + start_row, + start_col, + start_byte, + old_row, + old_col, + old_byte, + new_row, + new_col, + new_byte + ) + end end ---@private @@ -595,19 +677,15 @@ end ---@private ---@param tree TSTree ----@param range Range +---@param range Range4 ---@return boolean local function tree_contains(tree, range) - local start_row, start_col, end_row, end_col = tree:root():range() - local start_fits = start_row < range[1] or (start_row == range[1] and start_col <= range[2]) - local end_fits = end_row > range[3] or (end_row == range[3] and end_col >= range[4]) - - return start_fits and end_fits + return Range.contains({ tree:root():range() }, range) end --- Determines whether {range} is contained in the |LanguageTree|. 
--- ----@param range Range `{ start_line, start_col, end_line, end_col }` +---@param range Range4 `{ start_line, start_col, end_line, end_col }` ---@return boolean function LanguageTree:contains(range) for _, tree in pairs(self._trees) do @@ -621,7 +699,7 @@ end --- Gets the tree that contains {range}. --- ----@param range Range `{ start_line, start_col, end_line, end_col }` +---@param range Range4 `{ start_line, start_col, end_line, end_col }` ---@param opts table|nil Optional keyword arguments: --- - ignore_injections boolean Ignore injected languages (default true) ---@return TSTree|nil @@ -631,10 +709,9 @@ function LanguageTree:tree_for_range(range, opts) if not ignore then for _, child in pairs(self._children) do - for _, tree in pairs(child:trees()) do - if tree_contains(tree, range) then - return tree - end + local tree = child:tree_for_range(range, opts) + if tree then + return tree end end end @@ -650,7 +727,7 @@ end --- Gets the smallest named node that contains {range}. --- ----@param range Range `{ start_line, start_col, end_line, end_col }` +---@param range Range4 `{ start_line, start_col, end_line, end_col }` ---@param opts table|nil Optional keyword arguments: --- - ignore_injections boolean Ignore injected languages (default true) ---@return TSNode | nil Found node @@ -663,7 +740,7 @@ end --- Gets the appropriate language that contains {range}. 
--- ----@param range Range `{ start_line, start_col, end_line, end_col }` +---@param range Range4 `{ start_line, start_col, end_line, end_col }` ---@return LanguageTree Managing {range} function LanguageTree:language_for_range(range) for _, child in pairs(self._children) do diff --git a/runtime/lua/vim/treesitter/query.lua b/runtime/lua/vim/treesitter/query.lua index 58a29f2fe0..13d98a0625 100644 --- a/runtime/lua/vim/treesitter/query.lua +++ b/runtime/lua/vim/treesitter/query.lua @@ -406,7 +406,7 @@ predicate_handlers['vim-match?'] = predicate_handlers['match?'] ---@class TSMetadata ---@field [integer] TSMetadata ---@field [string] integer|string ----@field range Range +---@field range Range4 ---@alias TSDirective fun(match: TSMatch, _, _, predicate: any[], metadata: TSMetadata) diff --git a/scripts/lua2dox.lua b/scripts/lua2dox.lua index fc0e915307..17de0ea9b4 100644 --- a/scripts/lua2dox.lua +++ b/scripts/lua2dox.lua @@ -291,7 +291,7 @@ local types = { 'integer', 'number', 'string', 'table', 'list', 'boolean', 'func local tagged_types = { 'TSNode', 'LanguageTree' } -- Document these as 'table' -local alias_types = { 'Range' } +local alias_types = { 'Range4', 'Range6' } --! 
\brief run the filter function TLua2DoX_filter.readfile(this, AppStamp, Filename) diff --git a/test/functional/treesitter/parser_spec.lua b/test/functional/treesitter/parser_spec.lua index 2430021617..fdd6403859 100644 --- a/test/functional/treesitter/parser_spec.lua +++ b/test/functional/treesitter/parser_spec.lua @@ -639,6 +639,17 @@ int x = INT_MAX; {1, 26, 1, 65}, -- READ_STRING(x, y) (char_u *)read_string((x), (size_t)(y)) {2, 29, 2, 68} -- READ_STRING_OK(x, y) (char_u *)read_string((x), (size_t)(y)) }, get_ranges()) + + helpers.feed('ggo<esc>') + eq(5, exec_lua("return #parser:children().c:trees()")) + eq({ + {0, 0, 8, 0}, -- root tree + {4, 14, 4, 17}, -- VALUE 123 + {5, 15, 5, 18}, -- VALUE1 123 + {6, 15, 6, 18}, -- VALUE2 123 + {2, 26, 2, 65}, -- READ_STRING(x, y) (char_u *)read_string((x), (size_t)(y)) + {3, 29, 3, 68} -- READ_STRING_OK(x, y) (char_u *)read_string((x), (size_t)(y)) + }, get_ranges()) end) end) @@ -660,6 +671,18 @@ int x = INT_MAX; {1, 26, 2, 68} -- READ_STRING(x, y) (char_u *)read_string((x), (size_t)(y)) -- READ_STRING_OK(x, y) (char_u *)read_string((x), (size_t)(y)) }, get_ranges()) + + helpers.feed('ggo<esc>') + eq("table", exec_lua("return type(parser:children().c)")) + eq(2, exec_lua("return #parser:children().c:trees()")) + eq({ + {0, 0, 8, 0}, -- root tree + {4, 14, 6, 18}, -- VALUE 123 + -- VALUE1 123 + -- VALUE2 123 + {2, 26, 3, 68} -- READ_STRING(x, y) (char_u *)read_string((x), (size_t)(y)) + -- READ_STRING_OK(x, y) (char_u *)read_string((x), (size_t)(y)) + }, get_ranges()) end) end) @@ -688,6 +711,18 @@ int x = INT_MAX; {1, 26, 2, 68} -- READ_STRING(x, y) (char_u *)read_string((x), (size_t)(y)) -- READ_STRING_OK(x, y) (char_u *)read_string((x), (size_t)(y)) }, get_ranges()) + + helpers.feed('ggo<esc>') + eq("table", exec_lua("return type(parser:children().c)")) + eq(2, exec_lua("return #parser:children().c:trees()")) + eq({ + {0, 0, 8, 0}, -- root tree + {4, 14, 6, 18}, -- VALUE 123 + -- VALUE1 123 + -- VALUE2 123 + {2, 26, 3, 
68} -- READ_STRING(x, y) (char_u *)read_string((x), (size_t)(y)) + -- READ_STRING_OK(x, y) (char_u *)read_string((x), (size_t)(y)) + }, get_ranges()) end) it("should not inject bad languages", function() |