给定一棵 $n$ 个节点的树,每个点有一个权值,对于 $m$ 个询问 $(u, v, k)$,你需要回答 $u$ 和 $v$ 这两个节点间第 $k$ 小的点权。询问强制在线:输入的 $u$ 需要先异或上一次询问的答案(初始为 $0$)。
链接
题解
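对树上从根向下的路径做主席树前缀和:每个节点的版本由其父节点的版本插入自身点权得到,路径 $u \to v$ 上落在某个值域区间内的点数为 $cnt_u + cnt_v - cnt_{\mathrm{lca}} - cnt_{fa(\mathrm{lca})}$,在这四棵树上同时二分即可求出第 $k$ 小;需要倍增求 LCA。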
代码
#include <algorithm>
#include <climits>
#include <cstdio>
#include <queue>
#include <utility>
#include <vector>
// Upper bound on the number of tree nodes; sizes the N[] and f[][] arrays.
const int MAXN = 100000;
// Bound for the second dimension of f[][] (MAXN_LOG + 1 columns); 2^17 exceeds MAXN.
const int MAXN_LOG = 17;
// Shared sentinel node standing in for every empty subtree; it must be set up
// (count 0, both children pointing at itself) before any insert() is called.
struct SegT *null;

// Node of a counting segment tree over an integer value range.
// insert() never modifies existing nodes: it allocates new nodes along one
// root-to-leaf path and reuses the untouched child, so older versions stay valid.
struct SegT {
    SegT *lc, *rc;  // left / right child
    int cnt;        // number of inserted values inside this node's range

    // Inner node whose count is the sum of its two children.
    SegT(SegT *left, SegT *right) : lc(left), rc(right), cnt(left->cnt + right->cnt) {}

    // Node with an explicitly supplied count (used for leaves and the sentinel).
    SegT(SegT *left, SegT *right, int count) : lc(left), rc(right), cnt(count) {}

    // Returns the root of a new version of this subtree, covering [lo, hi],
    // in which `key` has been counted one more time.
    SegT *insert(int lo, int hi, int key) {
        if (lo == hi) {
            return new SegT(null, null, cnt + 1);
        }
        // Written as lo + (hi - lo) / 2 so that a range ending at INT_MAX does not overflow.
        const int half = lo + (hi - lo) / 2;
        if (key > half) {
            return new SegT(lc, rc->insert(half + 1, hi, key));
        }
        return new SegT(lc->insert(lo, half, key), rc);
    }
};
// One tree vertex. Vertices are stored in the global array N and addressed by
// index; links between them are plain pointers into that array.
struct Node {
    std::vector<Node *> adj;  // neighbouring vertices
    Node *fa;                 // parent pointer
    int dep;                  // depth
    int w;                    // vertex weight
    bool vis;                 // visited flag
    SegT *seg;                // segment-tree version attached to this vertex
};

// Vertex storage, indices 0..MAXN.
Node N[MAXN + 1];
inline void addEdge(int u, int v) {
N[u].adj.push_back(&N[v]);
N[v].adj.push_back(&N[u]);
}
inline void init() {
null = new SegT(NULL, NULL, 0);
null->lc = null->rc = null;
}
// Number of vertices in the tree.
int n;
// Largest exponent j with 2^j <= n once build() has run; bounds the jump loops.
int logn;
// Jump table filled by build(): f[i][j] is the vertex reached from i after 2^j
// parent steps, with vertex 1 mapped to itself.
int f[MAXN + 1][MAXN_LOG + 1];
// Roots the tree at vertex 1 and prepares everything lca() and query() need:
//   * a breadth-first pass assigns parent, depth and a segment-tree version to
//     every reachable vertex (each version is the parent's version plus the
//     vertex's own weight);
//   * logn and the jump table f[][] are then filled in.
// Requires init() to have been called, and the adjacency lists and weights to
// be in place.
inline void build() {
    // N[0] acts as the parent of the root: it is marked visited so it is never
    // enqueued, and it carries the empty tree.
    N[0].vis = true;
    N[0].seg = null;

    std::queue<Node *> q;
    q.push(&N[1]);
    N[1].vis = true;
    N[1].dep = 1;
    N[1].fa = &N[0];
    while (!q.empty()) {
        Node *v = q.front();
        q.pop();
        // A vertex is only enqueued after its parent was dequeued, so the
        // parent's version already exists here.
        v->seg = v->fa->seg->insert(0, INT_MAX, v->w);
        // Index-bounded loop: safe for an empty adjacency list (single-vertex
        // tree) and never reads past the last stored neighbour.
        for (std::size_t i = 0; i < v->adj.size(); i++) {
            Node *u = v->adj[i];
            if (!u->vis) {
                u->vis = true;
                u->dep = v->dep + 1;
                u->fa = v;
                q.push(u);
            }
        }
    }

    while ((1 << (logn + 1)) <= n) logn++;

    // The root jumps to itself so that lifting never leaves the tree.
    f[1][0] = 1;
    for (int i = 2; i <= n; i++) f[i][0] = static_cast<int>(N[i].fa - N);
    for (int j = 1; j <= logn; j++) {
        for (int i = 1; i <= n; i++) {
            f[i][j] = f[f[i][j - 1]][j - 1];
        }
    }
}
// Returns the lowest common ancestor of vertices u and v, using the depths in
// N[] and the jump table f[][] (both must already be filled in).
inline int lca(int u, int v) {
    // Make u the deeper (or equally deep) vertex.
    if (N[v].dep > N[u].dep) std::swap(u, v);

    // Lift u until it sits at v's depth, trying the largest jumps first.
    if (N[u].dep > N[v].dep) {
        int step = logn;
        while (step >= 0) {
            const int up = f[u][step];
            if (N[up].dep >= N[v].dep) u = up;
            --step;
        }
    }

    if (u == v) return u;

    // Lift both vertices together while their ancestors still differ; they
    // end up as two distinct children of the answer.
    for (int step = logn; step >= 0; --step) {
        const int upU = f[u][step];
        const int upV = f[v][step];
        if (upU != upV) {
            u = upU;
            v = upV;
        }
    }
    return f[u][0];
}
inline int query(int u, int v, int k) {
int p = lca(u, v);
SegT *su = N[u].seg, *sv = N[v].seg, *sp = N[p].seg, *sf = N[p].fa->seg;
int l = 0, r = INT_MAX;
while (l < r) {
int mid = l + (r - l) / 2;
int s = su->lc->cnt + sv->lc->cnt - sp->lc->cnt - sf->lc->cnt;
if (k > s) {
k -= s;
l = mid + 1;
su = su->rc;
sv = sv->rc;
sp = sp->rc;
sf = sf->rc;
} else {
r = mid;
su = su->lc;
sv = sv->lc;
sp = sp->lc;
sf = sf->lc;
}
}
return l;
}
int main() {
int m;
scanf("%d %d", &n, &m);
for (int i = 1; i <= n; i++) scanf("%d", &N[i].w);
for (int i = 1; i <= n - 1; i++) {
int u, v;
scanf("%d %d", &u, &v);
addEdge(u, v);
}
init();
build();
int lastAns = 0;
while (m--) {
int u, v, k;
scanf("%d %d %d", &u, &v, &k);
u ^= lastAns;
printf(m ? "%d\n" : "%d", lastAns = query(u, v, k));
}
return 0;
}