给你一棵 Tree,以及这棵 Tree 上边的距离。问有多少对点它们两者间的距离小于等于 。
链接
题解
点分治,考虑「经过根且两端在不同子树的路径」中满足条件的路径数目(如果两端在同一子树内,则距离不是最短)。
遍历整棵树,得到所有点到根的距离,排序后得到一个递增序列。问题转化为序列中找两个元素加起来小于等于 的方案数。
显然,对于给定的数 ,满足 的数 是连续的一段。我们只要从小到大枚举 ,并维护 指向的位置,每次 增加时减小 ,直到 ,此时 的数都满足条件,对答案的贡献即为 。
这样求出来的路径会包含两端在同一子树的路径,我们需要再减去每棵子树的「经过根的路径」的答案数。实现时相当于对每棵子树做一遍上述过程,但计算距离时计算的还是相对于当前根的距离。
时间复杂度为 。
代码
#include <cstdio>
#include <algorithm>
#include <stack>
#include <queue>
#include <vector>
const int MAXN = 40000;
struct Node;
struct Edge;
struct Node {
Edge *e;
bool solved, visited;
int size, dist, max;
Node *parent;
} N[MAXN];
struct Edge {
Node *s, *t;
int w;
Edge *next;
Edge(Node *s, Node *t, const int w) : s(s), t(t), w(w), next(s->e) {}
};
inline void addEdge(const int s, const int t, const int w) {
N[s].e = new Edge(&N[s], &N[t], w);
N[t].e = new Edge(&N[t], &N[s], w);
}
int n, k;
// int cnt_root, cnt_calc;
inline Node *center(Node *start) {
std::stack<Node *> s;
s.push(start);
start->visited = false;
start->parent = NULL;
static Node *a[MAXN];
int cnt = 0;
while (!s.empty()) {
Node *v = s.top();
if (!v->visited) {
a[cnt++] = v;
v->visited = true;
for (Edge *e = v->e; e; e = e->next) if (e->t != v->parent && !e->t->solved) {
e->t->parent = v;
e->t->visited = false;
s.push(e->t);
}
} else {
v->size = 1;
v->max = 0;
for (Edge *e = v->e; e; e = e->next) if (!e->t->solved && e->t->parent == v) {
v->size += e->t->size;
v->max = std::max(v->max, e->t->size);
}
s.pop();
}
}
Node *res = NULL;
for (int i = 0; i < cnt; i++) {
a[i]->max = std::max(a[i]->max, start->size - a[i]->max);
if (!res || res->max > a[i]->max) res = a[i];
}
// printf("root(%ld) = %ld\n", start - N + 1, res - N + 1);
// cnt_root++;
return res;
}
inline int calc(Node *root, const int dist = 0) {
static int a[MAXN];
int cnt = 0;
std::queue<Node *> q;
q.push(root);
root->dist = dist;
root->parent = NULL;
while (!q.empty()) {
Node *v = q.front();
q.pop();
a[cnt++] = v->dist;
for (Edge *e = v->e; e; e = e->next) if (!e->t->solved && e->t != v->parent) {
e->t->parent = v;
e->t->dist = v->dist + e->w;
q.push(e->t);
}
}
int res = 0;
std::sort(a, a + cnt);
for (int i = 0, j = cnt - 1; i < j; i++) {
while (i < j && a[i] + a[j] > k) j--;
res += j - i;
}
// cnt_calc++;
return res;
}
inline int solve() {
std::stack<Node *> s;
s.push(&N[0]);
int ans = 0;
while (!s.empty()) {
Node *v = s.top();
s.pop();
// printf("work(%ld)\n", v - N + 1);
Node *root = center(v);
root->solved = true;
ans += calc(root);
for (Edge *e = root->e; e; e = e->next) if (!e->t->solved) {
ans -= calc(e->t, e->w);
s.push(e->t);
}
}
return ans;
}
int main() {
scanf("%d", &n);
for (int i = 0; i < n - 1; i++) {
int u, v, w;
scanf("%d %d %d", &u, &v, &w), u--, v--;
addEdge(u, v, w);
}
scanf("%d", &k);
printf("%d\n", solve());
// printf("cnt_root = %d, cnt_calc = %d\n", cnt_root, cnt_calc);
return 0;
}