#ifndef _SHARED_AVL_TREE_H_ #define _SHARED_AVL_TREE_H_ #define ALLOC kalloc #define FREE kfree #include #include #include "kern/mem.h" #include "shared/stdmacro.h" #define get_height(n) ((n) == NULL ? 0 : (n)->height) #define max(a, b) ((a) < (b) ? (b) : (a)) #define reset_height(node) \ (node)->height = \ (max(get_height((node)->right), get_height((node)->left)) + 1) #define avl_tree_t(T) CONCAT(avl_tree__, T) #define avl_tree_node_t(T) CONCAT(avl_tree_node__, T) #define avl_tree_new(T) CONCAT(avl_tree_new__, T) #define avl_tree_free(T) CONCAT(avl_tree_free__, T) #define avl_tree_size(T) CONCAT(avl_tree_size__, T) #define avl_tree_insert(T) CONCAT(avl_tree_insert__, T) #define avl_tree_find(T) CONCAT(avl_tree_find__, T) #define avl_tree_erase(T) CONCAT(avl_tree_erase__, T) #define avl_tree_height(T) CONCAT(avl_tree_height__, T) #define null_dtor(a) typedef unsigned int bool; #define AVL_TREE_DECL(T) \ typedef struct CONCAT(AVL_TREE_NODE_, T) { \ T value; \ size_t height; \ struct CONCAT(AVL_TREE_NODE_, T) * left; \ struct CONCAT(AVL_TREE_NODE_, T) * right; \ } avl_tree_node_t(T); \ \ typedef struct { \ avl_tree_node_t(T) * root; \ } avl_tree_t(T); \ \ size_t avl_tree_height(T)(avl_tree_t(T) * tree); \ avl_tree_t(T) * avl_tree_new(T)(); \ void avl_tree_free(T)(avl_tree_t(T) * tree); \ size_t avl_tree_size(T)(const avl_tree_t(T) * avl_tree); \ T* avl_tree_insert(T)(avl_tree_t(T) * avl_tree, T value); \ T* avl_tree_find(T)(const avl_tree_t(T)*, T val); \ bool avl_tree_erase(T)(avl_tree_t(T) * tree, T val, T * out); #define AVL_TREE_IMPL(T, CMP, DTOR) \ avl_tree_t(T) * avl_tree_new(T)() \ { \ avl_tree_t(T)* ret = ALLOC(sizeof(avl_tree_t(T))); \ ret->root = NULL; \ return ret; \ } \ static void CONCAT(avl_loose_free__, T)(avl_tree_node_t(T) * node) \ { \ if (!node) return; \ CONCAT(avl_loose_free__, T)(node->right); \ CONCAT(avl_loose_free__, T)(node->left); \ DTOR(node->value); \ FREE(node); \ } \ void avl_tree_free(T)(avl_tree_t(T) * tree) \ { \ CONCAT(avl_loose_free__, T)(tree->root); \ FREE(tree); \ } \ static inline size_t CONCAT( \ loose_size__, T)(const avl_tree_node_t(T) * node) \ { \ if (!node) return 0; \ return 1 + CONCAT(loose_size__, T)(node->left) + \ CONCAT(loose_size__, T)(node->right); \ } \ size_t avl_tree_size(T)(const avl_tree_t(T) * tree) \ { \ return CONCAT(loose_size__, T)(tree->root); \ } \ static int CONCAT(balance_factor, T)(avl_tree_node_t(T) * node) \ { \ return get_height(node->left) - get_height(node->right); \ } \ static avl_tree_node_t(T) * CONCAT(ll_rotate, T)(avl_tree_node_t(T) * node) \ { \ avl_tree_node_t(T)* child = node->left; \ node->left = child->right; \ reset_height(node); \ child->right = node; \ reset_height(child); \ return child; \ } \ static avl_tree_node_t(T) * CONCAT(rr_rotate, T)(avl_tree_node_t(T) * node) \ { \ avl_tree_node_t(T)* child = node->right; \ node->right = child->left; \ reset_height(node); \ child->left = node; \ reset_height(child); \ return child; \ } \ static avl_tree_node_t(T) * CONCAT(rl_rotate, T)(avl_tree_node_t(T) * node) \ { \ avl_tree_node_t(T)* child = node->right; \ node->right = CONCAT(ll_rotate, T)(child); \ reset_height(node); \ return CONCAT(rr_rotate, T)(node); \ } \ static avl_tree_node_t(T) * CONCAT(lr_rotate, T)(avl_tree_node_t(T) * node) \ { \ avl_tree_node_t(T)* child = node->left; \ node->left = CONCAT(rr_rotate, T)(child); \ reset_height(node); \ return CONCAT(ll_rotate, T)(node); \ } \ static avl_tree_node_t(T) * \ CONCAT(avl_tree_balance_, T)(avl_tree_node_t(T) * node) \ { \ int d = CONCAT(balance_factor, T)(node); \ if (d > 1) { \ if (CONCAT(balance_factor, T)(node->left) > 0) { \ return CONCAT(ll_rotate, T)(node); \ } else { \ return CONCAT(lr_rotate, T)(node); \ } \ } else if (d < -1) { \ if (CONCAT(balance_factor, T)(node->right) > 0) { \ return CONCAT(rl_rotate, T)(node); \ } else { \ return CONCAT(rr_rotate, T)(node); \ } \ } \ \ return node; \ } \ static avl_tree_node_t(T) * \ CONCAT(avl_tree_loose_insert_, T)( \ avl_tree_node_t(T) * node, T value, T * *ptr_out) \ { \ if (!node) { \ node = ALLOC(sizeof(avl_tree_node_t(T))); \ assert(node); \ node->left = NULL; \ node->right = NULL; \ node->value = value; \ node->height = 1; \ *ptr_out = &node->value; \ } else { \ typeof(CMP(node->value, value)) cmp = CMP(node->value, value); \ if (cmp < 0) { \ node->left = \ CONCAT(avl_tree_loose_insert_, T)(node->left, value, ptr_out); \ reset_height(node); \ node = CONCAT(avl_tree_balance_, T)(node); \ } else if (cmp > 0) { \ node->right = \ CONCAT(avl_tree_loose_insert_, T)(node->right, value, ptr_out); \ reset_height(node); \ node = CONCAT(avl_tree_balance_, T)(node); \ } \ } \ return node; \ } \ T* avl_tree_insert(T)(avl_tree_t(T) * tree, T value) \ { \ T* ret; \ tree->root = CONCAT(avl_tree_loose_insert_, T)(tree->root, value, &ret); \ return ret; \ } \ size_t avl_tree_height(T)(avl_tree_t(T) * tree) \ { \ if (!tree) return 0; \ return get_height(tree->root); \ } \ static T* CONCAT(loose_avl_tree_find, T)(avl_tree_node_t(T) * node, T value) \ { \ if (!node) return NULL; \ \ typeof(CMP(node->value, value)) cmp = CMP(node->value, value); \ if (cmp > 0) { \ return CONCAT(loose_avl_tree_find, T)(node->right, value); \ } else if (cmp < 0) { \ return CONCAT(loose_avl_tree_find, T)(node->left, value); \ } \ return &node->value; \ } \ T* avl_tree_find(T)(const avl_tree_t(T) * tree, T val) \ { \ if (!tree) return NULL; \ return CONCAT(loose_avl_tree_find, T)(tree->root, val); \ } \ static avl_tree_node_t(T) * \ CONCAT(pluck_left, T)(avl_tree_node_t(T) * node, T * into) \ { \ if (node->left) { \ node->left = CONCAT(pluck_left, T)(node->left, into); \ reset_height(node); \ return CONCAT(avl_tree_balance_, T)(node); \ } else { \ *into = node->value; \ FREE(node); \ return node->right; \ } \ } \ static avl_tree_node_t(T) * \ CONCAT(pluck_right, T)(avl_tree_node_t(T) * node, T * into) \ { \ if (node->right) { \ node->right = CONCAT(pluck_right, T)(node->right, into); \ reset_height(node); \ return CONCAT(avl_tree_balance_, T)(node); \ } else { \ *into = node->value; \ FREE(node); \ return node->left; \ } \ } \ avl_tree_node_t(T) * \ CONCAT(loose_avl_tree_erase, T)( \ avl_tree_node_t(T) * node, T value, bool* out, T* erased) \ { \ if (!node) { \ *out = 0; \ return NULL; \ } \ typeof(CMP(node->value, value)) cmp = CMP(node->value, value); \ if (cmp == 0) { \ if (erased) *erased = node->value; \ *out = 1; \ if (!node->right && !node->left) { \ FREE(node); \ return NULL; \ } \ if (get_height(node->right) > get_height(node->left)) { \ node->right = CONCAT(pluck_left, T)(node->right, &node->value); \ reset_height(node); \ node = CONCAT(avl_tree_balance_, T)(node); \ return node; \ } \ node->left = CONCAT(pluck_right, T)(node->left, &node->value); \ reset_height(node); \ node = CONCAT(avl_tree_balance_, T)(node); \ return node; \ } else if (cmp < 0) { \ node->left = \ CONCAT(loose_avl_tree_erase, T)(node->left, value, out, erased); \ reset_height(node); \ node = CONCAT(avl_tree_balance_, T)(node); \ } else { \ node->right = \ CONCAT(loose_avl_tree_erase, T)(node->right, value, out, erased); \ reset_height(node); \ node = CONCAT(avl_tree_balance_, T)(node); \ } \ return node; \ } \ bool avl_tree_erase(T)(avl_tree_t(T) * tree, T val, T * erased) \ { \ if (!tree) return 0; \ bool out; \ tree->root = \ CONCAT(loose_avl_tree_erase, T)(tree->root, val, &out, erased); \ return out; \ } #endif /* _SHARED_AVL_TREE_H_ */