「BZOJ 3230」相似子串 - 后缀数组

对于一个长度为 的字符串 ,将其本质不同的所有子串按照字典序排序。我们定义两个子串的相似程度为 的最大值,其中 满足:

每次询问排序后的第 个和第 个子串的相似程度。

链接

BZOJ 3230

题解

一个长度为 的串,它的本质不同的子串数量是 级别的,所以不能直接对子串排序。

考虑到每个子串都是一个后缀的前缀,对所有后缀排序,相邻两个后缀的最长公共前缀即为重复的子串。设第 个与第 个后缀的最长公共前缀为 ,则由第 个后缀的前缀构成的子串中与前 个后缀的前缀构成的子串的不同的有 个,这其中长度最小的为 ,长度最大的为该后缀本身,这些子串的字典序即为它们长度顺序。

使用后缀数组求出相邻两个后缀的最长公共前缀,求出每个后缀对子串个数的贡献,通过二分即可定位产生第 个子串的后缀,进而求出这个子串在原串中的位置。 使用后缀数组求最长公共前缀、后缀解决即可。

代码

#include <cstdio>
#include <cassert>
#include <cstring>
#include <string>
#include <vector>
#include <algorithm>

const int MAXN = 100000;
const int MAXN_LOG = 17; // Math.log2(1e5) = 16.609640474436812;

char s[MAXN], sRev[MAXN];
int n, log[MAXN + 1];
long long cnt[MAXN];

struct SuffixArray {
    int n, sa[MAXN], rk[MAXN], ht[MAXN], st[MAXN][MAXN_LOG + 1];

    void build(const char *s, const int n) {
        this->n = n;
        static int set[MAXN], a[MAXN];
        std::copy(s, s + n, set);
        std::sort(set, set + n);
        int *end = std::unique(set, set + n);
        for (int i = 0; i < n; i++) a[i] = std::lower_bound(set, end, s[i]) - set;

        static int fir[MAXN], sec[MAXN], tmp[MAXN], _buc[MAXN + 1], *buc = _buc + 1;
        std::fill(buc - 1, buc + n, 0);
        for (int i = 0; i < n; i++) buc[a[i]]++;
        for (int i = 0; i < n; i++) buc[i] += buc[i - 1];
        for (int i = 0; i < n; i++) rk[i] = buc[a[i] - 1];

        for (int t = 1; t < n; t *= 2) {
            for (int i = 0; i < n; i++) fir[i] = rk[i], sec[i] = i + t < n ? rk[i + t] : -1;

            std::fill(buc - 1, buc + n, 0);
            for (int i = 0; i < n; i++) buc[sec[i]]++;
            for (int i = 0; i < n; i++) buc[i] += buc[i - 1];
            for (int i = 0; i < n; i++) tmp[n - buc[sec[i]]--] = i;

            std::fill(buc - 1, buc + n, 0);
            for (int i = 0; i < n; i++) buc[fir[i]]++;
            for (int i = 0; i < n; i++) buc[i] += buc[i - 1];
            for (int i = 0; i < n; i++) sa[--buc[fir[tmp[i]]]] = tmp[i];

            bool unique = true;
            for (int i = 0; i < n; i++) {
                if (!i) rk[sa[i]] = 0;
                else if (fir[sa[i]] == fir[sa[i - 1]] && sec[sa[i]] == sec[sa[i - 1]]) rk[sa[i]] = rk[sa[i - 1]], unique = false;
                else rk[sa[i]] = rk[sa[i - 1]] + 1;
            }
            if (unique) break;
        }

        for (int i = 0, k = 0; i < n; i++) {
            if (!rk[i]) continue;
            int j = sa[rk[i] - 1];
            if (k) k--;
            while (i + k < n && j + k < n && s[i + k] == s[j + k]) k++;
            ht[rk[i]] = k;
        }

        /*
        for (int i = 0; i < n; i++) printf("%d%c", sa[i], i == n - 1 ? '\n' : ' ');
        for (int i = 0; i < n; i++) printf("%d%c", rk[i], i == n - 1 ? '\n' : ' ');
        for (int i = 0; i < n; i++) printf("%d%c", ht[i], i == n - 1 ? '\n' : ' ');
        for (int i = 0; i < n; i++) printf("%d %s\n", ht[i], &s[sa[i]]);
        // */

        for (int i = 0; i < n; i++) st[i][0] = ht[i];
        for (int j = 1; (1 << j) < n; j++) {
            for (int i = 0; i < n; i++) {
                if (i + (1 << (j - 1)) >= n) st[i][j] = st[i][j - 1];
                else st[i][j] = std::min(st[i][j - 1], st[i + (1 << (j - 1))][j - 1]);
            }
        }
    }

