「BZOJ 3697」采药人的路径 - 点分治

采药人的药田是一个树状结构,每条路径上都种植着同种药材。

采药人每种药材进行了分类。大致分为两类,一种是阴性的,一种是阳性的。

采药人每天都要进行采药活动。他走的一定是两种药材数目相等的路径。他希望他选出的路径中有一个可以作为休息站的节点(不包括起点和终点),满足起点到休息站和休息站到终点的路径也是两种药材数目相等的。他想知道他一共可以选择多少种不同的路径。

链接

BZOJ 3697

题解

点分治,考虑经过根的路径中合法的路径数量。

将边权为 改为 ,对树遍历时记录路径上的前缀和。

对根的所有子树做 DFS 遍历,设 表示当前子树前缀和为 在路径上仅出现过一次的路径数, 表示当前子树前缀和为 在路径上出现过至少两次的路径数。

如果一个前缀和 在一棵子树内出现过两次,那么在根的另一棵子树选一条前缀和为 的路径与其相接,即可组成一条合法的路径 —— 休息站可以被选择在前一条路径上另一个前缀和为 的点上。

对树进行 DFS 遍历时,记录 表示当前路径前缀和为 的节点数量,根据情况将当前节点累加在 中。

记录 为之前的所有子树中对应的路径数量,每次更新答案,统计不以根节点为休息站的路径数量

的初始值为 ,表示根节点单独组成一条路径,统计以根节点为休息站的的路径数量

代码

#include <cstdio>
#include <cstdlib>
#include <climits>
#include <queue>
#include <stack>
#include <algorithm>

const int MAXN = 100003;

struct Node;
struct Edge;

struct Node {
    Edge *e;
    int size, max, dist;
    bool solved, visited;
    Node *parent;
} N[MAXN];

struct Edge {
    Node *s, *t;
    int w;
    Edge *next;

    Edge(Node *s, Node *t, const int w) : s(s), t(t), w(w), next(s->e) {}
};

inline void addEdge(const int s, const int t, const int w) {
    N[s].e = new Edge(&N[s], &N[t], w);
    N[t].e = new Edge(&N[t], &N[s], w);
}

int n;

template <typename T, int L, int R>
struct Array {
    T a[R - L + 1];

    T &operator[](const int pos) { return a[pos - L]; }

    const T &operator[](const int pos) const { return a[pos - L]; }
};

Array<long long [2], -(MAXN - 1), MAXN - 1> f, g;

inline Node *center(Node *start) {
    std::stack<Node *> s;
    s.push(start);
    start->visited = false;
    start->parent = NULL;

    static Node *a[MAXN];
    int cnt = 0;
    while (!s.empty()) {
        Node *v = s.top();
        if (!v->visited) {
            a[cnt++] = v;
            v->visited = true;
            for (Edge *e = v->e; e; e = e->next) if (e->t != v->parent && !e->t->solved) {
                e->t->visited = false;
                e->t->parent = v;
                s.push(e->t);
            }
        } else {
            v->size = 1;
            v->max = 0;
            for (Edge *e = v->e; e; e = e->next) if (!e->t->solved && e->t->parent == v) {
                v->size += e->t->size;
                v->max = std::max(v->max, e->t->size);
            }
            s.pop();
        }
    }

    Node *res = NULL;
    for (int i = 0; i < cnt; i++) {
        a[i]->max = std::max(a[i]->max, start->size - a[i]->max);
        if (!res || res->max > a[i]->max) res = a[i];
    }

    return res;
}

inline void dfs(Node *start, const int dist, int &max) {
    std::stack<Node *> s;
    s.push(start);
    start->parent = NULL;
    start->dist = dist;
    start->visited = false;

    static int _cnt[MAXN * 2 - 1], *cnt = _cnt + MAXN - 1;
    while (!s.empty()) {
        Node *v = s.top();
        if (!v->visited) {
            f[v->dist][!cnt[v->dist] ? 0 : 1]++;
            cnt[v->dist]++;

            max = std::max(max, v->dist);
            max = std::max(max, abs(v->dist));

            v->visited = true;
            for (Edge *e = v->e; e; e = e->next) if (!e->t->solved && e->t != v->parent) {
                e->t->parent = v;
                e->t->dist = v->dist + e->w;
                e->t->visited = false;
                s.push(e->t);
            }
        } else {
            cnt[v->dist]--;
            s.pop();
        }
    }
}

/*
inline void print(const int max) {
    for (int i = -max; i <= max; i++) {
        printf("f[%d][0] = %lld, f[%d][1] = %lld\n", i, f[i][0], i, f[i][1]);
        printf("g[%d][0] = %lld, g[%d][1] = %lld\n", i, g[i][0], i, g[i][1]);
    }
    putchar('\n');
}
*/

inline long long calc(Node *root) {
    long long res = 0;
    int max = 0;
    g[0][0] = 1;
    for (Edge *e = root->e; e; e = e->next) if (!e->t->solved) {
        int curr = 0;
        dfs(e->t, e->w, curr);
        // print(max);

        res += (g[0][0] - 1) * f[0][0];
        for (int i = -curr; i <= curr; i++) {
            // printf("res += %lld\n", (g[-i][1] * f[i][1]) + (g[-i][0] * f[i][1]) + (g[-i][1] * f[i][0]));
            res += (g[-i][1] * f[i][1]) + (g[-i][0] * f[i][1]) + (g[-i][1] * f[i][0]);
        }

        for (int i = -curr; i <= curr; i++) {
            g[i][0] += f[i][0];
            g[i][1] += f[i][1];
            f[i][0] = f[i][1] = 0;
        }

        max = std::max(max, curr);
    }

    for (int i = -max; i <= max; i++) {
        g[i][0] = g[i][1] = 0;
    }

    // printf("calc(%ld) = %lld\n", root - N + 1, res);
    return res;
}

inline long long solve() {
    std::stack<Node *> s;
    s.push(&N[0]);

    long long ans = 0;
    while (!s.empty()) {
        Node *v = s.top();
        s.pop();

        Node *root = center(v);
        root->solved = true;
        ans += calc(root);

        for (Edge *e = root->e; e; e = e->next) if (!e->t->solved) {
            s.push(e->t);
        }
    }

    return ans;
}

int main() {
    scanf("%d", &n);
    for (int i = 0; i < n - 1; i++) {
        int u, v, w;
        scanf("%d %d %d", &u, &v, &w), u--, v--;
        if (w == 0) w = -1;
        addEdge(u, v, w);
    }

    long long ans = solve();
    printf("%lld\n", ans);
    // printf("counter: %lld\n", counter);

    return 0;
}