aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorLewis Russell <lewis6991@gmail.com>2023-02-23 15:19:52 +0000
committerGitHub <noreply@github.com>2023-02-23 15:19:52 +0000
commit75e53341f37eeeda7d9be7f934249f7e5e4397e9 (patch)
treeb0b88f28c0ed701a9b2f13dbfe988061d48fa937
parent86807157438240757199f925f538d7ad02322754 (diff)
downloadrneovim-75e53341f37eeeda7d9be7f934249f7e5e4397e9.tar.gz
rneovim-75e53341f37eeeda7d9be7f934249f7e5e4397e9.tar.bz2
rneovim-75e53341f37eeeda7d9be7f934249f7e5e4397e9.zip
perf(treesitter): smarter languagetree invalidation
Problem: Treesitter injections are slow because all injected trees are invalidated on every change. Solution: Implement smarter invalidation to avoid reparsing injected regions. - In on_bytes, try and update self._regions as best we can. This PR just offsets any regions after the change. - Add valid flags for each region in self._regions. - Call on_bytes recursively for all children. - We still need to run the query every time for the top level tree. I don't know how to avoid this. However, if the new injection ranges don't change, then we re-use the old trees and avoid reparsing children. This should result in roughly a 2-3x reduction in tree parsing when the comment injections are enabled.
-rw-r--r--runtime/lua/vim/treesitter/_range.lua126
-rw-r--r--runtime/lua/vim/treesitter/languagetree.lua249
-rw-r--r--runtime/lua/vim/treesitter/query.lua2
-rw-r--r--scripts/lua2dox.lua2
-rw-r--r--test/functional/treesitter/parser_spec.lua35
5 files changed, 326 insertions, 88 deletions
diff --git a/runtime/lua/vim/treesitter/_range.lua b/runtime/lua/vim/treesitter/_range.lua
new file mode 100644
index 0000000000..b87542c20f
--- /dev/null
+++ b/runtime/lua/vim/treesitter/_range.lua
@@ -0,0 +1,126 @@
+local api = vim.api
+
+local M = {}
+
+---@alias Range4 {[1]: integer, [2]: integer, [3]: integer, [4]: integer}
+---@alias Range6 {[1]: integer, [2]: integer, [3]: integer, [4]: integer, [5]: integer, [6]: integer}
+
+---@private
+---@param a_row integer
+---@param a_col integer
+---@param b_row integer
+---@param b_col integer
+---@return integer
+--- 1: a > b
+--- 0: a == b
+--- -1: a < b
+local function cmp_pos(a_row, a_col, b_row, b_col)
+  -- Positions on different rows are ordered by row alone.
+  if a_row ~= b_row then
+    if a_row > b_row then
+      return 1
+    end
+    return -1
+  end
+
+  -- Same row: the column decides.
+  if a_col > b_col then
+    return 1
+  end
+  if a_col < b_col then
+    return -1
+  end
+  return 0
+end
+
+--- Boolean comparison helpers over two (row, col) positions.
+--- Every helper takes (a_row, a_col, b_row, b_col) and tests the result of
+--- cmp_pos() for the relation it is named after.
+M.cmp_pos = {
+  lt = function(...)
+    return cmp_pos(...) == -1
+  end,
+  le = function(...)
+    return cmp_pos(...) ~= 1
+  end,
+  gt = function(...)
+    return cmp_pos(...) == 1
+  end,
+  ge = function(...)
+    return cmp_pos(...) ~= -1
+  end,
+  eq = function(...)
+    return cmp_pos(...) == 0
+  end,
+  ne = function(...)
+    return cmp_pos(...) ~= 0
+  end,
+}
+
+-- Make the table itself callable as the raw three-way comparison.
+-- Lua hands the called table to __call as its first argument, so it has to be
+-- dropped before the four position values are forwarded to cmp_pos().
+setmetatable(M.cmp_pos, {
+  __call = function(_, ...)
+    return cmp_pos(...)
+  end,
+})
+
+---@private
+--- Checks whether two ranges overlap. Ranges that only touch (one ends
+--- exactly where the other begins) do not count as overlapping.
+---@param r1 Range4|Range6
+---@param r2 Range4|Range6
+---@return boolean
+function M.intercepts(r1, r2)
+  -- A Range6 stores a byte offset after each (row, col) pair, which pushes
+  -- its end row/col one slot further than in a Range4. Each range gets its
+  -- own offset so a Range4 can be compared against a Range6.
+  local off_1 = #r1 == 6 and 1 or 0
+  local off_2 = #r2 == 6 and 1 or 0
+
+  local srow_1, scol_1, erow_1, ecol_1 = r1[1], r1[2], r1[3 + off_1], r1[4 + off_1]
+  local srow_2, scol_2, erow_2, ecol_2 = r2[1], r2[2], r2[3 + off_2], r2[4 + off_2]
+
+  -- r1 is above r2
+  if M.cmp_pos.le(erow_1, ecol_1, srow_2, scol_2) then
+    return false
+  end
+
+  -- r1 is below r2
+  if M.cmp_pos.ge(srow_1, scol_1, erow_2, ecol_2) then
+    return false
+  end
+
+  return true
+end
+
+---@private
+--- Checks whether r2 lies entirely inside r1. Shared start or end positions
+--- still count as contained.
+---@param r1 Range4|Range6
+---@param r2 Range4|Range6
+---@return boolean whether r1 contains r2
+function M.contains(r1, r2)
+  -- A Range6 stores a byte offset after each (row, col) pair, which pushes
+  -- its end row/col one slot further than in a Range4. Each range gets its
+  -- own offset so a Range4 can be compared against a Range6.
+  local off_1 = #r1 == 6 and 1 or 0
+  local off_2 = #r2 == 6 and 1 or 0
+
+  local srow_1, scol_1, erow_1, ecol_1 = r1[1], r1[2], r1[3 + off_1], r1[4 + off_1]
+  local srow_2, scol_2, erow_2, ecol_2 = r2[1], r2[2], r2[3 + off_2], r2[4 + off_2]
+
+  -- start doesn't fit
+  if M.cmp_pos.gt(srow_1, scol_1, srow_2, scol_2) then
+    return false
+  end
+
+  -- end doesn't fit
+  if M.cmp_pos.lt(erow_1, ecol_1, erow_2, ecol_2) then
+    return false
+  end
+
+  return true
+end
+
+---@private
+---@param source integer|string
+---@param range Range4
+---@return Range6
+function M.add_bytes(source, range)
+  local srow, scol, erow, ecol = range[1], range[2], range[3], range[4]
+  local kind = type(source)
+
+  -- Both byte offsets stay 0 when source is neither a number nor a string.
+  local sbyte, ebyte = 0, 0
+
+  -- TODO(vigoux): proper byte computation here, and account for EOL ?
+  if kind == 'string' then
+    -- string parser, single `\n` delimited string
+    sbyte = vim.fn.byteidx(source, scol)
+    ebyte = vim.fn.byteidx(source, ecol)
+  elseif kind == 'number' then
+    -- Easy case, this is a buffer parser
+    sbyte = api.nvim_buf_get_offset(source, srow) + scol
+    ebyte = api.nvim_buf_get_offset(source, erow) + ecol
+  end
+
+  return { srow, scol, sbyte, erow, ecol, ebyte }
+end
+
+return M
diff --git a/runtime/lua/vim/treesitter/languagetree.lua b/runtime/lua/vim/treesitter/languagetree.lua
index c9fd4bb2ea..2d4e2e595b 100644
--- a/runtime/lua/vim/treesitter/languagetree.lua
+++ b/runtime/lua/vim/treesitter/languagetree.lua
@@ -1,9 +1,8 @@
local a = vim.api
local query = require('vim.treesitter.query')
local language = require('vim.treesitter.language')
+local Range = require('vim.treesitter._range')
----@alias Range {[1]: integer, [2]: integer, [3]: integer, [4]: integer}
---
---@alias TSCallbackName
---| 'changedtree'
---| 'bytes'
@@ -24,11 +23,13 @@ local language = require('vim.treesitter.language')
---@field private _injection_query Query Queries defining injected languages
---@field private _opts table Options
---@field private _parser TSParser Parser for language
----@field private _regions Range[][] List of regions this tree should manage and parse
+---@field private _regions Range6[][] List of regions this tree should manage and parse
---@field private _lang string Language name
---@field private _source (integer|string) Buffer or string to parse
---@field private _trees TSTree[] Reference to parsed tree (one for each language)
----@field private _valid boolean If the parsed tree is valid
+---@field private _valid boolean|table<integer,true> If the parsed tree is valid
+--- TODO(lewis6991): combine _regions, _valid and _trees
+---@field private _is_child boolean
local LanguageTree = {}
---@class LanguageTreeOpts
@@ -114,6 +115,9 @@ end
--- If the tree is invalid, call `parse()`.
--- This will return the updated tree.
function LanguageTree:is_valid()
+  local valid = self._valid
+  if type(valid) == 'table' then
+    -- Per-region validity: invalid regions are stored as nil holes, so the
+    -- length operator is unreliable on this table. Check every region index
+    -- explicitly instead.
+    for i = 1, #self._regions do
+      if not valid[i] then
+        return false
+      end
+    end
+    return true
+  end
return self._valid
end
@@ -127,6 +131,16 @@ function LanguageTree:source()
return self._source
end
+---@private
+---This is only exposed so it can be wrapped for profiling
+---@param old_tree TSTree
+---@return TSTree, integer[]
+function LanguageTree:_parse_tree(old_tree)
+  -- Parse self._source, passing old_tree along to the parser.
+  local new_tree, changed_ranges = self._parser:parse(old_tree, self._source)
+
+  -- Fire the 'changedtree' callback before handing the results back.
+  self:_do_callback('changedtree', changed_ranges, new_tree)
+
+  return new_tree, changed_ranges
+end
+
--- Parses all defined regions using a treesitter parser
--- for the language this tree represents.
--- This will run the injection query for this language to
@@ -135,35 +149,27 @@ end
---@return TSTree[]
---@return table|nil Change list
function LanguageTree:parse()
- if self._valid then
+ if self:is_valid() then
return self._trees
end
- local parser = self._parser
local changes = {}
- local old_trees = self._trees
- self._trees = {}
-
-- If there are no ranges, set to an empty list
-- so the included ranges in the parser are cleared.
- if self._regions and #self._regions > 0 then
+ if #self._regions > 0 then
for i, ranges in ipairs(self._regions) do
- local old_tree = old_trees[i]
- parser:set_included_ranges(ranges)
-
- local tree, tree_changes = parser:parse(old_tree, self._source)
- self:_do_callback('changedtree', tree_changes, tree)
-
- table.insert(self._trees, tree)
- vim.list_extend(changes, tree_changes)
+ if not self._valid or not self._valid[i] then
+ self._parser:set_included_ranges(ranges)
+ local tree, tree_changes = self:_parse_tree(self._trees[i])
+ self._trees[i] = tree
+ vim.list_extend(changes, tree_changes)
+ end
end
else
- local tree, tree_changes = parser:parse(old_trees[1], self._source)
- self:_do_callback('changedtree', tree_changes, tree)
-
- table.insert(self._trees, tree)
- vim.list_extend(changes, tree_changes)
+ local tree, tree_changes = self:_parse_tree(self._trees[1])
+ self._trees = { tree }
+ changes = tree_changes
end
local injections_by_lang = self:_get_injections()
@@ -249,6 +255,7 @@ function LanguageTree:add_child(lang)
end
self._children[lang] = LanguageTree.new(self._source, lang, self._opts)
+ self._children[lang]._is_child = true
self:invalidate()
self:_do_callback('child_added', self._children[lang])
@@ -298,43 +305,35 @@ end
--- This allows for embedded languages to be parsed together across different
--- nodes, which is useful for templating languages like ERB and EJS.
---
---- Note: This call invalidates the tree and requires it to be parsed again.
----
---@private
----@param regions integer[][][] List of regions this tree should manage and parse.
+---@param regions Range4[][] List of regions this tree should manage and parse.
function LanguageTree:set_included_regions(regions)
-- Transform the tables from 4 element long to 6 element long (with byte offset)
for _, region in ipairs(regions) do
for i, range in ipairs(region) do
if type(range) == 'table' and #range == 4 then
- ---@diagnostic disable-next-line:no-unknown
- local start_row, start_col, end_row, end_col = unpack(range)
- local start_byte = 0
- local end_byte = 0
- local source = self._source
- -- TODO(vigoux): proper byte computation here, and account for EOL ?
- if type(source) == 'number' then
- -- Easy case, this is a buffer parser
- start_byte = a.nvim_buf_get_offset(source, start_row) + start_col
- end_byte = a.nvim_buf_get_offset(source, end_row) + end_col
- elseif type(self._source) == 'string' then
- -- string parser, single `\n` delimited string
- start_byte = vim.fn.byteidx(self._source, start_col)
- end_byte = vim.fn.byteidx(self._source, end_col)
- end
+ region[i] = Range.add_bytes(self._source, range)
+ end
+ end
+ end
- region[i] = { start_row, start_col, start_byte, end_row, end_col, end_byte }
+ if #self._regions ~= #regions then
+ self._trees = {}
+ self:invalidate()
+ elseif self._valid ~= false then
+ if self._valid == true then
+ self._valid = {}
+ end
+ for i = 1, #regions do
+ self._valid[i] = true
+ if not vim.deep_equal(self._regions[i], regions[i]) then
+ self._valid[i] = nil
+ self._trees[i] = nil
end
end
end
self._regions = regions
- -- Trees are no longer valid now that we have changed regions.
- -- TODO(vigoux,steelsojka): Look into doing this smarter so we can use some of the
- -- old trees for incremental parsing. Currently, this only
- -- affects injected languages.
- self._trees = {}
- self:invalidate()
end
--- Gets the set of included regions
@@ -346,10 +345,10 @@ end
---@param node TSNode
---@param id integer
---@param metadata TSMetadata
----@return Range
+---@return Range4
local function get_range_from_metadata(node, id, metadata)
if metadata[id] and metadata[id].range then
- return metadata[id].range --[[@as Range]]
+ return metadata[id].range --[[@as Range4]]
end
return { node:range() }
end
@@ -378,7 +377,7 @@ function LanguageTree:_get_injections()
self._injection_query:iter_matches(root_node, self._source, start_line, end_line + 1)
do
local lang = nil ---@type string
- local ranges = {} ---@type Range[]
+ local ranges = {} ---@type Range4[]
local combined = metadata.combined ---@type boolean
-- Directives can configure how injections are captured as well as actual node captures.
@@ -408,6 +407,7 @@ function LanguageTree:_get_injections()
-- Lang should override any other language tag
if name == 'language' and not lang then
+ ---@diagnostic disable-next-line
lang = query.get_node_text(node, self._source, { metadata = metadata[id] })
elseif name == 'combined' then
combined = true
@@ -426,6 +426,8 @@ function LanguageTree:_get_injections()
end
end
+ assert(type(lang) == 'string')
+
-- Each tree index should be isolated from the other nodes.
if not injections[tree_index] then
injections[tree_index] = {}
@@ -446,7 +448,7 @@ function LanguageTree:_get_injections()
end
end
- ---@type table<string,Range[][]>
+ ---@type table<string,Range4[][]>
local result = {}
-- Generate a map by lang of node lists.
@@ -486,6 +488,45 @@ function LanguageTree:_do_callback(cb_name, ...)
end
---@private
+--- Works out which regions survive an edit and shifts the ranges that start
+--- after the edited span. The range tables inside {regions} are replaced in
+--- place.
+---@param regions Range6[][]
+---@param old_range Range6
+---@param new_range Range6
+---@return table<integer, true> set of region indices that are still valid
+local function update_regions(regions, old_range, new_range)
+ ---@type table<integer,true>
+ local valid = {}
+
+ for i, ranges in ipairs(regions or {}) do
+ -- Assume the region is valid until one of its ranges overlaps the edit.
+ valid[i] = true
+ for j, r in ipairs(ranges) do
+ -- An overlapping range invalidates the whole region; the remaining ranges
+ -- of this region are then left untouched.
+ if Range.intercepts(r, old_range) then
+ valid[i] = nil
+ break
+ end
+
+ -- Range after change. Adjust
+ if Range.cmp_pos.gt(r[1], r[2], old_range[4], old_range[5]) then
+ local byte_offset = new_range[6] - old_range[6]
+ local row_offset = new_range[4] - old_range[4]
+
+ -- Update the range to avoid invalidation in set_included_regions()
+ -- which will compare the regions against the parsed injection regions
+ -- Only rows and byte offsets are shifted; columns are kept as they are.
+ ranges[j] = {
+ r[1] + row_offset,
+ r[2],
+ r[3] + byte_offset,
+ r[4] + row_offset,
+ r[5],
+ r[6] + byte_offset,
+ }
+ end
+ end
+ end
+
+ return valid
+end
+
+---@private
---@param bufnr integer
---@param changed_tick integer
---@param start_row integer
@@ -510,14 +551,53 @@ function LanguageTree:_on_bytes(
new_col,
new_byte
)
- self:invalidate()
-
local old_end_col = old_col + ((old_row == 0) and start_col or 0)
local new_end_col = new_col + ((new_row == 0) and start_col or 0)
- -- Edit all trees recursively, together BEFORE emitting a bytes callback.
- -- In most cases this callback should only be called from the root tree.
- self:for_each_tree(function(tree)
+ local old_range = {
+ start_row,
+ start_col,
+ start_byte,
+ start_row + old_row,
+ old_end_col,
+ start_byte + old_byte,
+ }
+
+ local new_range = {
+ start_row,
+ start_col,
+ start_byte,
+ start_row + new_row,
+ new_end_col,
+ start_byte + new_byte,
+ }
+
+ local valid_regions = update_regions(self._regions, old_range, new_range)
+
+ if #self._regions == 0 or #valid_regions == 0 then
+ self._valid = false
+ else
+ self._valid = valid_regions
+ end
+
+ for _, child in pairs(self._children) do
+ child:_on_bytes(
+ bufnr,
+ changed_tick,
+ start_row,
+ start_col,
+ start_byte,
+ old_row,
+ old_col,
+ old_byte,
+ new_row,
+ new_col,
+ new_byte
+ )
+ end
+
+ -- Edit trees together BEFORE emitting a bytes callback.
+ for _, tree in ipairs(self._trees) do
tree:edit(
start_byte,
start_byte + old_byte,
@@ -529,22 +609,24 @@ function LanguageTree:_on_bytes(
start_row + new_row,
new_end_col
)
- end)
+ end
- self:_do_callback(
- 'bytes',
- bufnr,
- changed_tick,
- start_row,
- start_col,
- start_byte,
- old_row,
- old_col,
- old_byte,
- new_row,
- new_col,
- new_byte
- )
+ if not self._is_child then
+ self:_do_callback(
+ 'bytes',
+ bufnr,
+ changed_tick,
+ start_row,
+ start_col,
+ start_byte,
+ old_row,
+ old_col,
+ old_byte,
+ new_row,
+ new_col,
+ new_byte
+ )
+ end
end
---@private
@@ -595,19 +677,15 @@ end
---@private
---@param tree TSTree
----@param range Range
+---@param range Range4
---@return boolean
local function tree_contains(tree, range)
- local start_row, start_col, end_row, end_col = tree:root():range()
- local start_fits = start_row < range[1] or (start_row == range[1] and start_col <= range[2])
- local end_fits = end_row > range[3] or (end_row == range[3] and end_col >= range[4])
-
- return start_fits and end_fits
+ return Range.contains({ tree:root():range() }, range)
end
--- Determines whether {range} is contained in the |LanguageTree|.
---
----@param range Range `{ start_line, start_col, end_line, end_col }`
+---@param range Range4 `{ start_line, start_col, end_line, end_col }`
---@return boolean
function LanguageTree:contains(range)
for _, tree in pairs(self._trees) do
@@ -621,7 +699,7 @@ end
--- Gets the tree that contains {range}.
---
----@param range Range `{ start_line, start_col, end_line, end_col }`
+---@param range Range4 `{ start_line, start_col, end_line, end_col }`
---@param opts table|nil Optional keyword arguments:
--- - ignore_injections boolean Ignore injected languages (default true)
---@return TSTree|nil
@@ -631,10 +709,9 @@ function LanguageTree:tree_for_range(range, opts)
if not ignore then
for _, child in pairs(self._children) do
- for _, tree in pairs(child:trees()) do
- if tree_contains(tree, range) then
- return tree
- end
+ local tree = child:tree_for_range(range, opts)
+ if tree then
+ return tree
end
end
end
@@ -650,7 +727,7 @@ end
--- Gets the smallest named node that contains {range}.
---
----@param range Range `{ start_line, start_col, end_line, end_col }`
+---@param range Range4 `{ start_line, start_col, end_line, end_col }`
---@param opts table|nil Optional keyword arguments:
--- - ignore_injections boolean Ignore injected languages (default true)
---@return TSNode | nil Found node
@@ -663,7 +740,7 @@ end
--- Gets the appropriate language that contains {range}.
---
----@param range Range `{ start_line, start_col, end_line, end_col }`
+---@param range Range4 `{ start_line, start_col, end_line, end_col }`
---@return LanguageTree Managing {range}
function LanguageTree:language_for_range(range)
for _, child in pairs(self._children) do
diff --git a/runtime/lua/vim/treesitter/query.lua b/runtime/lua/vim/treesitter/query.lua
index 58a29f2fe0..13d98a0625 100644
--- a/runtime/lua/vim/treesitter/query.lua
+++ b/runtime/lua/vim/treesitter/query.lua
@@ -406,7 +406,7 @@ predicate_handlers['vim-match?'] = predicate_handlers['match?']
---@class TSMetadata
---@field [integer] TSMetadata
---@field [string] integer|string
----@field range Range
+---@field range Range4
---@alias TSDirective fun(match: TSMatch, _, _, predicate: any[], metadata: TSMetadata)
diff --git a/scripts/lua2dox.lua b/scripts/lua2dox.lua
index fc0e915307..17de0ea9b4 100644
--- a/scripts/lua2dox.lua
+++ b/scripts/lua2dox.lua
@@ -291,7 +291,7 @@ local types = { 'integer', 'number', 'string', 'table', 'list', 'boolean', 'func
local tagged_types = { 'TSNode', 'LanguageTree' }
-- Document these as 'table'
-local alias_types = { 'Range' }
+local alias_types = { 'Range4', 'Range6' }
--! \brief run the filter
function TLua2DoX_filter.readfile(this, AppStamp, Filename)
diff --git a/test/functional/treesitter/parser_spec.lua b/test/functional/treesitter/parser_spec.lua
index 2430021617..fdd6403859 100644
--- a/test/functional/treesitter/parser_spec.lua
+++ b/test/functional/treesitter/parser_spec.lua
@@ -639,6 +639,17 @@ int x = INT_MAX;
{1, 26, 1, 65}, -- READ_STRING(x, y) (char_u *)read_string((x), (size_t)(y))
{2, 29, 2, 68} -- READ_STRING_OK(x, y) (char_u *)read_string((x), (size_t)(y))
}, get_ranges())
+
+ helpers.feed('ggo<esc>')
+ eq(5, exec_lua("return #parser:children().c:trees()"))
+ eq({
+ {0, 0, 8, 0}, -- root tree
+ {4, 14, 4, 17}, -- VALUE 123
+ {5, 15, 5, 18}, -- VALUE1 123
+ {6, 15, 6, 18}, -- VALUE2 123
+ {2, 26, 2, 65}, -- READ_STRING(x, y) (char_u *)read_string((x), (size_t)(y))
+ {3, 29, 3, 68} -- READ_STRING_OK(x, y) (char_u *)read_string((x), (size_t)(y))
+ }, get_ranges())
end)
end)
@@ -660,6 +671,18 @@ int x = INT_MAX;
{1, 26, 2, 68} -- READ_STRING(x, y) (char_u *)read_string((x), (size_t)(y))
-- READ_STRING_OK(x, y) (char_u *)read_string((x), (size_t)(y))
}, get_ranges())
+
+ helpers.feed('ggo<esc>')
+ eq("table", exec_lua("return type(parser:children().c)"))
+ eq(2, exec_lua("return #parser:children().c:trees()"))
+ eq({
+ {0, 0, 8, 0}, -- root tree
+ {4, 14, 6, 18}, -- VALUE 123
+ -- VALUE1 123
+ -- VALUE2 123
+ {2, 26, 3, 68} -- READ_STRING(x, y) (char_u *)read_string((x), (size_t)(y))
+ -- READ_STRING_OK(x, y) (char_u *)read_string((x), (size_t)(y))
+ }, get_ranges())
end)
end)
@@ -688,6 +711,18 @@ int x = INT_MAX;
{1, 26, 2, 68} -- READ_STRING(x, y) (char_u *)read_string((x), (size_t)(y))
-- READ_STRING_OK(x, y) (char_u *)read_string((x), (size_t)(y))
}, get_ranges())
+
+ helpers.feed('ggo<esc>')
+ eq("table", exec_lua("return type(parser:children().c)"))
+ eq(2, exec_lua("return #parser:children().c:trees()"))
+ eq({
+ {0, 0, 8, 0}, -- root tree
+ {4, 14, 6, 18}, -- VALUE 123
+ -- VALUE1 123
+ -- VALUE2 123
+ {2, 26, 3, 68} -- READ_STRING(x, y) (char_u *)read_string((x), (size_t)(y))
+ -- READ_STRING_OK(x, y) (char_u *)read_string((x), (size_t)(y))
+ }, get_ranges())
end)
it("should not inject bad languages", function()