给定一棵 $n$ 个节点的树,树的节点标号从 $0$ 开始。每个节点可以是白色或黑色,初始时每个节点的颜色为白色。要求支持以下两种操作:
- 将节点 $x$ 涂黑;
- 查询节点 $x$ 到所有黑点的距离之和。
题解
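先预处理出每个点到根的距离 $dist_u$(即根到 $u$ 的路径上的边权之和),维护当前所有黑点到根的距离之和 $sum$ 和黑点总数 $cnt$。
考虑询问点由根节点向它的某个子节点 $u$ 移动:设 $w_u$ 为 $u$ 与其父节点之间的边权,$size_u$ 为 $u$ 的子树内的黑点数,则子树内的黑点各近了 $w_u$,其余黑点各远了 $w_u$,答案为:
$$ sum + (cnt - size_u) \times w_u - size_u \times w_u $$
即
$$ sum + cnt \times w_u - 2 \times size_u \times w_u $$
继续沿根到询问点 $x$ 的路径逐步移动并累加,得到 $x$ 的答案为 $sum + cnt \times dist_x - 2 \sum_{v} w_v \times size_v$,其中 $v$ 取遍根到 $x$ 的路径上的所有节点。
树链剖分,用线段树维护 $w_v \times size_v$ 在链上的总和,每次 $O(\log^2 n)$ 计算。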
代码
#include <cstdio>
#include <stack>
#include <algorithm>
// Capacity of the global `trees` array, i.e. the largest node count the
// program can hold.
const int MAXN = 100000;
// Presumably the upper bound on the number of operations; it is not referenced
// anywhere else in this file.
const int MAXM = 100000;
// Segment tree over the inclusive position range [l, r]. Every node owns its
// two children and deletes them in its destructor, so nodes must not be copied.
//
// State kept per position and aggregated over the covered range in inner nodes:
//   base       - weight installed by setBase();
//   count      - total of the deltas applied by updateCount();
//   sum        - total of base * count;
//   blackCount - total of the deltas applied by updateBlackCount().
// Range updates of `count` are lazy: lazyCount is the per-position delta that
// is already reflected in this node's count/sum but not yet in its children.
struct SegmentTree {
    int l, r;
    SegmentTree *lchild, *rchild;
    int blackCount;
    int lazyCount;
    // 64-bit on purpose: inner nodes add up the values of every position below
    // them, which can leave the range of int even when each leaf value fits.
    long long count, base, sum;

    SegmentTree(int l, int r, SegmentTree *lchild, SegmentTree *rchild)
        : l(l), r(r), lchild(lchild), rchild(rchild),
          blackCount(0), lazyCount(0), count(0), base(0), sum(0) {}

    ~SegmentTree() {
        // Deleting a null pointer is a no-op, so no guard is required.
        delete lchild;
        delete rchild;
    }

    // Adds `delta` to the count of every position covered by this node and
    // remembers it as pending for the children.
    void applyDelta(int delta) {
        lazyCount += delta;
        // One increment per covered position keeps the aggregate consistent
        // with the "sum of the children" recomputation in updateCount().
        count += static_cast<long long>(delta) * (r - l + 1);
        sum += base * delta;
    }

    // Forwards the pending count delta to both children.
    void pushDown() {
        if (lazyCount == 0) return;
        if (lchild) lchild->applyDelta(lazyCount);
        if (rchild) rchild->applyDelta(lazyCount);
        lazyCount = 0;
    }

    // Returns the total of base * count over the positions in [l, r].
    long long querySum(int l, int r) {
        if (l > this->r || r < this->l) return 0;
        if (l <= this->l && r >= this->r) return sum;
        pushDown();
        long long result = 0;
        if (lchild) result += lchild->querySum(l, r);
        if (rchild) result += rchild->querySum(l, r);
        return result;
    }

    // Returns the total of count over the positions in [l, r].
    long long queryCount(int l, int r) {
        if (l > this->r || r < this->l) return 0;
        if (l <= this->l && r >= this->r) return count;
        pushDown();
        long long result = 0;
        if (lchild) result += lchild->queryCount(l, r);
        if (rchild) result += rchild->queryCount(l, r);
        return result;
    }

    // Adds `delta` to the count of every position in [l, r].
    void updateCount(int l, int r, int delta) {
        if (l > this->r || r < this->l) return;
        if (l <= this->l && r >= this->r) {
            applyDelta(delta);
            return;
        }
        pushDown();
        // Partially covered: rebuild the aggregates from the children.
        sum = count = 0;
        if (lchild) {
            lchild->updateCount(l, r, delta);
            count += lchild->count;
            sum += lchild->sum;
        }
        if (rchild) {
            rchild->updateCount(l, r, delta);
            count += rchild->count;
            sum += rchild->sum;
        }
    }

    // Installs `base` as the weight of position x and refreshes the base totals
    // on the way back up. `sum` is not recomputed and pending lazy deltas are
    // not pushed, so this is meant to run before any updateCount() call.
    void setBase(int x, long long base) {
        if (x > this->r || x < this->l) return;
        if (x == this->l && x == this->r) {
            this->base = base;
            return;
        }
        long long total = 0;
        if (lchild) {
            lchild->setBase(x, base);
            total += lchild->base;
        }
        if (rchild) {
            rchild->setBase(x, base);
            total += rchild->base;
        }
        this->base = total;
    }

    // Returns the total of blackCount over the positions in [l, r].
    // blackCount has no lazy state, so nothing needs to be pushed down here.
    int queryBlackCount(int l, int r) {
        if (l > this->r || r < this->l) return 0;
        if (l <= this->l && r >= this->r) return blackCount;
        int result = 0;
        if (lchild) result += lchild->queryBlackCount(l, r);
        if (rchild) result += rchild->queryBlackCount(l, r);
        return result;
    }

    // Adds `delta` to the blackCount of position x.
    void updateBlackCount(int x, int delta) {
        if (x > this->r || x < this->l) return;
        if (x == this->l && x == this->r) {
            blackCount += delta;
            return;
        }
        blackCount = 0;
        if (lchild) {
            lchild->updateBlackCount(x, delta);
            blackCount += lchild->blackCount;
        }
        if (rchild) {
            rchild->updateBlackCount(x, delta);
            blackCount += rchild->blackCount;
        }
    }
};
// Recursively allocates a segment tree covering the inclusive range [l, r].
// Returns NULL for an empty range (l > r) and a childless node for a single
// position; otherwise the range is split at its lower midpoint.
SegmentTree *buildSegmentTree(int l, int r) {
    if (l > r) return NULL;
    if (l == r) return new SegmentTree(l, r, NULL, NULL);

    const int mid = l + ((r - l) >> 1);
    SegmentTree *leftHalf = buildSegmentTree(l, mid);
    SegmentTree *rightHalf = buildSegmentTree(mid + 1, r);
    return new SegmentTree(l, r, leftHalf, rightHalf);
}
struct Tree;
struct Path;

// One node of the rooted tree. All nodes live in the global `trees` array and
// are addressed by their index.
struct Tree {
    // Children form a singly linked list: firstChild heads it and `next`
    // links to the following sibling (see addRelation()).
    Tree *firstChild;
    Tree *next;
    Tree *parent;
    // Child with the largest subtree size, chosen by cut(); NULL for a leaf.
    Tree *maxSizeChild;

    // Traversal marker used by the two passes of cut().
    bool visited;
    // Set once the node has been painted by updateToBlack().
    bool colored;

    // Heavy chain this node belongs to, assigned by cut().
    Path *path;

    // Number of nodes in this node's subtree, itself included.
    int size;
    // Depth in the tree; cut() gives the root depth 1.
    int depth;
    // 1-based position handed out by cut(); the node's subtree occupies the
    // positions pos..posEnd.
    int pos;
    int posEnd;

    // Value read per node in main(); cut() adds it when stepping from the
    // parent to this node, i.e. it acts as the weight of the edge to the parent.
    long long cost;
    // Sum of `cost` along the path from the root down to this node.
    long long costPrefixSum;
};

Tree trees[MAXN];
// A heavy chain produced by cut().
struct Path {
    // Node the chain was created for: its shallowest node.
    Tree *top;
    // Number of nodes on the chain; starts at 1 and is incremented by cut()
    // for every further node that joins.
    int length;

    Path(Tree *head)
        : top(head),
          length(1) {}
};
inline void addRelation(int child, int parent) {
trees[child].parent = &trees[parent];
trees[child].next = trees[parent].firstChild;
trees[parent].firstChild = &trees[child];
}
// Number of tree nodes, read in main().
int n;
// Number of operations, read in main().
int m;
// Number of nodes painted so far; incremented by updateToBlack().
int blackCount;
// Sum of costPrefixSum over all painted nodes; maintained by updateToBlack().
long long sum;
// Segment tree over the positions 1..n assigned by cut().
SegmentTree *segment;
inline void cut() {
std::stack<Tree *> s;
s.push(&trees[0]);
trees[0].depth = 1;
trees[0].costPrefixSum = 0;
while (!s.empty()) {
Tree *t = s.top();
if (!t->visited) {
t->visited = true;
for (Tree *c = t->firstChild; c; c = c->next) {
c->depth = t->depth + 1;
c->costPrefixSum = t->costPrefixSum + c->cost;
s.push(c);
}
} else {
t->size = 1;
for (Tree *c = t->firstChild; c; c = c->next) {
t->size += c->size;
if (t->maxSizeChild == NULL || t->maxSizeChild->size < c->size) t->maxSizeChild = c;
}
s.pop();
}
}
for (int i = 0; i < n; i++) trees[i].visited = false;
s.push(&trees[0]);
int timeStamp = 0;
while (!s.empty()) {
Tree *t = s.top();
if (!t->visited) {
t->visited = true;
if (t->parent == NULL || t->parent->maxSizeChild != t) t->path = new Path(t);
else t->path = t->parent->path, t->path->length++;
t->pos = ++timeStamp;
for (Tree *c = t->firstChild; c; c = c->next) {
if (c != t->maxSizeChild) s.push(c);
}
if (t->maxSizeChild) s.push(t->maxSizeChild);
} else {
t->posEnd = timeStamp;
s.pop();
}
}
segment = buildSegmentTree(1, n);
for (int i = 0; i < n; i++) {
segment->setBase(trees[i].pos, trees[i].cost);
}
}
inline int queryDist(int u, int v) {
if (u == v) return 0;
Tree *a = &trees[u], *b = &trees[v];
int sum = 0;
while (a->path != b->path) {
if (a->path->top->depth < b->path->top->depth) std::swap(a, b);
sum += segment->querySum(a->path->top->pos, a->pos);
// printf("query(%d, %d)::sum = %d\n", u, v, sum);
a = a->path->top->parent;
}
if (a->pos > b->pos) std::swap(a, b);
sum += segment->querySum(a->pos, b->pos);
sum -= a->cost;
return sum;
}
inline void updateToBlack(int u) {
if (trees[u].colored) return;
trees[u].colored = true;
sum += trees[u].costPrefixSum;
for (Tree *t = &trees[u]; t; t = t->path->top->parent) {
segment->updateCount(t->path->top->pos, t->pos, 1);
}
segment->updateBlackCount(trees[u].pos, 1);
blackCount++;
}
inline long long queryAllDist(int u) {
Tree *t = &trees[u];
int childBlackCount = segment->queryBlackCount(t->pos, t->posEnd);
long long tmp = 0;
for (Tree *t = &trees[u]; t; t = t->path->top->parent) tmp += segment->querySum(t->path->top->pos, t->pos);
return sum + t->costPrefixSum * blackCount - 2 * tmp;
}
int main() {
// freopen("color.in", "r", stdin);
// freopen("color.out", "w", stdout);
scanf("%d %d", &n, &m);
for (int i = 1; i < n; i++) {
int parent;
scanf("%d", &parent);
addRelation(i, parent);
}
for (int i = 1; i < n; i++) {
scanf("%lld", &trees[i].cost);
}
cut();
for (int i = 0; i < m; i++) {
int t, u;
scanf("%d %d", &t, &u);
if (t == 1) updateToBlack(u);
else printf("%lld\n", queryAllDist(u));
}
fclose(stdin);
fclose(stdout);
return 0;
}