#include "balanced.h"

// Height-balanced (AVL) tree construction.
//
// Each node caches the height of the subtree rooted at it in its height
// field; the empty tree (NULL) has height 0 and a lone node has height 1.
// A tree is balanced when, at every node, the heights of the two subtrees
// differ by at most one.

// node_height(n) returns the cached height of n, or 0 when n is NULL
// (the empty tree).
static int node_height(const struct node *n) { return n ? n->height : 0; }

// set_node(self, left, right) sets self up as the root of a tree with the
// given left and right subtrees (either may be NULL for an empty subtree),
// recomputes self's height from them, and returns self.
// No rebalancing is performed.
// precondition: self != NULL
static struct node *set_node(struct node *self, struct node *left,
                             struct node *right) {
  int lh = node_height(left);
  int rh = node_height(right);
  self->left = left;
  self->right = right;
  self->height = 1 + (lh > rh ? lh : rh);
  return self;
}

// is_balanced(min_height, max_height) returns true when a node whose two
// subtrees have these heights satisfies the balance condition, i.e. when
// max_height exceeds min_height by at most one.
// precondition: min_height <= max_height + 1. The callers below pass the
// side expected to be shorter first, but that side may legitimately be one
// level taller (e.g. right after a rotation); under this precondition the
// test is equivalent to |max_height - min_height| <= 1.
static _Bool is_balanced(int min_height, int max_height) {
  return (max_height - min_height) <= 1;
}

// make_fun is the shape shared by the node constructors: build a node from
// self, left and right and return the root of the resulting tree.
typedef struct node * make_fun(struct node *self, struct node *left, struct node *right);

// rot_left(make, self, left, right) builds the left rotation of the tree
// (left, self, right): right's root is lifted above self, and self adopts
// right's left subtree. make constructs the two new nodes (the lower one
// first), so it can either just set fields (set_node) or also rebalance.
// precondition: right != NULL
static struct node *rot_left(make_fun make, struct node *self,
                             struct node *left, struct node *right) {
  // Read right's children before make() runs so the result never depends
  // on the (unspecified) evaluation order of function arguments.
  struct node *inner = right->left;
  struct node *outer = right->right;
  return make(right, make(self, left, inner), outer);
}

// rot_right(make, self, left, right) is the mirror image of rot_left:
// left's root is lifted above self, and self adopts left's right subtree.
// precondition: left != NULL
static struct node *rot_right(make_fun make, struct node *self,
                              struct node *left, struct node *right) {
  struct node *outer = left->left;
  struct node *inner = left->right;
  return make(left, outer, make(self, inner, right));
}

// node_left(self, left, right) returns a balanced tree holding, in order,
// the nodes of left, then self, then the nodes of right. left and right
// must be balanced and may be NULL.
// precondition: self != NULL and node_height(left) <= node_height(right) + 1
// (only the right side may be too tall).
static struct node *node_left(struct node *self, struct node *left,
                              struct node *right) {
  if (is_balanced(node_height(left), node_height(right)))
    return set_node(self, left, right);

  // right is now at least two levels taller than left, so right != NULL.
  // Rebalance by rotating left, which recursively joins self and left into
  // right's left subtree. If right's inner (left) child is the taller one,
  // first rotate right right so its outer side is at least as tall.
  // That pre-rotation must be built with plain set_node: its intermediate
  // shape may be temporarily lopsided (as in a classic AVL double
  // rotation), and rebalancing it here would simply undo the rotation and
  // let an unbalanced tree through.
  if (node_height(right->right) < node_height(right->left))
    right = rot_right(set_node, right, right->left, right->right);
  return rot_left(node_left, self, left, right);
}

// node_right(self, left, right) is the mirror image of node_left.
// precondition: self != NULL and node_height(right) <= node_height(left) + 1
// (only the left side may be too tall).
static struct node *node_right(struct node *self, struct node *left,
                               struct node *right) {
  if (is_balanced(node_height(right), node_height(left)))
    return set_node(self, left, right);

  // left is now at least two levels taller than right, so left != NULL.
  // See node_left for why the pre-rotation uses set_node.
  if (node_height(left->left) < node_height(left->right))
    left = rot_left(set_node, left, left->left, left->right);
  return rot_right(node_right, self, left, right);
}

// make_tree(self, left, right) returns a balanced tree holding, in order,
// the nodes of left, then self, then the nodes of right. left and right
// must each be balanced and may be NULL; their heights may differ, since
// the rebalancing recurses down the taller tree's inner spine. self's left,
// right and height fields are overwritten, nodes of left and right may be
// relinked, and the returned root is not necessarily self.
// precondition: self != NULL
struct node *make_tree(struct node *self, struct node *left,
                       struct node *right) {
  if (node_height(left) <= node_height(right))
    return node_left(self, left, right);
  return node_right(self, left, right);
}