Cod sursa(job #2570815)

Utilizator PopoviciRobert (Popovici Robert) Data 4 martie 2020 19:31:01
Problema Aho-Corasick Scor 100
Compilator cpp-32 Status done
Runda Arhiva educationala Marime 4.23 kb
#include <bits/stdc++.h>
// Lowest set bit of x (x & -x). The argument is fully parenthesised so that
// expression arguments such as lsb(a + b) expand correctly.
#define lsb(x) ((x) & (-(x)))
// Type shorthands (not used elsewhere in this file).
#define ll long long
#define ull unsigned long long
#define uint unsigned int


using namespace std;

// One state of the Aho-Corasick automaton (a trie node).
struct Node {
    // Alphabet size: only the letters 'a'..'z' get trie edges.
    static constexpr int SIGMA = 26;

    Node* son[SIGMA] = {};  // trie children indexed by (letter - 'a'); nullptr if absent
    Node* fail = nullptr;   // failure link, filled in by AHO::bfs()

    // Number of text positions whose automaton state is this node; after
    // AHO::antibfs() it is the number of text positions at which the string
    // spelled by this node ends.
    int num = 0;
    std::vector<int> ids;   // ids of the patterns that end exactly at this node
};

// Aho-Corasick automaton that counts how many times each pattern occurs
// (overlapping occurrences included) in a text.
//
// Call order: ins() every pattern, bfs(), solve(text), antibfs(sol).
// Characters outside 'a'..'z' act as separators: no match spans them, and a
// pattern containing one is never counted.
struct AHO {
    Node* root;

    AHO() : root(new Node) {}

    // Frees every node reachable from root through trie edges.
    ~AHO() {
        std::vector<Node*> pending{root};
        while (!pending.empty()) {
            Node* nod = pending.back();
            pending.pop_back();
            for (Node* child : nod->son) {
                if (child != nullptr) pending.push_back(child);
            }
            delete nod;
        }
    }

    // The automaton owns raw node pointers; a copy would delete them twice.
    AHO(const AHO&) = delete;
    AHO& operator=(const AHO&) = delete;

    // Maps a character to its child index, or -1 when it is outside 'a'..'z'.
    static int letter(char c) {
        return (c >= 'a' && c <= 'z') ? c - 'a' : -1;
    }

    // Inserts the suffix str[pos..] below `nod` and tags its final node with `id`.
    // Iterative, so the stack depth does not grow with the pattern length.
    // A pattern with a character outside 'a'..'z' can never be matched by
    // solve(), so it is left out of the trie (its count stays 0).
    void ins(Node* nod, int pos, const std::string& str, int id) {
        const std::size_t start = static_cast<std::size_t>(pos);
        for (std::size_t i = start; i < str.size(); ++i) {
            if (letter(str[i]) < 0) return;
        }
        for (std::size_t i = start; i < str.size(); ++i) {
            Node*& child = nod->son[letter(str[i])];
            if (child == nullptr) child = new Node;
            nod = child;
        }
        nod->ids.push_back(id);
    }

    // Nodes in BFS order, filled by bfs(). Every failure link points to a
    // strictly shallower node, so it appears earlier in Q.
    std::vector<Node*> Q;

    // Computes the failure links of every node by breadth-first search.
    // Safe to call again: Q is rebuilt from scratch.
    void bfs() {
        Q.clear();
        root->fail = root;
        Q.push_back(root);
        for (std::size_t head = 0; head < Q.size(); ++head) {
            Node* nod = Q[head];  // copy the pointer: push_back below may reallocate Q
            for (int ch = 0; ch < Node::SIGMA; ++ch) {
                Node* child = nod->son[ch];
                if (child == nullptr) continue;
                // Walk nod's failure chain to the deepest state with an edge on ch.
                Node* cur = nod->fail;
                while (cur != root && cur->son[ch] == nullptr) {
                    cur = cur->fail;
                }
                // For root's children cur == nod == root and cur->son[ch] is the
                // child itself; they must fail to root instead.
                child->fail = (cur != nod && cur->son[ch] != nullptr) ? cur->son[ch] : root;
                Q.push_back(child);
            }
        }
    }

    // Runs the text through the automaton and counts, per node, how many
    // positions end in that state. Requires bfs() to have been called.
    void solve(const std::string& str) {
        Node* nod = root;
        for (char c : str) {
            const int ch = letter(c);
            if (ch < 0) {
                nod = root;  // separator: no pattern spans this character
            } else {
                while (nod != root && nod->son[ch] == nullptr) {
                    nod = nod->fail;
                }
                if (nod->son[ch] != nullptr) {
                    nod = nod->son[ch];
                }
            }
            nod->num++;
        }
    }

    // Pushes counts up the failure tree (deepest nodes first) and adds each
    // node's final count to sol[id] for every pattern id ending there.
    // `sol` must be large enough to be indexed by every inserted id.
    void antibfs(std::vector<int>& sol) {
        for (std::size_t i = Q.size(); i-- > 0;) {
            Node* nod = Q[i];
            for (int id : nod->ids) {
                sol[id] += nod->num;
            }
            if (nod != root) {  // root's fail is itself; nothing to propagate
                nod->fail->num += nod->num;
            }
        }
    }
};

//namespace Brut {
//    static const int SIGMA = 2;
//    inline void Gen() {
//        ofstream fout("input.txt");
//        int sz = rand() % 100 + 1;
//        for(int i = 0; i < sz; i++) {
//            fout << (char)(rand() % SIGMA + 'a');
//        }
//        fout << "\n";
//        int n = rand() % 50;
//        fout << n << "\n";
//        for(int i = 0; i < n; i++) {
//            int len = rand() % 100 + 1;
//            for(int j = 0; j < len; j++) {
//                fout << (char)(rand() % SIGMA + 'a');
//            }
//            fout << "\n";
//        }
//        fout.close();
//    }
//};


// Reads the text and n patterns, then prints for each pattern (in input
// order) how many times it occurs in the text.
int main() {
#ifdef HOME
    ifstream cin("A.in");
    ofstream cout("A.out");
#else
    // Exactly one input/output pair may be declared in this scope; declaring
    // both unconditionally breaks the build when HOME is defined.
    ifstream cin("ahocorasick.in");
    ofstream cout("ahocorasick.out");
#endif
    int i, n = 0;
    ios::sync_with_stdio(false);
    cin.tie(0);

    string str; cin >> str >> n;
    if (n < 0) n = 0;  // a negative count would make vector<int>(n + 1) throw
    AHO aho;
    for (i = 1; i <= n; i++) {
        string pat; cin >> pat;
        aho.ins(aho.root, 0, pat, i);
    }
    aho.bfs();
    aho.solve(str);
    vector <int> sol(n + 1);
    aho.antibfs(sol);
    for (i = 1; i <= n; i++) {
        cout << sol[i] << '\n';
    }

//    srand(time(NULL));
//
//    while(1) {
//        Brut::Gen();
//
//        ifstream fin("input.txt");
//        string str; fin >> str >> n;
//        vector <string> pat(n + 1);
//        AHO aho;
//        for(i = 1; i <= n; i++) {
//            fin >> pat[i];
//            aho.ins(aho.root, 0, pat[i], i);
//        }
//        fin.close();
//
//        aho.bfs();
//        aho.solve(str);
//        vector <int> sol(n + 1);
//        aho.antibfs(sol);
//
//        for(i = 1; i <= n; i++) {
//            int num = 0;
//            for(int j = 0; j + pat[i].size() <= str.size(); j++) {
//                num += (str.substr(j, pat[i].size()) == pat[i]);
//            }
//            if(sol[i] != num) {
//                cerr << "WA";
//                exit(0);
//            }
//        }
//
//        cerr << "OK\n";
//    }

    return 0;
}