「BZOJ 3881」Divljak - AC 自动机 + 树上路径并

n 个字符串 S_1, S_2, \ldots, S_n ,另有一个集合 T ,初始为空。 有 q 次操作,每次向 T 中添加一个字符串 P ,或询问 T 中有多少串能匹配 S_i

链接

BZOJ 3881

题解

S_i 建立 AC 自动机,每次加入一个串 P 时,将所有匹配 P 的单词节点计数器 +1 ,询问时直接输出答案。

考虑到每一次能匹配一个串的单词节点数量是 O(n ^ 2) 级别的,需要优化。如果我们不考虑后缀链接,则匹配数量降为 O(n) ,而对于通过后缀链接转移到的点,都是这些点在 Fail 树上的祖先。

每次记录所有匹配 P 的首个单词节点,在 Fail 树上对这些点到根的路径求,对路径并的每个节点计数器 +1

关于求路径并,将所有节点按照 DFS 序排序,将每个点到根的路径加入到集合中,将相邻每一对点的最近公共祖先到根的路径从集合中去除即可。

体现在本题中即为将每个点到根的路径上的每一个点计数器 +1 ,相邻每一对点的最近公共祖先到根的路径上的每一个点计数器 -1 。转化为子树求和,用树状数组维护即可。

代码

#include <cstdio>
#include <cassert>
#include <cstring>
#include <algorithm>
#include <vector>
#include <queue>
#include <stack>

const int MAXN = 100000;
const int MAXLEN = 2000000;
const int CHARSET_SIZE = 'z' - 'a' + 1;
const int BASE_CHAR = 'a';

int tot;
struct Trie {
    struct Node {
        Node *c[CHARSET_SIZE], *fail, *next;
        int id;
        bool isWord, visited;

        Node(const bool isWord = false) : fail(NULL), next(NULL), id(tot++), isWord(isWord), visited(false) {
            for (int i = 0; i < CHARSET_SIZE; i++) c[i] = NULL;
        }
    } *root;

    Trie() : root(NULL) {}

    Node *insert(const char *begin, const char *end) {
        Node **v = &root;
        for (const char *p = begin; p != end; p++) {
            if (!*v) *v = new Node;
            v = &(*v)->c[*p];
        }
        if (!*v) *v = new Node(true);
        else (*v)->isWord = true;
        return *v;
    }

    void build() {
        std::queue<Node *> q;
        q.push(root);
        root->fail = root, root->next = NULL;
        while (!q.empty()) {
            Node *v = q.front();
            q.pop();

            for (int i = 0; i < CHARSET_SIZE; i++) {
                Node *&c = v->c[i];
                if (!c) {
                    c = v->fail->c[i] ? v->fail->c[i] : root;
                    continue;
                }
                Node *u = v->fail;
                while (u != root && !u->c[i]) u = u->fail;
                c->fail = v != root && u->c[i] ? u->c[i] : root;
                c->next = c->fail->isWord ? c->fail : c->fail->next;
                q.push(c);
            }
        }
    }

    void exec(const char *begin, const char *end, std::vector<int> &vec) {
        Node *v = root;
        for (const char *p = begin; p != end; p++) {
            v = v->c[*p];
            if (v->isWord) vec.push_back(v->id);
            else if (v->next) vec.push_back(v->next->id);
        }
    }
} t;

struct Node {
    struct Edge *e;
    int l, r, depth, size;
    Node *p, *c, *top;
    bool visited;
} N[MAXLEN + 1];

struct Edge {
    Node *s, *t;
    Edge *next;

    Edge(Node *s, Node *t) : s(s), t(t), next(s->e) {}
};

inline void addEdge(const int s, const int t) {
#ifdef DBG
    printf("addEdge: %d -> %d\n", s + 1, t + 1);
#endif
    N[s].e = new Edge(&N[s], &N[t]);
    N[t].e = new Edge(&N[t], &N[s]);
}

inline void build() {
    std::queue<Trie::Node *> q;
    q.push(t.root);
    t.root->visited = true;
    while (!q.empty()) {
        Trie::Node *v = q.front();
        q.pop();

        if (v->fail != v) addEdge(v->id, v->fail->id);

        for (int i = 0; i < CHARSET_SIZE; i++) {
            if (!v->c[i]->visited) {
                v->c[i]->visited = true;
                q.push(v->c[i]);
            }
        }
    }
}

