diff options
-rw-r--r-- | runtime/doc/treesitter.txt | 4 | ||||
-rw-r--r-- | runtime/lua/vim/treesitter/highlighter.lua | 63 | ||||
-rw-r--r-- | runtime/lua/vim/treesitter/languagetree.lua | 17 | ||||
-rw-r--r-- | runtime/lua/vim/treesitter/query.lua | 10 | ||||
-rw-r--r-- | src/nvim/mapping.c | 158 | ||||
-rw-r--r-- | test/functional/treesitter/highlight_spec.lua | 117 |
6 files changed, 181 insertions, 188 deletions
diff --git a/runtime/doc/treesitter.txt b/runtime/doc/treesitter.txt index 64b4ca7ca2..ee34c45cce 100644 --- a/runtime/doc/treesitter.txt +++ b/runtime/doc/treesitter.txt @@ -976,8 +976,8 @@ Query:iter_captures({node}, {source}, {start}, {stop}) • {stop} (integer) Stopping line for the search (end-exclusive) Return: ~ - (fun(): integer, TSNode, TSMetadata): capture id, capture node, - metadata + (fun(end_line: integer|nil): integer, TSNode, TSMetadata): capture id, + capture node, metadata *Query:iter_matches()* Query:iter_matches({node}, {source}, {start}, {stop}, {opts}) diff --git a/runtime/lua/vim/treesitter/highlighter.lua b/runtime/lua/vim/treesitter/highlighter.lua index 8d4d6a9337..496193c6ed 100644 --- a/runtime/lua/vim/treesitter/highlighter.lua +++ b/runtime/lua/vim/treesitter/highlighter.lua @@ -2,7 +2,7 @@ local api = vim.api local query = vim.treesitter.query local Range = require('vim.treesitter._range') ----@alias TSHlIter fun(): integer, TSNode, TSMetadata +---@alias TSHlIter fun(end_line: integer|nil): integer, TSNode, TSMetadata ---@class TSHighlightState ---@field next_row integer @@ -241,40 +241,43 @@ local function on_line_impl(self, buf, line, is_spell_nav) end while line >= state.next_row do - local capture, node, metadata = state.iter() + local capture, node, metadata = state.iter(line) - if capture == nil then - break + local range = { root_end_row + 1, 0, root_end_row + 1, 0 } + if node then + range = vim.treesitter.get_range(node, buf, metadata and metadata[capture]) end - - local range = vim.treesitter.get_range(node, buf, metadata[capture]) local start_row, start_col, end_row, end_col = Range.unpack4(range) - local hl = highlighter_query.hl_cache[capture] - - local capture_name = highlighter_query:query().captures[capture] - local spell = nil ---@type boolean? - if capture_name == 'spell' then - spell = true - elseif capture_name == 'nospell' then - spell = false - end - -- Give nospell a higher priority so it always overrides spell captures. - local spell_pri_offset = capture_name == 'nospell' and 1 or 0 - - if hl and end_row >= line and (not is_spell_nav or spell ~= nil) then - local priority = (tonumber(metadata.priority) or vim.highlight.priorities.treesitter) - + spell_pri_offset - api.nvim_buf_set_extmark(buf, ns, start_row, start_col, { - end_line = end_row, - end_col = end_col, - hl_group = hl, - ephemeral = true, - priority = priority, - conceal = metadata.conceal, - spell = spell, - }) + if capture then + local hl = highlighter_query.hl_cache[capture] + + local capture_name = highlighter_query:query().captures[capture] + local spell = nil ---@type boolean? + if capture_name == 'spell' then + spell = true + elseif capture_name == 'nospell' then + spell = false + end + + -- Give nospell a higher priority so it always overrides spell captures. + local spell_pri_offset = capture_name == 'nospell' and 1 or 0 + + if hl and end_row >= line and (not is_spell_nav or spell ~= nil) then + local priority = (tonumber(metadata.priority) or vim.highlight.priorities.treesitter) + + spell_pri_offset + api.nvim_buf_set_extmark(buf, ns, start_row, start_col, { + end_line = end_row, + end_col = end_col, + hl_group = hl, + ephemeral = true, + priority = priority, + conceal = metadata.conceal, + spell = spell, + }) + end end + if start_row > line then state.next_row = start_row end diff --git a/runtime/lua/vim/treesitter/languagetree.lua b/runtime/lua/vim/treesitter/languagetree.lua index 79f36a27fd..4dd5a18396 100644 --- a/runtime/lua/vim/treesitter/languagetree.lua +++ b/runtime/lua/vim/treesitter/languagetree.lua @@ -734,7 +734,8 @@ local has_parser = function(lang) or #vim.api.nvim_get_runtime_file('parser/' .. lang .. '.*', false) > 0 end ---- Return parser name for language (if exists) or filetype (if registered and exists) +--- Return parser name for language (if exists) or filetype (if registered and exists). +--- Also attempts with the input lower-cased. --- ---@param alias string language or filetype name ---@return string? # resolved parser name @@ -743,10 +744,19 @@ local function resolve_lang(alias) return alias end + if has_parser(alias:lower()) then + return alias:lower() + end + local lang = vim.treesitter.language.get_lang(alias) if lang and has_parser(lang) then return lang end + + lang = vim.treesitter.language.get_lang(alias:lower()) + if lang and has_parser(lang) then + return lang + end end ---@private @@ -758,9 +768,10 @@ end function LanguageTree:_get_injection(match, metadata) local ranges = {} ---@type Range6[] local combined = metadata['injection.combined'] ~= nil + local injection_lang = metadata['injection.language'] --[[@as string?]] local lang = metadata['injection.self'] ~= nil and self:lang() or metadata['injection.parent'] ~= nil and self._parent_lang - or metadata['injection.language'] --[[@as string?]] + or (injection_lang and resolve_lang(injection_lang)) local include_children = metadata['injection.include-children'] ~= nil for id, node in pairs(match) do @@ -768,7 +779,7 @@ function LanguageTree:_get_injection(match, metadata) -- Lang should override any other language tag if name == 'injection.language' then local text = vim.treesitter.get_node_text(node, self._source, { metadata = metadata[id] }) - lang = resolve_lang(text) or resolve_lang(text:lower()) + lang = resolve_lang(text) elseif name == 'injection.content' then ranges = get_node_ranges(node, self._source, metadata[id], include_children) end diff --git a/runtime/lua/vim/treesitter/query.lua b/runtime/lua/vim/treesitter/query.lua index d7973cc48f..6d9b214d4a 100644 --- a/runtime/lua/vim/treesitter/query.lua +++ b/runtime/lua/vim/treesitter/query.lua @@ -708,7 +708,8 @@ end ---@param start integer Starting line for the search ---@param stop integer Stopping line for the search (end-exclusive) --- ----@return (fun(): integer, TSNode, TSMetadata): capture id, capture node, metadata +---@return (fun(end_line: integer|nil): integer, TSNode, TSMetadata): +--- capture id, capture node, metadata function Query:iter_captures(node, source, start, stop) if type(source) == 'number' and source == 0 then source = api.nvim_get_current_buf() @@ -717,7 +718,7 @@ function Query:iter_captures(node, source, start, stop) start, stop = value_or_node_range(start, stop, node) local raw_iter = node:_rawquery(self.query, true, start, stop) - local function iter() + local function iter(end_line) local capture, captured_node, match = raw_iter() local metadata = {} @@ -725,7 +726,10 @@ function Query:iter_captures(node, source, start, stop) local active = self:match_preds(match, match.pattern, source) match.active = active if not active then - return iter() -- tail call: try next match + if end_line and captured_node:range() > end_line then + return nil, captured_node, nil + end + return iter(end_line) -- tail call: try next match end self:apply_directives(match, match.pattern, source, metadata) diff --git a/src/nvim/mapping.c b/src/nvim/mapping.c index ab528f8865..35a728314c 100644 --- a/src/nvim/mapping.c +++ b/src/nvim/mapping.c @@ -148,9 +148,7 @@ mapblock_T *get_maphash(int index, buf_T *buf) /// "mpp" is a pointer to the m_next field of the PREVIOUS entry! static void mapblock_free(mapblock_T **mpp) { - mapblock_T *mp; - - mp = *mpp; + mapblock_T *mp = *mpp; xfree(mp->m_keys); if (!mp->m_simplified) { NLUA_CLEAR_REF(mp->m_luaref); @@ -212,8 +210,6 @@ static char *map_mode_to_chars(int mode) /// @param local true for buffer-local map static void showmap(mapblock_T *mp, bool local) { - size_t len = 1; - if (message_filtered(mp->m_keys) && message_filtered(mp->m_str) && (mp->m_desc == NULL || message_filtered(mp->m_desc))) { return; @@ -226,12 +222,10 @@ static void showmap(mapblock_T *mp, bool local) } } - { - char *const mapchars = map_mode_to_chars(mp->m_mode); - msg_puts(mapchars); - len = strlen(mapchars); - xfree(mapchars); - } + char *const mapchars = map_mode_to_chars(mp->m_mode); + msg_puts(mapchars); + size_t len = strlen(mapchars); + xfree(mapchars); while (++len <= 3) { msg_putchar(' '); @@ -572,33 +566,16 @@ static void map_add(buf_T *buf, mapblock_T **map_table, mapblock_T **abbr_table, /// @param buf Target Buffer static int buf_do_map(int maptype, MapArguments *args, int mode, bool is_abbrev, buf_T *buf) { - mapblock_T *mp, **mpp; - const char *p; - int n; int retval = 0; - mapblock_T **abbr_table; - mapblock_T **map_table; - int noremap; - map_table = maphash; - abbr_table = &first_abbr; + // If <buffer> was given, we'll be searching through the buffer's + // mappings/abbreviations, not the globals. + mapblock_T **map_table = args->buffer ? buf->b_maphash : maphash; + mapblock_T **abbr_table = args->buffer ? &buf->b_first_abbr : &first_abbr; // For ":noremap" don't remap, otherwise do remap. - if (maptype == MAPTYPE_NOREMAP) { - noremap = REMAP_NONE; - } else { - noremap = REMAP_YES; - } - - if (args->buffer) { - // If <buffer> was given, we'll be searching through the buffer's - // mappings/abbreviations, not the globals. - map_table = buf->b_maphash; - abbr_table = &buf->b_first_abbr; - } - if (args->script) { - noremap = REMAP_SCRIPT; - } + int noremap = args->script ? REMAP_SCRIPT : + maptype == MAPTYPE_NOREMAP ? REMAP_NONE : REMAP_YES; const bool has_lhs = (args->lhs[0] != NUL); const bool has_rhs = args->rhs_lua != LUA_NOREF || (args->rhs[0] != NUL) || args->rhs_is_noop; @@ -648,8 +625,8 @@ static int buf_do_map(int maptype, MapArguments *args, int mode, bool is_abbrev, const int first = vim_iswordp(lhs); int last = first; - p = lhs + utfc_ptr2len(lhs); - n = 1; + const char *p = lhs + utfc_ptr2len(lhs); + int n = 1; while (p < lhs + len) { n++; // nr of (multi-byte) chars last = vim_iswordp(p); // type of last char @@ -685,6 +662,7 @@ static int buf_do_map(int maptype, MapArguments *args, int mode, bool is_abbrev, && maptype != MAPTYPE_UNMAP) { // need to loop over all global hash lists for (int hash = 0; hash < 256 && !got_int; hash++) { + mapblock_T *mp; if (is_abbrev) { if (hash != 0) { // there is only one abbreviation list break; @@ -714,6 +692,7 @@ static int buf_do_map(int maptype, MapArguments *args, int mode, bool is_abbrev, if (map_table != buf->b_maphash && !has_rhs && maptype != MAPTYPE_UNMAP) { // need to loop over all global hash lists for (int hash = 0; hash < 256 && !got_int; hash++) { + mapblock_T *mp; if (is_abbrev) { if (hash != 0) { // there is only one abbreviation list break; @@ -729,7 +708,7 @@ static int buf_do_map(int maptype, MapArguments *args, int mode, bool is_abbrev, showmap(mp, true); did_local = true; } else { - n = mp->m_keylen; + int n = mp->m_keylen; if (strncmp(mp->m_keys, lhs, (size_t)(n < len ? n : len)) == 0) { showmap(mp, true); did_local = true; @@ -759,8 +738,8 @@ static int buf_do_map(int maptype, MapArguments *args, int mode, bool is_abbrev, hash_end = 256; } for (int hash = hash_start; hash < hash_end && !got_int; hash++) { - mpp = is_abbrev ? abbr_table : &(map_table[hash]); - for (mp = *mpp; mp != NULL && !got_int; mp = *mpp) { + mapblock_T **mpp = is_abbrev ? abbr_table : &(map_table[hash]); + for (mapblock_T *mp = *mpp; mp != NULL && !got_int; mp = *mpp) { if ((mp->m_mode & mode) == 0) { // skip entries with wrong mode mpp = &(mp->m_next); @@ -772,6 +751,8 @@ static int buf_do_map(int maptype, MapArguments *args, int mode, bool is_abbrev, did_it = true; } } else { // do we have a match? + int n; + const char *p; if (round) { // second round: Try unmap "rhs" string n = (int)strlen(mp->m_str); p = mp->m_str; @@ -994,12 +975,10 @@ free_and_return: /// Get the mapping mode from the command name. static int get_map_mode(char **cmdp, bool forceit) { - char *p; - int modec; int mode; - p = *cmdp; - modec = (uint8_t)(*p++); + char *p = *cmdp; + int modec = (uint8_t)(*p++); if (modec == 'i') { mode = MODE_INSERT; // :imap } else if (modec == 'l') { @@ -1036,16 +1015,13 @@ static int get_map_mode(char **cmdp, bool forceit) /// This function used to be called map_clear(). static void do_mapclear(char *cmdp, char *arg, int forceit, int abbr) { - int mode; - int local; - - local = (strcmp(arg, "<buffer>") == 0); + bool local = strcmp(arg, "<buffer>") == 0; if (!local && *arg != NUL) { emsg(_(e_invarg)); return; } - mode = get_map_mode(&cmdp, forceit); + int mode = get_map_mode(&cmdp, forceit); map_clear_mode(curbuf, mode, local, abbr); } @@ -1057,11 +1033,8 @@ static void do_mapclear(char *cmdp, char *arg, int forceit, int abbr) /// @param abbr true for abbreviations void map_clear_mode(buf_T *buf, int mode, bool local, bool abbr) { - mapblock_T *mp, **mpp; - int hash; - int new_hash; - - for (hash = 0; hash < 256; hash++) { + for (int hash = 0; hash < 256; hash++) { + mapblock_T **mpp; if (abbr) { if (hash > 0) { // there is only one abbrlist break; @@ -1079,7 +1052,7 @@ void map_clear_mode(buf_T *buf, int mode, bool local, bool abbr) } } while (*mpp != NULL) { - mp = *mpp; + mapblock_T *mp = *mpp; if (mp->m_mode & mode) { mp->m_mode &= ~mode; if (mp->m_mode == 0) { // entry can be deleted @@ -1087,7 +1060,7 @@ void map_clear_mode(buf_T *buf, int mode, bool local, bool abbr) continue; } // May need to put this entry into another hash list. - new_hash = MAP_HASH(mp->m_mode, (uint8_t)mp->m_keys[0]); + int new_hash = MAP_HASH(mp->m_mode, (uint8_t)mp->m_keys[0]); if (!abbr && new_hash != hash) { *mpp = mp->m_next; if (local) { @@ -1119,7 +1092,6 @@ bool map_to_exists(const char *const str, const char *const modechars, const boo FUNC_ATTR_NONNULL_ALL FUNC_ATTR_WARN_UNUSED_RESULT FUNC_ATTR_PURE { int mode = 0; - int retval; char *buf = NULL; const char *const rhs = replace_termcodes(str, strlen(str), &buf, 0, @@ -1141,7 +1113,7 @@ bool map_to_exists(const char *const str, const char *const modechars, const boo MAPMODE(mode, modechars, 'c', MODE_CMDLINE); #undef MAPMODE - retval = map_to_exists_mode(rhs, mode, abbr); + int retval = map_to_exists_mode(rhs, mode, abbr); xfree(buf); return retval; @@ -1159,13 +1131,12 @@ bool map_to_exists(const char *const str, const char *const modechars, const boo /// @return true if there is at least one mapping with given parameters. int map_to_exists_mode(const char *const rhs, const int mode, const bool abbr) { - mapblock_T *mp; - int hash; bool exp_buffer = false; // Do it twice: once for global maps and once for local maps. while (true) { - for (hash = 0; hash < 256; hash++) { + for (int hash = 0; hash < 256; hash++) { + mapblock_T *mp; if (abbr) { if (hash > 0) { // There is only one abbr list. break; @@ -1486,10 +1457,7 @@ int ExpandMappings(char *pat, regmatch_T *regmatch, int *numMatches, char ***mat // Return true if there is an abbreviation, false if not. bool check_abbr(int c, char *ptr, int col, int mincol) { - int scol; // starting column of the abbr. uint8_t tb[MB_MAXBYTES + 4]; - mapblock_T *mp; - mapblock_T *mp2; int clen = 0; // length in characters if (typebuf.tb_no_abbr_cnt) { // abbrev. are not recursive @@ -1509,6 +1477,8 @@ bool check_abbr(int c, char *ptr, int col, int mincol) return false; } + int scol; // starting column of the abbr. + { bool is_id = true; bool vim_abbr; @@ -1539,8 +1509,8 @@ bool check_abbr(int c, char *ptr, int col, int mincol) if (scol < col) { // there is a word in front of the cursor ptr += scol; int len = col - scol; - mp = curbuf->b_first_abbr; - mp2 = first_abbr; + mapblock_T *mp = curbuf->b_first_abbr; + mapblock_T *mp2 = first_abbr; if (mp == NULL) { mp = mp2; mp2 = NULL; @@ -1715,18 +1685,13 @@ char *eval_map_expr(mapblock_T *mp, int c) /// @param buf buffer for local mappings or NULL int makemap(FILE *fd, buf_T *buf) { - mapblock_T *mp; - char c1, c2, c3; - char *p; - char *cmd; - int abbr; - int hash; bool did_cpo = false; // Do the loop twice: Once for mappings, once for abbreviations. // Then loop over all map hash lists. - for (abbr = 0; abbr < 2; abbr++) { - for (hash = 0; hash < 256; hash++) { + for (int abbr = 0; abbr < 2; abbr++) { + for (int hash = 0; hash < 256; hash++) { + mapblock_T *mp; if (abbr) { if (hash > 0) { // there is only one abbr list break; @@ -1755,6 +1720,7 @@ int makemap(FILE *fd, buf_T *buf) if (mp->m_luaref != LUA_NOREF) { continue; } + char *p; for (p = mp->m_str; *p != NUL; p++) { if ((uint8_t)p[0] == K_SPECIAL && (uint8_t)p[1] == KS_EXTRA && p[2] == KE_SNR) { @@ -1768,14 +1734,10 @@ int makemap(FILE *fd, buf_T *buf) // It's possible to create a mapping and then ":unmap" certain // modes. We recreate this here by mapping the individual // modes, which requires up to three of them. - c1 = NUL; - c2 = NUL; - c3 = NUL; - if (abbr) { - cmd = "abbr"; - } else { - cmd = "map"; - } + char c1 = NUL; + char c2 = NUL; + char c3 = NUL; + char *cmd = abbr ? "abbr" : "map"; switch (mp->m_mode) { case MODE_NORMAL | MODE_VISUAL | MODE_SELECT | MODE_OP_PENDING: break; @@ -1929,7 +1891,6 @@ int makemap(FILE *fd, buf_T *buf) int put_escstr(FILE *fd, char *strstart, int what) { uint8_t *str = (uint8_t *)strstart; - int c; // :map xx <Nop> if (*str == NUL && what == 1) { @@ -1953,7 +1914,7 @@ int put_escstr(FILE *fd, char *strstart, int what) continue; } - c = *str; + int c = *str; // Special key codes have to be translated to be able to make sense // when they are read back. if (c == K_SPECIAL && what != 2) { @@ -2030,14 +1991,13 @@ int put_escstr(FILE *fd, char *strstart, int what) char *check_map(char *keys, int mode, int exact, int ign_mod, int abbr, mapblock_T **mp_ptr, int *local_ptr, int *rhs_lua) { - int len, minlen; - mapblock_T *mp; *rhs_lua = LUA_NOREF; - len = (int)strlen(keys); + int len = (int)strlen(keys); for (int local = 1; local >= 0; local--) { // loop over all hash lists for (int hash = 0; hash < 256; hash++) { + mapblock_T *mp; if (abbr) { if (hash > 0) { // there is only one list. break; @@ -2062,7 +2022,7 @@ char *check_map(char *keys, int mode, int exact, int ign_mod, int abbr, mapblock s += 3; keylen -= 3; } - minlen = keylen < len ? keylen : len; + int minlen = keylen < len ? keylen : len; if (strncmp(s, keys, (size_t)minlen) == 0) { if (mp_ptr != NULL) { *mp_ptr = mp; @@ -2097,11 +2057,7 @@ void f_hasmapto(typval_T *argvars, typval_T *rettv, EvalFuncData fptr) } } - if (map_to_exists(name, mode, abbr)) { - rettv->vval.v_number = true; - } else { - rettv->vval.v_number = false; - } + rettv->vval.v_number = map_to_exists(name, mode, abbr); } /// Fill a Dictionary with all applicable maparg() like dictionaries @@ -2439,14 +2395,11 @@ void langmap_init(void) /// changed at any time! const char *did_set_langmap(optset_T *args) { - char *p; - char *p2; - int from, to; - - ga_clear(&langmap_mapga); // clear the previous map first - langmap_init(); // back to one-to-one map + ga_clear(&langmap_mapga); // clear the previous map first + langmap_init(); // back to one-to-one map - for (p = p_langmap; p[0] != NUL;) { + for (char *p = p_langmap; p[0] != NUL;) { + char *p2; for (p2 = p; p2[0] != NUL && p2[0] != ',' && p2[0] != ';'; MB_PTR_ADV(p2)) { if (p2[0] == '\\' && p2[1] != NUL) { @@ -2466,8 +2419,8 @@ const char *did_set_langmap(optset_T *args) if (p[0] == '\\' && p[1] != NUL) { p++; } - from = utf_ptr2char(p); - to = NUL; + int from = utf_ptr2char(p); + int to = NUL; if (p2 == NULL) { MB_PTR_ADV(p); if (p[0] != ',') { @@ -2524,9 +2477,8 @@ const char *did_set_langmap(optset_T *args) static void do_exmap(exarg_T *eap, int isabbrev) { - int mode; char *cmdp = eap->cmd; - mode = get_map_mode(&cmdp, eap->forceit || isabbrev); + int mode = get_map_mode(&cmdp, eap->forceit || isabbrev); switch (do_map((*cmdp == 'n') ? MAPTYPE_NOREMAP : (*cmdp == 'u') ? MAPTYPE_UNMAP : MAPTYPE_MAP, diff --git a/test/functional/treesitter/highlight_spec.lua b/test/functional/treesitter/highlight_spec.lua index 0528370e2a..0aa0cdd6d6 100644 --- a/test/functional/treesitter/highlight_spec.lua +++ b/test/functional/treesitter/highlight_spec.lua @@ -85,6 +85,56 @@ void ui_refresh(void) } }]] +local injection_text_c = [[ +int x = INT_MAX; +#define READ_STRING(x, y) (char *)read_string((x), (size_t)(y)) +#define foo void main() { \ + return 42; \ + } +]] + +local injection_grid_c = [[ + int x = INT_MAX; | + #define READ_STRING(x, y) (char *)read_string((x), (size_t)(y)) | + #define foo void main() { \ | + return 42; \ | + } | + ^ | + {1:~ }| + {1:~ }| + {1:~ }| + {1:~ }| + {1:~ }| + {1:~ }| + {1:~ }| + {1:~ }| + {1:~ }| + {1:~ }| + {1:~ }| + | +]] + +local injection_grid_expected_c = [[ + {3:int} x = {5:INT_MAX}; | + #define {5:READ_STRING}(x, y) ({3:char} *)read_string((x), ({3:size_t})(y)) | + #define foo {3:void} main() { \ | + {4:return} {5:42}; \ | + } | + ^ | + {1:~ }| + {1:~ }| + {1:~ }| + {1:~ }| + {1:~ }| + {1:~ }| + {1:~ }| + {1:~ }| + {1:~ }| + {1:~ }| + {1:~ }| + | +]] + describe('treesitter highlighting (C)', function() local screen @@ -411,34 +461,9 @@ describe('treesitter highlighting (C)', function() end) it("supports injected languages", function() - insert([[ - int x = INT_MAX; - #define READ_STRING(x, y) (char *)read_string((x), (size_t)(y)) - #define foo void main() { \ - return 42; \ - } - ]]) + insert(injection_text_c) - screen:expect{grid=[[ - int x = INT_MAX; | - #define READ_STRING(x, y) (char *)read_string((x), (size_t)(y)) | - #define foo void main() { \ | - return 42; \ | - } | - ^ | - {1:~ }| - {1:~ }| - {1:~ }| - {1:~ }| - {1:~ }| - {1:~ }| - {1:~ }| - {1:~ }| - {1:~ }| - {1:~ }| - {1:~ }| - | - ]]} + screen:expect{grid=injection_grid_c} exec_lua [[ local parser = vim.treesitter.get_parser(0, "c", { @@ -448,26 +473,24 @@ describe('treesitter highlighting (C)', function() test_hl = highlighter.new(parser, {queries = {c = hl_query}}) ]] - screen:expect{grid=[[ - {3:int} x = {5:INT_MAX}; | - #define {5:READ_STRING}(x, y) ({3:char} *)read_string((x), ({3:size_t})(y)) | - #define foo {3:void} main() { \ | - {4:return} {5:42}; \ | - } | - ^ | - {1:~ }| - {1:~ }| - {1:~ }| - {1:~ }| - {1:~ }| - {1:~ }| - {1:~ }| - {1:~ }| - {1:~ }| - {1:~ }| - {1:~ }| - | - ]]} + screen:expect{grid=injection_grid_expected_c} + end) + + it("supports injecting by ft name in metadata['injection.language']", function() + insert(injection_text_c) + + screen:expect{grid=injection_grid_c} + + exec_lua [[ + vim.treesitter.language.register("c", "foo") + local parser = vim.treesitter.get_parser(0, "c", { + injections = {c = '(preproc_def (preproc_arg) @injection.content (#set! injection.language "fOO")) (preproc_function_def value: (preproc_arg) @injection.content (#set! injection.language "fOO"))'} + }) + local highlighter = vim.treesitter.highlighter + test_hl = highlighter.new(parser, {queries = {c = hl_query}}) + ]] + + screen:expect{grid=injection_grid_expected_c} end) it("supports overriding queries, like ", function() |