diff options
-rw-r--r-- | runtime/lua/vim/treesitter/languagetree.lua | 19 | ||||
-rw-r--r-- | src/nvim/lua/treesitter.c | 170 | ||||
-rw-r--r-- | test/functional/lua/treesitter_spec.lua | 8 |
3 files changed, 137 insertions, 60 deletions
diff --git a/runtime/lua/vim/treesitter/languagetree.lua b/runtime/lua/vim/treesitter/languagetree.lua index ed07e73a55..70e2ac4c62 100644 --- a/runtime/lua/vim/treesitter/languagetree.lua +++ b/runtime/lua/vim/treesitter/languagetree.lua @@ -1,3 +1,4 @@ +local a = vim.api local query = require'vim.treesitter.query' local language = require'vim.treesitter.language' @@ -234,6 +235,24 @@ end -- -- @param regions A list of regions this tree should manange and parse. function LanguageTree:set_included_regions(regions) + -- Transform the tables from 4 element long to 6 element long (with byte offset) + for _, region in ipairs(regions) do + for i, range in ipairs(region) do + if type(range) == "table" and #range == 4 then + -- TODO(vigoux): I don't think string parsers are useful for now + if type(self._source) == "number" then + local start_row, start_col, end_row, end_col = unpack(range) + -- Easy case, this is a buffer parser + -- TODO(vigoux): proper byte computation here, and account for EOL ? + local start_byte = a.nvim_buf_get_offset(self.bufnr, start_row) + start_col + local end_byte = a.nvim_buf_get_offset(self.bufnr, end_row) + end_col + + region[i] = { start_row, start_col, start_byte, end_row, end_col, end_byte } + end + end + end + end + self._regions = regions -- Trees are no longer valid now that we have changed regions. -- TODO(vigoux,steelsojka): Look into doing this smarter so we can use some of the diff --git a/src/nvim/lua/treesitter.c b/src/nvim/lua/treesitter.c index a9a57d386b..a640b97d3b 100644 --- a/src/nvim/lua/treesitter.c +++ b/src/nvim/lua/treesitter.c @@ -22,6 +22,13 @@ #include "nvim/memline.h" #include "nvim/buffer.h" +#define TS_META_PARSER "treesitter_parser" +#define TS_META_TREE "treesitter_tree" +#define TS_META_NODE "treesitter_node" +#define TS_META_QUERY "treesitter_query" +#define TS_META_QUERYCURSOR "treesitter_querycursor" +#define TS_META_TREECURSOR "treesitter_treecursor" + typedef struct { TSQueryCursor *cursor; int predicated_match; @@ -115,12 +122,12 @@ void tslua_init(lua_State *L) langs = pmap_new(cstr_t)(); // type metatables - build_meta(L, "treesitter_parser", parser_meta); - build_meta(L, "treesitter_tree", tree_meta); - build_meta(L, "treesitter_node", node_meta); - build_meta(L, "treesitter_query", query_meta); - build_meta(L, "treesitter_querycursor", querycursor_meta); - build_meta(L, "treesitter_treecursor", treecursor_meta); + build_meta(L, TS_META_PARSER, parser_meta); + build_meta(L, TS_META_TREE, tree_meta); + build_meta(L, TS_META_NODE, node_meta); + build_meta(L, TS_META_QUERY, query_meta); + build_meta(L, TS_META_QUERYCURSOR, querycursor_meta); + build_meta(L, TS_META_TREECURSOR, treecursor_meta); } int tslua_has_language(lua_State *L) @@ -132,12 +139,8 @@ int tslua_has_language(lua_State *L) int tslua_add_language(lua_State *L) { - if (lua_gettop(L) < 2 || !lua_isstring(L, 1) || !lua_isstring(L, 2)) { - return luaL_error(L, "string expected"); - } - - const char *path = lua_tostring(L, 1); - const char *lang_name = lua_tostring(L, 2); + const char *path = luaL_checkstring(L, 1); + const char *lang_name = luaL_checkstring(L, 2); if (pmap_has(cstr_t)(langs, lang_name)) { return 0; @@ -176,8 +179,9 @@ int tslua_add_language(lua_State *L) || lang_version > TREE_SITTER_LANGUAGE_VERSION) { return luaL_error( L, - "ABI version mismatch : expected %d, found %d", - TREE_SITTER_MIN_COMPATIBLE_LANGUAGE_VERSION, lang_version); + "ABI version mismatch : supported between %d and %d, found %d", + TREE_SITTER_MIN_COMPATIBLE_LANGUAGE_VERSION, + TREE_SITTER_LANGUAGE_VERSION, lang_version); } pmap_put(cstr_t)(langs, xstrdup(lang_name), lang); @@ -188,10 +192,7 @@ int tslua_add_language(lua_State *L) int tslua_inspect_lang(lua_State *L) { - if (lua_gettop(L) < 1 || !lua_isstring(L, 1)) { - return luaL_error(L, "string expected"); - } - const char *lang_name = lua_tostring(L, 1); + const char *lang_name = luaL_checkstring(L, 1); TSLanguage *lang = pmap_get(cstr_t)(langs, lang_name); if (!lang) { @@ -232,11 +233,9 @@ int tslua_inspect_lang(lua_State *L) int tslua_push_parser(lua_State *L) { - // Gather language - if (lua_gettop(L) < 1 || !lua_isstring(L, 1)) { - return luaL_error(L, "string expected"); - } - const char *lang_name = lua_tostring(L, 1); + // Gather language name + const char *lang_name = luaL_checkstring(L, 1); + TSLanguage *lang = pmap_get(cstr_t)(langs, lang_name); if (!lang) { return luaL_error(L, "no such language: %s", lang_name); @@ -250,14 +249,14 @@ int tslua_push_parser(lua_State *L) return luaL_error(L, "Failed to load language : %s", lang_name); } - lua_getfield(L, LUA_REGISTRYINDEX, "treesitter_parser"); // [udata, meta] + lua_getfield(L, LUA_REGISTRYINDEX, TS_META_PARSER); // [udata, meta] lua_setmetatable(L, -2); // [udata] return 1; } static TSParser ** parser_check(lua_State *L, uint16_t index) { - return luaL_checkudata(L, index, "treesitter_parser"); + return luaL_checkudata(L, index, TS_META_PARSER); } static int parser_gc(lua_State *L) @@ -388,7 +387,7 @@ static int parser_parse(lua_State *L) TSRange *changed = old_tree ? ts_tree_get_changed_ranges( old_tree, new_tree, &n_ranges) : NULL; - tslua_push_tree(L, new_tree, false); // [tree] + push_tree(L, new_tree, false); // [tree] push_ranges(L, changed, n_ranges); // [tree, ranges] @@ -403,7 +402,7 @@ static int tree_copy(lua_State *L) return 0; } - tslua_push_tree(L, *tree, true); // [tree] + push_tree(L, *tree, true); // [tree] return 1; } @@ -435,6 +434,72 @@ static int tree_edit(lua_State *L) return 0; } +// Use the top of the stack (without popping it) to create a TSRange, it can be +// either a lua table or a TSNode +static void range_from_lua(lua_State *L, TSRange *range) +{ + TSNode node; + + if (lua_istable(L, -1)) { + // should be a table of 6 elements + if (lua_objlen(L, -1) != 6) { + goto error; + } + + uint32_t start_row, start_col, start_byte, end_row, end_col, end_byte; + lua_rawgeti(L, -1, 1); // [ range, start_row] + start_row = luaL_checkinteger(L, -1); + lua_pop(L, 1); + + lua_rawgeti(L, -1, 2); // [ range, start_col] + start_col = luaL_checkinteger(L, -1); + lua_pop(L, 1); + + lua_rawgeti(L, -1, 3); // [ range, start_byte] + start_byte = luaL_checkinteger(L, -1); + lua_pop(L, 1); + + lua_rawgeti(L, -1, 4); // [ range, end_row] + end_row = luaL_checkinteger(L, -1); + lua_pop(L, 1); + + lua_rawgeti(L, -1, 5); // [ range, end_col] + end_col = luaL_checkinteger(L, -1); + lua_pop(L, 1); + + lua_rawgeti(L, -1, 6); // [ range, end_byte] + end_byte = luaL_checkinteger(L, -1); + lua_pop(L, 1); // [ range ] + + *range = (TSRange) { + .start_point = (TSPoint) { + .row = start_row, + .column = start_col + }, + .end_point = (TSPoint) { + .row = end_row, + .column = end_col + }, + .start_byte = start_byte, + .end_byte = end_byte, + }; + } else if (node_check(L, -1, &node)) { + *range = (TSRange) { + .start_point = ts_node_start_point(node), + .end_point = ts_node_end_point(node), + .start_byte = ts_node_start_byte(node), + .end_byte = ts_node_end_byte(node) + }; + } else { + goto error; + } + return; +error: + luaL_error( + L, + "Ranges can only be made from 6 element long tables or nodes."); +} + static int parser_set_ranges(lua_State *L) { if (lua_gettop(L) < 2) { @@ -461,22 +526,8 @@ static int parser_set_ranges(lua_State *L) // [ parser, ranges ] for (size_t index = 0; index < tbl_len; index++) { lua_rawgeti(L, 2, index + 1); // [ parser, ranges, range ] - - TSNode node; - if (!node_check(L, -1, &node)) { - xfree(ranges); - return luaL_error( - L, - "ranges should be tables of nodes."); - } - lua_pop(L, 1); // [ parser, ranges ] - - ranges[index] = (TSRange) { - .start_point = ts_node_start_point(node), - .end_point = ts_node_end_point(node), - .start_byte = ts_node_start_byte(node), - .end_byte = ts_node_end_byte(node) - }; + range_from_lua(L, ranges + index); + lua_pop(L, 1); } // This memcpies ranges, thus we can free it afterwards @@ -506,7 +557,7 @@ static int parser_get_ranges(lua_State *L) /// push tree interface on lua stack. /// /// This makes a copy of the tree, so ownership of the argument is unaffected. -void tslua_push_tree(lua_State *L, TSTree *tree, bool do_copy) +void push_tree(lua_State *L, TSTree *tree, bool do_copy) { if (tree == NULL) { lua_pushnil(L); @@ -520,7 +571,7 @@ void tslua_push_tree(lua_State *L, TSTree *tree, bool do_copy) *ud = tree; } - lua_getfield(L, LUA_REGISTRYINDEX, "treesitter_tree"); // [udata, meta] + lua_getfield(L, LUA_REGISTRYINDEX, TS_META_TREE); // [udata, meta] lua_setmetatable(L, -2); // [udata] // table used for node wrappers to keep a reference to tree wrapper @@ -534,7 +585,7 @@ void tslua_push_tree(lua_State *L, TSTree *tree, bool do_copy) static TSTree **tree_check(lua_State *L, uint16_t index) { - TSTree **ud = luaL_checkudata(L, index, "treesitter_tree"); + TSTree **ud = luaL_checkudata(L, index, TS_META_TREE); return ud; } @@ -582,7 +633,7 @@ static void push_node(lua_State *L, TSNode node, int uindex) } TSNode *ud = lua_newuserdata(L, sizeof(TSNode)); // [udata] *ud = node; - lua_getfield(L, LUA_REGISTRYINDEX, "treesitter_node"); // [udata, meta] + lua_getfield(L, LUA_REGISTRYINDEX, TS_META_NODE); // [udata, meta] lua_setmetatable(L, -2); // [udata] lua_getfenv(L, uindex); // [udata, reftable] lua_setfenv(L, -2); // [udata] @@ -590,7 +641,7 @@ static void push_node(lua_State *L, TSNode node, int uindex) static bool node_check(lua_State *L, int index, TSNode *res) { - TSNode *ud = luaL_checkudata(L, index, "treesitter_node"); + TSNode *ud = luaL_checkudata(L, index, TS_META_NODE); if (ud) { *res = *ud; return true; @@ -618,13 +669,12 @@ static int node_eq(lua_State *L) if (!node_check(L, 1, &node)) { return 0; } - // This should only be called if both x and y in "x == y" has the - // treesitter_node metatable. So it is ok to error out otherwise. - TSNode *ud = luaL_checkudata(L, 2, "treesitter_node"); - if (!ud) { + + TSNode node2; + if (!node_check(L, 2, &node2)) { return 0; } - TSNode node2 = *ud; + lua_pushboolean(L, ts_node_eq(node, node2)); return 1; } @@ -859,7 +909,7 @@ static int node_named_descendant_for_range(lua_State *L) static int node_next_child(lua_State *L) { TSTreeCursor *ud = luaL_checkudata( - L, lua_upvalueindex(1), "treesitter_treecursor"); + L, lua_upvalueindex(1), TS_META_TREECURSOR); if (!ud) { return 0; } @@ -909,7 +959,7 @@ static int node_iter_children(lua_State *L) TSTreeCursor *ud = lua_newuserdata(L, sizeof(TSTreeCursor)); // [udata] *ud = ts_tree_cursor_new(source); - lua_getfield(L, LUA_REGISTRYINDEX, "treesitter_treecursor"); // [udata, mt] + lua_getfield(L, LUA_REGISTRYINDEX, TS_META_TREECURSOR); // [udata, mt] lua_setmetatable(L, -2); // [udata] lua_pushvalue(L, 1); // [udata, source_node] lua_pushcclosure(L, node_next_child, 2); @@ -919,7 +969,7 @@ static int node_iter_children(lua_State *L) static int treecursor_gc(lua_State *L) { - TSTreeCursor *ud = luaL_checkudata(L, 1, "treesitter_treecursor"); + TSTreeCursor *ud = luaL_checkudata(L, 1, TS_META_TREECURSOR); ts_tree_cursor_delete(ud); return 0; } @@ -1031,7 +1081,7 @@ static int node_rawquery(lua_State *L) ud->cursor = cursor; ud->predicated_match = -1; - lua_getfield(L, LUA_REGISTRYINDEX, "treesitter_querycursor"); + lua_getfield(L, LUA_REGISTRYINDEX, TS_META_QUERYCURSOR); lua_setmetatable(L, -2); // [udata] lua_pushvalue(L, 1); // [udata, node] @@ -1051,7 +1101,7 @@ static int node_rawquery(lua_State *L) static int querycursor_gc(lua_State *L) { - TSLua_cursor *ud = luaL_checkudata(L, 1, "treesitter_querycursor"); + TSLua_cursor *ud = luaL_checkudata(L, 1, TS_META_QUERYCURSOR); ts_query_cursor_delete(ud->cursor); return 0; } @@ -1084,7 +1134,7 @@ int ts_lua_parse_query(lua_State *L) TSQuery **ud = lua_newuserdata(L, sizeof(TSQuery *)); // [udata] *ud = query; - lua_getfield(L, LUA_REGISTRYINDEX, "treesitter_query"); // [udata, meta] + lua_getfield(L, LUA_REGISTRYINDEX, TS_META_QUERY); // [udata, meta] lua_setmetatable(L, -2); // [udata] return 1; } @@ -1102,7 +1152,7 @@ static const char *query_err_string(TSQueryError err) { static TSQuery *query_check(lua_State *L, int index) { - TSQuery **ud = luaL_checkudata(L, index, "treesitter_query"); + TSQuery **ud = luaL_checkudata(L, index, TS_META_QUERY); return *ud; } diff --git a/test/functional/lua/treesitter_spec.lua b/test/functional/lua/treesitter_spec.lua index 273a5119cb..65dc1b3e03 100644 --- a/test/functional/lua/treesitter_spec.lua +++ b/test/functional/lua/treesitter_spec.lua @@ -781,6 +781,14 @@ local hl_query = [[ ]] eq(range, { { 0, 0, 18, 1 } }) + + local range_tbl = exec_lua [[ + parser:set_included_regions { { { 0, 0, 17, 1 } } } + parser:parse() + return parser:included_regions() + ]] + + eq(range_tbl, { { { 0, 0, 0, 17, 1, 508 } } }) end) it("allows to set complex ranges", function() if not check_parser() then return end |