diff options
-rw-r--r-- | runtime/doc/lua.txt | 13 | ||||
-rw-r--r-- | runtime/lua/vim/treesitter.lua | 6 | ||||
-rw-r--r-- | src/nvim/lua/executor.c | 12 | ||||
-rw-r--r-- | src/nvim/lua/treesitter.c | 102 | ||||
-rw-r--r-- | test/functional/lua/treesitter_spec.lua | 61 |
5 files changed, 161 insertions, 33 deletions
diff --git a/runtime/doc/lua.txt b/runtime/doc/lua.txt index 5a49d36503..00126f668b 100644 --- a/runtime/doc/lua.txt +++ b/runtime/doc/lua.txt @@ -512,6 +512,9 @@ retained for the lifetime of a buffer but this is subject to change. A plugin should keep a reference to the parser object as long as it wants incremental updates. +Parser methods *lua-treesitter-parser* + +tsparser:parse() *tsparser:parse()* Whenever you need to access the current syntax tree, parse the buffer: > tstree = parser:parse() @@ -528,6 +531,16 @@ shouldn't be done directly in the change callback anyway as they will be very frequent. Rather a plugin that does any kind of analysis on a tree should use a timer to throttle too frequent updates. +tsparser:set_included_ranges(ranges) *tsparser:set_included_ranges()* + Changes the ranges the parser should consider. This is used for + language injection. `ranges` should be of the form (all zero-based): > + { + {start_node, end_node}, + ... + } +< + NOTE: `start_node` and `end_node` are both inclusive. + Tree methods *lua-treesitter-tree* tstree:root() *tstree:root()* diff --git a/runtime/lua/vim/treesitter.lua b/runtime/lua/vim/treesitter.lua index c502e45bd0..f356673839 100644 --- a/runtime/lua/vim/treesitter.lua +++ b/runtime/lua/vim/treesitter.lua @@ -30,6 +30,12 @@ function Parser:_on_lines(bufnr, _, start_row, old_stop_row, stop_row, old_byte_ self.valid = false end +function Parser:set_included_ranges(ranges) + self._parser:set_included_ranges(ranges) + -- The buffer will need to be parsed again later + self.valid = false +end + local M = { parse_query = vim._ts_parse_query, } diff --git a/src/nvim/lua/executor.c b/src/nvim/lua/executor.c index 327ed6d6b7..4b47b34d8a 100644 --- a/src/nvim/lua/executor.c +++ b/src/nvim/lua/executor.c @@ -1128,21 +1128,11 @@ void ex_luafile(exarg_T *const eap) } } -static int create_tslua_parser(lua_State *L) -{ - if (lua_gettop(L) < 1 || !lua_isstring(L, 1)) { - return luaL_error(L, "string expected"); - } - - const char *lang_name = lua_tostring(L, 1); - return tslua_push_parser(L, lang_name); -} - static void nlua_add_treesitter(lua_State *const lstate) FUNC_ATTR_NONNULL_ALL { tslua_init(lstate); - lua_pushcfunction(lstate, create_tslua_parser); + lua_pushcfunction(lstate, tslua_push_parser); lua_setfield(lstate, -2, "_create_ts_parser"); lua_pushcfunction(lstate, tslua_add_language); diff --git a/src/nvim/lua/treesitter.c b/src/nvim/lua/treesitter.c index 51d9549033..ddf54720a7 100644 --- a/src/nvim/lua/treesitter.c +++ b/src/nvim/lua/treesitter.c @@ -20,6 +20,7 @@ #include "nvim/lua/treesitter.h" #include "nvim/api/private/handle.h" #include "nvim/memline.h" +#include "nvim/buffer.h" typedef struct { TSParser *parser; @@ -41,6 +42,7 @@ static struct luaL_Reg parser_meta[] = { { "parse_buf", parser_parse_buf }, { "edit", parser_edit }, { "tree", parser_tree }, + { "set_included_ranges", parser_set_ranges }, { NULL, NULL } }; @@ -214,8 +216,13 @@ int tslua_inspect_lang(lua_State *L) return 1; } -int tslua_push_parser(lua_State *L, const char *lang_name) +int tslua_push_parser(lua_State *L) { + // Gather language + if (lua_gettop(L) < 1 || !lua_isstring(L, 1)) { + return luaL_error(L, "string expected"); + } + const char *lang_name = lua_tostring(L, 1); TSLanguage *lang = pmap_get(cstr_t)(langs, lang_name); if (!lang) { return luaL_error(L, "no such language: %s", lang_name); @@ -377,6 +384,57 @@ static int parser_edit(lua_State *L) return 0; } +static int parser_set_ranges(lua_State *L) +{ + if (lua_gettop(L) < 2) { + return luaL_error( + L, + "not enough args to parser:set_included_ranges()"); + } + + TSLua_parser *p = parser_check(L); + if (!p || !p->tree) { + return 0; + } + + if (!lua_istable(L, 2)) { + return luaL_error( + L, + "argument for parser:set_included_ranges() should be a table."); + } + + size_t tbl_len = lua_objlen(L, 2); + TSRange *ranges = xmalloc(sizeof(TSRange) * tbl_len); + + + // [ parser, ranges ] + for (size_t index = 0; index < tbl_len; index++) { + lua_rawgeti(L, 2, index + 1); // [ parser, ranges, range ] + + TSNode node; + if (!node_check(L, -1, &node)) { + xfree(ranges); + return luaL_error( + L, + "ranges should be tables of nodes."); + } + lua_pop(L, 1); // [ parser, ranges ] + + ranges[index] = (TSRange) { + .start_point = ts_node_start_point(node), + .end_point = ts_node_end_point(node), + .start_byte = ts_node_start_byte(node), + .end_byte = ts_node_end_byte(node) + }; + } + + // This memcpies ranges, thus we can free it afterwards + ts_parser_set_included_ranges(p->parser, ranges, tbl_len); + xfree(ranges); + + return 0; +} + // Tree methods @@ -459,9 +517,9 @@ static void push_node(lua_State *L, TSNode node, int uindex) lua_setfenv(L, -2); // [udata] } -static bool node_check(lua_State *L, TSNode *res) +static bool node_check(lua_State *L, int index, TSNode *res) { - TSNode *ud = luaL_checkudata(L, 1, "treesitter_node"); + TSNode *ud = luaL_checkudata(L, index, "treesitter_node"); if (ud) { *res = *ud; return true; @@ -473,7 +531,7 @@ static bool node_check(lua_State *L, TSNode *res) static int node_tostring(lua_State *L) { TSNode node; - if (!node_check(L, &node)) { + if (!node_check(L, 1, &node)) { return 0; } lua_pushstring(L, "<node "); @@ -486,7 +544,7 @@ static int node_tostring(lua_State *L) static int node_eq(lua_State *L) { TSNode node; - if (!node_check(L, &node)) { + if (!node_check(L, 1, &node)) { return 0; } // This should only be called if both x and y in "x == y" has the @@ -503,7 +561,7 @@ static int node_eq(lua_State *L) static int node_range(lua_State *L) { TSNode node; - if (!node_check(L, &node)) { + if (!node_check(L, 1, &node)) { return 0; } TSPoint start = ts_node_start_point(node); @@ -518,7 +576,7 @@ static int node_range(lua_State *L) static int node_start(lua_State *L) { TSNode node; - if (!node_check(L, &node)) { + if (!node_check(L, 1, &node)) { return 0; } TSPoint start = ts_node_start_point(node); @@ -532,7 +590,7 @@ static int node_start(lua_State *L) static int node_end(lua_State *L) { TSNode node; - if (!node_check(L, &node)) { + if (!node_check(L, 1, &node)) { return 0; } TSPoint end = ts_node_end_point(node); @@ -546,7 +604,7 @@ static int node_end(lua_State *L) static int node_child_count(lua_State *L) { TSNode node; - if (!node_check(L, &node)) { + if (!node_check(L, 1, &node)) { return 0; } uint32_t count = ts_node_child_count(node); @@ -557,7 +615,7 @@ static int node_child_count(lua_State *L) static int node_named_child_count(lua_State *L) { TSNode node; - if (!node_check(L, &node)) { + if (!node_check(L, 1, &node)) { return 0; } uint32_t count = ts_node_named_child_count(node); @@ -568,7 +626,7 @@ static int node_named_child_count(lua_State *L) static int node_type(lua_State *L) { TSNode node; - if (!node_check(L, &node)) { + if (!node_check(L, 1, &node)) { return 0; } lua_pushstring(L, ts_node_type(node)); @@ -578,7 +636,7 @@ static int node_type(lua_State *L) static int node_symbol(lua_State *L) { TSNode node; - if (!node_check(L, &node)) { + if (!node_check(L, 1, &node)) { return 0; } TSSymbol symbol = ts_node_symbol(node); @@ -589,7 +647,7 @@ static int node_symbol(lua_State *L) static int node_named(lua_State *L) { TSNode node; - if (!node_check(L, &node)) { + if (!node_check(L, 1, &node)) { return 0; } lua_pushboolean(L, ts_node_is_named(node)); @@ -599,7 +657,7 @@ static int node_named(lua_State *L) static int node_sexpr(lua_State *L) { TSNode node; - if (!node_check(L, &node)) { + if (!node_check(L, 1, &node)) { return 0; } char *allocated = ts_node_string(node); @@ -611,7 +669,7 @@ static int node_sexpr(lua_State *L) static int node_missing(lua_State *L) { TSNode node; - if (!node_check(L, &node)) { + if (!node_check(L, 1, &node)) { return 0; } lua_pushboolean(L, ts_node_is_missing(node)); @@ -621,7 +679,7 @@ static int node_missing(lua_State *L) static int node_has_error(lua_State *L) { TSNode node; - if (!node_check(L, &node)) { + if (!node_check(L, 1, &node)) { return 0; } lua_pushboolean(L, ts_node_has_error(node)); @@ -631,7 +689,7 @@ static int node_has_error(lua_State *L) static int node_child(lua_State *L) { TSNode node; - if (!node_check(L, &node)) { + if (!node_check(L, 1, &node)) { return 0; } long num = lua_tointeger(L, 2); @@ -644,7 +702,7 @@ static int node_child(lua_State *L) static int node_named_child(lua_State *L) { TSNode node; - if (!node_check(L, &node)) { + if (!node_check(L, 1, &node)) { return 0; } long num = lua_tointeger(L, 2); @@ -657,7 +715,7 @@ static int node_named_child(lua_State *L) static int node_descendant_for_range(lua_State *L) { TSNode node; - if (!node_check(L, &node)) { + if (!node_check(L, 1, &node)) { return 0; } TSPoint start = { (uint32_t)lua_tointeger(L, 2), @@ -673,7 +731,7 @@ static int node_descendant_for_range(lua_State *L) static int node_named_descendant_for_range(lua_State *L) { TSNode node; - if (!node_check(L, &node)) { + if (!node_check(L, 1, &node)) { return 0; } TSPoint start = { (uint32_t)lua_tointeger(L, 2), @@ -689,7 +747,7 @@ static int node_named_descendant_for_range(lua_State *L) static int node_parent(lua_State *L) { TSNode node; - if (!node_check(L, &node)) { + if (!node_check(L, 1, &node)) { return 0; } TSNode parent = ts_node_parent(node); @@ -771,7 +829,7 @@ static int query_next_capture(lua_State *L) static int node_rawquery(lua_State *L) { TSNode node; - if (!node_check(L, &node)) { + if (!node_check(L, 1, &node)) { return 0; } TSQuery *query = query_check(L, 2); diff --git a/test/functional/lua/treesitter_spec.lua b/test/functional/lua/treesitter_spec.lua index ecee471386..ab0224a911 100644 --- a/test/functional/lua/treesitter_spec.lua +++ b/test/functional/lua/treesitter_spec.lua @@ -404,4 +404,65 @@ static int nlua_schedule(lua_State *const lstate) end eq({true,true}, {has_named,has_anonymous}) end) + it('allows to set simple ranges', function() + if not check_parser() then return end + + insert(test_text) + + local res = exec_lua([[ + parser = vim.treesitter.get_parser(0, "c") + return { parser:parse():root():range() } + ]]) + + eq({0, 0, 19, 0}, res) + + -- The following sets the included ranges for the current parser + -- As stated here, this only includes the function (thus the whole buffer, without the last line) + local res2 = exec_lua([[ + local root = parser:parse():root() + parser:set_included_ranges({root:child(0)}) + parser.valid = false + return { parser:parse():root():range() } + ]]) + + eq({0, 0, 18, 1}, res2) + end) + it("allows to set complex ranges", function() + if not check_parser() then return end + + insert(test_text) + + + local res = exec_lua([[ + parser = vim.treesitter.get_parser(0, "c") + query = vim.treesitter.parse_query("c", "(declaration) @decl") + + local nodes = {} + for _, node in query:iter_captures(parser:parse():root(), 0, 0, 19) do + table.insert(nodes, node) + end + + parser:set_included_ranges(nodes) + + local root = parser:parse():root() + + local res = {} + for i=0,(root:named_child_count() - 1) do + table.insert(res, { root:named_child(i):range() }) + end + return res + ]]) + + eq({ + { 2, 2, 2, 40 }, + { 3, 3, 3, 32 }, + { 4, 7, 4, 8 }, + { 4, 8, 4, 25 }, + { 8, 2, 8, 6 }, + { 8, 7, 8, 33 }, + { 9, 8, 9, 20 }, + { 10, 4, 10, 5 }, + { 10, 5, 10, 20 }, + { 14, 9, 14, 27 } }, res) + end) end) |