「BZOJ 3277」串 - 后缀数组 + 并查集 + 启发式合并

n 个字符串,询问每个字符串有多少子串(不要求本质不同,不包括空串)是所有 n 个字符串中至少 k 个字符串的子串。

链接

BZOJ 3277
Codeforces 204E

题解

将所有串用在字符集之外的不同字符连接,对所得串建立后缀数组,考虑一个极大的两两间最长公共前缀(下文中省略) \geq x 的连通块。如果这个块内包含来自 \geq k 个原串前缀,则这些后缀长度为 x 的前缀是满足题意的子串。

从大到小枚举 x ,用并查集维护每个 \geq x 的连通块,用 set 维护连通块内。每次合并相邻的连通块,统计长度为 x 的子串对答案的贡献。将 < x 的子串留到下一次继续合并的时候统计。

但是会有一种情况,一个 x 的连通块附近没有 x - 1 的连通块,则长度为 x - 1 的子串不会被统计,并且长度为 x - 1 的子串和长度为 x 的子串是相同的,一起统计即可。

对于 k = 1 的情况,需要特判,输出每个字符串的子串数量。

代码

#include <cstdio>
#include <cstring>
#include <vector>
#include <set>
#include <algorithm>
#include <iostream>

const int MAXN = 100000;
const int MAXLEN = 100000 + MAXN;

int s[MAXLEN];
int n, m, sa[MAXLEN], rk[MAXLEN], ht[MAXLEN], pos[MAXN], belong[MAXLEN];

inline void suffixArray() {
    static int set[MAXLEN], a[MAXLEN];
    std::copy(s, s + n, set);
    std::sort(set, set + n);
    int *end = std::unique(set, set + n);
    for (int i = 0; i < n; i++) a[i] = std::lower_bound(set, end, s[i]) - set;

    static int fir[MAXLEN], sec[MAXLEN], tmp[MAXLEN], _buc[MAXLEN + 1], *buc = _buc + 1;
    for (int i = 0; i < n; i++) buc[a[i]]++;
    for (int i = 0; i < n; i++) buc[i] += buc[i - 1];
    for (int i = 0; i < n; i++) rk[i] = buc[a[i] - 1];

    for (int t = 1; t < n; t *= 2) {
        for (int i = 0; i < n; i++) fir[i] = rk[i], sec[i] = i + t < n ? rk[i + t] : -1;

        std::fill(buc - 1, buc + n, 0);
        for (int i = 0; i < n; i++) buc[sec[i]]++;
        for (int i = 0; i < n; i++) buc[i] += buc[i - 1];
        for (int i = 0; i < n; i++) tmp[n - buc[sec[i]]--] = i;

        std::fill(buc - 1, buc + n, 0);
        for (int i = 0; i < n; i++) buc[fir[i]]++;
        for (int i = 0; i < n; i++) buc[i] += buc[i - 1];
        for (int i = 0; i < n; i++) sa[--buc[fir[tmp[i]]]] = tmp[i];

        bool unique = true;
        for (int i = 0; i < n; i++) {
            if (!i) rk[sa[i]] = 0;
            else if (fir[sa[i]] == fir[sa[i - 1]] && sec[sa[i]] == sec[sa[i - 1]]) rk[sa[i]] = rk[sa[i - 1]], unique = false;
            else rk[sa[i]] = rk[sa[i - 1]] + 1;
        }
        if (unique) break;
    }

    for (int i = 0, k = 0; i < n; i++) {
        if (!rk[i]) continue;
        int j = sa[rk[i] - 1];
        if (k) k--;
        while (i + k < n && j + k < n && s[i + k] == s[j + k]) k++;
        ht[rk[i]] = k;
    }

#ifdef DBG
    for (int i = 0; i < n; i++) printf("%d%c", sa[i], i == n - 1 ? '\n' : ' ');
    for (int i = 0; i < n; i++) printf("%d%c", rk[i], i == n - 1 ? '\n' : ' ');
    for (int i = 0; i < n; i++) printf("%d%c", ht[i], i == n - 1 ? '\n' : ' ');
#endif
}

