「BZOJ 2152」聪聪可可 - 点分治

求在一棵 个点的带权树上随机选择两个有序点(可以相同),两点距离为 的倍数的概率。

链接

BZOJ 2152

题解

点分治,每次分治计算从根出发的长度 的值为 的路径数量为

点对有序,考虑两条路径可以连接,因此 是答案的一部分。 考虑 对答案的贡献,每条路径可以单独算,也可以与另一条连接,注意到每条路径可以与其连接的路径条数是一个等差数列,因此该部分答案为 ,点对有序,再乘以 ,加上单独算的答案,为

因此,统计 即可。

最终答案需要处以 (总方案数),之后约分即可。

代码

#include <cstdio>
#include <queue>
#include <stack>
#include <algorithm>

const int MAXN = 20000;

struct Node;
struct Edge;

struct Node {
    Edge *e;
    int size, dist, max;
    bool solved, visited;
    Node *parent;
} N[MAXN];

struct Edge {
    Node *s, *t;
    int w;
    Edge *next;

    Edge(Node *s, Node *t, const int w) : s(s), t(t), w(w), next(s->e) {}
};

inline void addEdge(const int s, const int t, const int w) {
    N[s].e = new Edge(&N[s], &N[t], w);
    N[t].e = new Edge(&N[t], &N[s], w);
}

int n;

inline Node *center(Node *start) {
    std::stack<Node *> s;
    s.push(start);
    start->visited = false;
    start->parent = NULL;

    static Node *a[MAXN];
    int cnt = 0;
    while (!s.empty()) {
        Node *v = s.top();
        if (!v->visited) {
            v->visited = true;
            a[cnt++] = v;
            for (Edge *e = v->e; e; e = e->next) if (!e->t->solved && e->t != v->parent) {
                e->t->visited = false;
                e->t->parent = v;
                s.push(e->t);
            }
        } else {
            v->size = 1;
            v->max = 0;
            for (Edge *e = v->e; e; e = e->next) if (!e->t->solved && e->t->parent == v) {
                v->size += e->t->size;
                v->max = std::max(v->max, e->t->size);
            }
            s.pop();
        }
    }

    Node *res = NULL;
    for (int i = 0; i < cnt; i++) {
        a[i]->max = std::max(start->size - a[i]->max, a[i]->max);
        if (!res || res->max > a[i]->max) res = a[i];
    }

    return res;
}

inline int calc(Node *root, const int dist = 0) {
    std::queue<Node *> q;
    q.push(root);
    root->parent = NULL;
    root->dist = dist;

    int a[3] = { 0, 0, 0 }, cnt = 0;
    while (!q.empty()) {
        Node *v = q.front();
        q.pop();

        cnt++;
        a[v->dist %= 3]++;

        for (Edge *e = v->e; e; e = e->next) if (!e->t->solved && e->t != v->parent) {
            e->t->parent = v;
            e->t->dist = v->dist + e->w;
            q.push(e->t);
        }
    }

    return a[0] * a[0] + a[1] * a[2] * 2;
}

inline int solve() {
    std::stack<Node *> s;
    s.push(&N[0]);

    int res = 0;
    while (!s.empty()) {
        Node *v = s.top();
        s.pop();

        Node *root = center(v);
        res += calc(root);
        root->solved = true;

        for (Edge *e = root->e; e; e = e->next) if (!e->t->solved) {
            res -= calc(e->t, e->w);
            s.push(e->t);
        }
    }

    return res;
}

int main() {
    scanf("%d", &n);
    for (int i = 0; i < n - 1; i++) {
        int u, v, w;
        scanf("%d %d %d", &u, &v, &w), u--, v--, w %= 3;
        addEdge(u, v, w);
    }

    int cnt = solve(), sum = n * n, g = std::__gcd(cnt, sum);
    printf("%d/%d\n", cnt / g, sum / g);

    return 0;
}