「省选模拟赛」染色 - 树链剖分

给定一棵 $n$ 个节点的树,树的节点标号从 $0$ 开始($0$ 号节点为根)。每个节点可以是白色或黑色,初始时每个节点的颜色为白色。要求支持以下两种操作:

  1. 将节点 $u$ 涂黑;
  2. 查询节点 $u$ 到所有黑点的距离之和。

题解

先预处理出每个点到根的距离(即路径上的边权之和)$dis_u$,并维护当前所有黑点到根的距离之和 $sum$ 和黑点总数 $cnt$。

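考虑询问点由根节点向它的某个子节点 $u$ 移动。记 $dis_u$ 为 $u$ 到根的距离,$sum$ 为所有黑点到根的距离之和,$cnt$ 为黑点总数,$w_v$ 为 $v$ 到其父亲的边权(根的 $w$ 为 $0$),$c_v$ 为 $v$ 子树内的黑点数,则答案为:

$$ans(u) = sum + cnt \cdot dis_u - 2 \sum_{v \in \text{path}(root,\, u)} w_v \cdot c_v$$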

树链剖分,用线段树维护 $w_v \cdot c_v$($w_v$ 为 $v$ 到其父亲的边权,$c_v$ 为 $v$ 子树内的黑点数)在链上的总和,每次 $O(\log^2 n)$ 计算。

代码

#include <cstdio>
#include <stack>
#include <algorithm>

// Capacity of the global `trees` array, i.e. the largest supported node count.
const int MAXN = 100000; 
// Presumably the maximum number of operations `m`; not referenced anywhere in
// the visible code.
const int MAXM = 100000; 

/**
 * Pointer-based segment tree node covering the position range [l, r].
 *
 * Every position i carries a fixed weight base_i (installed with setBase), a
 * counter count_i (changed by range additions through updateCount) and a
 * black-node counter (changed by point updates through updateBlackCount).
 * For its range, a node stores:
 *   - base:       sum of base_i,
 *   - count:      sum of count_i,
 *   - sum:        sum of base_i * count_i,
 *   - blackCount: sum of the black-node counters,
 *   - lazyCount:  counter increment already applied to this node but not yet
 *                 forwarded to its children.
 *
 * base, count and sum are 64-bit: they aggregate up to (r - l + 1) values and
 * would overflow a 32-bit int for large weights or chain-shaped trees.
 *
 * A node owns its children and deletes them in its destructor.
 */
struct SegmentTree {
    int l, r;
    SegmentTree *lchild, *rchild;

    int blackCount;
    long long count;
    int lazyCount;
    long long base;
    long long sum;

    SegmentTree(int l, int r, SegmentTree *lchild, SegmentTree *rchild) : l(l), r(r), lchild(lchild), rchild(rchild), blackCount(0), count(0), lazyCount(0), base(0), sum(0) {}

    ~SegmentTree() {
        // Deleting a null pointer is a no-op, so no guards are required.
        delete lchild;
        delete rchild;
    }

    /**
     * Adds `delta` to the counter of every position in this node's range and
     * remembers the increment so it can later be forwarded to the children.
     */
    void applyCount(int delta) {
        lazyCount += delta;
        // Each of the (r - l + 1) positions gains `delta`.
        count += static_cast<long long>(delta) * (r - l + 1);
        // sum of base_i * (count_i + delta) = sum + delta * (sum of base_i).
        sum += base * delta;
    }

    /** Forwards the pending counter increment to both children. */
    void pushDown() {
        if (lazyCount == 0) return;
        if (lchild) lchild->applyCount(lazyCount);
        if (rchild) rchild->applyCount(lazyCount);
        lazyCount = 0;
    }

    /** Returns the sum of base_i * count_i over the positions in [l, r]. */
    long long querySum(int l, int r) {
        if (l > this->r || r < this->l) return 0;
        if (l <= this->l && r >= this->r) return sum;

        pushDown();
        long long result = 0;
        if (lchild) result += lchild->querySum(l, r);
        if (rchild) result += rchild->querySum(l, r);
        return result;
    }

    /** Returns the sum of count_i over the positions in [l, r]. */
    long long queryCount(int l, int r) {
        if (l > this->r || r < this->l) return 0;
        if (l <= this->l && r >= this->r) return count;

        pushDown();
        long long result = 0;
        if (lchild) result += lchild->queryCount(l, r);
        if (rchild) result += rchild->queryCount(l, r);
        return result;
    }

    /** Adds `delta` to count_i for every position i in [l, r]. */
    void updateCount(int l, int r, int delta) {
        if (l > this->r || r < this->l) return;
        if (l <= this->l && r >= this->r) {
            applyCount(delta);
            return;
        }

        pushDown();
        // Rebuild the aggregates from the (now up to date) children.
        sum = 0;
        count = 0;
        if (lchild) {
            lchild->updateCount(l, r, delta);
            count += lchild->count;
            sum += lchild->sum;
        }
        if (rchild) {
            rchild->updateCount(l, r, delta);
            count += rchild->count;
            sum += rchild->sum;
        }
    }

    /**
     * Sets the weight of position `x` to `base` and refreshes the aggregated
     * weights on the way back up. `sum` is not recomputed here, so this is
     * meant to be called while all counters are still zero.
     */
    void setBase(int x, long long base) {
        if (x > this->r || x < this->l) return;
        if (x == this->l && x == this->r) {
            this->base = base;
            return;
        }

        if (lchild) lchild->setBase(x, base);
        if (rchild) rchild->setBase(x, base);
        this->base = (lchild ? lchild->base : 0) + (rchild ? rchild->base : 0);
    }

    /** Returns the total black-node counter over the positions in [l, r]. */
    int queryBlackCount(int l, int r) {
        if (l > this->r || r < this->l) return 0;
        if (l <= this->l && r >= this->r) return blackCount;

        // blackCount is only ever point-updated, so there is nothing lazy to push.
        int result = 0;
        if (lchild) result += lchild->queryBlackCount(l, r);
        if (rchild) result += rchild->queryBlackCount(l, r);
        return result;
    }

    /** Adds `delta` to the black-node counter of position `x`. */
    void updateBlackCount(int x, int delta) {
        if (x > this->r || x < this->l) return;
        if (x == this->l && x == this->r) {
            blackCount += delta;
            return;
        }

        blackCount = 0;
        if (lchild) {
            lchild->updateBlackCount(x, delta);
            blackCount += lchild->blackCount;
        }
        if (rchild) {
            rchild->updateBlackCount(x, delta);
            blackCount += rchild->blackCount;
        }
    }

private:
    // A node owns its children, so a copy would delete them twice.
    // Declared but intentionally not defined.
    SegmentTree(const SegmentTree &);
    SegmentTree &operator=(const SegmentTree &);
};

/**
 * Recursively builds a segment tree covering the positions [l, r].
 *
 * Returns NULL for an empty range (l > r) and a childless node for a single
 * position; otherwise the range is split at its midpoint into two subtrees.
 * The caller owns the returned tree.
 */
SegmentTree *buildSegmentTree(int l, int r) {
    if (l > r) return NULL;
    if (l == r) return new SegmentTree(l, r, NULL, NULL);

    // Midpoint written as l + (r - l) / 2 so that l + r is never formed.
    const int mid = l + ((r - l) >> 1);
    SegmentTree *left = buildSegmentTree(l, mid);
    SegmentTree *right = buildSegmentTree(mid + 1, r);
    return new SegmentTree(l, r, left, right);
}

struct Tree;
struct Path;

/**
 * One node of the rooted tree. All nodes live in the global `trees` array, so
 * every field starts out zero / NULL.
 */
struct Tree {
    Tree *firstChild;         // head of this node's singly linked child list
    Tree *next;               // next sibling in the parent's child list
    Tree *parent;             // NULL for the root
    Tree *maxSizeChild;       // child with the largest subtree (heavy child)
    bool visited;             // scratch flag for the iterative traversals in cut()
    bool colored;             // true once the node has been painted black
    Path *path;               // heavy path this node belongs to
    int size;                 // number of nodes in this node's subtree
    int depth;                // depth in the tree; the root has depth 1
    int pos;                  // 1-based position assigned by the second traversal in cut()
    int posEnd;               // largest position inside this node's subtree
    long long cost;           // weight of the edge to the parent
    long long costPrefixSum;  // sum of edge weights from the root down to this node
} trees[MAXN];

/** A heavy path of the decomposition, identified by its topmost node. */
struct Path {
    Tree *top;   // node of the path closest to the root
    int length;  // number of nodes that have joined the path so far

    /** Starts a new path consisting only of `top`. */
    Path(Tree *top) {
        this->top = top;
        this->length = 1;
    }
};

inline void addRelation(int child, int parent) {
    trees[child].parent = &trees[parent];
    trees[child].next = trees[parent].firstChild;
    trees[parent].firstChild = &trees[child];
}

// n: number of tree nodes; m: number of operations (both read in main);
// blackCount: number of nodes painted black so far.
int n, m, blackCount;
// Running total of costPrefixSum over all black nodes (accumulated in updateToBlack).
long long sum;
// Segment tree over the positions 1..n assigned by cut(); created there.
SegmentTree *segment;

inline void cut() {
    std::stack<Tree *> s;
    s.push(&trees[0]);
    trees[0].depth = 1;
    trees[0].costPrefixSum = 0;

    while (!s.empty()) {
        Tree *t = s.top();

        if (!t->visited) {
            t->visited = true;

            for (Tree *c = t->firstChild; c; c = c->next) {
                c->depth = t->depth + 1;
                c->costPrefixSum = t->costPrefixSum + c->cost;
                s.push(c);
            }
        } else {
            t->size = 1;
            for (Tree *c = t->firstChild; c; c = c->next) {
                t->size += c->size;
                if (t->maxSizeChild == NULL || t->maxSizeChild->size < c->size) t->maxSizeChild = c;
            }

            s.pop();
        }
    }

    for (int i = 0; i < n; i++) trees[i].visited = false;

    s.push(&trees[0]);

    int timeStamp = 0;
    while (!s.empty()) {
        Tree *t = s.top();

        if (!t->visited) {
            t->visited = true;

            if (t->parent == NULL || t->parent->maxSizeChild != t) t->path = new Path(t);
            else t->path = t->parent->path, t->path->length++;

            t->pos = ++timeStamp;

            for (Tree *c = t->firstChild; c; c = c->next) {
                if (c != t->maxSizeChild) s.push(c);
            }
            if (t->maxSizeChild) s.push(t->maxSizeChild);
        } else {
            t->posEnd = timeStamp;
            s.pop();
        }
    }

    segment = buildSegmentTree(1, n);

    for (int i = 0; i < n; i++) {
        segment->setBase(trees[i].pos, trees[i].cost);
    }
}

inline int queryDist(int u, int v) {
    if (u == v) return 0;

    Tree *a = &trees[u], *b = &trees[v];
    int sum = 0;
    while (a->path != b->path) {
        if (a->path->top->depth < b->path->top->depth) std::swap(a, b);
        sum += segment->querySum(a->path->top->pos, a->pos);
        // printf("query(%d, %d)::sum = %d\n", u, v, sum);
        a = a->path->top->parent;
    }

    if (a->pos > b->pos) std::swap(a, b);
    sum += segment->querySum(a->pos, b->pos);
    sum -= a->cost;

    return sum;
}

inline void updateToBlack(int u) {
    if (trees[u].colored) return;
    trees[u].colored = true;
    sum += trees[u].costPrefixSum;
    for (Tree *t = &trees[u]; t; t = t->path->top->parent) {
        segment->updateCount(t->path->top->pos, t->pos, 1);
    }
    segment->updateBlackCount(trees[u].pos, 1);
    blackCount++;
}

inline long long queryAllDist(int u) {
    Tree *t = &trees[u];
    int childBlackCount = segment->queryBlackCount(t->pos, t->posEnd);
    long long tmp = 0;
    for (Tree *t = &trees[u]; t; t = t->path->top->parent) tmp += segment->querySum(t->path->top->pos, t->pos);
    return sum + t->costPrefixSum * blackCount - 2 * tmp;
}

int main() {
    // freopen("color.in", "r", stdin);
    // freopen("color.out", "w", stdout);

    scanf("%d %d", &n, &m);

    for (int i = 1; i < n; i++) {
        int parent;
        scanf("%d", &parent);

        addRelation(i, parent);
    }

    for (int i = 1; i < n; i++) {
        scanf("%lld", &trees[i].cost);
    }

    cut();

    for (int i = 0; i < m; i++) {
        int t, u;
        scanf("%d %d", &t, &u);

        if (t == 1) updateToBlack(u);
        else printf("%lld\n", queryAllDist(u));
    }

    fclose(stdin);
    fclose(stdout);

    return 0;
}