struct UnionFindSet {
    int a[MAXLEN], l[MAXLEN], r[MAXLEN];
    std::set<int> set[MAXLEN];

    void init(const int n, const int *belong) {
        for (int i = 0; i < n; i++) {
            a[i] = i;
            rk[i] = 1;
            l[i] = r[i] = i;
            set[i].insert(belong[i]);
        }
    }

    int find(const int x) {
        return x == a[x] ? x : a[x] = find(a[x]);
    }

    int merge(const int x, const int y) {
#ifdef DBG
        printf("merge(%d, %d)\n", x, y);
#endif
        int _x = find(x), _y = find(y);
        if (set[_x].size() < set[_y].size()) std::swap(_x, _y);
        a[_y] = _x;
        set[_x].insert(set[_y].begin(), set[_y].end());
        l[_x] = std::min(l[_x], l[_y]);
        r[_x] = std::max(r[_x], r[_y]);
        return _x;
    }

    int uniqueCount(const int x) {
        return set[x].size();
    }

    void getRange(const int x, int &l, int &r) {
        l = this->l[x];
        r = this->r[x];
    }
} ufs;

long long gap[MAXLEN + 1];
inline void apply(const int id, const int val = 1) {
    int l, r;
    ufs.getRange(id, l, r);
    gap[l] += val, gap[r + 1] -= val;
#ifdef DBG
    for (int i = l; i <= r; i++) printf("ans[%d] += %d\n", belong[i], val);
#endif
}

int main() {
    int k;
    scanf("%d %d", &m, &k);

    int *p = s;
    int spliter = 233;
    for (int i = 0; i < m; i++) {
#ifdef DBG
        printf("%d\n", i);
#endif
        pos[i] = p - s;
        static char buf[MAXLEN];
        scanf("%s", buf);
        int len = strlen(buf);
        if (k == 1) {
            std::cout << (static_cast<long long>(len) * (len + 1) / 2) << (i == m - 1 ? '\n' : ' ');
        }
        for (int i = 0; i < len; i++) *p++ = buf[i];
        *p++ = spliter++;
    }

    if (k == 1) return 0;

    *--p = '\0';
    n = p - s;

    suffixArray();

    for (int i = 0; i < n; i++) belong[i] = std::upper_bound(pos, pos + m, sa[i]) - pos - 1;
    ufs.init(n, belong);

#ifdef DBG
    for (int i = 0; i < n; i++) {
        printf("%d %d %s\n", belong[i], ht[i], &s[sa[i]]);
    }
#endif

    std::vector<int> v[MAXLEN];
    int maxH = 0;
    for (int i = 0; i < n; i++) v[ht[i]].push_back(i), maxH = std::max(maxH, ht[i]);
    for (int i = maxH; i > 0; i--) {
#ifdef DBG
        printf("Now processing h = %d\n", i);
#endif
        std::vector<int> vec;
        for (std::vector<int>::const_iterator it = v[i].begin(); it != v[i].end(); it++) {
            int id = ufs.merge(*it, *it - 1);
            if (ufs.uniqueCount(id) >= k) {
                // apply(id);
                vec.push_back(id);
            }
        }
        for (std::vector<int>::iterator it = vec.begin(); it != vec.end(); it++) *it = ufs.find(*it);
        std::sort(vec.begin(), vec.end());
        std::vector<int>::const_iterator end = std::unique(vec.begin(), vec.end());
        for (std::vector<int>::const_iterator it = vec.begin(); it != end; it++) {
            int l, r;
            ufs.getRange(*it, l, r);
            int x = ht[l];
            if (r != n - 1) x = std::max(x, ht[r + 1]);
            apply(*it, i - x);
        }
    }

    static long long ans[MAXN];
    for (int i = 1; i < n; i++) gap[i] += gap[i - 1];
    for (int i = 0; i < n; i++) ans[belong[i]] += gap[i];

    // for (int i = 0; i < m; i++) printf("%lld%c", ans[i], i == m - 1 ? '\n' : ' ');
    for (int i = 0; i < m; i++) {
        std::cout << ans[i]; //  << (i == m - 1 ? '\0' : ' ');
        if (i != m - 1) std::cout << ' ';
    }

    return 0;
}