aboutsummaryrefslogtreecommitdiff
path: root/runtime/lua/vim/treesitter/_range.lua
diff options
context:
space:
mode:
authorLewis Russell <lewis6991@gmail.com>2023-02-23 15:19:52 +0000
committerGitHub <noreply@github.com>2023-02-23 15:19:52 +0000
commit75e53341f37eeeda7d9be7f934249f7e5e4397e9 (patch)
treeb0b88f28c0ed701a9b2f13dbfe988061d48fa937 /runtime/lua/vim/treesitter/_range.lua
parent86807157438240757199f925f538d7ad02322754 (diff)
downloadrneovim-75e53341f37eeeda7d9be7f934249f7e5e4397e9.tar.gz
rneovim-75e53341f37eeeda7d9be7f934249f7e5e4397e9.tar.bz2
rneovim-75e53341f37eeeda7d9be7f934249f7e5e4397e9.zip
perf(treesitter): smarter languagetree invalidation
Problem: Treesitter injections are slow because all injected trees are invalidated on every change. Solution: Implement smarter invalidation to avoid reparsing injected regions. - In on_bytes, try and update self._regions as best we can. This PR just offsets any regions after the change. - Add valid flags for each region in self._regions. - Call on_bytes recursively for all children. - We still need to run the query every time for the top level tree. I don't know how to avoid this. However, if the new injection ranges don't change, then we re-use the old trees and avoid reparsing children. This should result in roughly a 2-3x reduction in tree parsing when the comment injections are enabled.
Diffstat (limited to 'runtime/lua/vim/treesitter/_range.lua')
-rw-r--r--runtime/lua/vim/treesitter/_range.lua126
1 files changed, 126 insertions, 0 deletions
diff --git a/runtime/lua/vim/treesitter/_range.lua b/runtime/lua/vim/treesitter/_range.lua
new file mode 100644
index 0000000000..b87542c20f
--- /dev/null
+++ b/runtime/lua/vim/treesitter/_range.lua
@@ -0,0 +1,126 @@
+local api = vim.api
+
+local M = {}
+
+---@alias Range4 {[1]: integer, [2]: integer, [3]: integer, [4]: integer}
+---@alias Range6 {[1]: integer, [2]: integer, [3]: integer, [4]: integer, [5]: integer, [6]: integer}
+
+---@private
+---@param a_row integer
+---@param a_col integer
+---@param b_row integer
+---@param b_col integer
+---@return integer
+--- 1: a > b
+--- 0: a == b
+--- -1: a < b
+local function cmp_pos(a_row, a_col, b_row, b_col)
+ if a_row == b_row then
+ if a_col > b_col then
+ return 1
+ elseif a_col < b_col then
+ return -1
+ else
+ return 0
+ end
+ elseif a_row > b_row then
+ return 1
+ end
+
+ return -1
+end
+
+M.cmp_pos = {
+ lt = function(...)
+ return cmp_pos(...) == -1
+ end,
+ le = function(...)
+ return cmp_pos(...) ~= 1
+ end,
+ gt = function(...)
+ return cmp_pos(...) == 1
+ end,
+ ge = function(...)
+ return cmp_pos(...) ~= -1
+ end,
+ eq = function(...)
+ return cmp_pos(...) == 0
+ end,
+ ne = function(...)
+ return cmp_pos(...) ~= 0
+ end,
+}
+
+setmetatable(M.cmp_pos, { __call = cmp_pos })
+
+---@private
+---@param r1 Range4|Range6
+---@param r2 Range4|Range6
+---@return boolean
+function M.intercepts(r1, r2)
+ local off_1 = #r1 == 6 and 1 or 0
+ local off_2 = #r1 == 6 and 1 or 0
+
+ local srow_1, scol_1, erow_1, ecol_1 = r1[1], r2[2], r1[3 + off_1], r1[4 + off_1]
+ local srow_2, scol_2, erow_2, ecol_2 = r2[1], r2[2], r2[3 + off_2], r2[4 + off_2]
+
+ -- r1 is above r2
+ if M.cmp_pos.le(erow_1, ecol_1, srow_2, scol_2) then
+ return false
+ end
+
+ -- r1 is below r2
+ if M.cmp_pos.ge(srow_1, scol_1, erow_2, ecol_2) then
+ return false
+ end
+
+ return true
+end
+
+---@private
+---@param r1 Range4|Range6
+---@param r2 Range4|Range6
+---@return boolean whether r1 contains r2
+function M.contains(r1, r2)
+ local off_1 = #r1 == 6 and 1 or 0
+ local off_2 = #r1 == 6 and 1 or 0
+
+ local srow_1, scol_1, erow_1, ecol_1 = r1[1], r2[2], r1[3 + off_1], r1[4 + off_1]
+ local srow_2, scol_2, erow_2, ecol_2 = r2[1], r2[2], r2[3 + off_2], r2[4 + off_2]
+
+ -- start doesn't fit
+ if M.cmp_pos.gt(srow_1, scol_1, srow_2, scol_2) then
+ return false
+ end
+
+ -- end doesn't fit
+ if M.cmp_pos.lt(erow_1, ecol_1, erow_2, ecol_2) then
+ return false
+ end
+
+ return true
+end
+
+---@private
+---@param source integer|string
+---@param range Range4
+---@return Range6
+function M.add_bytes(source, range)
+ local start_row, start_col, end_row, end_col = range[1], range[2], range[3], range[4]
+ local start_byte = 0
+ local end_byte = 0
+ -- TODO(vigoux): proper byte computation here, and account for EOL ?
+ if type(source) == 'number' then
+ -- Easy case, this is a buffer parser
+ start_byte = api.nvim_buf_get_offset(source, start_row) + start_col
+ end_byte = api.nvim_buf_get_offset(source, end_row) + end_col
+ elseif type(source) == 'string' then
+ -- string parser, single `\n` delimited string
+ start_byte = vim.fn.byteidx(source, start_col)
+ end_byte = vim.fn.byteidx(source, end_col)
+ end
+
+ return { start_row, start_col, start_byte, end_row, end_col, end_byte }
+end
+
+return M