You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

6.9 KiB

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

##P5357 【模板】AC 自动机(二次加强版)

一、题目描述

给你一个文本串 Sn 个模式串 T_{1 \sim n},请你分别求出每个模式串 T_i​在 S 中出现的次数。

二、对加强版进行改造

因为是什么二次加强版,所以大家先去做一下 加强版 吧,做法差不多。

好了,看到这里大家都一定做过加强版了吧,那么这道题的做法也是差不多的: 我们这一次不需要求出现最多的字符串啦,直接将cnt数组输出就好了!(应该都知道cnt数组是什么吧,就是统计每个模式串在文本串出现多少次的数组

重复的单词有没有影响啊!有啊!对于加强版这一次重复的单词就会有影响啦,怎么办?

这道题有相同字符串要统计,设当前字符串是第x个,我们用family[x]数组存当前字符串在Trie中的那个位置输入模式串序号,最后把cnt[family[i]]输出就OK了。另外id只在第一次赋值时变化,其他都不变。

本题思路很简单,如果你做过加强版的话。 这个思路很好搞,就是简单统计出现次数,然后输出。 不过如果你直接交会发现TLE。 我当时就是非常高兴的把加强版的代码改了改交了上去:

#include <cstdio>
#include <cstring>
#include <algorithm>
#include <iostream>
using namespace std;
const int N = 2 * 1e5 + 10;
const int M = 2 * 1e6 + 10;

char s[N], T[M];
int n;
int tr[N][26], idx, ne[N];
int id[N];
int cnt[N];
int family[N];

void insert(char *s, int x) {
    int p = 0;
    for (int i = 0; s[i]; i++) {
        int t = s[i] - 'a';
        if (!tr[p][t]) tr[p][t] = ++idx;
        p = tr[p][t];
    }

    if (!id[p]) id[p] = x; // id[p]记录的是首个入驻的模式串号x
    family[x] = id[p];     // 将所有最终位置是p号节点也就是重复模式串都划归到family[x]这个首次入驻模式串x为同一家族
}

int q[N], hh, tt = -1;
void bfs() {
    for (int i = 0; i < 26; i++)
        if (tr[0][i]) q[++tt] = tr[0][i];

    while (hh <= tt) {
        int p = q[hh++];
        for (int i = 0; i < 26; i++) {
            int t = tr[p][i];
            if (!t)
                tr[p][i] = tr[ne[p]][i];
            else {
                ne[t] = tr[ne[p]][i];
                q[++tt] = t;
            }
        }
    }
}

void query(char *s) {
    int p = 0;
    for (int i = 0; s[i]; i++) {
        p = tr[p][s[i] - 'a'];
        for (int j = p; j; j = ne[j])
            if (id[j]) cnt[id[j]]++;
    }
}

int main() {
    //加快读入
    ios::sync_with_stdio(false), cin.tie(0);
    cin >> n;
    for (int i = 1; i <= n; i++) {
        cin >> s;
        insert(s, i);
    }
    bfs();

    cin >> T;
    query(T);

    for (int i = 1; i <= n; i++) printf("%d\n", cnt[family[i]]);
    return 0;
}

三、持续优化

那么我们现在得到的算法是直接跑自动机跳ne树,然后每跳到一个标记点就计数加一。

考虑对其进行优化:

当匹配到单词的时,我们会不断地去跳ne,同一个点会可能被跳多次。 那么我们可以想一下把ne指针单独建出来他是个什么样子的。

它就是一棵树。

那么对于这一棵树,我们每次匹配时都会去更新它的父亲节点(ne树),那么对于树上的一条链,每一个子节点也有父子关系,他们会有共同的祖先。对于一对被遍历过的父子节点,它们的共同祖先显然会被父亲跳一次再被儿子跳一次,如果能够减少跳的次数,同时不丢失贡献,那么我们就能降低复杂度,从而完成本题。

那么我们思考一下,我们在跑自动机时如果先不跳ne,而是单纯的跑trie树,是比连跑带跳(ne)复杂度小不少的。那么跑完trie树,我们得到的是什么?

我们得到的是文本串在自动机上跑过的痕迹(脚印),我们也就得到了每个节点(不跳ne)被遍历的次数,在这些节点中,我们可以拿出来再更新ne

这时我们应该想一下既然都拿出来了,有没有什么方法能优化更新? 这样我们就需要思考ne树的性质 我们思考一下ne指针的建立:当前节点的 / 父亲节点的 / ne指针指向节点的 / 子节点 。 ne指针在树上跳时是一定向上跳的,最下面的节点会更新上面的父亲节点。 那么一个点被遍历的次数就是:trie树上遍历次数 + ne树上子节点被遍历次数

而子节点被遍历次数又取决于其trie树上遍历次数和自身子节点的个数。 那么最下面的点是不需要被其他点通过ne更新的

如果最下面的点更新过自己的父亲节点,那么它的父亲节点也就是次深的点就成了刚才最下面的点的状态。 而且一个节点只会被更新一次。

于是我们就得到了trie树的更新方法, 通过拓扑序 更新ne树,从底往上不断累加,最后输出结果。

拓扑序优化递推版本

#include <bits/stdc++.h>
using namespace std;
const int N = 200010;
const int M = 2000010; //文本串长度

int f[N];
char s[N];
char T[M];
int tr[N][26], idx, ne[N], id[N];

void insert(char *s, int x) {
    int p = 0;
    for (int i = 0; s[i]; i++) {
        int t = s[i] - 'a';
        if (!tr[p][t]) tr[p][t] = ++idx;
        p = tr[p][t];
    }
    id[x] = p;
}

//构建AC自动机
int q[N], hh, tt = -1;
void bfs() {
    for (int i = 0; i < 26; i++)
        if (tr[0][i]) q[++tt] = tr[0][i];

    while (hh <= tt) {
        int t = q[hh++];
        for (int i = 0; i < 26; ++i) {
            if (tr[t][i]) {
                ne[tr[t][i]] = tr[ne[t]][i];
                q[++tt] = tr[t][i];
            } else
                tr[t][i] = tr[ne[t]][i];
        }
    }
}

void query(char *s) {
    int p = 0;
    for (int i = 0; s[i]; i++) { // 枚举文本串每一个字符
        int t = s[i] - 'a';      // 字符映射的数字t,可以理解为边
        p = tr[p][t];            // 走进去到达的位置替换p
        f[p]++;                  // 标识此位置有人走过,记录走的次数
    }
}

int main() {
    //加快读入
    ios::sync_with_stdio(false), cin.tie(0);
    int n;
    cin >> n;
    for (int i = 1; i <= n; i++) {
        cin >> s;
        insert(s, i);
    }
    //构建AC自动机
    bfs();
    //文本串
    cin >> T;
    query(T);
    for (int i = idx; i; i--) f[ne[q[i]]] += f[q[i]]; //一路向上,计算叠加值
    //输出
    for (int i = 1; i <= n; i++) printf("%d\n", f[id[i]]);
    return 0;
}