diff options
-rw-r--r-- | runtime/lua/vim/treesitter/query.lua | 6 | ||||
-rw-r--r-- | test/functional/treesitter/parser_spec.lua | 50 |
2 files changed, 55 insertions, 1 deletions
diff --git a/runtime/lua/vim/treesitter/query.lua b/runtime/lua/vim/treesitter/query.lua index ed5146be44..51d538c0ff 100644 --- a/runtime/lua/vim/treesitter/query.lua +++ b/runtime/lua/vim/treesitter/query.lua @@ -253,7 +253,11 @@ local directive_handlers = { ["set!"] = function(_, _, _, pred, metadata) if #pred == 4 then -- (#set! @capture "key" "value") - metadata[pred[2]][pred[3]] = pred[4] + local capture = pred[2] + if not metadata[capture] then + metadata[capture] = {} + end + metadata[capture][pred[3]] = pred[4] else -- (#set! "key" "value") metadata[pred[2]] = pred[3] diff --git a/test/functional/treesitter/parser_spec.lua b/test/functional/treesitter/parser_spec.lua index 72ff6f2fb6..f267f9fb5d 100644 --- a/test/functional/treesitter/parser_spec.lua +++ b/test/functional/treesitter/parser_spec.lua @@ -599,6 +599,56 @@ int x = INT_MAX; eq(result, "value") end) + + describe("when setting a key on a capture", function() + it("it should create the nested table", function() + insert([[ + int x = 3; + ]]) + + local result = exec_lua([[ + local query = require("vim.treesitter.query") + local value + + query = vim.treesitter.parse_query("c", '((number_literal) @number (#set! @number "key" "value"))') + parser = vim.treesitter.get_parser(0, "c") + + for pattern, match, metadata in query:iter_matches(parser:parse()[1]:root(), 0) do + for _, nested_tbl in pairs(metadata) do + return nested_tbl.key + end + end + ]]) + + eq(result, "value") + end) + + it("it should not overwrite the nested table", function() + insert([[ + int x = 3; + ]]) + + local result = exec_lua([[ + local query = require("vim.treesitter.query") + local result + + query = vim.treesitter.parse_query("c", '((number_literal) @number (#set! @number "key" "value") (#set! @number "key2" "value2"))') + parser = vim.treesitter.get_parser(0, "c") + + for pattern, match, metadata in query:iter_matches(parser:parse()[1]:root(), 0) do + for _, nested_tbl in pairs(metadata) do + return nested_tbl + end + end + ]]) + local expected = { + ["key"] = "value", + ["key2"] = "value2", + } + + eq(expected, result) + end) + end) end) end) end) |