Splay 学习笔记(三)

在《Splay 学习笔记(一)》中,我们学会了利用 Splay 来维护二叉排序树,现在让我们来把我们的 Splay 变得更加优美。

模板请见《Splay 模板 + 详细注释》

结构体定义

两个孩子用一个数组来存,0 表示左孩子,1 表示右孩子,不需要再编写函数来获得某个孩子的引用了。

引入 count 成员,表示这个值一共出现了几次。相同的值不再重复插入新节点,效率可以得到提高;求前驱、后继的实现也会变得更加简单。

// Index into Node::child: tells on which side of its parent a node hangs.
enum Relation {
    L = 0,  // left child
    R = 1   // right child
};

// One node of the splay tree. Equal values share a single node and are
// tracked through `count` instead of being inserted repeatedly.
struct Node {
    Node *child[2];  // child[L] is the left child, child[R] the right child
    Node *parent;    // parent node
    Node **root;     // address of the tree's root pointer
    T value;         // the stored value
    int size;        // size of the subtree rooted at this node
    int count;       // how many times `value` occurs
};

Splay 操作

把之前的“旋转到某位置”改为“一直旋转,直到指定节点成为自己的父节点”,这样就不需要二级指针了;参数默认为 NULL,而根节点的父节点正是 NULL,所以也不需要再特判“参数为 NULL 时旋转到根”的情况。

// Rotates this node upward until its parent is `targetParent`. With the
// default NULL argument the loop stops once this node has no parent.
void splay(Node *targetParent = NULL) {
    while (parent != targetParent) {
        if (parent->parent == targetParent) {
            // The parent sits directly below the target: one rotation finishes.
            rotate();
        } else if (parent->relation() == relation()) {
            // This node and its parent are on the same side: rotate the
            // parent first, then this node.
            parent->rotate();
            rotate();
        } else {
            // Opposite sides: rotate this node twice.
            rotate();
            rotate();
        }
    }
}

节点的前驱 / 后继

直接 Splay 后求即可,不需要多次迭代了。

// Splays this node and returns the right-most node of its left subtree, i.e.
// the node that immediately precedes it in order. Returns NULL when there is
// no left subtree instead of dereferencing a null child.
Node *pred() {
    splay();
    Node *v = child[L];
    if (!v) return NULL;
    while (v->child[R]) v = v->child[R];
    return v;
}

// Splays this node and returns the left-most node of its right subtree, i.e.
// the node that immediately follows it in order. Returns NULL when there is
// no right subtree instead of dereferencing a null child.
Node *succ() {
    splay();
    Node *v = child[R];
    if (!v) return NULL;
    while (v->child[L]) v = v->child[L];
    return v;
}

选择

选择第 k 小的元素时,需要把循环继续的条件改为 **k 不在 [rank + 1, rank + count] 的范围内**(也就是一直迭代到 k 落入该范围为止);迭代到右子树时,也要相应地从 k 中减去 rank + count。

// Returns the node holding the k-th smallest element and splays it, or NULL
// when k is out of range. k is incremented first, so position 1 in the tree
// is not counted.
Node *select(int k) {
    k++;
    Node *v = root;
    while (v) {
        const int smaller = v->rank();  // elements in the left subtree
        if (k <= smaller) {
            v = v->child[L];
        } else if (k > smaller + v->count) {
            // Skip the left subtree and every copy stored in this node.
            k -= smaller + v->count;
            v = v->child[R];
        } else {
            break;  // k lies in [rank + 1, rank + count]: this is the node
        }
    }
    if (v) v->splay();
    return v;
}

完整代码(普通平衡树)

#include <cstdio>
#include <climits>

// Not referenced by any code in this listing; presumably the maximum number
// of operations from the problem statement — TODO confirm.
const int MAXN = 100000;

/**
 * Splay tree storing a multiset of T.
 *
 * Equal values share one node and are counted through Node::count. The
 * constructor inserts the two sentinels INF and -INF, so user values are
 * expected to lie strictly between -INF and INF; within that range every
 * element has both a predecessor and a successor node.
 */
template <typename T, T INF>
struct Splay {
    /** Index into Node::child: which side of its parent a node is on. */
    enum Relation {
        L = 0, R = 1
    };

