给一棵树,每一条边上有一个 $[1, 9]$ 内的数字,求有多少有序点对 $(u, v)$($u \neq v$)满足:将 $u$ 到 $v$ 的最短路上所有边上的数字依次连接成一个十进制数,这个数是 $M$ 的倍数。其中 $\gcd(M, 10) = 1$。
链接
题解
点分治,考虑一棵子树内满足条件的点对。
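记录 $\mathrm{backward}(u)$ 为从根到 $u$ 节点路径上所有边上的数字按倒序连接成的数(即从 $u$ 走到根读出的数);$\mathrm{forward}(u)$ 为对应按正序连接成的数(即从根走到 $u$ 读出的数);$\mathrm{depth}(u)$ 为节点 $u$ 的深度,等于 $\mathrm{backward}(u)$ 和 $\mathrm{forward}(u)$ 的十进制位数。如根到 $u$ 的路径上的数字依次为 $1, 2, 3$,则 $\mathrm{forward}(u) = 123$,$\mathrm{backward}(u) = 321$,$\mathrm{depth}(u) = 3$。
从 $u$ 经过根走到 $v$ 的路径连接成的数为 $\mathrm{backward}(u) \times 10^{\mathrm{depth}(v)} + \mathrm{forward}(v)$。
题目要求的条件即为 $\mathrm{backward}(u) \times 10^{\mathrm{depth}(v)} + \mathrm{forward}(v) \equiv 0 \pmod M$。
因为这是一个同余式,所以 $\mathrm{backward}(u) \times 10^{\mathrm{depth}(v)} \equiv -\mathrm{forward}(v) \pmod M$。
整理(由于 $\gcd(M, 10) = 1$,$10^{\mathrm{depth}(v)}$ 在模 $M$ 意义下存在逆元),得 $\mathrm{backward}(u) \equiv -\mathrm{forward}(v) \times 10^{-\mathrm{depth}(v)} \pmod M$。
将式子右边存入哈希表中,对于每个节点 $u$,在表中查询 $\mathrm{backward}(u)$ 出现的次数并累加到答案中;两个端点位于根的同一棵子树内的点对并不经过根,需要对每棵子树用同样的方法统计一次并从答案中减去。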
代码
#include <cassert>
#include <cstdio>
#include <algorithm>
#include <iostream>
#include <map>
#include <queue>
#include <stack>
// Capacity of every per-vertex table in this file.
const int MAXN = 100000;

struct Edge;

// A tree vertex plus the scratch fields the centroid decomposition writes into it.
struct Node {
    Edge *e;             // head of the adjacency list of edges leaving this vertex
    Node *parent;        // parent in whichever traversal touched the vertex last
    int size;            // subtree size, written by center()
    int depth;           // number of edges on the path from the current root, written by calc()
    int max;             // size of the largest piece left if this vertex is removed, written by center()
    bool visited;        // pre-order marker used by center()
    bool solved;         // set by solve() once the vertex has served as a centroid
    long long backward;  // digits read from this vertex towards the root, modulo `mod`
    long long forward;   // digits read from the root towards this vertex, modulo `mod`
} N[MAXN];
// One direction of a tree edge, kept as a node of the singly linked adjacency
// list that hangs off its source vertex.
struct Edge {
    Node *s;     // vertex the edge leaves from
    Node *t;     // vertex the edge leads to
    int w;       // weight carried by the edge
    Edge *next;  // following edge in the source vertex's list

    // Chains the new edge in front of the list currently headed by from->e.
    // The constructor does not update from->e itself; the caller stores the result there.
    Edge(Node *from, Node *to, const int weight)
        : s(from), t(to), w(weight), next(from->e) {}
};
inline void addEdge(const int s, const int t, const int w) {
N[s].e = new Edge(&N[s], &N[t], w);
N[t].e = new Edge(&N[t], &N[s], w);
}
int n;    // number of vertices, read in main()
int mod;  // modulus read in main(); path values are stored reduced by it

// Filled by main() for every index in [0, n]:
// pow10[i] is 10^i modulo `mod`, pow10Inv[i] is the modular inverse of pow10[i].
long long pow10[MAXN + 1];
long long pow10Inv[MAXN + 1];
// Extended Euclidean algorithm.
// On return g holds gcd(a, b) and (x, y) satisfy a * x + b * y == g.
void exgcd(const long long a, const long long b, long long &g, long long &x, long long &y) {
    if (b == 0) {
        // gcd(a, 0) = a = a * 1 + 0 * 0
        g = a;
        x = 1;
        y = 0;
        return;
    }
    // Solve for (b, a mod b) with the coefficient slots swapped, then fold the
    // quotient back in to obtain the coefficients for (a, b).
    exgcd(b, a % b, g, y, x);
    y -= (a / b) * x;
}
inline long long inv(const long long x) {
long long tmp1, res, tmp2;
exgcd(x, mod, tmp1, res, tmp2);
return (res % mod + mod) % mod;
}
// Finds the centroid of the connected piece of unsolved vertices that contains
// `start`: the vertex whose removal leaves the smallest largest remaining piece.
// Ties go to the vertex that was discovered first. Overwrites parent, visited,
// size and max of every vertex in the piece.
inline Node *center(Node *start) {
    static Node *order[MAXN];  // vertices of the piece in discovery order
    int total = 0;

    // Explicit-stack DFS. A vertex stays on the stack until it is seen a second
    // time, at which point all of its children already have their sizes.
    std::stack<Node *> pending;
    start->parent = NULL;
    start->visited = false;
    pending.push(start);

    while (!pending.empty()) {
        Node *cur = pending.top();

        if (cur->visited) {
            // Second visit: accumulate the children's sizes and retire the vertex.
            cur->size = 1;
            cur->max = 0;
            for (Edge *e = cur->e; e; e = e->next) {
                Node *child = e->t;
                if (child->solved || child->parent != cur) continue;
                cur->size += child->size;
                cur->max = std::max(cur->max, child->size);
            }
            pending.pop();
            continue;
        }

        // First visit: record the vertex and schedule its unexplored neighbours.
        order[total++] = cur;
        cur->visited = true;
        for (Edge *e = cur->e; e; e = e->next) {
            Node *next = e->t;
            if (next->solved || next == cur->parent) continue;
            next->visited = false;
            next->parent = cur;
            pending.push(next);
        }
    }

    // The piece on the parent's side has total - size vertices; include it before
    // picking the vertex with the smallest maximum.
    Node *best = NULL;
    for (int i = 0; i < total; ++i) {
        Node *cand = order[i];
        cand->max = std::max(cand->max, total - cand->size);
        if (best == NULL || cand->max < best->max) best = cand;
    }
    return best;
}
// Counts ordered pairs (u, v) among the unsolved vertices reachable from `root`
// (root itself included, u == v included) for which
//     backward(u) * 10^depth(v) + forward(v) == 0 (mod `mod`),
// and returns that count minus the number of vertices u with backward(u) == 0
// (see the note on `map[0]--` below).
//
// root  start of the traversal.
// pre   weight of the edge through which `root` is entered, or a negative
//       value (the default) when `root` is the actual path root. With an
//       entry edge, root starts at depth 1 with forward = backward = pre.
//
// The "no entry edge" case used to be signalled by pre == 0. main() stores edge
// weights reduced modulo `mod`, so a digit divisible by `mod` also arrived here
// as 0 and the subtree pass was then run with every depth one too small (for
// example mod = 7 with an edge digit 7), which made solve() subtract the wrong
// pairs. A negative sentinel keeps a zero weight distinguishable.
inline long long calc(Node *root, const int pre = -1) {
    const bool viaEdge = pre >= 0;
    if (viaEdge) {
        root->forward = root->backward = pre % mod;
        root->depth = 1;
    } else {
        root->forward = root->backward = 0;
        root->depth = 0;
    }
    root->parent = NULL;
    std::queue<Node *> q;
    q.push(root);
    std::map<long long, int> map;  // key -> number of vertices v producing it
    static Node *a[MAXN];          // every visited vertex, in BFS order
    int cnt = 0;
    while (!q.empty()) {
        Node *v = q.front();
        q.pop();
        // Value backward(u) has to take for (u, v) to satisfy the congruence:
        // -forward(v) * 10^(-depth(v)) modulo `mod`.
        const long long key = (mod - v->forward) * pow10Inv[v->depth] % mod;
        assert(key >= 0);
        map[key]++;
        a[cnt++] = v;
        for (Edge *e = v->e; e; e = e->next) if (!e->t->solved && e->t != v->parent) {
            // Going one edge further from the root appends the digit on the right of
            // `forward` and puts it in front of `backward`.
            e->t->forward = (v->forward * 10 + e->w) % mod;
            e->t->backward = (e->w * pow10[v->depth] + v->backward) % mod;
#ifdef DBG
            printf("%lld <--> %lld\n", e->t->forward, e->t->backward);
#endif
            e->t->parent = v;
            e->t->depth = v->depth + 1;
            q.push(e->t);
        }
    }
    long long ans = 0;
#ifdef FORCE
    // Brute-force cross-check used while debugging.
    // NOTE(review): the modulus 3 is hard-coded here and main() overrides `mod` in this build.
    for (int i = 0; i < cnt; i++) {
        for (int j = 0; j < cnt; j++) {
            if ((a[i]->backward * pow10[a[j]->depth] + a[j]->forward) % 3 == 0 /*&& !(a[i]->backward == 0 && a[j]->forward == 0)*/) {
                ans++;
#ifdef DBG
                printf("%lld\n", a[i]->backward * pow10[a[j]->depth] + a[j]->forward);
#endif
            }
            if (a[i]->backward == 0 && a[j]->forward == 0) printf("%lu %lu in %lu, %d\n", a[i] - N + 1, a[j] - N + 1, root - N + 1, pre), ans--;
        }
    }
#else
    // Every pass removes one count from key 0, i.e. one unit per vertex u with
    // backward(u) == 0. Over a centroid pass and the subtree passes that solve()
    // subtracts from it, these removals cancel except for a single unit, which is
    // the (centroid, centroid) self pair.
    map[0]--;
    for (int i = 0; i < cnt; i++) {
        ans += map[a[i]->backward];
#ifdef DBG
        printf("back = %lld, ans + %d\n", a[i]->backward, map[a[i]->backward]);
#endif
    }
#endif
#ifdef DBG
    for (std::map<long long, int>::const_iterator it = map.begin(); it != map.end(); it++) printf("[%lld => %d]\n", it->first, it->second);
    printf("** ans = %lld\n", ans);
#endif
    return ans;
}
inline long long solve() {
std::queue<Node *> q;
q.push(&N[0]);
long long ans = 0;
while (!q.empty()) {
Node *v = q.front();
q.pop();
Node *root = center(v);
root->solved = true;
#ifdef DBG
puts("+++++++++++++++++++");
#endif
ans += calc(root);
for (Edge *e = root->e; e; e = e->next) if (!e->t->solved) {
#ifdef DBG
puts("-------------------");
#endif
ans -= calc(e->t, e->w);
q.push(e->t);
}
}
return ans;
}
int main() {
scanf("%d %d", &n, &mod);
#ifdef FORCE
mod = 1e9 + 7;
#endif
for (int i = 0; i < n - 1; i++) {
int u, v, w;
scanf("%d %d %d", &u, &v, &w);
addEdge(u, v, w % mod);
}
pow10[0] = pow10Inv[0] = 1;
for (int i = 1; i <= n; i++) {
pow10[i] = pow10[i - 1] * 10 % mod;
pow10Inv[i] = inv(pow10[i]);
}
std::cout << solve() << std::endl;
return 0;
}