「HDU 5906」Square Revolution - 后缀数组 + 并查集 + 树状数组

对于一个给定的字符串 $s$,求它有多少个连续子串是 prefix-suffix-square free 的。

一个字符串被称为 square 当且仅当它可以由两个相同的串连接而成。例如,`abab`、`aa` 是 square,而 `aaa`、`abba` 不是。一个字符串是 prefix-suffix-square free 的当且仅当它的任何前缀或者后缀都不是 square。

链接

HDU 5906

题解

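求出从每个位置开始、结束的最短 square,进而求出从第 $i$ 个位置结束的子串,其开始位置的最小值。设从第 $i$ 个位置开始的最短 square 的长度为 $f(i)$,从第 $i$ 个位置结束的子串,其开始位置的最小值为 $g(i)$,则答案为对于每个 $i$,在 $[i,\ i + f(i) - 1]$ 内(若不存在从 $i$ 开始的 square,则右端点取串尾)满足 $g(j) \leq i$ 的 $j$ 的数量之和。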

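考虑如何求 square:类似「NOI2016 优秀的拆分」一题,枚举 $len$,枚举每个长度为 $len$ 的区间,则所有长度为 $2 \times len$ 的 square 都会跨越某个这样的区间 $[i, j]$(其中 $j = i + len$),从端点分别向前、后求最长公共后缀、前缀(长度分别记为 $a$、$b$,均与 $len$ 取 $\min$)。进而可以求出若干个 square。如,当 $a + b - 1 \geq len$ 时,令 $l = i - a + 1$,$r = j + b - 1$,从 $[l,\ r - 2 \times len + 1]$ 内一点开始,均有一个长度为 $2 \times len$ 的 square,此时用 $2 \times len$ 更新这些位置的 $f$;对结束位置 $[l + 2 \times len - 1,\ r]$ 同理。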

因为是从小到大枚举的 $len$,所以每个位置只会被赋值一次,使用并查集维护所有赋值过的点,赋值过的点合并,之后直接跳过这些区间即可。

代码

#include <cstdio>
#include <climits>
#include <cstring>
#include <algorithm>
#include <vector>

// Capacity used for the fixed-size arrays in this file (suffix array tables,
// forward/backward, the input buffer, ...); strings longer than this do not fit.
const int MAXN = 1e5;
// Highest sparse-table level that is allocated: st has columns 0..MAXN_LOG.
const int MAXN_LOG = 17; // Math.log2(1e5) = 16.609640474436812

/**
 * Suffix array of a string, together with the height (adjacent-LCP) array and
 * a sparse table over it, so that lcp(i, j) is answered with two table
 * lookups after build().
 *
 * Members:
 *   n   - length of the string passed to the last build().
 *   sa  - sa[r] is the start index of the suffix with rank r. It is only
 *         written inside the doubling loop, so it is left untouched when n == 1.
 *   rk  - rk[i] is the rank of the suffix starting at i.
 *   ht  - ht[r], for r >= 1, is the LCP length of suffixes sa[r - 1] and sa[r];
 *         ht[0] is never written by build().
 *   st  - st[i][j] is the minimum of ht over [i, min(n - 1, i + 2^j - 1)].
 *   log - log[i] is floor(log2(i)) for i >= 1 and -1 for i == 0.
 *
 * sa1 and sa2 are the two global instances declared with the struct.
 */
struct SuffixArray {
    int n, sa[MAXN], rk[MAXN], ht[MAXN], st[MAXN][MAXN_LOG + 1], log[MAXN + 1];

    /**
     * Builds every table for the first n characters of s.
     *
     * @param s  the string; only s[0..n-1] is read (the DBG dump additionally
     *           prints s as a NUL-terminated string).
     * @param n  number of characters; must not exceed MAXN.
     */
    inline void build(const char *s, const int n) {
        this->n = n;
        // Coordinate-compress the characters: a[i] is the index of s[i] among
        // the sorted distinct characters of s.
        static int set[MAXN], a[MAXN];
        std::copy(s, s + n, set);
        std::sort(set, set + n);
        int *end = std::unique(set, set + n);
        for (int i = 0; i < n; i++) a[i] = std::lower_bound(set, end, s[i]) - set;

        // buc points one past the start of _buc so that buc[-1] is a valid
        // slot: it is the base of the prefix sums and the bucket for sec == -1.
        static int fir[MAXN], sec[MAXN], tmp[MAXN], _buc[MAXN + 1], *buc = _buc + 1;
        // Initial rank of suffix i = number of characters strictly smaller
        // than s[i] (prefix sum of the bucket just below a[i]).
        std::fill(buc - 1, buc + n, 0);
        for (int i = 0; i < n; i++) buc[a[i]]++;
        for (int i = 0; i < n; i++) buc[i] += buc[i - 1];
        for (int i = 0; i < n; i++) rk[i] = buc[a[i] - 1];

        // Prefix doubling: each round sorts suffixes by the pair
        // (rank of first t characters, rank of the following t characters).
        for (int t = 1; t < n; t *= 2) {
            // sec is -1 when the suffix has no character at offset t, which
            // makes it sort before every real second key.
            for (int i = 0; i < n; i++) fir[i] = rk[i], sec[i] = i + t < n ? rk[i + t] : -1;

            // Counting sort on the second key. Slots are assigned as
            // n - (prefix count), so tmp lists the indices in DESCENDING
            // order of sec.
            std::fill(buc - 1, buc + n, 0);
            for (int i = 0; i < n; i++) buc[sec[i]]++;
            for (int i = 0; i < n; i++) buc[i] += buc[i - 1];
            for (int i = 0; i < n; i++) tmp[n - buc[sec[i]]--] = i;

            // Counting sort on the first key. Each fir bucket is filled from
            // its back while walking tmp (largest sec first), so inside a
            // bucket the entries end up in ascending order of sec.
            std::fill(buc - 1, buc + n, 0);
            for (int i = 0; i < n; i++) buc[fir[i]]++;
            for (int i = 0; i < n; i++) buc[i] += buc[i - 1];
            for (int i = 0; i < n; i++) sa[--buc[fir[tmp[i]]]] = tmp[i];

            // Re-rank: neighbours in sa with identical key pairs share a rank.
            bool unique = true;
            for (int i = 0; i < n; i++) {
                if (!i) rk[sa[i]] = 0;
                else if (fir[sa[i]] == fir[sa[i - 1]] && sec[sa[i]] == sec[sa[i - 1]]) rk[sa[i]] = rk[sa[i - 1]], unique = false;
                else rk[sa[i]] = rk[sa[i - 1]] + 1;
            }
            // All ranks distinct: the order is final, no further doubling needed.
            if (unique) break;
        }

        // Height array: walk the suffixes in text order and reuse the previous
        // match length minus one as the starting point for the next comparison.
        for (int i = 0, k = 0; i < n; i++) {
            if (!rk[i]) continue;
            int j = sa[rk[i] - 1];
            if (k) k--;
            while (i + k < n && j + k < n && a[i + k] == a[j + k]) k++;
            ht[rk[i]] = k;
        }

#ifdef DBG
        for (int i = 0; i < n; i++) printf("%d%c", sa[i], i == n - 1 ? '\n' : ' ');
        for (int i = 0; i < n; i++) printf("%d%c", rk[i], i == n - 1 ? '\n' : ' ');
        for (int i = 0; i < n; i++) printf("%d %s\n", ht[i], &s[sa[i]]);
#endif

        // Sparse table of minima over ht. A block that would run past n - 1
        // simply copies the shorter block from the level below.
        for (int i = 0; i < n; i++) st[i][0] = ht[i];
        for (int j = 1; (1 << j) < n; j++) {
            for (int i = 0; i < n; i++) {
                if (i + (1 << (j - 1)) >= n) st[i][j] = st[i][j - 1];
                else st[i][j] = std::min(st[i][j - 1], st[i + (1 << (j - 1))][j - 1]);
            }
        }

        // log[i] = floor(log2(i)); the loop leaves x == 0 for i == 0, so log[0] == -1.
        for (int i = 0; i <= n; i++) {
            int x = 0;
            while ((1 << x) <= i) x++;
            log[i] = x - 1;
        }
    }

    /**
     * Minimum of ht over the inclusive rank range [l, r].
     * Uses t = log[r - l]; the two blocks of length 2^t anchored at l and at r
     * overlap and together cover the whole range. l == r is handled
     * separately because log[0] is -1.
     */
    inline int rmq(const int l, const int r) {
        if (l == r) return st[l][0];
        int t = log[r - l];
        return std::min(st[l][t], st[r - (1 << t) + 1][t]);
    }

    /**
     * Length of the longest common prefix of the suffixes starting at i and j.
     * For i == j this is the whole suffix, n - i; otherwise it is the minimum
     * height strictly after the smaller rank up to the larger rank.
     */
    inline int lcp(const int i, const int j) {
        if (i == j) return n - i;
        int a = rk[i], b = rk[j];
        if (a > b) std::swap(a, b);
        return rmq(a + 1, b);
    }
} sa1, sa2;

// n        - length of the string of the current test case (assigned in main).
// forward  - after prepare(): forward[i] is the smallest square length recorded
//            for a square starting at i, or INT_MAX if none was found.
// backward - after prepare(): backward[i] is the smallest square length
//            recorded for a square ending at i, or INT_MAX if none was found.
int n, forward[MAXN], backward[MAXN];

// Longest common prefix of the suffixes starting at i and j, answered by sa1.
inline int lcp(const int i, const int j) {
    const int commonPrefix = sa1.lcp(i, j);
    return commonPrefix;
}
// Longest common suffix of the prefixes ending at i and j. Both positions are
// mirrored (x -> n - 1 - x) and the query is answered as a prefix query on sa2.
inline int lcs(const int i, const int j) {
    const int mirroredI = n - i - 1;
    const int mirroredJ = n - j - 1;
    return sa2.lcp(mirroredI, mirroredJ);
}

/**
 * Fills the global arrays forward[] and backward[] for the current string.
 *
 * For every half-length len (ascending) and every pair of positions
 * i, j = i + len with i a multiple of len, the common suffix ending at (i, j)
 * and the common prefix starting at (i, j) are measured through lcs()/lcp(),
 * each capped at len. When together they span at least len characters, every
 * start in [l, r - 2*len + 1] begins a square of length 2*len and every end in
 * [l + 2*len - 1, r] finishes one, where l = i - a + 1 and r = j + b - 1.
 *
 * Because len only grows, the first value written to a position is already
 * its minimum; the union-find below lets later passes jump over positions
 * that have been written instead of revisiting them.
 *
 * @param n  length of the string; must match the global n used by lcs().
 */
inline void prepare(const int n) {
    // Disjoint-set used as a "next not-yet-written position" pointer:
    // once position p has been written it is linked towards p + 1.
    static struct UnionFindSet {
        int a[MAXN], n;

        void init(const int n) {
            this->n = n;
            for (int i = 0; i < n; i++) a[i] = i;
        }

        // Iterative find with full path compression. cover() builds chains
        // p -> p + 1 -> ... that can be as long as the string, so a recursive
        // find could recurse about n levels deep; the loops keep stack use
        // constant.
        int find(const int x) {
            int root = x;
            while (a[root] != root) root = a[root];
            for (int cur = x; cur != root;) {
                const int next = a[cur];
                a[cur] = root;
                cur = next;
            }
            return root;
        }

        void merge(const int child, const int parent) {
#ifdef DBG
            printf("merge: %d => %d\n", child, parent);
#endif
            const int _child = find(child), _parent = find(parent);
            a[_child] = _parent;
        }

        // Lowers val[p] to x for every not-yet-skipped p in [l, r] and links
        // each such p to p + 1. The last position n - 1 has no successor, so
        // it is never linked and the walk stops there.
        void cover(const int l, const int r, const int x, int *val) {
#ifdef DBG
            printf("cover(%d, %d, %d)\n", l, r, x);
#endif
            for (int i = find(l); i <= r; i = find(i + 1)) {
                val[i] = std::min(val[i], x);
                if (i != n - 1) merge(i, i + 1);
                else break;
            }
        }
    } sForward, sBackward;

    sForward.init(n);
    sBackward.init(n);

    for (int i = 0; i < n; i++) forward[i] = backward[i] = INT_MAX;

    for (int len = 1; len < n; len++) {
        for (int i = 0; i + len < n; i += len) {
            const int j = i + len;
            // a: matching characters going left from (i, j); b: going right.
            // Both are capped at len; s counts the distinct positions covered
            // (the pair (i, j) itself is counted in both a and b).
            const int a = std::min(lcs(i, j), len), b = std::min(lcp(i, j), len), s = a + b - 1;
#ifdef DBG
            printf("s = %d, len = %d\n", s, len);
#endif
            if (s >= len) {
#ifdef DBG
                printf("!FOUND! i = %d, j = %d, a = %d, b = %d, len = %d\n", i, j, a, b, len);
#endif
                const int l = i - a + 1, r = j + b - 1;
                // Starts l .. r - 2*len + 1 each begin a square of length 2*len.
                sForward.cover(l, l + (r - (l + len * 2 - 1)), len * 2, forward);
                // Ends l + 2*len - 1 .. r each finish a square of length 2*len.
                sBackward.cover(l + len * 2 - 1, r, len * 2, backward);
            }
        }
    }

#ifdef DBG
    puts("forward / backward");
    for (int i = 0; i < n; i++) printf("%d%c", forward[i] == INT_MAX ? 0 : forward[i], i == n - 1 ? '\n' : ' ');
    for (int i = 0; i < n; i++) printf("%d%c", backward[i] == INT_MAX ? 0 : backward[i], i == n - 1 ? '\n' : ' ');
#endif
}

/*
struct ChairmanTree {
    struct SegmentTree {
        int l, r, mid;
        SegmentTree *lc, *rc;
        int cnt, refCnt;

        SegmentTree(const int l, const int r, SegmentTree *lc, SegmentTree *rc, const int cnt = 0) : l(l), r(r), mid(l + (r - l) / 2), lc(lc), rc(rc), cnt(cnt), refCnt(0) {}

        ~SegmentTree() {
            if (lc && lc->unRef()) delete lc;
            if (rc && rc->unRef()) delete rc;
        }

        SegmentTree *ref() {
            refCnt++;
            return this;
        }

        bool unRef() {
            return !refCnt--;
        }

        int query(const int l, const int r) {
            if (l > this->r || r < this->l) return 0;
            else if (l <= this->l && r >= this->r) return cnt;
            else return (lc ? lc->query(l, r) : 0) + (rc ? rc->query(l, r) : 0);
        }

        SegmentTree *insertSelf(const int x) {
            if (l == r) return this;
            if (x <= mid) {
                this->lc = (new SegmentTree(l, mid, NULL, NULL, 1))->insertSelf(x);
            } else {
                this->rc = (new SegmentTree(mid + 1, r, NULL, NULL, 1))->insertSelf(x);
            }
            return this;
        }

        SegmentTree *insert(const int x) {
            if (x <= mid) {
                return new SegmentTree(l, r, lc ? lc->insert(x) : (new SegmentTree(l, mid, NULL, NULL, 1))->insertSelf(x), rc ? rc->ref() : NULL, cnt + 1);
            } else {
                return new SegmentTree(l, r, lc ? lc->ref() : NULL, rc ? rc->insert(x) : (new SegmentTree(mid + 1, r, NULL, NULL, 1))->insertSelf(x), cnt + 1);
            }
        }

        int lcount() {
            return lc ? lc->cnt : 0;
        }
    } *root[MAXN + 1];
    int n, l, r;

    void clear() {
        for (int i = n; i >= 0; i--) {
            delete root[i];
        }
    }

    // Build tree with array a, whose elements are limited in [l, r]
    void build(const int *a, const int n, const int l, const int r) {
        this->n = n, this->l = l, this->r = r;
        root[0] = new SegmentTree(0, n, NULL, NULL, 0);
        for (int i = 1; i <= n; i++) {
            root[i] = root[i - 1]->insert(a[i - 1]);
        }
    }

    int query(const int l, const int r, const int x) {
        // Adjust input [0, n - 1] to [1, n]
        SegmentTree *vr = root[r + 1], *vl = root[l];
        int ans = vr->query(this->l, x);
        if (vl) ans -= vl->query(this->l, x);
        return ans;

        / *
        while (min != max) {
            const int mid = mid + (max - min) / 2, t = vr->lcount() - (vl ? vl->lcount() : 0);
            if (t < k) {
                k -= t, vr = vr->rc;
                if (vl) vl = vl->rc;
            } else {
                vr = vr->lc;
                if (vl) vl = vl->lc;
            }
        }
        return min;
        * /
    }
} t;
*/

/**
 * Counts the pairs (i, e) with i <= e <= hi(i) and limit[e] <= i, where
 * hi(i) = i + forward[i] - 1, or n - 1 when forward[i] == INT_MAX.
 *
 * The count is computed offline: every position e contributes an "insert
 * limit[e]" event at time e, and every start i contributes two signed prefix
 * queries "how many inserted values are <= i" - one with sign -1 at time
 * i - 1 and one with sign +1 at time hi(i). Events are processed in time
 * order (inserts first on ties) against a Fenwick tree indexed by value + 1.
 *
 * All working storage is local and zero-initialised on every call. The
 * previous version kept a static Fenwick tree and tried to undo its inserts
 * by scanning only the first 2n sorted events; inserts that had been sorted
 * into the last third were never undone and corrupted later calls.
 *
 * @param limit    limit[e] is the smallest allowed start for a substring
 *                 ending at e; values are expected to lie in [0, n - 1].
 * @param forward  forward[i] is the length of the shortest square starting at
 *                 i, or INT_MAX when there is none.
 * @param n        number of positions; n <= 0 yields 0.
 * @return the number of counted pairs.
 */
inline unsigned long long solve(const int *limit, const int *forward, const int n) {
    if (n <= 0) return 0;

    struct Query {
        int type;      // 0: insert val; +1 / -1: signed query "count values <= val"
        int pos, val;  // pos: time of the event

        // Order by time; at equal time inserts come before queries.
        bool operator<(const Query &other) const {
            if (pos != other.pos) return pos < other.pos;
            return !type && other.type;
        }
    };

    std::vector<Query> events;
    events.reserve(static_cast<std::vector<Query>::size_type>(n) * 3);

    for (int i = 0; i < n; i++) {
        const Query add = {0, i, limit[i]};
        events.push_back(add);

#ifdef DBG
        printf("add(pos = %d, val = %d)\n", add.pos, add.val);
#endif
    }

    for (int i = 0; i < n; i++) {
        // Subtract what was already inserted before position i ...
        const Query lower = {-1, i - 1, i};
        events.push_back(lower);

#ifdef DBG
        printf("query(pos = %d, val = %d, type = %d)\n", lower.pos, lower.val, lower.type);
#endif

        // ... from what has been inserted up to the last admissible end.
        const int last = (forward[i] == INT_MAX) ? (n - 1) : (i + forward[i] - 1);
        const Query upper = {1, last, i};
        events.push_back(upper);

#ifdef DBG
        printf("query(pos = %d, val = %d, type = %d)\n", upper.pos, upper.val, upper.type);
#endif
    }

    std::sort(events.begin(), events.end());

    // Fenwick tree over 1-based indices 1..n, freshly zeroed for this call.
    struct BinaryIndexedTree {
        int n;
        std::vector<int> a;

        explicit BinaryIndexedTree(const int size) : n(size), a(size + 1, 0) {}

        static int lowbit(const int x) {
            return x & -x;
        }

        void update(const int x, const int delta) {
            for (int i = x; i <= n; i += lowbit(i)) a[i] += delta;
        }

        int query(const int x) const {
            int sum = 0;
            for (int i = x; i > 0; i -= lowbit(i)) sum += a[i];
            return sum;
        }
    } bit(n);

    // Partial sums may dip below zero between a -1 and its matching +1 query,
    // so accumulate in a signed type; the final total is non-negative.
    long long ans = 0;
    for (std::vector<Query>::size_type k = 0; k < events.size(); k++) {
        const Query &q = events[k];
#ifdef DBG
        printf("Query: %d %d %d\n", q.type, q.pos, q.val);
#endif
        if (q.type) {
            const int t = bit.query(q.val + 1) * q.type;
            ans += t;
#ifdef DBG
            printf("t = %d\n", t);
#endif
        } else {
            bit.update(q.val + 1, 1);
        }
    }

    return static_cast<unsigned long long>(ans);
}

/**
 * Reads the number of test cases followed by one string per case and prints,
 * for each string, the value returned by solve() on its own line.
 *
 * Per case: build the suffix array of the string (sa1) and of its reverse
 * (sa2), run prepare() to fill forward[]/backward[], turn backward[] into the
 * per-end lower bound limit[] on the start position, and hand both to solve().
 */
int main() {
    int testcase;
    // Stop right away if the case count is missing; otherwise the loop below
    // would run on an uninitialised counter.
    if (scanf("%d", &testcase) != 1) return 0;

    while (testcase-- > 0) {
        static char s[MAXN + 1];
        // The field width keeps scanf from writing past s[MAXN]; it has to be
        // spelled as a literal, so fail the build if MAXN ever changes.
        static_assert(MAXN == 100000, "update the scanf field width to match MAXN");
        // On premature end of input, stop instead of re-processing the
        // previous contents of s.
        if (scanf("%100000s", s) != 1) break;
        n = static_cast<int>(strlen(s));
        sa1.build(s, n);
        // sa2 indexes the reversed string so that lcs() can be answered as a
        // prefix query.
        std::reverse(s, s + n);
        sa2.build(s, n);

        prepare(n);
        static int limit[MAXN];
        // A substring ending at i must start after the start of the shortest
        // square ending at i, i.e. at i - backward[i] + 2 or later; with no
        // such square any start from 0 is allowed.
        for (int i = 0; i < n; i++) limit[i] = backward[i] == INT_MAX ? 0 : (i - backward[i] + 2);

#ifdef DBG
        for (int i = 0; i < n; i++) printf("%d%c", limit[i], i == n - 1 ? '\n' : ' ');
#endif

        unsigned long long ans = solve(limit, forward, n);
        printf("%llu\n", ans);
    }

    return 0;
}