inline void split() {
    std::stack<Node *> s;
    s.push(&N[0]);
    N[0].depth = 1;

    while (!s.empty()) {
        Node *v = s.top();
        if (!v->visited) {
            v->visited = true;
            for (Edge *e = v->e; e; e = e->next) if (!e->t->depth) {
                e->t->depth = v->depth + 1;
                e->t->p = v;
                s.push(e->t);
            }
        } else {
            v->size = 1;
            for (Edge *e = v->e; e; e = e->next) if (e->t->p == v) {
                v->size += e->t->size;
                if (!v->c || v->c->size < e->t->size) v->c = e->t;
            }
            s.pop();
        }
    }

    for (int i = 0; i < tot; i++) N[i].visited = false;

    s.push(&N[0]);
    int ts = 0;
    while (!s.empty()) {
        Node *v = s.top();
        if (!v->visited) {
            v->visited = true;
            v->top = (!v->p || v != v->p->c) ? v : v->p->top;
            v->l = ++ts;
            for (Edge *e = v->e; e; e = e->next) if (e->t->p == v && e->t != v->c) {
                s.push(e->t);
            }
            if (v->c) s.push(v->c);
        } else {
            v->r = ts;
            s.pop();
        }
    }

    for (int i = 0; i < tot; i++) assert(N[i].l && N[i].r);

#ifdef DBG
    for (int i = 0; i < tot; i++) {
        printf("N[%d] => [%d, %d]\n", i + 1, N[i].l, N[i].r);
    }
#endif
}

inline Node *lca(Node *a, Node *b) {
    while (a->top != b->top) {
        if (a->top->depth < b->top->depth) std::swap(a, b);
        a = a->top->p;
    }
    Node *res = a->depth < b->depth ? a : b;
#ifdef DBG
    printf("lca(%lu, %lu) = %lu\n", a - N + 1, b - N + 1, res - N + 1);
#endif
    return res;
}

inline bool compare(const int a, const int b) {
    return N[a].l < N[b].l;
}

struct BinaryIndexedTree {
    int a[MAXLEN + 1 + 1 + 1], n;

    void init(const int n) { this->n = n; }

    static int lowbit(const int x) { return x & -x; }

    void update(const int pos, const int delta) {
#ifdef DBG
        printf("BIT: %d add %d\n", pos, delta);
#endif
#ifdef FORCE
        a[pos] += delta;
        return;
#endif
        for (int i = pos; i <= n; i += lowbit(i)) a[i] += delta;
    }

    int sum(const int pos) {
        int res = 0;
#ifdef FORCE
        for (int i = 1; i <= pos; i++) res += a[i];
        // printf("sum(%d) = %d\n", pos, res);
        return res;
#endif
        for (int i = pos; i > 0; i -= lowbit(i)) res += a[i];
        return res;
    }

    int query(const int l, const int r) {
#ifdef FORCE
        int res = 0;
        for (int i = l; i <= r; i++) res += a[i];
        // printf("query(%d, %d) = %d\n", l, r, res);
        return res;
#endif
        return sum(r) - sum(l - 1);
    }
} bit;

inline void add(const char *begin, const char *end) {
    std::vector<int> vec;
    t.exec(begin, end, vec);
    std::sort(vec.begin(), vec.end(), compare);
    vec.erase(std::unique(vec.begin(), vec.end()), vec.end());

#ifdef DBG
    printf("%lu matched:", vec.size());
    for (size_t i = 0; i < vec.size(); i++) printf(" %d", vec[i] + 1);
    putchar('\n');
#endif

    for (size_t i = 0; i < vec.size(); i++) {
        bit.update(N[vec[i]].l, 1);
        if (i != 0) {
            Node *p = lca(&N[vec[i]], &N[vec[i - 1]]);
            bit.update(p->l, -1);
        }
    }
}

inline int solve(const int x) {
    return bit.query(N[x].l, N[x].r);
}

int main() {
    int n;
    static char s[MAXLEN + 1];
    scanf("%d", &n);
    Trie::Node *a[MAXN];
    for (int i = 0; i < n; i++) {
        scanf("%s", s);
        const int len = strlen(s);
        for (int i = 0; i < len; i++) s[i] -= BASE_CHAR;
        a[i] = t.insert(s, s + len);
    }

    t.build();
    build();
    split();
    bit.init(tot);

    int q;
    scanf("%d", &q);
    while (q--) {
        int cmd;
        scanf("%d", &cmd);
        if (cmd == 1) {
            scanf("%s", s);
            const int len = strlen(s);
#ifdef DBG
            printf("add(\"%s\"): ", s);
#endif
            for (int i = 0; i < len; i++) s[i] -= BASE_CHAR;
            add(s, s + len);
        } else {
            int x;
            scanf("%d", &x), x--;
            printf("%d\n", solve(a[x]->id));
        }
    }

    return 0;
}