对于一个给定的字符串 $s$,求它有多少个连续子串是 prefix-suffix-square free 的。
一个字符串被称为 square 当且仅当它可以由两个相同的串连接而成。例如,`abab` 是 square,而 `aaa` 不是。一个字符串是 prefix-suffix-square free 的当且仅当它的任何前缀或者后缀都不是 square。
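求出从每个位置开始、结束的最短 square,进而求出从第 $j$ 个位置结束的子串(要求其任何后缀都不是 square),其开始位置的最小值。设从第 $i$ 个位置开始的最短 square 的长度为 $f_i$,从第 $j$ 个位置结束的子串,其开始位置的最小值为 $g_j$,则答案为对于每个 $i$,在 $[i, i + f_i - 1]$ 内满足 $g_j \le i$ 的 $j$ 的数量之和(若不存在从 $i$ 开始的 square,则区间取到串尾)。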
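考虑如何求 square:类似「NOI2016 优秀的拆分」一题,枚举 $len$,枚举每个长度为 $len$ 的区间 $[i, i + len]$,则所有长度为 $2 \times len$ 的 square 都会跨越某个这样的区间,从两个端点分别向前、向后求最长公共后缀、最长公共前缀,进而可以求出若干个 square。如,从 $[l, r]$ 内任意一点开始,均有一个长度为 $2 \times len$ 的 square,此时用 $2 \times len$ 更新 $f_l, \ldots, f_r$。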
因为是从小到大枚举的 $len$,所以每个位置只会被赋值一次,使用并查集维护所有赋值过的点,赋值过的点合并,之后直接跳过这些区间即可。
#include <algorithm>
#include <climits>
#include <cstdio>
#include <cstring>
#include <vector>
// Maximum supported string length.
const int MAXN = 1e5;
// Number of sparse-table levels; log2(1e5) = 16.6096..., so 17 is sufficient.
const int MAXN_LOG = 17;

/**
 * Suffix array with height (LCP) array and a sparse table for O(1)
 * longest-common-prefix queries between arbitrary suffixes.
 *
 * Members after build():
 *   sa[k]  - start index of the k-th smallest suffix,
 *   rk[i]  - rank of the suffix starting at i (inverse of sa),
 *   ht[k]  - LCP length of suffixes sa[k - 1] and sa[k], for k >= 1,
 *   st     - sparse table of range minima over ht,
 *   log    - log[d] = floor(log2(d)) for d >= 1 (log[0] = -1).
 */
struct SuffixArray {
    int n, sa[MAXN], rk[MAXN], ht[MAXN], st[MAXN][MAXN_LOG + 1], log[MAXN + 1];

    /**
     * Builds all tables for the first n characters of s (n <= MAXN).
     * Uses prefix doubling with two counting-sort passes per round.
     * The scratch buffers are function-local statics, so they are shared by
     * every SuffixArray instance; build() calls must not overlap.
     */
    inline void build(const char *s, const int n) {
        this->n = n;

        // Coordinate-compress the characters into a[] (values in [0, n)).
        static int set[MAXN], a[MAXN];
        std::copy(s, s + n, set);
        std::sort(set, set + n);
        int *end = std::unique(set, set + n);
        for (int i = 0; i < n; i++) a[i] = std::lower_bound(set, end, s[i]) - set;

        // buc is offset by one so that buc[-1] is a valid bucket; it holds the
        // "no second key" sentinel (-1) used during doubling.
        static int fir[MAXN], sec[MAXN], tmp[MAXN], _buc[MAXN + 1], *buc = _buc + 1;

        // Initial rank of a suffix = number of characters strictly smaller
        // than its first character.
        std::fill(buc - 1, buc + n, 0);
        for (int i = 0; i < n; i++) buc[a[i]]++;
        for (int i = 0; i < n; i++) buc[i] += buc[i - 1];
        for (int i = 0; i < n; i++) rk[i] = buc[a[i] - 1];

        for (int t = 1; t < n; t *= 2) {
            // Sort key of suffix i is (rank of s[i..], rank of s[i + t..]).
            for (int i = 0; i < n; i++) fir[i] = rk[i], sec[i] = i + t < n ? rk[i + t] : -1;

            // Pass 1: order by the second key, descending, into tmp.
            std::fill(buc - 1, buc + n, 0);
            for (int i = 0; i < n; i++) buc[sec[i]]++;
            for (int i = 0; i < n; i++) buc[i] += buc[i - 1];
            for (int i = 0; i < n; i++) tmp[n - buc[sec[i]]--] = i;

            // Pass 2: bucket by the first key, filling each bucket from its
            // end; because tmp is descending in the second key, each bucket
            // ends up ascending in it.
            std::fill(buc - 1, buc + n, 0);
            for (int i = 0; i < n; i++) buc[fir[i]]++;
            for (int i = 0; i < n; i++) buc[i] += buc[i - 1];
            for (int i = 0; i < n; i++) sa[--buc[fir[tmp[i]]]] = tmp[i];

            // Re-rank; suffixes with equal key pairs share a rank.
            bool unique = true;
            for (int i = 0; i < n; i++) {
                if (!i) rk[sa[i]] = 0;
                else if (fir[sa[i]] == fir[sa[i - 1]] && sec[sa[i]] == sec[sa[i - 1]]) rk[sa[i]] = rk[sa[i - 1]], unique = false;
                else rk[sa[i]] = rk[sa[i - 1]] + 1;
            }

            // All ranks distinct: the order is final.
            if (unique) break;
        }

        // Height array: walk suffixes in text order, reusing the previous
        // match length minus one as the starting point.
        for (int i = 0, k = 0; i < n; i++) {
            if (!rk[i]) continue;
            int j = sa[rk[i] - 1];
            if (k) k--;
            while (i + k < n && j + k < n && a[i + k] == a[j + k]) k++;
            ht[rk[i]] = k;
        }

#ifdef DBG
        for (int i = 0; i < n; i++) printf("%d%c", sa[i], i == n - 1 ? '\n' : ' ');
        for (int i = 0; i < n; i++) printf("%d%c", rk[i], i == n - 1 ? '\n' : ' ');
        for (int i = 0; i < n; i++) printf("%d %s\n", ht[i], &s[sa[i]]);
#endif

        // Sparse table: st[i][j] = min of ht over [i, i + 2^j), clipped at n.
        for (int i = 0; i < n; i++) st[i][0] = ht[i];
        for (int j = 1; (1 << j) < n; j++) {
            for (int i = 0; i < n; i++) {
                if (i + (1 << (j - 1)) >= n) st[i][j] = st[i][j - 1];
                else st[i][j] = std::min(st[i][j - 1], st[i + (1 << (j - 1))][j - 1]);
            }
        }

        // log[i] = floor(log2(i)); log[0] = -1 and is never used by rmq().
        for (int i = 0; i <= n; i++) {
            int x = 0;
            while ((1 << x) <= i) x++;
            log[i] = x - 1;
        }
    }

    /** Minimum of ht over the inclusive rank range [l, r], l <= r. */
    inline int rmq(const int l, const int r) {
        if (l == r) return st[l][0];
        int t = log[r - l];
        return std::min(st[l][t], st[r - (1 << t) + 1][t]);
    }

    /** Length of the longest common prefix of the suffixes starting at i and j. */
    inline int lcp(const int i, const int j) {
        if (i == j) return n - i;
        int a = rk[i], b = rk[j];
        if (a > b) std::swap(a, b);
        return rmq(a + 1, b);
    }
} sa1, sa2;
// Length of the string currently being processed.
int n;
// forward[i] / backward[i]: length of the shortest square starting / ending
// at position i, as filled in by prepare(); INT_MAX when there is none.
int forward[MAXN];
int backward[MAXN];

// Longest common prefix of the suffixes starting at i and j.
inline int lcp(const int i, const int j) {
    return sa1.lcp(i, j);
}

// Longest common suffix of the prefixes ending at i and j (inclusive),
// answered as an LCP query on sa2 at the mirrored positions.
inline int lcs(const int i, const int j) {
    const int mirroredI = n - i - 1;
    const int mirroredJ = n - j - 1;
    return sa2.lcp(mirroredI, mirroredJ);
}
inline void prepare(const int n) {
struct SegmentTree {
int l, r, mid;
SegmentTree *lc, *rc;
int val;
SegmentTree(const int l, const int r, SegmentTree *lc, SegmentTree *rc) : l(l), r(r), mid(l + (r - l) / 2), lc(lc), rc(rc), val(INT_MAX) {}
~SegmentTree() {
if (lc) delete lc;
if (rc) delete rc;
void update(const int l, const int r, const int x) {
if (l > this->r || r < this->l) return;
else if (l <= this->l && r >= this->r) this->val = std::min(this->val, x);
else lc->update(l, r, x), rc->update(l, r, x);
int query(const int pos) {
return (l == r) ? val : std::min(val, ((pos <= mid) ? lc : rc)->query(pos));
static SegmentTree *build(const int l, const int r) {
if (l > r) return NULL;
else if (l == r) return new SegmentTree(l, r, NULL, NULL);
else {
const int mid = l + (r - l) / 2;
return new SegmentTree(l, r, build(l, mid), build(mid + 1, r));
} *sForawrd = SegmentTree::build(0, n - 1), *sBackward = SegmentTree::build(0, n - 1);
static struct UnionFindSet {
int a[MAXN], n;
void init(const int n) {
this->n = n;
for (int i = 0; i < n; i++) a[i] = i;
int find(const int x) {
return x == a[x] ? x : a[x] = find(a[x]);
void merge(const int child, const int parent) {
#ifdef DBG
printf("merge: %d => %d\n", child, parent);
const int _child = find(child), _parent = find(parent);
a[_child] = _parent;
void cover(const int l, const int r, const int x, int *val) {
#ifdef DBG
printf("cover(%d, %d, %d)\n", l, r, x);
for (int i = find(l); i <= r; i = find(i + 1)) {
val[i] = std::min(val[i], x);
if (i != n - 1) merge(i, i + 1);
else break;
} sForawrd, sBackward;
for (int i = 0; i < n; i++) forward[i] = backward[i] = INT_MAX;
for (int len = 1; len < n; len++) {
for (int i = 0; i + len < n; i += len) {
const int j = i + len;
const int a = std::min(lcs(i, j), len), b = std::min(lcp(i, j), len), s = a + b - 1;
// const int a = lcs(i, j), b = lcp(i, j), s = a + b - 1;
#ifdef DBG
printf("s = %d, len = %d\n", s, len);
if (s >= len) {
#ifdef DBG
printf("!FOUND! i = %d, j = %d, a = %d, b = %d, len = %d\n", i, j, a, b, len);
const int l = i - a + 1, r = j + b - 1;
sForawrd.cover(l, l + (r - (l + len * 2 - 1)), len * 2, forward);
sBackward.cover(l + len * 2 - 1, r, len * 2, backward);
// sForawrd->update(l, l + (r - (l + len * 2 - 1)), len * 2);
// sBackward->update(l + len * 2 - 1, r, len * 2);
for (int i = l + len * 2 - 1; i <= r; i++) {
printf("[%d, %d]\n", i - len * 2 + 1, i);
// for (int i = 0; i < n; i++) forward[i] = sForawrd->query(i);
// for (int i = 0; i < n; i++) backward[i] = sBackward->query(i);
#ifdef DBG
puts("forward / backward");
for (int i = 0; i < n; i++) printf("%d%c", forward[i] == INT_MAX ? 0 : forward[i], i == n - 1 ? '\n' : ' ');
for (int i = 0; i < n; i++) printf("%d%c", backward[i] == INT_MAX ? 0 : backward[i], i == n - 1 ? '\n' : ' ');
/**
 * Persistent ("chairman") segment tree over values: version k holds the
 * multiset of the first k array elements, so the number of elements of
 * a[l..r] whose value is at most x is a difference of two prefix versions.
 *
 * Nodes are shared between versions and reference-counted: refCnt is the
 * number of owners beyond the first, and a node is freed when an owner is
 * destroyed while refCnt is 0.
 */
struct ChairmanTree {
    struct SegmentTree {
        int l, r, mid;
        SegmentTree *lc, *rc;
        int cnt, refCnt;

        SegmentTree(const int l, const int r, SegmentTree *lc, SegmentTree *rc, const int cnt = 0)
            : l(l), r(r), mid(l + (r - l) / 2), lc(lc), rc(rc), cnt(cnt), refCnt(0) {}

        // Releases this node's claim on each child; frees a child only when
        // no other version still owns it.
        ~SegmentTree() {
            if (lc && lc->unRef()) delete lc;
            if (rc && rc->unRef()) delete rc;
        }

        // Registers one more owner of this node and returns it for sharing.
        SegmentTree *ref() {
            refCnt++;
            return this;
        }

        // Drops one owner; returns true when that was the last one, i.e. the
        // caller must delete the node.
        bool unRef() {
            return !refCnt--;
        }

        // Number of stored values inside [l, r].
        int query(const int l, const int r) {
            if (l > this->r || r < this->l) return 0;
            else if (l <= this->l && r >= this->r) return cnt;
            else return (lc ? lc->query(l, r) : 0) + (rc ? rc->query(l, r) : 0);
        }

        // For a freshly created node that already counts x: creates the chain
        // of descendants down to the leaf for x. Returns this.
        SegmentTree *insertSelf(const int x) {
            if (l == r) return this;
            if (x <= mid) {
                this->lc = (new SegmentTree(l, mid, NULL, NULL, 1))->insertSelf(x);
            } else {
                this->rc = (new SegmentTree(mid + 1, r, NULL, NULL, 1))->insertSelf(x);
            }
            return this;
        }

        // Returns the root of a new version that additionally contains x.
        // Only the path towards x is copied; the sibling subtree is shared.
        SegmentTree *insert(const int x) {
            // A leaf has nothing below it to copy; just bump the count.
            if (l == r) return new SegmentTree(l, r, NULL, NULL, cnt + 1);
            if (x <= mid) {
                return new SegmentTree(l, r, lc ? lc->insert(x) : (new SegmentTree(l, mid, NULL, NULL, 1))->insertSelf(x), rc ? rc->ref() : NULL, cnt + 1);
            } else {
                return new SegmentTree(l, r, lc ? lc->ref() : NULL, rc ? rc->insert(x) : (new SegmentTree(mid + 1, r, NULL, NULL, 1))->insertSelf(x), cnt + 1);
            }
        }

        // Number of stored values in the left half.
        int lcount() {
            return lc ? lc->cnt : 0;
        }
    };

    // root[k] is the version containing the first k array elements.
    std::vector<SegmentTree *> root;
    int n, l, r;

    // Frees every version, newest first. Safe to call repeatedly.
    void clear() {
        for (std::size_t i = root.size(); i-- > 0;) {
            delete root[i];
        }
        root.clear();
    }

    // Build tree with array a of n elements. l is the lower value bound used
    // by query(); r is stored but not otherwise used. The nodes themselves
    // cover the value range [0, n], so elements must lie within it.
    void build(const int *a, const int n, const int l, const int r) {
        clear();  // drop any previous build instead of leaking it
        this->n = n, this->l = l, this->r = r;
        root.assign(static_cast<std::size_t>(n) + 1, static_cast<SegmentTree *>(NULL));
        root[0] = new SegmentTree(0, n, NULL, NULL, 0);
        for (int i = 1; i <= n; i++) {
            root[i] = root[i - 1]->insert(a[i - 1]);
        }
    }

    // Number of elements of a[l..r] (0-based, inclusive) whose value lies in
    // [this->l, x].
    int query(const int l, const int r, const int x) {
        // Adjust input [0, n - 1] to [1, n]
        SegmentTree *vr = root[r + 1], *vl = root[l];
        int ans = vr->query(this->l, x);
        if (vl) ans -= vl->query(this->l, x);
        return ans;
    }
} t;
/**
 * Counts the pairs (i, j) with
 *     i <= j <= R(i)   and   limit[j] <= i,
 * where R(i) = n - 1 if forward[i] == INT_MAX, else i + forward[i] - 1.
 *
 * @param limit    limit[j] for every position j in [0, n); values are
 *                 expected to lie in [0, n - 1].
 * @param forward  forward[i] for every position i in [0, n); INT_MAX means
 *                 "no bound", otherwise i + forward[i] - 1 must be < n.
 * @param n        number of positions.
 * @return the number of such pairs.
 *
 * Solved offline: sweep j from left to right, inserting limit[j] into a
 * Fenwick tree indexed by value; each i contributes
 *     count(j <= R(i), limit[j] <= i) - count(j <= i - 1, limit[j] <= i).
 * All working storage is local, so repeated calls do not interfere.
 */
inline unsigned long long solve(const int *limit, const int *forward, const int n) {
    if (n <= 0) return 0;

    struct Query {
        int type;      // 0: insert val; +1 / -1: add / subtract count(values <= val)
        int pos, val;  // pos: sweep position after which the event applies

        // Order by position; at equal positions insertions come first so a
        // query at pos sees every j <= pos.
        bool operator<(const Query &other) const {
            if (pos != other.pos) return pos < other.pos;
            return !type && other.type;
        }
    };

    std::vector<Query> Q(static_cast<std::size_t>(n) * 3);
    for (int i = 0; i < n; i++) {
        Query &add = Q[i];
        add.pos = i;
        add.type = 0;
        add.val = limit[i];
#ifdef DBG
        printf("add(pos = %d, val = %d)\n", add.pos, add.val);
#endif
    }

    for (int i = 0; i < n; i++) {
        // Subtract what lies strictly before i ...
        Query &lower = Q[n + i];
        lower.pos = i - 1;
        lower.val = i;
        lower.type = -1;
#ifdef DBG
        printf("query(pos = %d, val = %d, type = %d)\n", lower.pos, lower.val, lower.type);
#endif

        // ... from what lies up to and including R(i).
        Query &upper = Q[n + i + n];
        upper.pos = (forward[i] == INT_MAX) ? (n - 1) : (i + forward[i] - 1);
        upper.val = i;
        upper.type = 1;
#ifdef DBG
        printf("query(pos = %d, val = %d, type = %d)\n", upper.pos, upper.val, upper.type);
#endif
    }

    std::sort(Q.begin(), Q.end());

    // Fenwick tree over values, 1-indexed: value v is stored at index v + 1.
    std::vector<int> bit(static_cast<std::size_t>(n) + 1, 0);
    long long ans = 0;
    for (std::size_t i = 0; i < Q.size(); i++) {
        const Query &q = Q[i];
#ifdef DBG
        printf("Query: %d %d %d\n", q.type, q.pos, q.val);
#endif
        if (q.type) {
            int count = 0;
            for (int k = q.val + 1; k > 0; k -= k & -k) count += bit[k];
            const int t = count * q.type;
            ans += t;
#ifdef DBG
            printf("t = %d\n", t);
#endif
        } else {
            // A negative value satisfies every query, so store it at index 1;
            // a value >= n satisfies none, and the loop simply does not run.
            for (int k = std::max(q.val, 0) + 1; k <= n; k += k & -k) bit[k]++;
        }
    }

    return static_cast<unsigned long long>(ans);
}
int main() {
int testcase;
scanf("%d", &testcase);
while (testcase--) {
static char s[MAXN + 1];
scanf("%s", s);
n = strlen(s);
sa1.build(s, n);
std::reverse(s, s + n);
sa2.build(s, n);
static int limit[MAXN];
for (int i = 0; i < n; i++) limit[i] = backward[i] == INT_MAX ? 0 : (i - backward[i] + 2);
#ifdef DBG
for (int i = 0; i < n; i++) printf("%d%c", limit[i], i == n - 1 ? '\n' : ' ');
// t.clear();
// t.build(limit, n, 0, n - 1);
unsigned long long ans = 0;
for (int i = 0; i < n; i++) {
ans += t.query(i, (forward[i] == INT_MAX) ? (n - 1) : (i + forward[i] - 1), i);
unsigned long long ans = solve(limit, forward, n);
printf("%llu\n", ans);
return 0;