采药人的药田是一个树状结构,每条路径上都种植着同种药材。
采药人每种药材进行了分类。大致分为两类,一种是阴性的,一种是阳性的。
采药人每天都要进行采药活动。他走的一定是两种药材数目相等的路径。他希望他选出的路径中有一个可以作为休息站的节点(不包括起点和终点),满足起点到休息站和休息站到终点的路径也是两种药材数目相等的。他想知道他一共可以选择多少种不同的路径。
链接
题解
点分治,考虑经过根的路径中合法的路径数量。
将边权为 改为 ,对树遍历时记录路径上的前缀和。
对根的所有子树做 DFS 遍历,设 表示当前子树前缀和为 且 在路径上仅出现过一次的路径数, 表示当前子树前缀和为 且 在路径上出现过至少两次的路径数。
如果一个前缀和 在一棵子树内出现过两次,那么在根的另一棵子树选一条前缀和为 的路径与其相接,即可组成一条合法的路径 —— 休息站可以被选择在前一条路径上另一个前缀和为 的点上。
对树进行 DFS 遍历时,记录 表示当前路径前缀和为 的节点数量,根据情况将当前节点累加在 或 中。
记录 、 为之前的所有子树中对应的路径数量,每次更新答案,统计不以根节点为休息站的路径数量
令 的初始值为 ,表示根节点单独组成一条路径,统计以根节点为休息站的的路径数量
代码
#include <cstdio>
#include <cstdlib>
#include <climits>
#include <queue>
#include <stack>
#include <algorithm>
const int MAXN = 100003;
struct Node;
struct Edge;
struct Node {
Edge *e;
int size, max, dist;
bool solved, visited;
Node *parent;
} N[MAXN];
struct Edge {
Node *s, *t;
int w;
Edge *next;
Edge(Node *s, Node *t, const int w) : s(s), t(t), w(w), next(s->e) {}
};
inline void addEdge(const int s, const int t, const int w) {
N[s].e = new Edge(&N[s], &N[t], w);
N[t].e = new Edge(&N[t], &N[s], w);
}
int n;
template <typename T, int L, int R>
struct Array {
T a[R - L + 1];
T &operator[](const int pos) { return a[pos - L]; }
const T &operator[](const int pos) const { return a[pos - L]; }
};
Array<long long [2], -(MAXN - 1), MAXN - 1> f, g;
inline Node *center(Node *start) {
std::stack<Node *> s;
s.push(start);
start->visited = false;
start->parent = NULL;
static Node *a[MAXN];
int cnt = 0;
while (!s.empty()) {
Node *v = s.top();
if (!v->visited) {
a[cnt++] = v;
v->visited = true;
for (Edge *e = v->e; e; e = e->next) if (e->t != v->parent && !e->t->solved) {
e->t->visited = false;
e->t->parent = v;
s.push(e->t);
}
} else {
v->size = 1;
v->max = 0;
for (Edge *e = v->e; e; e = e->next) if (!e->t->solved && e->t->parent == v) {
v->size += e->t->size;
v->max = std::max(v->max, e->t->size);
}
s.pop();
}
}
Node *res = NULL;
for (int i = 0; i < cnt; i++) {
a[i]->max = std::max(a[i]->max, start->size - a[i]->max);
if (!res || res->max > a[i]->max) res = a[i];
}
return res;
}
inline void dfs(Node *start, const int dist, int &max) {
std::stack<Node *> s;
s.push(start);
start->parent = NULL;
start->dist = dist;
start->visited = false;
static int _cnt[MAXN * 2 - 1], *cnt = _cnt + MAXN - 1;
while (!s.empty()) {
Node *v = s.top();
if (!v->visited) {
f[v->dist][!cnt[v->dist] ? 0 : 1]++;
cnt[v->dist]++;
max = std::max(max, v->dist);
max = std::max(max, abs(v->dist));
v->visited = true;
for (Edge *e = v->e; e; e = e->next) if (!e->t->solved && e->t != v->parent) {
e->t->parent = v;
e->t->dist = v->dist + e->w;
e->t->visited = false;
s.push(e->t);
}
} else {
cnt[v->dist]--;
s.pop();
}
}
}
/*
inline void print(const int max) {
for (int i = -max; i <= max; i++) {
printf("f[%d][0] = %lld, f[%d][1] = %lld\n", i, f[i][0], i, f[i][1]);
printf("g[%d][0] = %lld, g[%d][1] = %lld\n", i, g[i][0], i, g[i][1]);
}
putchar('\n');
}
*/
inline long long calc(Node *root) {
long long res = 0;
int max = 0;
g[0][0] = 1;
for (Edge *e = root->e; e; e = e->next) if (!e->t->solved) {
int curr = 0;
dfs(e->t, e->w, curr);
// print(max);
res += (g[0][0] - 1) * f[0][0];
for (int i = -curr; i <= curr; i++) {
// printf("res += %lld\n", (g[-i][1] * f[i][1]) + (g[-i][0] * f[i][1]) + (g[-i][1] * f[i][0]));
res += (g[-i][1] * f[i][1]) + (g[-i][0] * f[i][1]) + (g[-i][1] * f[i][0]);
}
for (int i = -curr; i <= curr; i++) {
g[i][0] += f[i][0];
g[i][1] += f[i][1];
f[i][0] = f[i][1] = 0;
}
max = std::max(max, curr);
}
for (int i = -max; i <= max; i++) {
g[i][0] = g[i][1] = 0;
}
// printf("calc(%ld) = %lld\n", root - N + 1, res);
return res;
}
inline long long solve() {
std::stack<Node *> s;
s.push(&N[0]);
long long ans = 0;
while (!s.empty()) {
Node *v = s.top();
s.pop();
Node *root = center(v);
root->solved = true;
ans += calc(root);
for (Edge *e = root->e; e; e = e->next) if (!e->t->solved) {
s.push(e->t);
}
}
return ans;
}
int main() {
scanf("%d", &n);
for (int i = 0; i < n - 1; i++) {
int u, v, w;
scanf("%d %d %d", &u, &v, &w), u--, v--;
if (w == 0) w = -1;
addEdge(u, v, w);
}
long long ans = solve();
printf("%lld\n", ans);
// printf("counter: %lld\n", counter);
return 0;
}