「BZOJ 3940」Censoring - AC 自动机

给定一个串 和一些单词串,每一次在 中寻找第一次出现的单词串,并将其删除,求最终串。

链接

BZOJ 3940

题解

对单词串建立 AC 自动机,用链表维护原串,匹配到单词串时从链表中删除。

删去一个单词时,原单词的左右两边可能会连接形成一个新的单词,为了不漏掉这些新单词,必须在删掉一个单词后以在这个单词之前的状态继续匹配 —— 使用栈维护历史状态即可。

代码

#include <cstdio>
#include <cassert>
#include <cstring>
#include <stack>
#include <queue>
#include <list>

const int MAXN = 1e5;
const int CHARSET_SIZE = 'z' - 'a' + 1;
const int BASE_CHAR = 'a';

struct Trie {
    struct Node {
        Node *c[CHARSET_SIZE], *fail, *next;
        bool isWord;
        int length;

        Node(const bool isWord = false, const int length = 0) : fail(NULL), next(NULL), isWord(isWord), length(length) {
            for (int i = 0; i < CHARSET_SIZE; i++) c[i] = NULL;
        }
    } *root;

    Trie() : root(NULL) {}

    void insert(const char *begin, const char *end) {
        Node **v = &root;
        for (const char *p = begin; p != end; p++) {
            if (!*v) *v = new Node();
            v = &(*v)->c[*p];
        }
        if (!*v) *v = new Node(true, end - begin);
        else (*v)->isWord = true, (*v)->length = end - begin;
    }

    void build() {
        std::queue<Node *> q;
        q.push(root);
        root->fail = root, root->next = NULL;
        while (!q.empty()) {
            Node *v = q.front();
            q.pop();

            for (int i = 0; i < CHARSET_SIZE; i++) {
                Node *&c = v->c[i];
                if (!c) {
                    c = v == root ? root : v->fail->c[i];
                    continue;
                }
                Node *u = v->fail;
                while (u != root && !u->c[i]) u = u->fail;
                c->fail = v != root && u->c[i] ? u->c[i] : root;
                c->next = c->fail->isWord ? c->fail : c->fail->next;
                q.push(c);
            }
        }
    }

    void exec(std::list<char> &list) {
        Node *v = root;
        std::stack<Node *> history;
        for (std::list<char>::iterator it = list.begin(); it != list.end(); ) {
            history.push(v);
            // while (v != root && !v->c[*it]) v = v->fail, assert(false);
            if (v->c[*it]) v = v->c[*it];
            if (!v->isWord && v->next) v = v->next;
            if (v->isWord) {
                for (int i = 1; i < v->length; i++) it--;
                for (int i = 0; i < v->length; i++) it = list.erase(it);
                for (int i = 1; i < v->length; i++) history.pop();
                v = history.top();
                history.pop();
            } else it++;
        }
    }
} t;

int main() {
    static char s[MAXN + 1];
    scanf("%s", s);
    std::list<char> list;
    for (const char *p = s; *p; p++) list.push_back(*p - BASE_CHAR);

    int n;
    scanf("%d", &n);
    while (n--) {
        scanf("%s", s);
        const int len = strlen(s);
        for (int i = 0; i < len; i++) s[i] -= BASE_CHAR;
        t.insert(s, s + len);
    }

    t.build();
    t.exec(list);

    for (std::list<char>::const_iterator it = list.begin(); it != list.end(); it++) putchar(*it + BASE_CHAR);
    putchar('\n');

    return 0;
}