「BZOJ 1396」识别子串 - SAM + 线段树

对于一个字符串 ,和 中的第 个字符 ,定义子串 为一个关于 的识别子串,当且仅当:

  1. 中只出现一次。

关于每一位字符的最短识别子串长度。

链接

BZOJ 1396

题解

建立 SAM,找出其中所有 集合大小为 的,这些节点所表示的字符串只出现过一次。

用线段树维护每个位置的最短识别子串长度。对于一个符合条件的节点 ,它表示的所有子串的结束位置是固定的,设这个位置为 ,设最短的子串的起始位置为 ,则对于 范围内的每个位置 都是它的识别子串,用 去更新它们的答案。起始位置在 左侧的那些子串(起始位置在 之间),以每个起始位置 开始的子串长度为 ,是一个关于 的一次函数,可以用线段树(李超树)维护。

代码

#include <cstdio>
#include <climits>
#include <cstring>
#include <algorithm>
#include <vector>

const int MAXN = 1e5;
const int CHARSET_SIZE = 26;

struct SuffixAutomaton {
    struct Node {
        Node *ch[CHARSET_SIZE], *next;
        int max, posCnt, pos;

        Node(int max = 0, bool newSuffix = false) : ch(), next(NULL), max(max), posCnt(newSuffix), pos(-1) {}

        int getMin() {
            return next->max + 1;
        }
    } *start, *last, _pool[MAXN * 2 + 1], *_curr;

    std::vector<Node *> topo;

    void init() {
        _curr = _pool;
        start = last = new (_curr++) Node;
    }

    Node *extend(int c) {
        Node *u = new (_curr++) Node(last->max + 1, true), *v = last;
        do {
            v->ch[c] = u;
            v = v->next;
        } while (v && !v->ch[c]);

        if (!v) {
            u->next = start;
        } else if (v->ch[c]->max == v->max + 1) {
            u->next = v->ch[c];
        } else {
            Node *n = new (_curr++) Node(v->max + 1, false), *o = v->ch[c];
            std::copy(o->ch, o->ch + CHARSET_SIZE, n->ch);
            n->next = o->next;
            o->next = u->next = n;
            for (; v && v->ch[c] == o; v = v->next) v->ch[c] = n;
        }

        last = u;
        return u;
    }

    std::vector<Node *> &toposort() {
        static int buc[MAXN * 2 + 1];
        int max = 0;
        for (Node *p = _pool; p != _curr; p++) {
            max = std::max(max, p->max);
            buc[p->max]++;
        }
        for (int i = 1; i <= max; i++) buc[i] += buc[i - 1];

        topo.resize(_curr - _pool);
        for (Node *p = _pool; p != _curr; p++) {
            topo[--buc[p->max]] = p;
        }

        for (int i = 0; i <= max; i++) buc[i] = 0;

        return topo;
    }

    void calc() {
        toposort();

        for (int i = topo.size() - 1; i > 0; i--) {
            Node *v = topo[i];
            v->next->posCnt += v->posCnt;
        }
    }
} sam;

struct LinearFunction {
    int k, b;

    LinearFunction(int x) : k(0), b(x) {}
    LinearFunction(int k, int b) : k(k), b(b) {}

    int operator()(int x) {
        return k * x + b;
    }
};

struct SegT {
    int l, r, mid;
    SegT *lc, *rc;
    LinearFunction f;

    SegT(int l, int r, SegT *lc, SegT *rc) : l(l), r(r), mid(l + (r - l) / 2), lc(lc), rc(rc), f(INT_MAX) {}

    int query(int pos) {
        if (l == r) return f(pos);
        else return std::min(f(pos), (pos > mid ? rc : lc)->query(pos));
    }

    void cover(LinearFunction f) {
        if (f(mid) < this->f(mid)) std::swap(f, this->f);
        if (f(l) < this->f(l)) lc->cover(f);
        if (f(r) < this->f(r)) rc->cover(f);
    }

    void update(int ql, int qr, const LinearFunction &f) {
        if (ql > r || qr < l) return;
        else if (ql <= l && qr >= r) cover(f);
        else lc->update(ql, qr, f), rc->update(ql, qr, f);;
    }

    static SegT *build(int l, int r) {
        if (l == r) return new SegT(l, r, NULL, NULL);
        else {
            int mid = l + (r - l) / 2;
            return new SegT(l, r, build(l, mid), build(mid + 1, r));
        }
    }
};

int main() {
    static char s[MAXN + 1];
    scanf("%s", s);
    int n = strlen(s);
    sam.init();
    for (int i = 0; i < n; i++) {
        SuffixAutomaton::Node *v = sam.extend(s[i] - 'a');
        v->pos = i;
    }

    sam.calc();

    /*
    static int ans[MAXN + 1];
    for (int i = 0; i < n; i++) ans[i] = INT_MAX;
    */

    SegT *seg = SegT::build(0, n - 1);
    for (int i = sam.topo.size() - 1; i > 0; i--) {
        SuffixAutomaton::Node *v = sam.topo[i];
        if (v->posCnt == 1 && v->pos != -1) {
            // printf("pos = %d, [%d, %d]\n", v->pos, v->getMin(), v->max);

            int min = v->getMin(), max = v->max, right = v->pos;

            // [right - min + 1, right] <-- min
            seg->update(right - min + 1, right, LinearFunction(min));

            int l = right - max + 1, r = right - min + 1;
            // f(x) = -(x - l) + max
            //      = -x + (l + max)

            // seg->update(l, r, LinearFunction(-1, l + max));
            seg->update(l, r, LinearFunction(-1, right + 1));

            /*
            for (int j = v->getMin(); j <= v->max; j++) {
                for (int k = v->pos - j + 1; k <= v->pos; k++) {
                    putchar(s[k]);
                    ans[k] = std::min(ans[k], j);
                }
                puts("");
            }
            */
        }
    }

    for (int i = 0; i < n; i++) printf("%d\n", seg->query(i));
}