// Source code (job #2814272)
//
// User: horiacool (Nedelcu Horia Alexandru)   Date: 7 December 2021 20:59:33
// Problem: Aho-Corasick   Score: 80
// Compiler: cpp-64   Status: done
// Round: Arhiva educationala (educational archive)   Size: 4.76 kb
#include <algorithm>
#include <memory>
#include <queue>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

#include <fstream>
#include <iostream>

// Byte-alphabet trie extended with Aho-Corasick failure links.
//
// Usage: insert() every dictionary word, call buildFailEdges() once, then
// call patternMatching() on a text to find the occurrences of every word.
class Trie
{
private:
	// Alphabet size: one child slot per possible byte value.
	static const size_t CHAR_SIZE = 256;

	using StrId_t = size_t;

	// Last id handed out by getStringId(). It must start from a known value:
	// the class has no user-written constructor, and incrementing an
	// indeterminate value is undefined behavior.
	StrId_t idCount = 0;
	// Two-way mapping between dictionary words and their ids.
	std::unordered_map<std::string, StrId_t> strToId;
	std::unordered_map<StrId_t, std::string> idToStr;

	struct Node
	{
		// Failure link. Until buildFailEdges() runs it points at the root.
		// Afterwards it points at the node of the longest proper suffix of
		// this node's path that also exists in the trie. It is nullptr only
		// for the root itself.
		Node* failNode;
		// Children indexed by byte value; nullptr where no edge exists.
		Node* character[CHAR_SIZE];
		// Ids of words ending at this node. After buildFailEdges() this also
		// includes the ids collected from the fail-link targets.
		std::vector<StrId_t> stringsMatched;

		// The default argument belongs on this in-class declaration.
		// Adding it only on the out-of-class definition would make this a
		// default constructor after the fact, which the language rules make
		// ill-formed (Clang rejects it).
		explicit Node(Node* rootNode = nullptr);
		~Node();
		Node(const Node&) = delete;
		Node& operator =(const Node&) = delete;
	} root;  // default-constructed, so its failNode is nullptr

	StrId_t getStringId(const std::string&);
	const std::string& getIdString(const StrId_t);

public:
	// Insert a word into a Trie
	void insert(const std::string&);

	// Create failed edges in Aho-Corasick Trie automaton
	void buildFailEdges();

	// Find words from the dictionary in the given text
	std::unordered_map<std::string, std::vector<size_t>>
		patternMatching(const std::string&);
};

// Create a node with no children whose fail link points at `rootNode`.
// rootNode is nullptr when the node is default-constructed, which is how the
// trie root itself is built.
Trie::Node::Node(Node* rootNode)
	: failNode(rootNode)
{
	// no outgoing edges yet
	for (size_t i = 0; i < CHAR_SIZE; ++i) {
		this->character[i] = nullptr;
	}
}

// Free every node below this one. Each child's destructor frees its own
// children in turn, so this recurses as deep as the longest inserted word.
Trie::Node::~Node()
{
	// deleting an empty (nullptr) slot is a no-op, so no check is needed
	for (Node* child : character) {
		delete child;
	}
}

// Return the id registered for `string`. The first time a string is seen it
// is assigned the next id after idCount, and that id is recorded in both
// strToId and idToStr.
Trie::StrId_t Trie::getStringId(const std::string& string)
{
	// a single hash lookup covers the common "already known" case
	const auto found = strToId.find(string);
	if (found != strToId.end()) {
		return found->second;
	}

	++idCount;
	strToId.emplace(string, idCount);
	idToStr.emplace(idCount, string);

	return idCount;
}

// Return the word registered under `id` by getStringId().
// Uses at() so an unknown id throws std::out_of_range. operator[] would
// instead silently insert an empty string into idToStr and return it.
const std::string& Trie::getIdString(const StrId_t id)
{
	return idToStr.at(id);
}

