aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorLewis Russell <lewis6991@gmail.com>2023-05-16 16:41:47 +0100
committerGitHub <noreply@github.com>2023-05-16 16:41:47 +0100
commit6b19170d44ca56cf65542ee184d2bc89c6d622a9 (patch)
tree5881cff78f1f8f05dd9027aa33d936a981244919
parentd36dd2bae8e899b40cc21603e600a5046213bc36 (diff)
downloadrneovim-6b19170d44ca56cf65542ee184d2bc89c6d622a9.tar.gz
rneovim-6b19170d44ca56cf65542ee184d2bc89c6d622a9.tar.bz2
rneovim-6b19170d44ca56cf65542ee184d2bc89c6d622a9.zip
fix(treesitter): correctly calculate bytes for text sources (#23655)
Fixes #20419
-rw-r--r--runtime/lua/vim/treesitter/_range.lua38
-rw-r--r--test/functional/treesitter/parser_spec.lua56
2 files changed, 81 insertions, 13 deletions
diff --git a/runtime/lua/vim/treesitter/_range.lua b/runtime/lua/vim/treesitter/_range.lua
index f4db5016ac..35081c6400 100644
--- a/runtime/lua/vim/treesitter/_range.lua
+++ b/runtime/lua/vim/treesitter/_range.lua
@@ -143,6 +143,29 @@ function M.contains(r1, r2)
return true
end
+--- @param source integer|string
+--- @param index integer
+--- @return integer
+local function get_offset(source, index)
+ if index == 0 then
+ return 0
+ end
+
+ if type(source) == 'number' then
+ return api.nvim_buf_get_offset(source, index)
+ end
+
+ local byte = 0
+ local next_offset = source:gmatch('()\n')
+ local line = 1
+ while line <= index do
+ byte = next_offset() --[[@as integer]]
+ line = line + 1
+ end
+
+ return byte
+end
+
---@private
---@param source integer|string
---@param range Range
@@ -152,19 +175,10 @@ function M.add_bytes(source, range)
return range --[[@as Range6]]
end
- local start_row, start_col, end_row, end_col = range[1], range[2], range[3], range[4]
- local start_byte = 0
- local end_byte = 0
+ local start_row, start_col, end_row, end_col = M.unpack4(range)
-- TODO(vigoux): proper byte computation here, and account for EOL ?
- if type(source) == 'number' then
- -- Easy case, this is a buffer parser
- start_byte = api.nvim_buf_get_offset(source, start_row) + start_col
- end_byte = api.nvim_buf_get_offset(source, end_row) + end_col
- elseif type(source) == 'string' then
- -- string parser, single `\n` delimited string
- start_byte = vim.fn.byteidx(source, start_col)
- end_byte = vim.fn.byteidx(source, end_col)
- end
+ local start_byte = get_offset(source, start_row) + start_col
+ local end_byte = get_offset(source, end_row) + end_col
return { start_row, start_col, start_byte, end_row, end_col, end_byte }
end
diff --git a/test/functional/treesitter/parser_spec.lua b/test/functional/treesitter/parser_spec.lua
index 60042ab2fd..14de07639b 100644
--- a/test/functional/treesitter/parser_spec.lua
+++ b/test/functional/treesitter/parser_spec.lua
@@ -486,7 +486,6 @@ end]]
eq({ 'any-of?', 'contains?', 'eq?', 'has-ancestor?', 'has-parent?', 'is-main?', 'lua-match?', 'match?', 'vim-match?' }, res_list)
end)
-
it('allows to set simple ranges', function()
insert(test_text)
@@ -528,6 +527,7 @@ end]]
eq(range_tbl, { { { 0, 0, 0, 17, 1, 508 } } })
end)
+
it("allows to set complex ranges", function()
insert(test_text)
@@ -992,4 +992,58 @@ int x = INT_MAX;
}, run_query())
end)
+
+ it('handles ranges when source is a multiline string (#20419)', function()
+ local source = [==[
+ vim.cmd[[
+ set number
+ set cmdheight=2
+ set lastsatus=2
+ ]]
+
+ set query = [[;; query
+ ((function_call
+ name: [
+ (identifier) @_cdef_identifier
+ (_ _ (identifier) @_cdef_identifier)
+ ]
+ arguments: (arguments (string content: _ @injection.content)))
+ (#set! injection.language "c")
+ (#eq? @_cdef_identifier "cdef"))
+ ]]
+ ]==]
+
+ local r = exec_lua([[
+ local parser = vim.treesitter.get_string_parser(..., 'lua')
+ parser:parse()
+ local ranges = {}
+ parser:for_each_tree(function(tstree, tree)
+ ranges[tree:lang()] = { tstree:root():range(true) }
+ end)
+ return ranges
+ ]], source)
+
+ eq({
+ lua = { 0, 6, 6, 16, 4, 438 },
+ query = { 6, 20, 113, 15, 6, 431 },
+ vim = { 1, 0, 16, 4, 6, 89 }
+ }, r)
+
+ -- The above ranges are provided directly from treesitter, however query directives may mutate
+ -- the ranges but only provide a Range4. Strip the byte entries from the ranges and make sure
+ -- add_bytes() produces the same result.
+
+ local rb = exec_lua([[
+ local r, source = ...
+ local add_bytes = require('vim.treesitter._range').add_bytes
+ for lang, range in pairs(r) do
+ r[lang] = {range[1], range[2], range[4], range[5]}
+ r[lang] = add_bytes(source, r[lang])
+ end
+ return r
+ ]], r, source)
+
+ eq(rb, r)
+
+ end)
end)