    struct Node {
        Node *child[2], *parent, **root;
        T value;
        int size, count;

        /**
         * Creates a leaf holding one copy of `value`.
         * `root` is the address of the owning tree's root pointer.
         * `size` starts at 1 so it is never read uninitialised, and the
         * initialisers follow the declaration order of the members.
         */
        Node(Node *parent, const T &value, Node **root)
            : parent(parent), root(root), value(value), size(1), count(1) {
            child[L] = child[R] = NULL;
        }

        /** Deletes both subtrees recursively (deleting NULL is a no-op). */
        ~Node() {
            delete child[L];
            delete child[R];
        }

        /** Which child of its parent this node is. Requires a parent. */
        Relation relation() {
            return this == parent->child[L] ? L : R;
        }

        /** Recomputes `size` from both children's sizes plus `count`. */
        void maintain() {
            size = (child[L] ? child[L]->size : 0) + (child[R] ? child[R]->size : 0) + count;
        }

        /** Rotates this node above its parent. Requires a parent. */
        void rotate() {
            Relation x = relation();
            Node *oldParent = parent;

            // Take the old parent's place under the grandparent, if any.
            if (oldParent->parent) oldParent->parent->child[oldParent->relation()] = this;
            parent = oldParent->parent;

            // The inner subtree moves over to the old parent.
            oldParent->child[x] = child[x ^ 1];
            if (child[x ^ 1]) child[x ^ 1]->parent = oldParent;

            child[x ^ 1] = oldParent;
            oldParent->parent = this;

            // The old parent is now below this node, so update it first.
            oldParent->maintain(), maintain();

            if (!parent) *root = this;
        }

        /**
         * Rotates this node upward until its parent is `targetParent`.
         * With the default NULL it stops when the node has no parent.
         */
        void splay(Node *targetParent = NULL) {
            while (parent != targetParent) {
                if (parent->parent == targetParent) rotate();
                else if (parent->relation() == relation()) parent->rotate(), rotate();
                else rotate(), rotate();
            }
        }

        /**
         * Splays this node and returns the right-most node of its left
         * subtree, or NULL when there is no left subtree.
         */
        Node *pred() {
            splay();
            Node *v = child[L];
            if (!v) return NULL;
            while (v->child[R]) v = v->child[R];
            return v;
        }

        /**
         * Splays this node and returns the left-most node of its right
         * subtree, or NULL when there is no right subtree.
         */
        Node *succ() {
            splay();
            Node *v = child[R];
            if (!v) return NULL;
            while (v->child[L]) v = v->child[L];
            return v;
        }

        /** Size of the left subtree, or 0 when there is none. */
        int rank() {
            return !child[L] ? 0 : child[L]->size;
        }
    } *root;

    /** Builds a tree that contains only the two sentinels. */
    Splay() : root(NULL) {
        insert(INF), insert(-INF);
    }

    ~Splay() {
        delete root;
    }

    /**
     * Looks up `value`. Returns its node after splaying it to the root, or
     * NULL when the value is not stored.
     */
    Node *find(const T &value) {
        Node *v = root;
        while (v && value != v->value) {
            if (value < v->value) {
                v = v->child[L];
            } else {
                v = v->child[R];
            }
        }

        if (!v) return NULL;

        v->splay();
        return v;
    }

    /**
     * Adds one copy of `value` and returns its node, which is then the root.
     * An existing node only has its count bumped.
     */
    Node *insert(const T &value) {
        Node *v = find(value);
        if (v) {
            v->count++, v->maintain();
            return v;
        }

        Node **target = &root, *parent = NULL;

        while (*target) {
            parent = *target;
            parent->size++;  // the new leaf will end up below this node
            if (value < parent->value) {
                target = &parent->child[L];
            } else {
                target = &parent->child[R];
            }
        }

        *target = new Node(parent, value, &root);
        (*target)->splay();

        return root;
    }

    /** Removes one copy of `value`; does nothing when it is not stored. */
    void erase(const T &value) {
        Node *v = find(value);
        if (v) erase(v);
    }

