aboutsummaryrefslogtreecommitdiff
path: root/test/unit/set.lua
blob: 4e66546f320842c1eac7d32d65e69b45073b6b01 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
-- A set class for fast union/diff. Calling the to_table method always returns
-- a table holding the items in the same relative order in which they were
-- added. It does this by keeping two Lua tables that mirror each other:
-- 1) index => item  (self.tbl; may contain holes after removals)
-- 2) item => index  (self.items)
-- Indexes come from a counter (self.last_idx) that only ever grows, so an item
-- added later always has a larger index than one added earlier.
local Set = {}

--- Create a new set.
-- @param items optional table whose values become the initial members. The
--   sequence values items[1], items[2], ... are added first, in order; any
--   other values follow in `pairs` order. Keys are ignored. Non-table
--   arguments are ignored.
-- @return the new set
function Set:new(items)
  -- Set on the receiver so that whatever table `new` is called on can serve
  -- as the metatable/prototype of the object.
  self.__index = self
  local obj = setmetatable({}, self)

  obj.tbl = {}
  obj.items = {}
  obj.nelem = 0
  -- Highest index ever handed out by add(); never decreases.
  obj.last_idx = 0

  if type(items) == 'table' then
    obj:union_table(items)
  end

  return obj
end

--- Add every member of the set `other` to this set.
-- Members new to this set are appended in the order they were added to
-- `other`.
function Set:union(other)
  for _, e in ipairs(other:to_table()) do
    self:add(e)
  end
end

--- Add every value of the plain table `t` to this set (keys are ignored).
-- Sequence values t[1], t[2], ... go first and in order, so list input keeps
-- its order; remaining values follow in `pairs` order.
function Set:union_table(t)
  for _, v in ipairs(t) do
    self:add(v)
  end
  -- Picks up values stored outside the sequence part; values already added
  -- by the loop above are skipped by add().
  for _, v in pairs(t) do
    self:add(v)
  end
end

--- Remove from this set every member that is also a member of the set `other`.
function Set:diff(other)
  if other:size() > self:size() then
    -- This set is the smaller one: probe `other` for each of our members.
    -- Clearing existing fields of self.items during a pairs() traversal is
    -- allowed by Lua.
    for e in self:iterator() do
      if other:contains(e) then
        self:remove(e)
      end
    end
  else
    -- `other` is no larger: remove each of its members from this set;
    -- remove() ignores items this set does not hold.
    for e in other:iterator() do
      self:remove(e)
    end
  end
end

--- Add `it` to the set; does nothing if it is already a member.
-- `it` must be usable as a table key (nil and NaN raise an error).
function Set:add(it)
  if not self:contains(it) then
    -- Use the slot after the highest index ever used, not `#self.tbl + 1`:
    -- once remove() has left holes in self.tbl, `#` may report any border,
    -- which could drop the new item into a gap ahead of older items and
    -- break the insertion order returned by to_table().
    local idx = self.last_idx + 1
    -- Assign the item => index entry first so an invalid key errors out
    -- before any other state is modified.
    self.items[it] = idx
    self.tbl[idx] = it
    self.last_idx = idx
    self.nelem = self.nelem + 1
  end
end

--- Remove `it` from the set; does nothing if it is not a member.
-- The freed slot in self.tbl is left as a hole and is never reused.
function Set:remove(it)
  local idx = self.items[it]
  if idx then
    self.tbl[idx] = nil
    self.items[it] = nil
    self.nelem = self.nelem - 1
  end
end

--- Membership test.
-- @return the member's internal index (a truthy number) if `it` is in the
--   set, otherwise false
function Set:contains(it)
  return self.items[it] or false
end

--- @return the number of members in the set
function Set:size()
  return self.nelem
end

--- @return the internal index => item table itself (not a copy; it may
--   contain holes left by remove())
function Set:raw_tbl()
  return self.tbl
end

--- @return the internal item => index table itself (not a copy)
function Set:raw_items()
  return self.items
end

--- Iterate over the members in unspecified order.
-- Use as `for item, index in set:iterator() do ... end`.
function Set:iterator()
  return pairs(self.items)
end

--- @return a new array of the members, in the order they were added
function Set:to_table()
  -- self.tbl may contain holes left by remove(), so collect the occupied
  -- indexes and sort them instead of walking 1..#self.tbl.
  local indexes = {}
  for idx in pairs(self.tbl) do
    indexes[#indexes + 1] = idx
  end
  table.sort(indexes)

  local copy = {}
  for i, idx in ipairs(indexes) do
    copy[i] = self.tbl[idx]
  end
  return copy
end
-- Module export: `require` of this file yields the Set class table.
return Set