    int query(const int l, const int r) {
#ifdef TEST
        assert(l <= r);
#endif
        if (l == r) return st[l][0];
        int t = log[r - l];
        int res = std::min(st[l][t], st[r - (1 << t) + 1][t]);
#ifdef TEST
        int ans = n;
        for (int k = l; k <= r; k++) ans = std::min(ans, ht[k]);
        assert(res == ans);
#endif
        return res;
    }

    int lcp(const int i, const int j) {
        if (i == j) return n - i;
        int a = rk[i], b = rk[j];
        if (a > b) std::swap(a, b);
        return query(a + 1, b);
    }
} sa, saRev;

inline int lcp(const int i, const int j) {
    int res = sa.lcp(i, j);
#ifdef TEST
    int ans;
    for (ans = 0; i + ans < n && j + ans < n && s[i + ans] == s[j + ans]; ans++);
    assert(res == ans);
#endif
    return res;
}

inline int lcs(const int i, const int j) {
    int res = saRev.lcp(n - i - 1, n - j - 1);
#ifdef TEST
    int ans;
    for (ans = 0; i - ans >= 0 && j - ans >= 0 && s[i - ans] == s[j - ans]; ans++);
    assert(res == ans);
#endif
    return res;
}

inline void prepare() {
    for (int i = 0; i <= n; i++) {
        int x = 0;
        while ((1 << x) <= i) x++;
        log[i] = x - 1;
        // printf("log(%d) = %d\n", i, log[i]);
    }

    sa.build(s, n);

    long long x = 0;
    for (int i = 0; i < n; i++) {
        x += n - sa.sa[i];
        if (i) x -= sa.ht[i];
        cnt[i] = x;
        // printf("%lld\n", x);
    }

    std::copy(s, s + n, sRev);
    std::reverse(sRev, sRev + n);
    saRev.build(sRev, n);
}

inline bool locate(const long long k, int &l, int &r) {
    long long *p = std::upper_bound(cnt, cnt + n, k - 1);
    if (p == cnt + n) return false;
    int t = *p - k;
    l = sa.sa[p - cnt];
    r = n - t - 1;
    return true;
}

inline void all() {
    std::vector<std::string> v;
    for (int i = 0; i < n; i++) for (int j = i + 1; j <= n; j++) v.push_back(std::string(&s[i], &s[j]));
    std::sort(v.begin(), v.end());
    v.erase(std::unique(v.begin(), v.end()), v.end());
    int i = 1;
    for (std::vector<std::string>::const_iterator it = v.begin(); it != v.end(); it++) {
        int l, r;
        locate(i++, l, r);
        // printf("%d %d\n", l, r);
        for (int j = l; j <= r; j++) putchar(s[j]);
        putchar('\n');
        printf("%s\n", it->c_str());
    }
}

int main() {
    int q;
    scanf("%d %d\n%s", &n, &q, s);
    n = strlen(s);
    prepare();
    // all();

    /*
    puts("left query:");
    for (int i = 0; i < n; i++) for (int j = i; j < n; j++) printf("lq(%d, %d) = %d\n", i, j, query(stp, i, j));

    puts("right query:");
    for (int i = 0; i < n; i++) for (int j = i; j < n; j++) printf("rq(%d, %d) = %d\n", i, j, query(sts, i, j));
    */

    /*
    for (int i = 0; i < n; i++) for (int j = 0; j < n; j++) printf("lcp(%d, %d) = %d\n", i, j, lcp(i, j));
    putchar('\n');
    for (int i = 0; i < n; i++) for (int j = 0; j < n; j++) printf("lcs(%d, %d) = %d\n", i, j, lcs(i, j));
    */

    while (q--) {
        long long i, j;
        scanf("%lld %lld", &i, &j);
        int l1, r1, l2, r2;
        if (!locate(i, l1, r1) || !locate(j, l2, r2)) {
            puts("-1");
            continue;
        }
        int lim = std::min(r1 - l1 + 1, r2 - l2 + 1);
        long long a = std::min(lim, lcp(l1, l2)), b = std::min(lim, lcs(r1, r2));
        printf("%lld\n", a * a + b * b);
    }

    return 0;
}