diff options
-rw-r--r-- | runtime/lua/vim/treesitter/languagetree.lua | 80 | ||||
-rw-r--r-- | test/functional/treesitter/parser_spec.lua | 34 |
2 files changed, 44 insertions, 70 deletions
diff --git a/runtime/lua/vim/treesitter/languagetree.lua b/runtime/lua/vim/treesitter/languagetree.lua index f2e745ec65..e7cee33a03 100644 --- a/runtime/lua/vim/treesitter/languagetree.lua +++ b/runtime/lua/vim/treesitter/languagetree.lua @@ -868,35 +868,42 @@ end ---@alias vim.treesitter.languagetree.Injection table<string,table<integer,vim.treesitter.languagetree.InjectionElem>> ----@param t table<integer,vim.treesitter.languagetree.Injection> ----@param tree_index integer +---@param t vim.treesitter.languagetree.Injection ---@param pattern integer ---@param lang string ---@param combined boolean ---@param ranges Range6[] -local function add_injection(t, tree_index, pattern, lang, combined, ranges) +---@param result table<string,Range6[][]> +local function add_injection(t, pattern, lang, combined, ranges, result) if #ranges == 0 then -- Make sure not to add an empty range set as this is interpreted to mean the whole buffer. return end - -- Each tree index should be isolated from the other nodes. - if not t[tree_index] then - t[tree_index] = {} + if not result[lang] then + result[lang] = {} end - if not t[tree_index][lang] then - t[tree_index][lang] = {} + if not combined then + table.insert(result[lang], ranges) + return + end + + if not t[lang] then + t[lang] = {} end - -- Key this by pattern. If combined is set to true all captures of this pattern + -- Key this by pattern. For combined injections, all captures of this pattern -- will be parsed by treesitter as the same "source". - -- If combined is false, each "region" will be parsed as a single source. - if not t[tree_index][lang][pattern] then - t[tree_index][lang][pattern] = { combined = combined, regions = {} } + if not t[lang][pattern] then + local regions = {} + t[lang][pattern] = regions + table.insert(result[lang], regions) end - table.insert(t[tree_index][lang][pattern].regions, ranges) + for _, range in ipairs(ranges) do + table.insert(t[lang][pattern], range) + end end -- TODO(clason): replace by refactored `ts.has_parser` API (without side effects) @@ -964,19 +971,6 @@ function LanguageTree:_get_injection(match, metadata) return lang, combined, ranges end ---- Can't use vim.tbl_flatten since a range is just a table. ----@param regions Range6[][] ----@return Range6[] -local function combine_regions(regions) - local result = {} ---@type Range6[] - for _, region in ipairs(regions) do - for _, range in ipairs(region) do - result[#result + 1] = range - end - end - return result -end - --- Gets language injection regions by language. --- --- This is where most of the injection processing occurs. @@ -993,13 +987,16 @@ function LanguageTree:_get_injections(range, thread_state) return {} end - ---@type table<integer,vim.treesitter.languagetree.Injection> - local injections = {} local start = vim.uv.hrtime() + ---@type table<string,Range6[][]> + local result = {} + local full_scan = range == true or self._injection_query.has_combined_injections - for index, tree in pairs(self._trees) do + for _, tree in pairs(self._trees) do + ---@type vim.treesitter.languagetree.Injection + local injections = {} local root_node = tree:root() local start_line, end_line ---@type integer, integer if full_scan then @@ -1013,7 +1010,7 @@ function LanguageTree:_get_injections(range, thread_state) do local lang, combined, ranges = self:_get_injection(match, metadata) if lang then - add_injection(injections, index, pattern, lang, combined, ranges) + add_injection(injections, pattern, lang, combined, ranges, result) else self:_log('match from injection query failed for pattern', pattern) end @@ -1025,29 +1022,6 @@ function LanguageTree:_get_injections(range, thread_state) end end - ---@type table<string,Range6[][]> - local result = {} - - -- Generate a map by lang of node lists. - -- Each list is a set of ranges that should be parsed together. - for _, lang_map in pairs(injections) do - for lang, patterns in pairs(lang_map) do - if not result[lang] then - result[lang] = {} - end - - for _, entry in pairs(patterns) do - if entry.combined then - table.insert(result[lang], combine_regions(entry.regions)) - else - for _, ranges in pairs(entry.regions) do - table.insert(result[lang], ranges) - end - end - end - end - end - if full_scan then self._processed_injection_range = entire_document_range else diff --git a/test/functional/treesitter/parser_spec.lua b/test/functional/treesitter/parser_spec.lua index b348f77b38..510eacb958 100644 --- a/test/functional/treesitter/parser_spec.lua +++ b/test/functional/treesitter/parser_spec.lua @@ -575,22 +575,22 @@ int x = INT_MAX; eq(5, exec_lua('return #parser:children().c:trees()')) eq({ { 0, 0, 7, 0 }, -- root tree + { 1, 26, 1, 63 }, -- READ_STRING(x, y) (char *)read_string((x), (size_t)(y)) + { 2, 29, 2, 66 }, -- READ_STRING_OK(x, y) (char *)read_string((x), (size_t)(y)) { 3, 14, 3, 17 }, -- VALUE 123 { 4, 15, 4, 18 }, -- VALUE1 123 { 5, 15, 5, 18 }, -- VALUE2 123 - { 1, 26, 1, 63 }, -- READ_STRING(x, y) (char *)read_string((x), (size_t)(y)) - { 2, 29, 2, 66 }, -- READ_STRING_OK(x, y) (char *)read_string((x), (size_t)(y)) }, get_ranges()) n.feed('ggo<esc>') eq(5, exec_lua('return #parser:children().c:trees()')) eq({ { 0, 0, 8, 0 }, -- root tree + { 2, 26, 2, 63 }, -- READ_STRING(x, y) (char *)read_string((x), (size_t)(y)) + { 3, 29, 3, 66 }, -- READ_STRING_OK(x, y) (char *)read_string((x), (size_t)(y)) { 4, 14, 4, 17 }, -- VALUE 123 { 5, 15, 5, 18 }, -- VALUE1 123 { 6, 15, 6, 18 }, -- VALUE2 123 - { 2, 26, 2, 63 }, -- READ_STRING(x, y) (char *)read_string((x), (size_t)(y)) - { 3, 29, 3, 66 }, -- READ_STRING_OK(x, y) (char *)read_string((x), (size_t)(y)) }, get_ranges()) end) end) @@ -613,11 +613,11 @@ int x = INT_MAX; eq(2, exec_lua('return #parser:children().c:trees()')) eq({ { 0, 0, 7, 0 }, -- root tree + { 1, 26, 2, 66 }, -- READ_STRING(x, y) (char *)read_string((x), (size_t)(y)) + -- READ_STRING_OK(x, y) (char *)read_string((x), (size_t)(y)) { 3, 14, 5, 18 }, -- VALUE 123 -- VALUE1 123 -- VALUE2 123 - { 1, 26, 2, 66 }, -- READ_STRING(x, y) (char *)read_string((x), (size_t)(y)) - -- READ_STRING_OK(x, y) (char *)read_string((x), (size_t)(y)) }, get_ranges()) n.feed('ggo<esc>') @@ -625,11 +625,11 @@ int x = INT_MAX; eq(2, exec_lua('return #parser:children().c:trees()')) eq({ { 0, 0, 8, 0 }, -- root tree - { 4, 14, 6, 18 }, -- VALUE 123 - -- VALUE1 123 - -- VALUE2 123 { 2, 26, 3, 66 }, -- READ_STRING(x, y) (char *)read_string((x), (size_t)(y)) -- READ_STRING_OK(x, y) (char *)read_string((x), (size_t)(y)) + -- VALUE 123 + { 4, 14, 6, 18 }, -- VALUE1 123 + -- VALUE2 123 }, get_ranges()) n.feed('7ggI//<esc>') @@ -638,10 +638,10 @@ int x = INT_MAX; eq(2, exec_lua('return #parser:children().c:trees()')) eq({ { 0, 0, 8, 0 }, -- root tree - { 4, 14, 5, 18 }, -- VALUE 123 - -- VALUE1 123 { 2, 26, 3, 66 }, -- READ_STRING(x, y) (char *)read_string((x), (size_t)(y)) -- READ_STRING_OK(x, y) (char *)read_string((x), (size_t)(y)) + -- VALUE 123 + { 4, 14, 5, 18 }, -- VALUE1 123 }, get_ranges()) end) @@ -794,22 +794,22 @@ int x = INT_MAX; eq(5, exec_lua('return #parser:children().c:trees()')) eq({ { 0, 0, 7, 0 }, -- root tree + { 1, 26, 1, 63 }, -- READ_STRING(x, y) (char *)read_string((x), (size_t)(y)) + { 2, 29, 2, 66 }, -- READ_STRING_OK(x, y) (char *)read_string((x), (size_t)(y)) { 3, 14, 3, 17 }, -- VALUE 123 { 4, 15, 4, 18 }, -- VALUE1 123 { 5, 15, 5, 18 }, -- VALUE2 123 - { 1, 26, 1, 63 }, -- READ_STRING(x, y) (char *)read_string((x), (size_t)(y)) - { 2, 29, 2, 66 }, -- READ_STRING_OK(x, y) (char *)read_string((x), (size_t)(y)) }, get_ranges()) n.feed('ggo<esc>') eq(5, exec_lua('return #parser:children().c:trees()')) eq({ { 0, 0, 8, 0 }, -- root tree + { 2, 26, 2, 63 }, -- READ_STRING(x, y) (char *)read_string((x), (size_t)(y)) + { 3, 29, 3, 66 }, -- READ_STRING_OK(x, y) (char *)read_string((x), (size_t)(y)) { 4, 14, 4, 17 }, -- VALUE 123 { 5, 15, 5, 18 }, -- VALUE1 123 { 6, 15, 6, 18 }, -- VALUE2 123 - { 2, 26, 2, 63 }, -- READ_STRING(x, y) (char *)read_string((x), (size_t)(y)) - { 3, 29, 3, 66 }, -- READ_STRING_OK(x, y) (char *)read_string((x), (size_t)(y)) }, get_ranges()) end) end) @@ -831,11 +831,11 @@ int x = INT_MAX; eq('table', exec_lua('return type(parser:children().c)')) eq({ { 0, 0, 7, 0 }, -- root tree + { 1, 26, 1, 63 }, -- READ_STRING(x, y) (char *)read_string((x), (size_t)(y)) + { 2, 29, 2, 66 }, -- READ_STRING_OK(x, y) (char *)read_string((x), (size_t)(y)) { 3, 16, 3, 16 }, -- VALUE 123 { 4, 17, 4, 17 }, -- VALUE1 123 { 5, 17, 5, 17 }, -- VALUE2 123 - { 1, 26, 1, 63 }, -- READ_STRING(x, y) (char *)read_string((x), (size_t)(y)) - { 2, 29, 2, 66 }, -- READ_STRING_OK(x, y) (char *)read_string((x), (size_t)(y)) }, get_ranges()) end) it('should list all directives', function() |