「BZOJ 1251」序列终结者 - Splay

给定一个长度为 N 的序列,每个序列的元素是一个整数。要支持以下三种操作:

  1. 将 [L,R] 这个区间内的所有数加上 V;
  2. 将 [L,R] 这个区间翻转,比如 1 2 3 4 变成 4 3 2 1;
  3. 求 [L,R] 这个区间中的最大值。

最开始所有元素都是 0。

链接

BZOJ 1251
CodeVS 4655

题解

Splay 裸题,比文艺平衡树强一点,那个没有区间修改和查询。

区间修改:和翻转一样,维护一个 lazy-tag,然后 pushDown() 的时候下放即可。

区间查询:维护一个子树值的最大值,查询的时候直接选择区间然后返回这个最大值。需要在 maintain() 中维护。

细节需要注意,各种下放 ……

代码

#include <cstdio>
#include <cstring>
#include <algorithm>

// Capacity of the global element array `a`.
static const int MAXN = 50000;
// Upper bound on the number of operations; not referenced by the code in this file.
static const int MAXM = 100000;

// Forward declaration of the debugging tree dump defined further down in this file.
void print();

/**
 * Splay tree that stores a sequence by position and supports range add,
 * range reverse and range maximum through lazy tags.
 *
 * Positions are 1-based. build() surrounds the data with two sentinel
 * ("bound") nodes at positions 0 and n + 1, so every range [l, r] with
 * 1 <= l <= r <= n has a node directly before and directly after it.
 *
 * T must support +=, std::max, construction from 0 and conversion to bool
 * (the latter is used to test whether an addition tag is pending).
 */
template <typename T>
struct Splay {
    // Which child slot of its parent a node occupies.
    enum Relation {
        L = 0, R = 1
    };

    struct Node {
        Node *child[2], *parent, **root; // `root` points at the owning tree's root pointer
        T value, max, lazy;              // own value, subtree maximum, pending addition for the children
        int size;                        // number of nodes in this subtree, this node included
        bool reversed, bound;            // pending reversal of the subtree; true for sentinel nodes

        // Initializers are listed in member declaration order, which is the
        // order in which they are actually executed.
        Node(Node *parent, const T &value, Node **root, bool bound = false)
            : parent(parent), root(root), value(value), max(value), lazy(0), size(1), reversed(false), bound(bound) {
            child[L] = child[R] = NULL;
        }

        // Recursively frees the whole subtree below this node.
        ~Node() {
            delete child[L];
            delete child[R];
        }

        // Child slot this node occupies in its parent. Requires parent != NULL.
        Relation relation() {
            return this == parent->child[L] ? L : R;
        }

        // Pushes pending tags down, then recomputes size and max from the children.
        void maintain() {
            pushDown();

            size = 1;
            if (child[L]) size += child[L]->size;
            if (child[R]) size += child[R]->size;

            max = value;
            if (child[L]) max = std::max(max, child[L]->max);
            if (child[R]) max = std::max(max, child[R]->max);
        }

        // Applies the pending reversal and addition tags to the children and clears them here.
        void pushDown() {
            if (reversed) {
                if (child[L]) child[L]->reversed ^= 1;
                if (child[R]) child[R]->reversed ^= 1;
                std::swap(child[L], child[R]);

                reversed = false;
            }

            if (lazy) {
                // A child's value and max are updated immediately; its own
                // children receive the addition when the child is pushed down.
                if (child[L]) child[L]->lazy += lazy, child[L]->value += lazy, child[L]->max += lazy;
                if (child[R]) child[R]->lazy += lazy, child[R]->value += lazy, child[R]->max += lazy;

                lazy = 0;
            }
        }

        // Rotates this node above its parent. Requires parent != NULL.
        void rotate() {
            // Tags must be resolved top-down first: a pending reversal would
            // change which child slot each node occupies.
            if (parent->parent) parent->parent->pushDown();
            parent->pushDown(), pushDown();

            Relation x = relation();
            Node *oldParent = parent;

            if (oldParent->parent) oldParent->parent->child[oldParent->relation()] = this;
            parent = oldParent->parent;

            oldParent->child[x] = child[x ^ 1];
            if (child[x ^ 1]) child[x ^ 1]->parent = oldParent;

            child[x ^ 1] = oldParent;
            oldParent->parent = this;

            // oldParent is now below this node, so it is recomputed first.
            oldParent->maintain(), maintain();

            if (!parent) *root = this;
        }

        // Rotates this node upwards until its parent is targetParent
        // (NULL makes it the root of the tree).
        void splay(Node *targetParent = NULL) {
            while (parent != targetParent) {
                if (parent->parent == targetParent) rotate();
                else {
                    // Resolve reversals before comparing the two relations.
                    parent->parent->pushDown(), parent->pushDown();
                    if (parent->relation() == relation()) parent->rotate(), rotate();
                    else rotate(), rotate();
                }
            }
        }

        // Number of nodes in the left subtree.
        int rank() {
            return !child[L] ? 0 : child[L]->size;
        }
    } *root;

    Splay() : root(NULL) {}

    ~Splay() {
        delete root;
    }

