普通平衡树的模板。
插入、查询、删除、前趋、后继、排名、选择。
更新于 2016 年 1 月 11 日:insert
有锅,根节点处重复插入时没有使 size
增加。
#include <cstdio>
#include <climits>
struct Splay
{
struct Node
{
// root 可能会被修改,使用指向指针的指针
Node *fa, *ch[2], **root;
// x 表示这个节点上的数
// size 表示整棵树(左右子树和自身)的大小
// cnt 表示这个数 x 有多少个
int x, size, cnt;
Node(Node **root, Node *fa, int x) : root(root), fa(fa), x(x), cnt(1), size(1)
{
ch[0] = ch[1] = NULL; // 不要忘记初始化
}
// 如果当前节点是其父节点的左儿子,返回 0 否则返回 1
int relation()
{
return this == fa->ch[0] ? 0 : 1;
}
void maintain()
{
size = cnt;
if (ch[0]) size += ch[0]->size;
if (ch[1]) size += ch[1]->size;
}
void rotate()
{
// 旧的父节点
Node *old = fa;
// 旋转前当前节点与父节点的关系
int r = relation();
// 自身 <--> 父节点的父节点
fa = old->fa;
if (old->fa)
{
// 如果父节点有父节点(不是根),那么用当前节点替换父节点的父节点中父节点的位置的指向
old->fa->ch[old->relation()] = this;
}
// 原有的另一个子节点 <--> 原有的父节点
if (ch[r ^ 1])
{
ch[r ^ 1]->fa = old;
}
old->ch[r] = ch[r ^ 1]; // 右边的 ch[r ^ 1] 相当于 this->ch[r ^ 1]
// 原有的父节点 作为自身的子节点
old->fa = this;
ch[r ^ 1] = old;
// 更新维护信息
old->maintain();
maintain();
// 如果转到了根,更新储存的指向根节点的指针
if (fa == NULL)
{
*root = this;
}
}
// 旋转到某一特定位置,如在删除时将后继节点旋转为根的右儿子
// target 为目标父节点,缺省参数为 NULL,即旋转直到成为根(没有父节点)
void splay(Node *target = NULL)
{
while (fa != target) // while (父节点不是目标父节点)
{
if (fa->fa == target) // 父节点的父节点是目标父节点,直接转一次
{
rotate();
}
else if (fa->relation() == relation()) // 关系相同,先转父节点,再转自身
{
fa->rotate();
rotate();
}
else
{
rotate();
rotate();
}
}
}
// 前趋,全称 precursor 或 predecessor
Node *pred()
{
// 左子树的最右点
Node *v = ch[0];
while (v->ch[1]) v = v->ch[1];
return v;
}
// 后继,全称 successor
Node *succ()
{
// 右子树的最左点
Node *v = ch[1];
while (v->ch[0]) v = v->ch[0];
return v;
}
// 求一个节点的排名,即左子树大小
int rank()
{
return ch[0] ? ch[0]->size : 0;
}
} *root;
Splay() : root(NULL)
{
insert(INT_MAX);
insert(INT_MIN);
}
Node *insert(int x)
{
// v 是一个指向指针的指针
// 表示要插入到的位置
// 如果 v 指向一个空指针
// 那么就可以插入到这里
//
// fa 是新节点的父节点
Node **v = &root, *fa = NULL;
while (*v != NULL && (*v)->x != x) // 直到找到一个空位置,或者找到原有的值为 x 的节点
{
fa = *v; // 以上一次找到的不空的 v 作为父节点
fa->size++; // 因为要在这棵子树下插入一个新的节点
// 根据大小向左右子树迭代
if (x < fa->x)
{
v = &fa->ch[0];
}
else
{
v = &fa->ch[1];
}
}
if (*v != NULL)
{
(*v)->cnt++;
(*v)->size++; // 锅
}
else
{
(*v) = new Node(&root, fa, x);
}
(*v)->splay();
// 伸展之后,*v 这个位置存的可能不再是新节点(因为父子关系改变了,而 v 一般指向一个节点的某个子节点的位置)
// 直接返回根
return root;
}
Node *find(int x)
{
// 从根节点开始找
Node *v = root;
while (v != NULL && v->x != x)
{
if (x < v->x)
{
v = v->ch[0];
}
else
{
v = v->ch[1];
}
}
if (v) v->splay();
return v;
}
// 删除一个节点
void erase(Node *v)
{
Node *pred = v->pred(), *succ = v->succ();
pred->splay();
succ->splay(pred); // 使后继成为前趋(根)的右儿子
// 此时后继的左儿子即为要删除的节点 v,且 v 此时为叶子节点
if (v->size > 1)
{
// 如果存在不止一个 x,只需将计数减一
v->size--;
v->cnt--;
}
else
{
delete succ->ch[0]; // 使用 delete 关键字(严格地说,new 和 delete 是运算符)释放 new 分配的内存
succ->ch[0] = NULL; // 不要忘记置空
}
// 删了节点,需要将大小减小
succ->size--;
pred->size--;
}
// 删除一个数
void erase(int x)
{
Node *v = find(x);
if (!v) return; // 没有找到要删的节点,直接返回(一般题目中不会出现这种情况)
erase(v);
}
// 求一个数的前趋
int pred(int x)
{
Node *v = find(x);
if (v == NULL)
{
v = insert(x);
int res = v->pred()->x;
erase(v);
return res;
}
else
{
return v->pred()->x;
}
}
// 求一个数的后继
int succ(int x)
{
Node *v = find(x);
if (v == NULL)
{
v = insert(x);
int res = v->succ()->x;
erase(v);
return res;
}
else
{
return v->succ()->x;
}
}
// 求一个数的排名
int rank(int x)
{
Node *v = find(x);
if (v == NULL)
{
v = insert(x);
// 此时 v 已经是根节点了,因为新插入插入节点会被伸展
int res = v->rank(); // 因为有一个无穷小,所以不需要 +1
erase(v);
return res;
}
else
{
// 此时 v 已经是根节点了,因为查找到的节点会被伸展
return v->rank();
}
}
// 求第 k 小的数
int select(int k)
{
Node *v = root;
while (!(k >= v->rank() && k < v->rank() + v->cnt))
{
if (k < v->rank())
{
// 要查的第 k 比当前节点的排名小,需要继续在左子树中查询
v = v->ch[0];
}
else
{
k -= v->rank() + v->cnt;
v = v->ch[1];
}
}
v->splay();
return v->x;
}
} splay;
int main()
{
int n;
scanf("%d", &n);
while (n--)
{
int opt, x;
scanf("%d %d", &opt, &x);
if (opt == 1)
{
splay.insert(x);
}
else if (opt == 2)
{
splay.erase(x);
}
else if (opt == 3)
{
printf("%d\n", splay.rank(x));
}
else if (opt == 4)
{
printf("%d\n", splay.select(x));
}
else if (opt == 5)
{
printf("%d\n", splay.pred(x));
}
else if (opt == 6)
{
printf("%d\n", splay.succ(x));
}
}
return 0;
}