Cod sursa(job #1201367)

Utilizator AlexandruValeanuAlexandru Valeanu AlexandruValeanu Data 24 iunie 2014 23:36:26
Problema Potrivirea sirurilor Scor 0
Compilator cpp Status done
Runda Arhiva educationala Marime 3.43 kb
#include <iostream>
#include <fstream>
#include <cstring>
#include <unordered_set>
#include <cassert>

using namespace std;

// Base of the polynomial hash (see getCodeHash and the B[] table built in main).
const int BASE = 91;
// Modulus for every hash computation. NOTE(review): inv() computes a^(MOD - 2),
// which is a modular inverse only if MOD is prime -- presumably it is; confirm.
const int MOD = 1299721;

// Capacity of the raw input buffer A (the string is stored from index 1).
const int Nmax = 5e5 + 2;

// B[k] = BASE^k mod MOD. main fills indices 0 .. Nmax - 1 only.
int B[2 * Nmax];

// Current input string, 1-indexed: main reads it at A + 1.
char A[Nmax];
// Separator-interleaved copy of A built by transform(): '#', '$' c1 '$' ... cN '$', '@'.
char S[2 * Nmax];
// P[i] = expansion radius around position i of S, computed by Manacher().
int P[2 * Nmax];
// N: length of the current input string; NN: index of the last '$' in S.
int N, NN;
// Modular inverse of BASE, set once in main via inv(BASE).
int invBASE;

// Hash values collected by solveString for the first and the second input string.
unordered_set <int> HT1, HT2;

/**
 * Rebuilds the global buffer S from the 1-indexed string A[1..N].
 *
 * Resulting layout: S[0] = '#', then '$' A[1] '$' A[2] '$' ... A[N] '$'
 * occupying S[1..NN] with NN = 2 * N + 1, and finally S[NN + 1] = '@'.
 * The whole buffer is zeroed first, so nothing from a previous call survives.
 * '#' and '@' act as end markers: Manacher() expands without an explicit
 * bounds check and relies on the two ends comparing unequal.
 *
 * @param A source characters, read from index 1
 * @param N number of characters to copy
 */
void transform( char A[], int N )
{
    memset( S, 0, sizeof( S ) );

    char *cursor = S;
    *cursor++ = '#';
    *cursor++ = '$';

    const char *src = A + 1;
    for ( int left = N; left > 0; --left )
    {
        *cursor++ = *src++;
        *cursor++ = '$';
    }

    // cursor now sits one past the last separator, i.e. at index NN + 1.
    *cursor = '@';
    NN = static_cast<int>( cursor - S ) - 1;
}

/**
 * Fills P[1..NN] with the palindrome radius around every position of S.
 *
 * P[pos] is the largest r such that S[pos - k] == S[pos + k] for all
 * k <= r. The radius of the mirrored position inside the right-most
 * palindrome found so far is reused as a starting point, then extended
 * character by character. The expansion has no bounds check; it stops at the
 * differing end markers that transform() places at S[0] and S[NN + 1].
 */
void Manacher()
{
    memset( P, 0, sizeof( P ) );

    int center = 0;   // centre of the palindrome that reaches furthest right
    int reach = 0;    // right-most index covered by that palindrome

    for ( int pos = 1; pos <= NN; ++pos )
    {
        int radius = 0;

        if ( reach > pos )
        {
            // Start from the mirrored radius, clipped to the known region.
            const int room = reach - pos;
            const int mirrored = P[2 * center - pos];
            radius = ( mirrored < room ) ? mirrored : room;
        }

        while ( S[pos - radius - 1] == S[pos + radius + 1] )
            ++radius;

        P[pos] = radius;

        if ( pos + radius > reach )
        {
            reach = pos + radius;
            center = pos;
        }
    }
}

/**
 * Computes a^p modulo MOD by binary exponentiation.
 *
 * The exponent is consumed by shifting it right one bit at a time. The
 * earlier form tested `(1 << i) <= p`, which shifts into (and past) the sign
 * bit once p >= 2^30 -- undefined behaviour. Results are unchanged for every
 * exponent the earlier form handled correctly.
 *
 * @param a base (not reduced beforehand; products are formed in 64 bits)
 * @param p exponent; values <= 0 yield 1
 * @return a^p mod MOD
 */
int pw( int a, int p )
{
    long long result = 1;
    long long factor = a;

    for ( ; p > 0; p >>= 1 )
    {
        if ( p & 1 )
            result = result * factor % MOD;

        factor = factor * factor % MOD;
    }

    return static_cast<int>( result );
}

/**
 * Returns a^(MOD - 2) mod MOD.
 *
 * By Fermat's little theorem this is the modular inverse of a, provided MOD
 * is prime and a is not a multiple of MOD.
 */
int inv( int a )
{
    const int exponent = MOD - 2;
    return pw( a, exponent );
}

/**
 * Polynomial hash (base BASE, modulo MOD) of S[i..j], ignoring every '$'.
 *
 * Characters are folded left to right, so the first non-separator character
 * ends up with the highest power of BASE.
 *
 * @param i first index of the range in S (inclusive)
 * @param j last index of the range in S (inclusive)
 * @return hash in [0, MOD); 0 when the range is empty or holds only '$'
 */
int getCodeHash( int i, int j )
{
    unsigned long long acc = 0;

    int k = i;
    while ( k <= j )
    {
        const char c = S[k++];

        if ( c == '$' )
            continue;

        acc = ( acc * BASE + c ) % MOD;
    }

    return static_cast<int>( acc );
}

/**
 * Derives the hash of a range with its first and last real characters removed.
 *
 * i and j are first moved inward past any '$' separators. With hash_p being
 * the hash of lg non-separator characters c1..c_lg, the function removes the
 * last term (c_lg), removes the leading term (c1 * BASE^(lg - 1)) and divides
 * by BASE via invBASE, which yields the hash of c2..c_(lg - 1).
 * On success i and j are advanced past the removed characters.
 *
 * @param i      in/out: left end of the range in S
 * @param j      in/out: right end of the range in S
 * @param hash_p hash of the current range
 * @param lg     number of non-separator characters in the current range
 * @return the new hash, or -1 when i has passed j after skipping separators
 */
int eraseCharHash( int &i, int &j, int hash_p, int lg )
{
    for ( ; S[i] == '$'; ++i ) {}
    for ( ; S[j] == '$'; --j ) {}

    // After the skips neither end is a separator, so crossing is the only
    // way to have nothing left to remove.
    if ( i > j )
        return -1;

    // Contribution of the leading character, kept in 64-bit unsigned.
    const unsigned long long leadTerm = ( 1ULL * B[lg - 1] * S[i] ) % MOD;

    int reduced = ( hash_p - S[j] + MOD ) % MOD;
    reduced = ( reduced - leadTerm + MOD ) % MOD;
    reduced = ( 1ULL * reduced * invBASE ) % MOD;

    ++i;
    --j;

    return reduced;
}

/**
 * Inserts into HT the hashes of the palindromes described by P over S.
 *
 * For each position with a non-zero radius, the maximal palindrome is hashed
 * from scratch with getCodeHash, then shrunk by one character on each side
 * with eraseCharHash while at least three characters remain. Processing of a
 * centre stops as soon as a hash is already present in HT.
 *
 * @param HT set receiving the hash values
 */
void solveString( unordered_set <int> &HT )
{
    for ( int center = 1; center <= NN; ++center )
    {
        int remaining = P[center];

        if ( remaining == 0 )
            continue;

        int left = center - remaining;
        int right = center + remaining;

        int code = getCodeHash( left, right );
        assert( code >= 0 );

        // insert() reports whether the value was new, so no separate lookup.
        bool fresh = HT.insert( code ).second;

        while ( fresh && left <= right && remaining >= 3 )
        {
            code = eraseCharHash( left, right, code, remaining );
            remaining -= 2;
            assert( code >= 0 );

            fresh = HT.insert( code ).second;
        }
    }
}

/**
 * Reads the next whitespace-delimited token from `in` into A[1..], rebuilds
 * S and P for it, and adds the resulting hashes to HT.
 */
static void collectHashes( ifstream &in, unordered_set <int> &HT )
{
    // Start from an empty string so that a failed read cannot leave the
    // previous token in A and have it processed a second time.
    A[1] = '\0';

    // Bound the extraction to the space left after A[0]: at most
    // Nmax - 2 characters plus the terminating null are stored.
    in.width( Nmax - 1 );
    in >> ( A + 1 );

    N = strlen( A + 1 );

    transform( A, N );
    Manacher();
    solveString( HT );
}

/**
 * Reads two strings from "unicat.in", collects the hashes of their
 * palindromic substrings into HT1 and HT2, and writes to "unicat.out" the
 * number of hash values present in both sets.
 */
int main()
{
    ifstream in("unicat.in");
    ofstream out("unicat.out");

    // Powers of BASE modulo MOD, needed by eraseCharHash.
    B[0] = 1;

    for ( int i = 1; i < Nmax; ++i )
        B[i] = ( 1LL * B[i - 1] * BASE ) % MOD;

    invBASE = inv( BASE );

    collectHashes( in, HT1 );
    collectHashes( in, HT2 );

    int sol = 0;

    for ( int x : HT2 )
        sol += static_cast<int>( HT1.count( x ) );

    out << sol << "\n";

    return 0;
}