给 个字符串,询问每个字符串有多少子串(不要求本质不同,不包括空串)是所有 个字符串中至少 个字符串的子串。
链接
题解
将所有串用在字符集之外的不同字符连接,对所得串建立后缀数组,考虑一个极大的两两间最长公共前缀(下文中省略) 的连通块。如果这个块内包含来自 个原串前缀,则这些后缀长度为 的前缀是满足题意的子串。
从大到小枚举 ,用并查集维护每个 的连通块,用 set
维护连通块内。每次合并相邻的连通块,统计长度为 的子串对答案的贡献。将 的子串留到下一次继续合并的时候统计。
但是会有一种情况,一个 的连通块附近没有 的连通块,则长度为 的子串不会被统计,并且长度为 的子串和长度为 的子串是相同的,一起统计即可。
对于 的情况,需要特判,输出每个字符串的子串数量。
代码
#include <cstdio>
#include <cstring>
#include <vector>
#include <set>
#include <algorithm>
#include <iostream>
const int MAXN = 100000;
const int MAXLEN = 100000 + MAXN;
int s[MAXLEN];
int n, m, sa[MAXLEN], rk[MAXLEN], ht[MAXLEN], pos[MAXN], belong[MAXLEN];
inline void suffixArray() {
static int set[MAXLEN], a[MAXLEN];
std::copy(s, s + n, set);
std::sort(set, set + n);
int *end = std::unique(set, set + n);
for (int i = 0; i < n; i++) a[i] = std::lower_bound(set, end, s[i]) - set;
static int fir[MAXLEN], sec[MAXLEN], tmp[MAXLEN], _buc[MAXLEN + 1], *buc = _buc + 1;
for (int i = 0; i < n; i++) buc[a[i]]++;
for (int i = 0; i < n; i++) buc[i] += buc[i - 1];
for (int i = 0; i < n; i++) rk[i] = buc[a[i] - 1];
for (int t = 1; t < n; t *= 2) {
for (int i = 0; i < n; i++) fir[i] = rk[i], sec[i] = i + t < n ? rk[i + t] : -1;
std::fill(buc - 1, buc + n, 0);
for (int i = 0; i < n; i++) buc[sec[i]]++;
for (int i = 0; i < n; i++) buc[i] += buc[i - 1];
for (int i = 0; i < n; i++) tmp[n - buc[sec[i]]--] = i;
std::fill(buc - 1, buc + n, 0);
for (int i = 0; i < n; i++) buc[fir[i]]++;
for (int i = 0; i < n; i++) buc[i] += buc[i - 1];
for (int i = 0; i < n; i++) sa[--buc[fir[tmp[i]]]] = tmp[i];
bool unique = true;
for (int i = 0; i < n; i++) {
if (!i) rk[sa[i]] = 0;
else if (fir[sa[i]] == fir[sa[i - 1]] && sec[sa[i]] == sec[sa[i - 1]]) rk[sa[i]] = rk[sa[i - 1]], unique = false;
else rk[sa[i]] = rk[sa[i - 1]] + 1;
}
if (unique) break;
}
for (int i = 0, k = 0; i < n; i++) {
if (!rk[i]) continue;
int j = sa[rk[i] - 1];
if (k) k--;
while (i + k < n && j + k < n && s[i + k] == s[j + k]) k++;
ht[rk[i]] = k;
}
#ifdef DBG
for (int i = 0; i < n; i++) printf("%d%c", sa[i], i == n - 1 ? '\n' : ' ');
for (int i = 0; i < n; i++) printf("%d%c", rk[i], i == n - 1 ? '\n' : ' ');
for (int i = 0; i < n; i++) printf("%d%c", ht[i], i == n - 1 ? '\n' : ' ');
#endif
}
struct UnionFindSet {
int a[MAXLEN], l[MAXLEN], r[MAXLEN];
std::set<int> set[MAXLEN];
void init(const int n, const int *belong) {
for (int i = 0; i < n; i++) {
a[i] = i;
rk[i] = 1;
l[i] = r[i] = i;
set[i].insert(belong[i]);
}
}
int find(const int x) {
return x == a[x] ? x : a[x] = find(a[x]);
}
int merge(const int x, const int y) {
#ifdef DBG
printf("merge(%d, %d)\n", x, y);
#endif
int _x = find(x), _y = find(y);
if (set[_x].size() < set[_y].size()) std::swap(_x, _y);
a[_y] = _x;
set[_x].insert(set[_y].begin(), set[_y].end());
l[_x] = std::min(l[_x], l[_y]);
r[_x] = std::max(r[_x], r[_y]);
return _x;
}
int uniqueCount(const int x) {
return set[x].size();
}
void getRange(const int x, int &l, int &r) {
l = this->l[x];
r = this->r[x];
}
} ufs;
long long gap[MAXLEN + 1];
inline void apply(const int id, const int val = 1) {
int l, r;
ufs.getRange(id, l, r);
gap[l] += val, gap[r + 1] -= val;
#ifdef DBG
for (int i = l; i <= r; i++) printf("ans[%d] += %d\n", belong[i], val);
#endif
}
int main() {
int k;
scanf("%d %d", &m, &k);
int *p = s;
int spliter = 233;
for (int i = 0; i < m; i++) {
#ifdef DBG
printf("%d\n", i);
#endif
pos[i] = p - s;
static char buf[MAXLEN];
scanf("%s", buf);
int len = strlen(buf);
if (k == 1) {
std::cout << (static_cast<long long>(len) * (len + 1) / 2) << (i == m - 1 ? '\n' : ' ');
}
for (int i = 0; i < len; i++) *p++ = buf[i];
*p++ = spliter++;
}
if (k == 1) return 0;
*--p = '\0';
n = p - s;
suffixArray();
for (int i = 0; i < n; i++) belong[i] = std::upper_bound(pos, pos + m, sa[i]) - pos - 1;
ufs.init(n, belong);
#ifdef DBG
for (int i = 0; i < n; i++) {
printf("%d %d %s\n", belong[i], ht[i], &s[sa[i]]);
}
#endif
std::vector<int> v[MAXLEN];
int maxH = 0;
for (int i = 0; i < n; i++) v[ht[i]].push_back(i), maxH = std::max(maxH, ht[i]);
for (int i = maxH; i > 0; i--) {
#ifdef DBG
printf("Now processing h = %d\n", i);
#endif
std::vector<int> vec;
for (std::vector<int>::const_iterator it = v[i].begin(); it != v[i].end(); it++) {
int id = ufs.merge(*it, *it - 1);
if (ufs.uniqueCount(id) >= k) {
// apply(id);
vec.push_back(id);
}
}
for (std::vector<int>::iterator it = vec.begin(); it != vec.end(); it++) *it = ufs.find(*it);
std::sort(vec.begin(), vec.end());
std::vector<int>::const_iterator end = std::unique(vec.begin(), vec.end());
for (std::vector<int>::const_iterator it = vec.begin(); it != end; it++) {
int l, r;
ufs.getRange(*it, l, r);
int x = ht[l];
if (r != n - 1) x = std::max(x, ht[r + 1]);
apply(*it, i - x);
}
}
static long long ans[MAXN];
for (int i = 1; i < n; i++) gap[i] += gap[i - 1];
for (int i = 0; i < n; i++) ans[belong[i]] += gap[i];
// for (int i = 0; i < m; i++) printf("%lld%c", ans[i], i == m - 1 ? '\n' : ' ');
for (int i = 0; i < m; i++) {
std::cout << ans[i]; // << (i == m - 1 ? '\0' : ' ');
if (i != m - 1) std::cout << ' ';
}
return 0;
}