    // Builds the tree from a[0 .. n - 1] and adds the two sentinel nodes.
    // Any tree built earlier is released first, so build() may be called again.
    void build(const T *a, int n) {
        delete root;
        root = NULL;

        root = buildRange(a, 1, n, NULL);
        buildBound(L), buildBound(R);
    }

    // Builds a subtree for the 1-based positions [l, r], using the middle
    // position as the subtree root. Returns NULL for an empty range.
    Node *buildRange(const T *a, int l, int r, Node *parent) {
        if (l > r) return NULL;
        int mid = l + ((r - l) >> 1);

        Node *v = new Node(parent, a[mid - 1], &root);
        if (l != r) {
            v->child[L] = buildRange(a, l, mid - 1, v);
            v->child[R] = buildRange(a, mid + 1, r, v);
        }

        v->maintain();
        return v;
    }

    // Appends a sentinel node (value 0) at the far end of the x side,
    // increasing the size of every node on the way down.
    void buildBound(Relation x) {
        Node **v = &root, *parent = NULL;
        while (*v) {
            parent = *v;
            parent->size++;
            v = &parent->child[x];
        }
        *v = new Node(parent, 0, &root, true);
        (*v)->maintain();
    }

    // Returns the node at position k and splays it to the root. Position 0 is
    // the left sentinel and position n + 1 the right one. k must be in that
    // range; no bounds check is performed.
    Node *select(int k) {
        k++; // convert to a 1-based in-order index that counts the left sentinel
        Node *v = root;
        for (;;) {
            v->pushDown(); // rank() is only meaningful once a pending reversal is applied
            int position = v->rank() + 1;
            if (k == position) break;
            if (k < position) {
                v = v->child[L];
            } else {
                k -= position;
                v = v->child[R];
            }
        }
        v->splay();
        return v;
    }

    // Returns the root of the subtree holding exactly positions [l, r].
    // Requires 1 <= l <= r <= n.
    Node *select(int l, int r) {
        Node *lbound = select(l - 1);
        Node *rbound = select(r + 1);

        // Put the predecessor at the root and the successor right below it;
        // the successor's left subtree is then exactly the requested range.
        lbound->splay();
        rbound->splay(lbound);

        return rbound->child[L];
    }

    // Adds `addition` to every element in [l, r].
    void update(int l, int r, const T &addition) {
        Node *range = select(l, r);
        range->value += addition, range->lazy += addition, range->max += addition;
    }

    // Returns the maximum of [l, r]. The reference points into a tree node,
    // so copy the value before performing further operations on the tree.
    const T &queryMax(int l, int r) {
        Node *range = select(l, r);
        return range->max;
    }

    // Reverses the order of the elements in [l, r].
    void reverse(int l, int r) {
        Node *range = select(l, r);
        range->reversed ^= 1;
    }

    // Writes the current sequence, without the sentinels, to a[0 .. n - 1].
    void fetch(T *a) {
        dfsFetch(a, root);
    }

    // In-order traversal that advances `a` for every non-sentinel node written.
    void dfsFetch(T *&a, Node *v) {
        if (!v) return;
        v->maintain(); // pushes pending tags down before the children are visited
        dfsFetch(a, v->child[L]);
        if (!v->bound) *a++ = v->value;
        dfsFetch(a, v->child[R]);
    }

private:
    // Not copyable: the tree owns its nodes through raw pointers, so a
    // member-wise copy would delete them twice. Declared but never defined.
    Splay(const Splay &);
    Splay &operator=(const Splay &);
};

// Sequence length and number of operations, both read from stdin in main().
int n;
int m;
// Initial element values handed to splay.build(); main() zeroes the first n entries.
int a[MAXN];
// The single tree instance all operations run against.
Splay<int> splay;

// Debug helper: prints the subtree rooted at v in in-order, one node per line
// as "value : size", indented by one space per level of depth.
void dfs(Splay<int>::Node *v, int depth) {
    if (!v) return;

    // Resolve this node's tags BEFORE descending. maintain() starts with
    // pushDown(), and a pending reversal swaps the two children: visiting
    // child[L] first and pushing down afterwards would print one subtree
    // twice and skip the other, and a pending addition would leave the
    // left subtree printed with stale values.
    v->maintain();

    dfs(v->child[Splay<int>::L], depth + 1);
    for (int i = 0; i < depth; i++) {
        putchar(' ');
    }
    printf("%d : %d\n", v->value, v->size);
    dfs(v->child[Splay<int>::R], depth + 1);
}

// Debug helper: dumps the whole global tree via dfs(), followed by a separator line.
void print() {
    static const char separator[] = "--------------------------------------------------";

    dfs(splay.root, 0);
    puts(separator);
}

int main() {
    scanf("%d %d", &n, &m);

    memset(a, 0, sizeof(int) * n);
    splay.build(a, n);

    for (int i = 0; i < m; i++) {
        int command;
        scanf("%d", &command);
        if (command == 1) {
            int l, r, addition;
            scanf("%d %d %d", &l, &r, &addition);
            splay.update(l, r, addition);
        } else if (command == 2) {
            int l, r;
            scanf("%d %d", &l, &r);
            splay.reverse(l, r);
        } else if (command == 3) {
            int l, r;
            scanf("%d %d", &l, &r);
            printf("%d\n", splay.queryMax(l, r));
        } else throw;
    }

    //print();

    return 0;
}