set.hpp 6.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267
  1. #pragma once
  2. //set
  3. //implementation: red-black tree
  4. //
  5. //search: O(log n) average; O(log n) worst
  6. //insert: O(log n) average; O(log n) worst
  7. //remove: O(log n) average; O(log n) worst
  8. //
  9. //requirements:
  10. // bool T::operator==(const T&) const;
  11. // bool T::operator< (const T&) const;
  12. #include <nall/utility.hpp>
  13. #include <nall/vector.hpp>
  14. namespace nall {
  15. template<typename T> struct set {
  16. struct node_t {
  17. T value;
  18. bool red = 1;
  19. node_t* link[2] = {nullptr, nullptr};
  20. node_t() = default;
  21. node_t(const T& value) : value(value) {}
  22. };
  23. node_t* root = nullptr;
  24. uint nodes = 0;
  25. set() = default;
  26. set(const set& source) { operator=(source); }
  27. set(set&& source) { operator=(move(source)); }
  28. set(std::initializer_list<T> list) { for(auto& value : list) insert(value); }
  29. ~set() { reset(); }
  30. auto operator=(const set& source) -> set& {
  31. if(this == &source) return *this;
  32. reset();
  33. copy(root, source.root);
  34. nodes = source.nodes;
  35. return *this;
  36. }
  37. auto operator=(set&& source) -> set& {
  38. if(this == &source) return *this;
  39. root = source.root;
  40. nodes = source.nodes;
  41. source.root = nullptr;
  42. source.nodes = 0;
  43. return *this;
  44. }
  45. explicit operator bool() const { return nodes; }
  46. auto size() const -> uint { return nodes; }
  47. auto reset() -> void {
  48. reset(root);
  49. nodes = 0;
  50. }
  51. auto find(const T& value) -> maybe<T&> {
  52. if(node_t* node = find(root, value)) return node->value;
  53. return nothing;
  54. }
  55. auto find(const T& value) const -> maybe<const T&> {
  56. if(node_t* node = find(root, value)) return node->value;
  57. return nothing;
  58. }
  59. auto insert(const T& value) -> maybe<T&> {
  60. uint count = size();
  61. node_t* v = insert(root, value);
  62. root->red = 0;
  63. if(size() == count) return nothing;
  64. return v->value;
  65. }
  66. template<typename... P> auto insert(const T& value, P&&... p) -> bool {
  67. bool result = insert(value);
  68. insert(forward<P>(p)...) | result;
  69. return result;
  70. }
  71. auto remove(const T& value) -> bool {
  72. uint count = size();
  73. bool done = 0;
  74. remove(root, &value, done);
  75. if(root) root->red = 0;
  76. return size() < count;
  77. }
  78. template<typename... P> auto remove(const T& value, P&&... p) -> bool {
  79. bool result = remove(value);
  80. return remove(forward<P>(p)...) | result;
  81. }
  82. struct base_iterator {
  83. auto operator!=(const base_iterator& source) const -> bool { return position != source.position; }
  84. auto operator++() -> base_iterator& {
  85. if(++position >= source.size()) { position = source.size(); return *this; }
  86. if(stack.right()->link[1]) {
  87. stack.append(stack.right()->link[1]);
  88. while(stack.right()->link[0]) stack.append(stack.right()->link[0]);
  89. } else {
  90. node_t* child;
  91. do child = stack.takeRight();
  92. while(child == stack.right()->link[1]);
  93. }
  94. return *this;
  95. }
  96. base_iterator(const set& source, uint position) : source(source), position(position) {
  97. node_t* node = source.root;
  98. while(node) {
  99. stack.append(node);
  100. node = node->link[0];
  101. }
  102. }
  103. protected:
  104. const set& source;
  105. uint position;
  106. vector<node_t*> stack;
  107. };
  108. struct iterator : base_iterator {
  109. iterator(const set& source, uint position) : base_iterator(source, position) {}
  110. auto operator*() const -> T& { return base_iterator::stack.right()->value; }
  111. };
  112. auto begin() -> iterator { return iterator(*this, 0); }
  113. auto end() -> iterator { return iterator(*this, size()); }
  114. struct const_iterator : base_iterator {
  115. const_iterator(const set& source, uint position) : base_iterator(source, position) {}
  116. auto operator*() const -> const T& { return base_iterator::stack.right()->value; }
  117. };
  118. auto begin() const -> const const_iterator { return const_iterator(*this, 0); }
  119. auto end() const -> const const_iterator { return const_iterator(*this, size()); }
  120. private:
  121. auto reset(node_t*& node) -> void {
  122. if(!node) return;
  123. if(node->link[0]) reset(node->link[0]);
  124. if(node->link[1]) reset(node->link[1]);
  125. delete node;
  126. node = nullptr;
  127. }
  128. auto copy(node_t*& target, const node_t* source) -> void {
  129. if(!source) return;
  130. target = new node_t(source->value);
  131. target->red = source->red;
  132. copy(target->link[0], source->link[0]);
  133. copy(target->link[1], source->link[1]);
  134. }
  135. auto find(node_t* node, const T& value) const -> node_t* {
  136. if(node == nullptr) return nullptr;
  137. if(node->value == value) return node;
  138. return find(node->link[node->value < value], value);
  139. }
  140. auto red(node_t* node) const -> bool { return node && node->red; }
  141. auto black(node_t* node) const -> bool { return !red(node); }
  142. auto rotate(node_t*& a, bool dir) -> void {
  143. node_t*& b = a->link[!dir];
  144. node_t*& c = b->link[dir];
  145. a->red = 1, b->red = 0;
  146. std::swap(a, b);
  147. std::swap(b, c);
  148. }
  149. auto rotateTwice(node_t*& node, bool dir) -> void {
  150. rotate(node->link[!dir], !dir);
  151. rotate(node, dir);
  152. }
  153. auto insert(node_t*& node, const T& value) -> node_t* {
  154. if(!node) { nodes++; node = new node_t(value); return node; }
  155. if(node->value == value) { node->value = value; return node; } //prevent duplicate entries
  156. bool dir = node->value < value;
  157. node_t* v = insert(node->link[dir], value);
  158. if(black(node->link[dir])) return v;
  159. if(red(node->link[!dir])) {
  160. node->red = 1;
  161. node->link[0]->red = 0;
  162. node->link[1]->red = 0;
  163. } else if(red(node->link[dir]->link[dir])) {
  164. rotate(node, !dir);
  165. } else if(red(node->link[dir]->link[!dir])) {
  166. rotateTwice(node, !dir);
  167. }
  168. return v;
  169. }
  170. auto balance(node_t*& node, bool dir, bool& done) -> void {
  171. node_t* p = node;
  172. node_t* s = node->link[!dir];
  173. if(!s) return;
  174. if(red(s)) {
  175. rotate(node, dir);
  176. s = p->link[!dir];
  177. }
  178. if(black(s->link[0]) && black(s->link[1])) {
  179. if(red(p)) done = 1;
  180. p->red = 0, s->red = 1;
  181. } else {
  182. bool save = p->red;
  183. bool head = node == p;
  184. if(red(s->link[!dir])) rotate(p, dir);
  185. else rotateTwice(p, dir);
  186. p->red = save;
  187. p->link[0]->red = 0;
  188. p->link[1]->red = 0;
  189. if(head) node = p;
  190. else node->link[dir] = p;
  191. done = 1;
  192. }
  193. }
  194. auto remove(node_t*& node, const T* value, bool& done) -> void {
  195. if(!node) { done = 1; return; }
  196. if(node->value == *value) {
  197. if(!node->link[0] || !node->link[1]) {
  198. node_t* save = node->link[!node->link[0]];
  199. if(red(node)) done = 1;
  200. else if(red(save)) save->red = 0, done = 1;
  201. nodes--;
  202. delete node;
  203. node = save;
  204. return;
  205. } else {
  206. node_t* heir = node->link[0];
  207. while(heir->link[1]) heir = heir->link[1];
  208. node->value = heir->value;
  209. value = &heir->value;
  210. }
  211. }
  212. bool dir = node->value < *value;
  213. remove(node->link[dir], value, done);
  214. if(!done) balance(node, dir, done);
  215. }
  216. };
  217. }