「Codeforces 716E」Digit Tree - 点分治

给一棵 $ n $ 个节点的树,每一条边上有一个 $ [1, 9] $ 内的数字,求有多少有序点对 $ (u, v) $($ u \neq v $)满足,将 $ u $ 到 $ v $ 的最短路上所有边上的数字依次连接成一个十进制数,这个数是 $ M $ 的倍数。其中 $ \gcd(M, 10) = 1 $。

链接

Codeforces 716E

题解

点分治,考虑一棵子树内满足条件的点对。

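记录 $ \text{backward}(i) $ 为从根到 $ i $ 节点路径上所有边上的数字按倒序(从 $ i $ 到根)连接成的数;$ \text{forward}(i) $ 为对应按正序(从根到 $ i $)连接成的数;$ \text{depth}(i) $ 为节点 $ i $ 的深度,等于 $ \text{forward}(i) $ 的十进制位数。如根到 $ i $ 的路径上各边的数字依次为 $ 1, 2, 3 $,则 $ \text{forward}(i) = 123 $,$ \text{backward}(i) = 321 $,$ \text{depth}(i) = 3 $。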

例子

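经过根的 $ u \rightarrow v $ 的路径组成的数可以表示为

$$ \text{backward}(u) \times 10 ^ {\text{depth}(v)} + \text{forward}(v) $$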

题目要求的条件即为

$$ \text{backward}(u) \times 10 ^ {\text{depth}(v)} + \text{forward}(v) \equiv 0 \pmod M $$

因为这是一个同余式,所以 $ \text{backward}(u) $、$ \text{forward}(v) $ 和 $ 10 ^ {\text{depth}(v)} $ 都可以是模 $ M $ 意义下的。

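整理(由于 $ \gcd(M, 10) = 1 $,$ 10 ^ {\text{depth}(v)} $ 在模 $ M $ 意义下存在逆元),得

$$ \text{backward}(u) \equiv -\text{forward}(v) \times 10 ^ {-\text{depth}(v)} \pmod M $$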

将式子右边(对每个节点 $ v $ 计算)存入哈希表中,对于每个节点 $ u $,对答案的贡献即为哈希表中 $ \text{backward}(u) $ 出现的次数。

代码

#include <cassert>
#include <cstdio>
#include <algorithm>
#include <iostream>
#include <map>
#include <queue>
#include <stack>

// Capacity of the static vertex pool N and of the scratch arrays used by
// center() and calc(); the power tables hold MAXN + 1 entries.
const int MAXN = 100000;

struct Edge;

// Tree vertex plus the per-traversal scratch state used by center() and calc().
struct Node {
    Edge *e;              // head of this vertex's adjacency list (NULL when it has no edges)
    Node *parent;         // predecessor in whichever traversal ran last
    int size;             // vertex count of the subtree below this vertex, set by center()
    int depth;            // digit count of the prefix leading to this vertex, set by calc()
    int max;              // largest piece left if this vertex is removed, set by center()
    bool visited;         // center(): true once the vertex's children have been pushed
    bool solved;          // true once solve() has used this vertex as a centroid
    long long backward;   // digits from this vertex up to the traversal root, reduced by calc()
    long long forward;    // digits from the traversal root down to this vertex, reduced by calc()
};

// Vertex pool; main() passes the vertex numbers from the input straight in as indices.
Node N[MAXN];

// One direction of a tree edge, kept as an element of the singly linked
// adjacency list that starts at s->e.
struct Edge {
    Node *s;      // vertex this edge leaves
    Node *t;      // vertex this edge enters
    int w;        // digit written on the edge
    Edge *next;   // following edge in s's adjacency list

    // Chains the new edge in front of the current list head of `from`; the
    // caller is expected to store the new edge back into from->e.
    Edge(Node *from, Node *to, const int digit)
        : s(from), t(to), w(digit), next(from->e) {}
};

inline void addEdge(const int s, const int t, const int w) {
    N[s].e = new Edge(&N[s], &N[t], w);
    N[t].e = new Edge(&N[t], &N[s], w);
}

// n: vertex count, mod: modulus; both are read from the input by main().
int n, mod;
// Filled by main(): pow10[0] = pow10Inv[0] = 1, and for i >= 1 pow10[i] is
// 10^i reduced modulo `mod` while pow10Inv[i] is inv(pow10[i]).
long long pow10[MAXN + 1], pow10Inv[MAXN + 1];

// Extended Euclidean algorithm.
// On return g holds the last non-zero remainder of the Euclidean chain for
// (a, b), and x, y satisfy a * x + b * y == g.
void exgcd(const long long a, const long long b, long long &g, long long &x, long long &y) {
    if (b == 0) {
        // Base case: a * 1 + 0 * 0 == a.
        g = a;
        x = 1;
        y = 0;
        return;
    }

    // Solve b * subX + (a % b) * subY == g, then rewrite a % b as a - (a / b) * b.
    long long subX = 0, subY = 0;
    exgcd(b, a % b, g, subX, subY);
    x = subY;
    y = subX - subY * (a / b);
}

// Returns the Bezout coefficient of x from exgcd(x, mod), normalised into
// [0, mod). This is the modular inverse of x when x and mod are coprime.
inline long long inv(const long long x) {
    long long divisor = 0, coefX = 0, coefMod = 0;
    exgcd(x, mod, divisor, coefX, coefMod);

    const long long reduced = coefX % mod;  // lies in (-mod, mod)
    return reduced < 0 ? reduced + mod : reduced;
}

// Finds the centroid of the component of not-yet-solved vertices containing
// `start`: the vertex whose removal leaves the smallest largest piece.
// Overwrites parent, visited, size and max of every vertex in the component.
// Ties are resolved in favour of the vertex discovered first.
inline Node *center(Node *start) {
    static Node *order[MAXN];  // vertices in discovery order
    int total = 0;

    std::stack<Node *> pending;
    start->parent = NULL;
    start->visited = false;
    pending.push(start);

    // Iterative DFS: every vertex is looked at twice while on the stack, first
    // to expand its children and then, once they are finished, to sum them up.
    while (!pending.empty()) {
        Node *cur = pending.top();

        if (cur->visited) {
            int subtree = 1;
            int heaviest = 0;
            for (Edge *e = cur->e; e; e = e->next) {
                Node *child = e->t;
                if (child->solved || child->parent != cur) continue;
                subtree += child->size;
                if (child->size > heaviest) heaviest = child->size;
            }
            cur->size = subtree;
            cur->max = heaviest;
            pending.pop();
            continue;
        }

        cur->visited = true;
        order[total++] = cur;
        for (Edge *e = cur->e; e; e = e->next) {
            Node *next = e->t;
            if (next->solved || next == cur->parent) continue;
            next->visited = false;
            next->parent = cur;
            pending.push(next);
        }
    }

    // The piece on the parent's side has total - size vertices; fold it into
    // max and keep the first vertex that attains the minimum.
    Node *best = NULL;
    for (int i = 0; i < total; i++) {
        Node *candidate = order[i];
        const int above = total - candidate->size;
        if (above > candidate->max) candidate->max = above;
        if (best == NULL || candidate->max < best->max) best = candidate;
    }

    return best;
}

// Counts ordered pairs (u, v) inside the unsolved component reachable from
// `root` such that backward(u) * 10^depth(v) + forward(v) is divisible by mod,
// i.e. the number read along u -> traversal origin -> v.
//
// root: vertex the BFS starts from.
// pre:  digit on the edge leading into `root` from the centroid, or a negative
//       value (the default) when `root` is the centroid itself. Any pre >= 0
//       is treated as a real edge digit. This matters because main() stores
//       digits reduced modulo mod, so a digit that is a multiple of mod
//       arrives here as 0 and must still count as one digit of depth.
//
// Returns the pair count after the map[0] correction described below.
// Overwrites forward, backward, depth and parent of every visited vertex.
inline long long calc(Node *root, const int pre = -1) {
    if (pre >= 0) {
        root->forward = root->backward = pre % mod;
        root->depth = 1;
    } else {
        root->forward = root->backward = 0;
        root->depth = 0;
    }
    root->parent = NULL;

    std::queue<Node *> q;
    q.push(root);

    // map[k] = number of vertices v with -forward(v) * 10^(-depth(v)) == k (mod `mod`).
    std::map<long long, int> map;
    static Node *a[MAXN];
    int cnt = 0;
    while (!q.empty()) {
        Node *v = q.front();
        q.pop();

        const long long key = (mod - v->forward) * pow10Inv[v->depth] % mod;
        assert(key >= 0);
        map[key]++;
        a[cnt++] = v;

        for (Edge *e = v->e; e; e = e->next) if (!e->t->solved && e->t != v->parent) {
            // Going one edge deeper appends the digit to the root-to-vertex
            // reading and prepends it to the vertex-to-root reading.
            e->t->forward = (v->forward * 10 + e->w) % mod;
            e->t->backward = (e->w * pow10[v->depth] + v->backward) % mod;
#ifdef DBG
            printf("%lld <--> %lld\n", e->t->forward, e->t->backward);
#endif
            e->t->parent = v;
            e->t->depth = v->depth + 1;
            q.push(e->t);
        }
    }

    long long ans = 0;
#ifdef FORCE
    for (int i = 0; i < cnt; i++) {
        for (int j = 0; j < cnt; j++) {
            if ((a[i]->backward * pow10[a[j]->depth] + a[j]->forward) % 3 == 0 /*&& !(a[i]->backward == 0 && a[j]->forward == 0)*/) {
                ans++;
#ifdef DBG
                printf("%lld\n", a[i]->backward * pow10[a[j]->depth] + a[j]->forward);
#endif
            }
            if (a[i]->backward == 0 && a[j]->forward == 0) printf("%lu %lu in %lu, %d\n", a[i] - N + 1, a[j] - N + 1, root - N + 1, pre), ans--;
        }
    }
#else
    // Drops the (centroid, centroid) self pair. Every other vertex whose
    // backward value is 0 is also charged once, but solve() subtracts the
    // per-child calls, which apply the same charge, so those cancel out.
    map[0]--;
    for (int i = 0; i < cnt; i++) {
        // find() instead of operator[] so lookups do not insert empty entries.
        const std::map<long long, int>::const_iterator it = map.find(a[i]->backward);
        const int hits = it == map.end() ? 0 : it->second;
        ans += hits;
#ifdef DBG
        printf("back = %lld, ans + %d\n", a[i]->backward, hits);
#endif
    }
#endif
#ifdef DBG
    for (std::map<long long, int>::const_iterator it = map.begin(); it != map.end(); it++) printf("[%lld => %d]\n", it->first, it->second);
    printf("** ans = %lld\n", ans);
#endif

    return ans;
}

inline long long solve() {
    std::queue<Node *> q;
    q.push(&N[0]);

    long long ans = 0;
    while (!q.empty()) {
        Node *v = q.front();
        q.pop();

        Node *root = center(v);
        root->solved = true;

#ifdef DBG
        puts("+++++++++++++++++++");
#endif
        ans += calc(root);
        for (Edge *e = root->e; e; e = e->next) if (!e->t->solved) {
#ifdef DBG
            puts("-------------------");
#endif
            ans -= calc(e->t, e->w);
            q.push(e->t);
        }
    }

    return ans;
}

int main() {
    scanf("%d %d", &n, &mod);
#ifdef FORCE
    mod = 1e9 + 7;
#endif
    for (int i = 0; i < n - 1; i++) {
        int u, v, w;
        scanf("%d %d %d", &u, &v, &w);
        addEdge(u, v, w % mod);
    }

    pow10[0] = pow10Inv[0] = 1;
    for (int i = 1; i <= n; i++) {
        pow10[i] = pow10[i - 1] * 10 % mod;
        pow10Inv[i] = inv(pow10[i]);
    }

    std::cout << solve() << std::endl;

    return 0;
}