aboutsummaryrefslogtreecommitdiff
path: root/runtime/lua/vim/treesitter/_range.lua
blob: b87542c20f5929f8e5b697978a56e4e86b8dcf8d (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
local api = vim.api

local M = {}

---@alias Range4 {[1]: integer, [2]: integer, [3]: integer, [4]: integer}
---@alias Range6 {[1]: integer, [2]: integer, [3]: integer, [4]: integer, [5]: integer, [6]: integer}

---@private
---@param a_row integer
---@param a_col integer
---@param b_row integer
---@param b_col integer
---@return integer
--- 1: a > b
--- 0: a == b
--- -1: a < b
local function cmp_pos(a_row, a_col, b_row, b_col)
  if a_row == b_row then
    if a_col > b_col then
      return 1
    elseif a_col < b_col then
      return -1
    else
      return 0
    end
  elseif a_row > b_row then
    return 1
  end

  return -1
end

M.cmp_pos = {
  lt = function(...)
    return cmp_pos(...) == -1
  end,
  le = function(...)
    return cmp_pos(...) ~= 1
  end,
  gt = function(...)
    return cmp_pos(...) == 1
  end,
  ge = function(...)
    return cmp_pos(...) ~= -1
  end,
  eq = function(...)
    return cmp_pos(...) == 0
  end,
  ne = function(...)
    return cmp_pos(...) ~= 0
  end,
}

setmetatable(M.cmp_pos, { __call = cmp_pos })

---@private
---@param r1 Range4|Range6
---@param r2 Range4|Range6
---@return boolean
function M.intercepts(r1, r2)
  local off_1 = #r1 == 6 and 1 or 0
  local off_2 = #r1 == 6 and 1 or 0

  local srow_1, scol_1, erow_1, ecol_1 = r1[1], r2[2], r1[3 + off_1], r1[4 + off_1]
  local srow_2, scol_2, erow_2, ecol_2 = r2[1], r2[2], r2[3 + off_2], r2[4 + off_2]

  -- r1 is above r2
  if M.cmp_pos.le(erow_1, ecol_1, srow_2, scol_2) then
    return false
  end

  -- r1 is below r2
  if M.cmp_pos.ge(srow_1, scol_1, erow_2, ecol_2) then
    return false
  end

  return true
end

---@private
---@param r1 Range4|Range6
---@param r2 Range4|Range6
---@return boolean whether r1 contains r2
function M.contains(r1, r2)
  local off_1 = #r1 == 6 and 1 or 0
  local off_2 = #r1 == 6 and 1 or 0

  local srow_1, scol_1, erow_1, ecol_1 = r1[1], r2[2], r1[3 + off_1], r1[4 + off_1]
  local srow_2, scol_2, erow_2, ecol_2 = r2[1], r2[2], r2[3 + off_2], r2[4 + off_2]

  -- start doesn't fit
  if M.cmp_pos.gt(srow_1, scol_1, srow_2, scol_2) then
    return false
  end

  -- end doesn't fit
  if M.cmp_pos.lt(erow_1, ecol_1, erow_2, ecol_2) then
    return false
  end

  return true
end

---@private
---@param source integer|string
---@param range Range4
---@return Range6
function M.add_bytes(source, range)
  local start_row, start_col, end_row, end_col = range[1], range[2], range[3], range[4]
  local start_byte = 0
  local end_byte = 0
  -- TODO(vigoux): proper byte computation here, and account for EOL ?
  if type(source) == 'number' then
    -- Easy case, this is a buffer parser
    start_byte = api.nvim_buf_get_offset(source, start_row) + start_col
    end_byte = api.nvim_buf_get_offset(source, end_row) + end_col
  elseif type(source) == 'string' then
    -- string parser, single `\n` delimited string
    start_byte = vim.fn.byteidx(source, start_col)
    end_byte = vim.fn.byteidx(source, end_col)
  end

  return { start_row, start_col, start_byte, end_row, end_col, end_byte }
end

return M