给定一个字符串，支持三种操作：修改某个位置的字符、在某个位置之后插入一个字符、查询某两个后缀的最长公共前缀（LCP）。
链接
题解
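使用 Splay 按顺序维护字符串，并在每个节点上维护其整棵子树所表示子串的 Hash 值（第 i 个字符乘以 BASE^i）。合并时，设左子树大小为 k，则节点自身的字符乘以 BASE^k，右子树的 Hash 值乘以 BASE^(k+1)，再与左子树的 Hash 值相加。
查询时二分 LCP 的长度，每次比较两段等长子串的 Hash 值即可，单次询问均摊 O(log² n)。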
代码
#include <cstdio>
#include <cstring>
#include <algorithm>
const int MAXN = 100000;  // maximum string length at any time
const int MAXM = 150000;  // maximum number of operations
const unsigned long long BASE = 233;
// NOTE(review): hashing modulo 2^64 (natural unsigned overflow) is known to be
// defeatable by Thue-Morse style inputs; consider a prime modulus if that matters.

// base[i] = BASE^i (mod 2^64).
// The splay tree holds up to MAXN characters plus two sentinel nodes, so a left
// subtree can contain up to MAXN + 1 nodes. Node::maintain indexes base[] by a
// left-subtree size, so indices 0 .. MAXN + 1 must be valid.
unsigned long long base[MAXN + 2];

// Fills base[] during static initialization, so the whole table is ready before
// main() runs. Recomputing entries later is harmless because the values are
// deterministic.
struct BasePowersInit {
    BasePowersInit() {
        base[0] = 1;
        for (int i = 1; i < MAXN + 2; i++) base[i] = base[i - 1] * BASE;
    }
} basePowersInit;
// Splay tree over the characters of the string, kept in string order.
// Two sentinel nodes (val == 0) sit at the far left and far right, so any real
// range [l, r] can be isolated as the left child of the right child of the root.
// Every node stores the size and polynomial hash of its whole subtree:
//   hash(s_0 .. s_{k-1}) = sum of s_i * BASE^i   (mod 2^64)
struct Splay {
    struct Node {
        Node *c[2], *p, **r;      // children, parent, and the tree's root slot (rewritten when this node becomes root)
        int size;                 // number of nodes in this subtree
        char val;                 // character stored here (0 for the sentinels)
        unsigned long long hash;  // hash of this subtree's in-order string

        Node(Node *p, Node **r, const char val) : p(p), r(r), size(1), val(val), hash(val) {
            c[0] = c[1] = NULL;
        }

        // 0 if this node is its parent's left child, 1 if right. Requires p != NULL.
        int relation() { return this == p->c[0] ? 0 : 1; }

        // Single rotation that lifts this node above its parent.
        void rotate() {
            Node *o = p;
            int x = relation();
            p = o->p;
            if (o->p) o->p->c[o->relation()] = this;
            o->c[x] = c[x ^ 1];
            if (c[x ^ 1]) c[x ^ 1]->p = o;
            c[x ^ 1] = o;
            o->p = this;
            // o is now our child, so it must be recomputed before we are.
            o->maintain(), maintain();
            if (!p) *r = this;
        }

        // Rotates this node upward until its parent is targetParent
        // (NULL means all the way to the root). Returns this node.
        Node *splay(Node *targetParent = NULL) {
            while (p != targetParent) {
                if (p->p == targetParent) rotate();
                else if (relation() == p->relation()) p->rotate(), rotate();
                else rotate(), rotate();
            }
            return this;
        }

        // Recomputes size and hash from the children:
        //   hash = hash(left) + BASE^|left| * val + BASE^(|left|+1) * hash(right)
        void maintain() {
            size = 1;
            if (c[0]) size += c[0]->size;
            if (c[1]) size += c[1]->size;
            hash = val;
            if (c[1]) hash += c[1]->hash * BASE;
            if (c[0]) hash = hash * base[c[0]->size] + c[0]->hash;
        }

        // Size of the left subtree (0 if absent).
        int lsize() { return c[0] ? c[0]->size : 0; }

        // Debug dump: prints the subtree sideways (right side on top), one node
        // per line, indented by depth; sentinels are shown as a space.
        void print(const int depth = 0) {
            if (c[1]) c[1]->print(depth + 1);
            for (int i = 0; i < depth; i++) putchar(' ');
            printf("%c\n", val == 0 ? ' ' : val);
            if (c[0]) c[0]->print(depth + 1);
        }
    } *r;  // root of the tree (NULL while empty)

    Splay() : r(NULL) {}

    // Builds a balanced subtree from the characters [first, last] (inclusive),
    // attached under parent p. Returns NULL for an empty range.
    Node *build(const char *first, const char *last, Node *p) {
        if (first > last) return NULL;
        if (first == last) return new Node(p, &r, *first);
        const char *mid = first + (last - first) / 2;
        Node *v = new Node(p, &r, *mid);
        // Only recurse left when the left half is non-empty. Otherwise
        // `mid - 1` could point before the start of the caller's array,
        // which is undefined pointer arithmetic.
        if (mid > first) v->c[0] = build(first, mid - 1, v);
        v->c[1] = build(mid + 1, last, v);
        v->maintain();
        return v;
    }

    // Attaches a sentinel at the far left (x == 0) or far right (x == 1) of the
    // tree and splays it to the root. On an empty tree the sentinel becomes the root.
    void buildBounds(const int x) {
        if (!r) {
            r = new Node(NULL, &r, 0);
            return;
        }
        Node *v = r;
        while (v->c[x]) v = v->c[x];
        v->c[x] = new Node(v, &r, 0);
        // Refresh sizes/hashes along the path from the attachment point to the root.
        Node *u = v;
        do {
            u->maintain();
            u = u->p;
        } while (u);
        v->c[x]->splay();
    }

    // Builds the tree from the characters [first, last] (inclusive) plus both
    // sentinels. An empty range (first > last) yields a tree holding only the sentinels.
    void build(const char *first, const char *last) {
        r = build(first, last, NULL);
        buildBounds(0);
        buildBounds(1);
    }

    // Returns the node at 1-based string position k, after splaying it to the root.
    // k == 0 is the left sentinel and k == size() + 1 is the right one; other
    // values are out of range.
    Node *select(const int k) {
        int x = k + 1;  // rank within the tree, counting the left sentinel
        Node *v = r;
        while (x != v->lsize() + 1) {
            if (x < v->lsize() + 1) v = v->c[0];
            else x -= v->lsize() + 1, v = v->c[1];
        }
        return v->splay();
    }

    // Isolates positions [left, right] (1-based, inclusive, requires left <= right).
    // Position left-1 is splayed to the root and position right+1 directly under it.
    // The range is then exactly the left subtree of the latter, whose root is returned.
    Node *select(const int left, const int right) {
        Node *a = select(left - 1), *b = select(right + 1);
        a->splay();
        b->splay(a);
        return b->c[0];
    }

    // Inserts ch right after position pos (pos == 0 inserts at the front).
    // Returns the new node, splayed to the root.
    Node *insert(const int pos, const char ch) {
        Node *a = select(pos), *b = select(pos + 1);
        a->splay();
        b->splay(a);
        // a and b are adjacent, so b's left child slot is empty here.
        b->c[0] = new Node(b, &r, ch);
        Node *v = b->c[0];
        do {
            v->maintain();
            v = v->p;
        } while (v);
        return b->c[0]->splay();
    }

    // Overwrites the character at 1-based position pos.
    void update(const int pos, const char ch) {
        Node *v = select(pos);  // now the root, so only it needs recomputing
        v->val = ch;
        v->maintain();
    }

    // Hash of the substring at positions [left, right] (1-based, inclusive).
    unsigned long long query(const int left, const int right) {
        return select(left, right)->hash;
    }

    // Number of real characters (the two sentinels are excluded).
    int size() { return r->size - 2; }
} splay;
inline void print() {
splay.r->print();
puts("------------------------");
}
// Length of the longest common prefix of the suffixes starting at 1-based
// positions a and b, found by binary search on the length with hash comparisons.
inline int lcp(const int a, const int b) {
    // Invariant: the first `lo` characters are known to match, and the answer is at most `hi`.
    const int total = splay.size();
    int lo = 0;
    int hi = std::min(total - a, total - b) + 1;
    while (lo != hi) {
        // Upper middle, so that `lo = len` always makes progress.
        const int len = lo + (hi - lo) / 2 + 1;
        const unsigned long long hashA = splay.query(a, a + len - 1);
        const unsigned long long hashB = splay.query(b, b + len - 1);
        if (hashA != hashB) {
            hi = len - 1;
        } else {
            lo = len;
        }
    }
    return lo;
}
// Reads the initial string, then m commands:
//   Q a b : print the LCP of the suffixes starting at positions a and b
//   R x d : replace the character at position x with d
//   I x d : insert character d after position x (x == 0 inserts at the front)
// Stops early if the input ends or is malformed.
int main() {
    // Powers of BASE used by the subtree-hash combination in Splay::Node::maintain.
    base[0] = 1;
    for (int i = 1; i <= MAXN; i++) base[i] = base[i - 1] * BASE;

    static char s[MAXN + 1];
    // Bound the read to the buffer capacity. The width is derived from MAXN so the two stay in sync.
    char strFormat[16];
    sprintf(strFormat, "%%%ds", MAXN);
    if (scanf(strFormat, s) != 1) return 0;
    const int n = static_cast<int>(strlen(s));  // >= 1 after a successful %s read
    splay.build(s, s + n - 1);

    int m;
    if (scanf("%d", &m) != 1) return 0;
    while (m--) {
        // " %c" skips whitespace and reads exactly one character, so it cannot overflow a buffer.
        char cmd;
        if (scanf(" %c", &cmd) != 1) break;
        if (cmd == 'Q') {
            int a, b;
            if (scanf("%d %d", &a, &b) != 2) break;
            printf("%d\n", lcp(a, b));
        } else if (cmd == 'R') {
            int pos;
            char ch;
            if (scanf("%d %c", &pos, &ch) != 2) break;
            splay.update(pos, ch);
        } else if (cmd == 'I') {
            int pos;
            char ch;
            if (scanf("%d %c", &pos, &ch) != 2) break;
            splay.insert(pos, ch);
        }
        // print();
    }
    return 0;
}