/*
 * Cod sursa (job #1201509)
 *
 * Utilizator: Alexandru Valeanu (AlexandruValeanu)   Data: 25 iunie 2014 12:40:50
 * Problema: Potrivirea sirurilor   Scor: 16
 * Compilator: cpp   Status: done
 * Runda: Arhiva educationala   Marime: 2.21 kb
 */
#include <cstring>
#include <fstream>
#include <iostream>
#include <string>
#include <vector>

using namespace std;

// Parameters of the polynomial rolling hash.
constexpr int BASE = 91;
constexpr int MOD = 666013;

// Size of every per-character table. Strings are stored 1-indexed, so this
// leaves room for 2,000,000 characters plus the unused slot 0 and the '\0'.
constexpr int Lmax = 2000000 + 2;

/// Hash convention used by the prefix table H[] and getHash():
/// H( S[i...j] ) = S[i] + S[i + 1] * base + ... + S[j] * base ^ ( j - i )   (mod MOD)

std::vector<int> match;   // 0-based start positions of occurrences found in B

int H[Lmax];              // H[k]: hash of the prefix B[1..k], B[t] weighted by BASE^(t-1)
char A[Lmax];             // pattern, read into A[1..N]
char B[Lmax];             // text, read into B[1..M]
int invBasePw[Lmax];      // invBasePw[k] = BASE^(-k) mod MOD
int N, M;                 // lengths of A and B

/// Computes a^p mod MOD by binary (square-and-multiply) exponentiation.
/// @param a base; products are formed in long long, so any int is accepted
/// @param p exponent; p <= 0 yields 1
/// @return a^p reduced modulo MOD
int pw( int a, int p )
{
    int res = 1;

    // Consume the exponent from its lowest bit upwards by shifting p itself.
    // Testing a moving mask (1 << i) instead would overflow a signed int once
    // p reaches 2^30, which is undefined behaviour.
    while ( p > 0 )
    {
        if ( p & 1 )
            res = ( 1LL * res * a ) % MOD;

        a = ( 1LL * a * a ) % MOD;
        p >>= 1;
    }

    return res;
}

/// Returns the multiplicative inverse of a modulo MOD using Fermat's little
/// theorem: a^(MOD-2) == a^(-1) (mod MOD).
/// NOTE: only valid when MOD is prime and a is not a multiple of MOD;
/// neither condition is checked here.
int inv( int a )
{
    const int fermatExponent = MOD - 2;
    return pw( a, fermatExponent );
}

/// Precomputes invBasePw[k] = BASE^(-k) mod MOD for k = 0..M, where BASE^(-1)
/// is the modular inverse returned by inv(). getHash() uses these values to
/// shift a window's weights down so that its first character has weight 1.
void calc_inv()
{
    const int baseInverse = inv( BASE );

    invBasePw[0] = 1;
    for ( int k = 1; k <= M; ++k )
    {
        const long long next = 1LL * invBasePw[k - 1] * baseInverse;
        invBasePw[k] = next % MOD;
    }
}

/// Fills the prefix-hash table of the text B[1..M]:
///   H[k] = B[1] * BASE^0 + B[2] * BASE^1 + ... + B[k] * BASE^(k-1)   (mod MOD)
/// with H[0] = 0 standing for the empty prefix.
void compute_hash()
{
    H[0] = 0;

    int weight = 1;   // BASE^(k-1) mod MOD for the character added at step k
    for ( int k = 1; k <= M; ++k )
    {
        const int term = weight * B[k];
        H[k] = ( H[k - 1] + term ) % MOD;
        weight = ( weight * BASE ) % MOD;
    }
}

/// Returns the hash of the text window B[i..j], normalised so its first
/// character has weight BASE^0:
///   B[i] + B[i + 1] * BASE + ... + B[j] * BASE^(j - i)   (mod MOD)
/// Requires 1 <= i, i - 1 <= j <= M, and H[] / invBasePw[] already filled.
int getHash( int i, int j )
{
    // Difference of two prefix hashes: B[k] still carries weight BASE^(k-1).
    int c = H[j] - H[i - 1];

    // For non-negative characters both prefixes lie in [0, MOD), so the
    // difference lies in (-MOD, MOD); one correction maps it into [0, MOD).
    if ( c < 0 ) c += MOD;

    // Multiply by BASE^-(i-1) to shift the window's weights down to BASE^0.
    return ( 1LL * c * invBasePw[i - 1] ) % MOD;
}

/// Hashes X[i..j] directly, using the same weighting as getHash():
///   X[i] + X[i + 1] * BASE + ... + X[j] * BASE^(j - i)   (mod MOD)
/// Horner's scheme is applied from the last character down to the first, so
/// X[j] is multiplied by BASE the most times and X[i] ends with weight BASE^0.
/// Iterating from i upwards would produce the reversed polynomial, which
/// would never equal getHash() for the same string.
/// NOTE(review): assumes non-negative char values (plain ASCII input).
int getCodeHash( char X[], int i, int j )
{
    int hash_p = 0;

    for ( int k = j; k >= i; --k )
    {
        hash_p = ( hash_p * BASE + X[k] ) % MOD;
    }

    return hash_p;
}

/// Reads a pattern and a text from strmatch.in. Writes to strmatch.out the
/// number of occurrences of the pattern in the text, followed by the 0-based
/// start positions of (at most) the first 1000 occurrences.
int main()
{
    ifstream in("strmatch.in");
    ofstream out("strmatch.out");

    // Pattern first, then text. Reading through std::string keeps an
    // over-long token from overflowing the fixed-size A/B buffers: the char*
    // extractor performs no bound check, and it no longer exists in C++20.
    string pattern, text;
    in >> pattern >> text;

    // Each buffer needs slot 0 (strings are stored 1-indexed) plus a '\0'.
    if ( pattern.size() + 2 > sizeof( A ) || text.size() + 2 > sizeof( B ) )
    {
        cerr << "input string too long\n";
        return 1;
    }

    // Copy including the terminating '\0' so A/B stay valid C strings.
    memcpy( A + 1, pattern.c_str(), pattern.size() + 1 );
    memcpy( B + 1, text.c_str(), text.size() + 1 );

    N = static_cast<int>( pattern.size() );
    M = static_cast<int>( text.size() );

    if ( N > M )
    {
        out << "0\n";
        return 0;
    }

    calc_inv();
    compute_hash();

    const int hashPattern = getCodeHash( A, 1, N );

    int totalMatches = 0;

    // Slide a length-N window over B; `start` is its 1-based first index.
    for ( int start = 1; start + N - 1 <= M; ++start )
    {
        // MOD is small compared with the number of windows, so equal hashes
        // only mark a candidate. Confirm each one with a direct comparison
        // to rule out collisions.
        if ( getHash( start, start + N - 1 ) != hashPattern )
            continue;
        if ( memcmp( A + 1, B + start, N ) != 0 )
            continue;

        ++totalMatches;

        // Only the first 1000 positions are printed, so only those are kept.
        if ( match.size() < 1000u )
            match.push_back( start - 1 );   // 0-based position in B
    }

    out << totalMatches << "\n";

    for ( int position : match )
    {
        out << position << " ";
    }
    out << "\n";

    return 0;
}