aboutsummaryrefslogtreecommitdiff
path: root/test/unit/set.lua
blob: f93238cc47990c309ea8a08c8acb2043bb12ca3d (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
-- A set class for fast union/diff that can always return its elements, via
-- the to_table method, in the same relative order in which they were added.
-- It does this by keeping two Lua tables that mirror each other:
-- 1) index => item   (`tbl`; becomes sparse after removals)
-- 2) item => index   (`items`)
-- Indices come from a counter (`last_idx`) that only ever grows, so a newly
-- added element always sorts after every element added before it, even when
-- `tbl` has holes left by remove().
--- @class Set
--- @field nelem integer number of elements currently in the set
--- @field items table<string, integer> item => its index in `tbl`
--- @field tbl table<integer, string> index => item (may contain holes)
--- @field last_idx integer highest index ever assigned in `tbl`
local Set = {}

--- Creates a new set, optionally seeded with the values of `items`.
--- Any non-table argument (including nil) yields an empty set.
--- @param items? string[]
--- @return Set
function Set:new(items)
  local obj = { tbl = {}, items = {}, nelem = 0, last_idx = 0 } --- @type Set
  setmetatable(obj, self)
  self.__index = self

  if type(items) == 'table' then
    obj:union_table(items)
  end

  return obj
end

--- Returns an independent copy of this set with the same elements and the
--- same insertion order.
--- @return Set
function Set:copy()
  --- @type Set
  local obj = {
    nelem = self.nelem,
    last_idx = self.last_idx,
    tbl = {},
    items = {},
  }
  for idx, item in pairs(self.tbl) do
    obj.tbl[idx] = item
  end
  for item, idx in pairs(self.items) do
    obj.items[item] = idx
  end
  -- Reuse the source's metatable; the instance needs no `__index` field.
  return setmetatable(obj, getmetatable(self))
end

--- Adds every element of `other` to this set.
--- Elements are visited in `other`'s insertion order, so the merged order
--- reported by to_table() is deterministic.
--- @param other Set
function Set:union(other)
  for _, item in ipairs(other:to_table()) do
    self:add(item)
  end
end

--- Adds every value of table `t` to this set.
--- The sequence part (t[1], t[2], ...) is added first, in index order; any
--- remaining values are then added in (unspecified) pairs() order.
--- @param t table
function Set:union_table(t)
  for _, v in ipairs(t) do
    self:add(v)
  end
  -- add() is a no-op for values already present, so this only picks up
  -- values stored outside the sequence part.
  for _, v in pairs(t) do
    self:add(v)
  end
end

--- Removes from this set every element that is also in `other`.
--- Iterates over whichever of the two sets is smaller.
--- @param other Set
function Set:diff(other)
  if other:size() > self:size() then
    -- This set is the smaller one: test each of our elements against other.
    -- Clearing existing keys during a pairs() traversal is allowed in Lua.
    for e in self:iterator() do
      if other:contains(e) then
        self:remove(e)
      end
    end
  else
    -- The other set is smaller (or equal): drop each of its elements here;
    -- remove() does nothing for elements this set does not hold.
    for e in other:iterator() do
      self:remove(e)
    end
  end
end

--- Adds `it` to the set if it is not already present.
--- @param it string
function Set:add(it)
  if not self:contains(it) then
    -- Do not use `#self.tbl + 1`: once remove() has left holes in tbl, `#`
    -- may return any border, which could slot `it` in before older elements.
    local idx = self.last_idx + 1
    -- Write the item => index map first so an invalid key (nil/NaN) raises
    -- before any other state has been modified.
    self.items[it] = idx
    self.tbl[idx] = it
    self.last_idx = idx
    self.nelem = self.nelem + 1
  end
end

--- Removes `it` from the set; does nothing if it is not present.
--- @param it string
function Set:remove(it)
  local idx = self.items[it]
  if idx ~= nil then
    self.tbl[idx] = nil
    self.items[it] = nil
    self.nelem = self.nelem - 1
  end
end

--- @param it string
--- @return boolean true if `it` is an element of the set
function Set:contains(it)
  return self.items[it] ~= nil
end

--- @return integer number of elements in the set
function Set:size()
  return self.nelem
end

--- Returns the internal index => item table (not a copy; may have holes).
--- @return table<integer, string>
function Set:raw_tbl()
  return self.tbl
end

--- Returns the internal item => index table (not a copy).
--- @return table<string, integer>
function Set:raw_items()
  return self.items
end

--- Returns a pairs() iterator over the elements, yielding (item, index).
--- Iteration order is unspecified; use to_table() for insertion order.
function Set:iterator()
  return pairs(self.items)
end

--- Returns the elements as a new sequence, in insertion order.
--- @return string[]
function Set:to_table()
  -- tbl may have holes after remove(), so collect and sort its indices
  -- instead of walking it with ipairs().
  local indices = {} --- @type integer[]
  for idx in pairs(self.tbl) do
    indices[#indices + 1] = idx
  end
  table.sort(indices)

  local result = {} --- @type string[]
  for i, idx in ipairs(indices) do
    result[i] = self.tbl[idx]
  end
  return result
end

-- Module export: the Set class table.
return Set