void Trie::insert(const std::string& word)
{
	StrId_t word_id = getStringId(word);
	Node* curr = &root;

	for (size_t i = 0; i < word.length(); ++i) {
		// create a new node if the path doesn't exist
		if (curr->character[(size_t) word[i]] == nullptr) {
			curr->character[(size_t) word[i]] = new Node(&root);
		}

		// go to the next node
		curr = curr->character[(size_t) word[i]];
	}

	// mark the current node as a leaf
	curr->stringsMatched.push_back(word_id);
}

void Trie::buildFailEdges()
{
	std::queue<Node*> queue;
	Node* curr,* fail;

	// first insert in the queue the root's children
	for (size_t i = 0; i < CHAR_SIZE; ++i) {
		if (root.character[i] != nullptr) {
			queue.push(root.character[i]);
		}
	}

	while (!queue.empty()) {
		curr = queue.front();
		queue.pop();

		// for each child node compute failed node
		for (size_t i = 0; i < CHAR_SIZE; ++i) {
			if (curr->character[i] == nullptr) {
				continue;
			}

			fail = curr->failNode;

			do {
				// check if the current fail node matches the child's suffix
				if (fail->character[i] != nullptr) {
					curr->character[i]->failNode = fail->character[i];

					// add matched strings of the fail node to the child's ones
					auto& childStrings = curr->character[i]->stringsMatched;
					auto& failedStrings = fail->character[i]->stringsMatched;

					childStrings.insert(childStrings.end(),
						failedStrings.begin(), failedStrings.end());

					break;
				}

				fail = fail->failNode;
			} while (fail != nullptr);

			queue.push(curr->character[i]);
		}
	}
}

std::unordered_map<std::string, std::vector<size_t>>
	Trie::patternMatching(const std::string& text)
{
	// start from the root node
	Node* curr = &root;

	std::unordered_map<StrId_t, std::vector<size_t>> tempPatterns;
	std::unordered_map<std::string, std::vector<size_t>> patterns;

	// traverse the automaton with the given text
	for (size_t i = 0; i < text.length(); ++i) {
		// find the longest prefix match till now
		do {
			if (curr->character[(size_t) text[i]] != nullptr) {
				curr = curr->character[(size_t) text[i]];

				for (auto id : curr->stringsMatched) {
					auto found_index = i - getIdString(id).length() + 1;

					tempPatterns[id].push_back(found_index);
				}

				break;
			}

			curr = curr->failNode;
		} while (curr != nullptr);

		if (curr == nullptr) {
			curr = &root;
		}
	}

	// move the found positions to the returned map
	for (auto& id_vect : tempPatterns) {
		patterns[getIdString(id_vect.first)] = std::move(id_vect.second);
	}

	return patterns;
}


// Read the text (first line) and the dictionary (a count followed by that
// many whitespace-separated words) from ahocorasick.in. For each dictionary
// word, in input order, write to ahocorasick.out the number of occurrences
// reported by Trie::patternMatching, one count per line.
int main()
{
	std::ifstream fin("ahocorasick.in");
	std::ofstream fout("ahocorasick.out");
	if (!fin || !fout) {
		std::cerr << "cannot open ahocorasick.in / ahocorasick.out\n";
		return 1;
	}

	// the first line is the text to search in
	std::string text;
	std::getline(fin, text);

	// Reject a missing or negative word count. Sizing the vector from it
	// would otherwise read garbage or request an enormous allocation.
	int n = 0;
	if (!(fin >> n) || n < 0) {
		std::cerr << "invalid dictionary size\n";
		return 1;
	}

	Trie automaton;
	std::vector<std::string> dictionary(static_cast<size_t>(n));

	for (std::string& word : dictionary) {
		fin >> word;
		automaton.insert(word);
	}

	automaton.buildFailEdges();
	const auto patternMatched = automaton.patternMatching(text);

	// Use find() rather than operator[] so absent words are not inserted;
	// a word that never occurs has no entry and therefore count 0.
	for (const std::string& word : dictionary) {
		const auto found = patternMatched.find(word);
		const size_t count = (found == patternMatched.end()) ? 0 : found->second.size();
		fout << count << '\n';
	}

	return 0;
}