aboutsummaryrefslogtreecommitdiff
path: root/runtime/lua/vim/treesitter/_fold.lua
blob: a66cc6d543a3107612395717e22f021fa02f6a5f (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
local api = vim.api

local M = {}

--- Memoizes a function based on the buffer tick of the provided bufnr.
---
--- The wrapped function is re-run only when the buffer's changedtick differs
--- from the one recorded with the cached result. Extra arguments are forwarded
--- to `fn` but are not part of the cache key.
---
--- The cache entry is cleared when the buffer is reloaded or detached; the
--- latter avoids memory leaks. Each buffer is attached to at most once per
--- attachment lifetime: a reload drops the cached result but does not trigger
--- a second `nvim_buf_attach()`, so callbacks do not pile up on the buffer.
---@generic F: function
---@param fn F fn to memoize, taking the bufnr as first argument
---@return F
local function memoize_by_changedtick(fn)
  ---@type table<integer,{result:any,last_tick:integer}>
  local cache = {}

  -- Buffers that currently carry our cleanup callbacks. Tracked separately
  -- from `cache` because `on_reload` empties the cache entry while the
  -- attachment itself stays in place.
  ---@type table<integer,boolean>
  local attached = {}

  ---@param bufnr integer
  return function(bufnr, ...)
    local tick = api.nvim_buf_get_changedtick(bufnr)

    local entry = cache[bufnr]
    if entry and entry.last_tick == tick then
      return entry.result
    end

    if not attached[bufnr] then
      -- Clean up logic only!
      local ok = api.nvim_buf_attach(bufnr, false, {
        on_detach = function()
          -- The attachment is gone: forget both the result and the attachment
          -- so that a later call attaches again.
          cache[bufnr] = nil
          attached[bufnr] = nil
        end,
        on_reload = function()
          -- Still attached after a reload: only the cached result is stale.
          cache[bufnr] = nil
        end,
      })

      -- nvim_buf_attach() reports failure with a falsy value; in that case no
      -- callback was registered, so leave the buffer unmarked and retry later.
      if ok then
        attached[bufnr] = true
      end
    end

    cache[bufnr] = {
      result = fn(bufnr, ...),
      last_tick = tick,
    }

    return cache[bufnr].result
  end
end

--- Runs the query `query_name` over every tree of the buffer's parser and
--- invokes `callback` for each captured node whose capture name is `capture`.
--- Does nothing when the buffer has no parser; trees whose language has no
--- such query are skipped.
---@param bufnr integer
---@param capture string
---@param query_name string
---@param callback fun(id: integer, node:TSNode, metadata: TSMetadata)
local function iter_matches_with_capture(bufnr, capture, query_name, callback)
  local parser = vim.treesitter.get_parser(bufnr)
  if not parser then
    return
  end

  -- Handles a single tree: look up the query for its language and report
  -- every node captured under the requested name.
  local function visit_tree(tree, lang_tree)
    local query = vim.treesitter.query.get_query(lang_tree:lang(), query_name)
    if not query then
      return
    end

    local root = tree:root()
    -- Only the rows of the root's range are needed; the start column is ignored.
    local first_row, _, last_row = root:range()

    for _, match, metadata in query:iter_matches(root, bufnr, first_row, last_row) do
      for capture_id, node in pairs(match) do
        if query.captures[capture_id] == capture then
          callback(capture_id, node, metadata)
        end
      end
    end
  end

  parser:for_each_tree(visit_tree)
end

---@private
--- TODO(lewis6991): copied from languagetree.lua. Consolidate
---
--- Resolves the range to use for a captured node: a `range` stored in the
--- metadata entry for this capture id takes precedence, otherwise the node's
--- own range is packed into a new list.
---@param node TSNode
---@param id integer
---@param metadata TSMetadata
---@return Range
local function get_range_from_metadata(node, id, metadata)
  local capture_meta = metadata[id]
  local override = capture_meta and capture_meta.range
  if override then
    return override --[[@as Range]]
  end
  return { node:range() }
end

-- Computes the fold-expression value of every line of a buffer from the
-- `fold` captures of the `folds` query.
--
-- This is cached on buf tick to avoid computing that multiple times
-- Especially not for every line in the file when `zx` is hit
--
-- The result is indexed by 1-based line number. Each value is the nesting
-- level of that line (clamped to 'foldnestmax') as a string, prefixed with
-- '>' on lines where the clamped level increases.
--
-- NOTE: 'foldnestmax' and 'foldminlines' are read from the current window at
-- computation time, while the memoized result is keyed on buffer/changedtick
-- only.
---@param bufnr integer
---@return table<integer,string>
local folds_levels = memoize_by_changedtick(function(bufnr)
  local max_fold_level = vim.wo.foldnestmax
  local min_fold_lines = vim.wo.foldminlines

  -- Clamp a nesting depth to 'foldnestmax'.
  local function trim_level(level)
    if level > max_fold_level then
      return max_fold_level
    end
    return level
  end

  -- start..stop is an inclusive range
  -- Both tables are keyed by 0-based row and count how many folds open
  -- (respectively close) on that row.
  local start_counts = {} ---@type table<integer,integer>
  local stop_counts = {} ---@type table<integer,integer>

  local prev_start = -1
  local prev_stop = -1

  iter_matches_with_capture(bufnr, 'fold', 'folds', function(id, node, metadata)
    local range = get_range_from_metadata(node, id, metadata)
    local start, stop, stop_col = range[1], range[3], range[4]

    -- A range ending at column 0 does not cover any text on its last row, so
    -- the fold ends on the row before.
    if stop_col == 0 then
      stop = stop - 1
    end

    local fold_length = stop - start + 1

    -- Fold only multiline nodes that are not exactly the same as previously met folds
    -- Checking against just the previously found fold is sufficient if nodes
    -- are returned in preorder or postorder when traversing tree
    if fold_length > min_fold_lines and not (start == prev_start and stop == prev_stop) then
      start_counts[start] = (start_counts[start] or 0) + 1
      stop_counts[stop] = (stop_counts[stop] or 0) + 1
      prev_start = start
      prev_stop = stop
    end
  end)

  ---@type table<integer,string>
  local levels = {}
  local current_level = 0

  -- We now have the list of fold opening and closing, fill the gaps and mark where fold start
  -- Rows are 0-based, so the last row of the buffer is line_count - 1.
  for lnum = 0, api.nvim_buf_line_count(bufnr) - 1 do
    local last_trimmed_level = trim_level(current_level)
    -- Folds opening on this row include it ...
    current_level = current_level + (start_counts[lnum] or 0)
    local trimmed_level = trim_level(current_level)
    -- ... and folds closing on this row only take effect from the next row.
    current_level = current_level - (stop_counts[lnum] or 0)

    -- Determine if it's the start/end of a fold
    -- NB: vim's fold-expr interface does not have a mechanism to indicate that
    -- two (or more) folds start at this line, so it cannot distinguish between
    --  ( \n ( \n )) \n (( \n ) \n )
    -- versus
    --  ( \n ( \n ) \n ( \n ) \n )
    -- If it did have such a mechanism, (trimmed_level - last_trimmed_level)
    -- would be the correct number of starts to pass on.
    local prefix = ''
    if trimmed_level - last_trimmed_level > 0 then
      prefix = '>'
    end

    -- `levels` is keyed by 1-based line number.
    levels[lnum + 1] = prefix .. tostring(trimmed_level)
  end

  return levels
end)

--- Fold expression backed by treesitter: returns the fold level string for
--- line `lnum` of the current buffer, or '0' when the buffer has no parser or
--- no level is known for that line.
---@param lnum integer|nil line number; defaults to `vim.v.lnum`
---@return string
function M.foldexpr(lnum)
  lnum = lnum or vim.v.lnum
  local bufnr = api.nvim_get_current_buf()

  ---@diagnostic disable-next-line:invisible
  local has_parser = vim.treesitter._has_parser(bufnr)
  if not (has_parser and lnum) then
    return '0'
  end

  local levels = folds_levels(bufnr)
  local level = levels and levels[lnum]

  return level or '0'
end

return M