    /**
     * Removes one copy of the value held by `v`.
     * A NULL `v` is ignored. A node without a predecessor or successor
     * (a sentinel, or a value outside the sentinel range) is left alone.
     */
    void erase(Node *v) {
        if (!v) return;

        if (v->count != 1) {
            v->splay();
            v->count--;
            v->maintain();
            return;
        }

        Node *before = v->pred();
        Node *after = v->succ();
        if (!before || !after) return;

        // With `before` at the root and `after` directly below it, the left
        // subtree of `after` holds exactly the nodes between the two: `v`.
        before->splay();
        after->splay(before);

        delete after->child[L];
        after->child[L] = NULL;

        after->maintain(), before->maintain();
    }

    /**
     * Number of stored elements smaller than `value`, counting the -INF
     * sentinel, which makes this the 1-based rank among user values.
     * An absent value is inserted temporarily to measure its position.
     */
    int rank(const T &value) {
        Node *v = find(value);
        if (v) return v->rank();
        else {
            v = insert(value);
            int ans = v->rank();
            erase(v);
            return ans;
        }
    }

    /**
     * Returns the node holding the k-th smallest user value (1-based) and
     * splays it, or NULL when k is out of range. k is incremented first to
     * skip the -INF sentinel.
     */
    Node *select(int k) {
        k++;
        Node *v = root;
        while (v) {
            const int smaller = v->rank();
            if (k <= smaller) {
                v = v->child[L];
            } else if (k > smaller + v->count) {
                // Skip the left subtree and every copy stored in this node.
                k -= smaller + v->count;
                v = v->child[R];
            } else {
                break;  // k lies in [rank + 1, rank + count]
            }
        }
        if (v) v->splay();
        return v;
    }

    /**
     * Largest stored value smaller than `value`. Requires `value` to be
     * greater than -INF so that a predecessor node exists. The returned
     * reference points into the tree and stays valid until that node is
     * erased.
     */
    const T &pred(const T &value) {
        Node *v = find(value);
        if (v) return v->pred()->value;
        else {
            v = insert(value);
            const T &ans = v->pred()->value;
            erase(v);  // removes only the temporary node, not the predecessor
            return ans;
        }
    }

    /**
     * Smallest stored value greater than `value`. Requires `value` to be
     * smaller than INF so that a successor node exists. The returned
     * reference points into the tree and stays valid until that node is
     * erased.
     */
    const T &succ(const T &value) {
        Node *v = find(value);
        if (v) return v->succ()->value;
        else {
            v = insert(value);
            const T &ans = v->succ()->value;
            erase(v);  // removes only the temporary node, not the successor
            return ans;
        }
    }

private:
    // The tree owns its nodes through raw pointers, so a member-wise copy
    // would delete every node twice. Declared but not defined on purpose.
    Splay(const Splay &);
    Splay &operator=(const Splay &);
};

// Number of commands; main() reads it from the first input token.
int n;
// The tree every command in main() operates on, instantiated with INT_MAX as INF.
Splay<int, INT_MAX> splay;

// Debug helper: prints the subtree rooted at v in order (left subtree, the
// node itself, right subtree), one value per line, each indented by as many
// spaces as the node is deep.
void dfs(Splay<int, INT_MAX>::Node *v, int depth) {
    typedef Splay<int, INT_MAX> Tree;

    Tree::Node *left = v->child[Tree::L];
    if (left) {
        dfs(left, depth + 1);
    }

    for (int pad = depth; pad > 0; pad--) {
        putchar(' ');
    }
    printf("%d\n", v->value);

    Tree::Node *right = v->child[Tree::R];
    if (right) {
        dfs(right, depth + 1);
    }
}

// Debug helper: dumps the whole global tree, then a rule of dashes so that
// successive dumps can be told apart.
void print() {
    static const char *const separator = "--------------------------------------------------";

    dfs(splay.root, 0);
    puts(separator);
}

int main() {
    scanf("%d", &n);

    for (int i = 0; i < n; i++) {
        int command, x;
        scanf("%d %d", &command, &x);
        if (command == 1) {
            splay.insert(x);
        } else if (command == 2) {
            splay.erase(x);
        } else if (command == 3) {
            printf("%d\n", splay.rank(x));
        } else if (command == 4) {
            printf("%d\n", splay.select(x)->value);
        } else if (command == 5) {
            printf("%d\n", splay.pred(x));
        } else if (command == 6) {
            printf("%d\n", splay.succ(x));
        }
    }

    return 0;
}