aboutsummaryrefslogtreecommitdiff
path: root/runtime/lua/vim/treesitter/_range.lua
blob: 21e46a560acbdc53aa6e30706e0aecb6d6e7a5c2 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
local api = vim.api

local M = {}

---@alias Range4 {[1]: integer, [2]: integer, [3]: integer, [4]: integer}
---@alias Range6 {[1]: integer, [2]: integer, [3]: integer, [4]: integer, [5]: integer, [6]: integer}

---@private
---Three-way comparison of two (row, col) positions. Rows are compared
---first; columns only break ties between equal rows.
---@param a_row integer
---@param a_col integer
---@param b_row integer
---@param b_col integer
---@return integer
--- 1: a > b
--- 0: a == b
--- -1: a < b
local function cmp_pos(a_row, a_col, b_row, b_col)
  -- different rows: the row alone decides the ordering
  if a_row ~= b_row then
    return a_row > b_row and 1 or -1
  end

  -- same row: fall back to the column
  if a_col > b_col then
    return 1
  end
  if a_col < b_col then
    return -1
  end
  return 0
end

-- Position comparison predicates. Each takes (a_row, a_col, b_row, b_col)
-- and compares position a against position b:
--   lt: a < b    le: a <= b   gt: a > b
--   ge: a >= b   eq: a == b   ne: a ~= b
-- Calling the table directly, M.cmp_pos(a_row, a_col, b_row, b_col),
-- returns the raw three-way result (1, 0 or -1).
M.cmp_pos = {
  lt = function(...)
    return cmp_pos(...) == -1
  end,
  le = function(...)
    return cmp_pos(...) ~= 1
  end,
  gt = function(...)
    return cmp_pos(...) == 1
  end,
  ge = function(...)
    return cmp_pos(...) ~= -1
  end,
  eq = function(...)
    return cmp_pos(...) == 0
  end,
  ne = function(...)
    return cmp_pos(...) ~= 0
  end,
}

-- The `__call` metamethod receives the called table as its first argument.
-- Drop it so the remaining arguments line up with cmp_pos's
-- (a_row, a_col, b_row, b_col); passing cmp_pos directly would shift every
-- argument by one and compare the table against a number.
setmetatable(M.cmp_pos, {
  __call = function(_, ...)
    return cmp_pos(...)
  end,
})

---@private
---Check if a variable is a valid range object: a table whose length is
---4 (Range4) or 6 (Range6) and whose every slot up to that length holds a
---number.
---@param r any
---@return boolean
function M.validate(r)
  if type(r) ~= 'table' then
    return false
  end

  local len = #r
  if len ~= 4 and len ~= 6 then
    return false
  end

  -- Index every slot up to `len` instead of using `ipairs`: ipairs stops at
  -- the first nil, so a table with a hole (e.g. {1, 2, nil, 4, 5, 6}, whose
  -- length may still be reported as 6) would otherwise be accepted.
  for i = 1, len do
    if type(r[i]) ~= 'number' then
      return false
    end
  end

  return true
end

---@private
---Report whether two ranges overlap. A range that ends exactly where the
---other begins does not count as overlapping.
---@param r1 Range4|Range6
---@param r2 Range4|Range6
---@return boolean
function M.intercepts(r1, r2)
  local a_srow, a_scol, a_erow, a_ecol = M.unpack4(r1)
  local b_srow, b_scol, b_erow, b_ecol = M.unpack4(r2)

  -- The ranges are disjoint when r1 ends at or before the start of r2
  -- (r1 above), or r1 starts at or after the end of r2 (r1 below).
  return not (
    M.cmp_pos.le(a_erow, a_ecol, b_srow, b_scol)
    or M.cmp_pos.ge(a_srow, a_scol, b_erow, b_ecol)
  )
end

---@private
---Extract (start_row, start_col, end_row, end_col) from a range. For a
---6-element range the byte-offset slots (3 and 6) are skipped.
---@param r Range4|Range6
---@return integer, integer, integer, integer
function M.unpack4(r)
  if #r == 6 then
    return r[1], r[2], r[4], r[5]
  end
  return r[1], r[2], r[3], r[4]
end

---@private
---Report whether r1 fully encloses r2 (boundaries may coincide).
---@param r1 Range4|Range6
---@param r2 Range4|Range6
---@return boolean whether r1 contains r2
function M.contains(r1, r2)
  local outer_srow, outer_scol, outer_erow, outer_ecol = M.unpack4(r1)
  local inner_srow, inner_scol, inner_erow, inner_ecol = M.unpack4(r2)

  -- r1 fails to contain r2 if r2 starts before r1 does, or ends after r1 does
  return not (
    M.cmp_pos.gt(outer_srow, outer_scol, inner_srow, inner_scol)
    or M.cmp_pos.lt(outer_erow, outer_ecol, inner_erow, inner_ecol)
  )
end

---@private
---Byte offset (0-based) at which 0-based row `row` begins in `source`.
---@param source integer|string buffer handle, or the source text itself
---@param row integer 0-based row
---@return integer
local function row_start_byte(source, row)
  if type(source) == 'number' then
    return api.nvim_buf_get_offset(source, row)
  end

  if row <= 0 then
    return 0
  end

  -- Row `row` starts right after the `row`-th '\n'. The position capture
  -- `()` yields the 1-based index of that '\n', which is exactly the 0-based
  -- byte offset of the first byte of the following row.
  local seen = 0
  for nl_pos in source:gmatch('()\n') do
    seen = seen + 1
    if seen == row then
      return nl_pos
    end
  end

  -- The text has fewer rows than requested: clamp to the end of the text.
  return #source
end

---@private
---Convert a range to a Range6 by computing the byte offsets of its start and
---end positions. A 6-element range is returned unchanged.
---@param source integer|string buffer handle, or the source text itself
---@param range Range4|Range6
---@return Range6
function M.add_bytes(source, range)
  if type(range) == 'table' and #range == 6 then
    return range --[[@as Range6]]
  end

  local start_row, start_col, end_row, end_col = range[1], range[2], range[3], range[4]
  local start_byte = 0
  local end_byte = 0
  -- NOTE: columns are added to the row's starting offset as-is, i.e. they are
  -- treated as byte columns for both buffer and string sources.
  -- TODO(vigoux): account for EOL ?
  if type(source) == 'number' or type(source) == 'string' then
    start_byte = row_start_byte(source, start_row) + start_col
    end_byte = row_start_byte(source, end_row) + end_col
  end

  return { start_row, start_col, start_byte, end_row, end_col, end_byte }
end

return M