Cod sursa(job #2887359)

Utilizator ArdeleanOficialAlexandru ArdeleanOficial Data 9 aprilie 2022 14:01:47
Problema A+B Scor 0
Compilator cpp-64 Status done
Runda Arhiva de probleme Marime 8.32 kb
#include <iostream>
#include <vector>
#include <bitset>

using namespace std;

#define DEBUG


///TINE MINTE SA SCHIMBI N INAPOI DUPA DEBUG


const int N = 5e2 + 7, MOD = 1e9 + 9;
const long long INF = 1e18 + 7;

int h[N], v[N], t[N], g[N], prod[N], bigger[N];
vector < int > adia[N];
bitset < N > posibil;

pair < long long, int > imp[N], dp[N];

int s[N];
int k;

int m, invm;

int logpow(int b, int e) {
    int a(1);
    for (; e; e >>= 1, b = 1LL * b * b % MOD)
        if (e&1)
            a = 1LL * a * b % MOD;
    return a;
}

///SOL:
/// mereu ai de ales intre <, =, >
/// niciodata nu pui < de buna voie
/// mereu tintesti sa faci S pe unul dintre fii
/// pe restul ii pui >
/// nu are sens sa ii pui =, pentru ca cu siguranta incetinesti cu macar 1
/// pentru restul fiilor care raman cu ceva =, calculezi:
/// imp = cu cat te impiedica
/// dp[nod] = suma(g[fiu] | fiu < ) +
///         + suma(imp[fiu] | fiu = ) +
///         + dp[fiu], unde fiu este optim
/// imp[nod] = suma(g[fiu] | fiu < ) +
///          + suma(imp[fiu] | fiu = )
/// mai prelucrezi putin formula si ajungi la un min(dp[fiu] - imp[fiu])
/// multe alte detalii sunt luate in considerare pentru C = 2 si da
/// niste detalii:
/// cand esti >, o sa inmultesti si toate posibilitatile
/// cand e posibil sa atingi minim in mai multe locuri, insumezi niste fractii si dupa inmultesti rasp
/// etc...
/// scriu asta pentru ca stiu ca nu o sa apuc sa o depanez in concurs



void dfs3(int nod = 1) {
///niciodata nu o sa pui < s[i]
///v[nod] < s[i] => iti face doar rau
///v[nod] == s[i] nu este preferabil
///o sa faci pasi in plus daca nu intentionezi sa il creezi de acolo
///in rest faci mai mare strict

    if (h[nod] >= k) {
        dp[nod] = {0, prod[nod]};
        return;
    }
    if (!posibil[nod]) {
        dp[nod] = {INF, 0};
        return;
    }
    if (h[nod] == k - 1) {
        dp[nod] = {1, 1LL * prod[nod] * (v[nod] == -1 ? invm : 1) % MOD};
        return;
    }
    for (auto i : adia[nod]) {
        if (i == t[nod])
            continue;
        dfs3(i);
    }
    dp[nod] = {1, 1};
    long long minim(INF);
    int need = s[h[nod] + 1];
    if (need == m) {
        for (auto i : adia[nod]) {
            if (i == t[nod])
                continue;
            if (v[i] != -1 && v[i] < m) {
                dp[nod].first += g[i];
                dp[nod].second = 1LL * dp[nod].second * prod[i] % MOD;
                continue;
            }
            dp[nod].first += imp[i].first;
            dp[nod].second = 1LL * dp[nod].second * imp[i].second % MOD;
            if (minim > dp[i].first - imp[i].first)
                m
                inim = dp[i].first - imp[i].first;
        }
        dp[nod].first += minim;
        int bruh(0);
        for (auto i : adia[nod]) {
            if (i == t[nod])
                continue;
            if (v[i] != -1 && v[i] < m)
                continue;
            if (minim == dp[i].first - imp[i].first)
                bruh = (1LL * bruh + 1LL * dp[i].second * logpow(imp[i].second, MOD - 2)) % MOD;
        }
        dp[nod].second = 1LL * dp[nod].second * bruh % MOD;
    }
    else {
        for (auto i : adia[nod]) {
            if (i == t[nod])
                continue;
            if (v[i] != -1 && v[i] < need) {
                dp[nod].first += g[i];
                dp[nod].second = 1LL * dp[nod].second * prod[i] % MOD;
                continue;
            }
            if (v[i] > need) {
                dp[nod].second = 1LL * dp[nod].second * prod[i] % MOD;
                continue;
            }
            if (v[i] == need) {
                dp[nod].first += imp[i].first;
                dp[nod].second = 1LL * dp[nod].second * imp[i].second % MOD;
                if (minim > dp[i].first - imp[i].first)
                    minim = dp[i].first - imp[i].first;
            }
            else {
                dp[nod].second = 1LL * dp[nod].second * bigger[i] % MOD;
                if (minim > dp[i].first)
                    minim = dp[i].first;
            }
        }
        dp[nod].first += minim;
        int bruh(0);
        for (auto i : adia[nod]) {
            if (i == t[nod])
                continue;
            if (v[i] != -1 && v[i] < need)
                continue;
            if (v[i] == need) {
                if (minim == dp[i].first - imp[i].first)
                    bruh = (1LL * bruh + 1LL * dp[i].second * logpow(imp[i].second, MOD - 2)) % MOD;
            }
            else {
                if (minim == dp[i].first)
                    bruh = (1LL * bruh + 1LL * dp[i].second * logpow(bigger[i], MOD - 2)) % MOD;
            }
        }
        dp[nod].second = 1LL * dp[nod].second * bruh % MOD;
    }
}

void dfs2(int nod = 1) {
///daca s[i] == M, nu poti sa il ignori, cat te impiedica
    for (auto i : adia[nod]) {
        if (i == t[nod])
            continue;
        dfs2(i);
    }
    if (h[nod] >= k - 1 || (v[nod] >= 1 && v[nod] != s[h[nod]])) {
        imp[nod] = {0, 1LL * prod[nod] * (v[nod] == -1 ? invm : 1) % MOD};
        return;
    }
    imp[nod] = {1, 1};
    for (auto i : adia[nod]) {
        if (i == t[nod])
            continue;
        if (v[i] >= 1 && v[i] < s[h[i]]) {
            imp[nod].first += g[i];
            imp[nod].second = 1LL * imp[nod].second * prod[i] % MOD;
            continue;
        }
        if (v[i] > s[h[i]]) {
            imp[nod].second = 1LL * imp[nod].second * prod[i] % MOD;
            continue;
        }
        if (v[i] == s[h[i]]) {
            imp[nod].first += imp[i].first;
            imp[nod].second = 1LL * imp[nod].second * imp[i].second % MOD;
            continue;
        }
        //if (v[i] == -1)
        if (s[h[i]] == m) {
            imp[nod].first += imp[i].first;
            imp[nod].second = 1LL * imp[nod].second * imp[i].second % MOD;
        }
        else {
            if (imp[i].first == 0)
                imp[nod].second = 1LL * imp[nod].second * (imp[i].second + bigger[i]) % MOD;
            else
                imp[nod].second = 1LL * imp[nod].second * bigger[i] % MOD;
        }
    }
}

void dfs(int nod = 1) {
    bool pos((h[nod] == k - 1));
    g[nod] = 1;
    prod[nod] = (v[nod] == -1 ? m : 1);
    for (auto i : adia[nod]) {
        if (t[nod] == i)
            continue;
        t[i] = nod;
        h[i] = h[nod] + 1;
        dfs(i);
        prod[nod] = 1LL * prod[i] * prod[nod] % MOD;
        pos |= posibil[i];
        g[nod] += g[i];
    }
    if (h[nod] >= k) {
        posibil[nod] = 0;
        bigger[nod] = prod[nod];
    }
    else {
        posibil[nod] = (s[h[nod]] == v[nod] || v[nod] == -1) && pos;
        bigger[nod] = 1LL * prod[nod] * invm % MOD * (m - s[h[nod]]) % MOD;
    }
    ///cand nu e -1, in biggger e o problema
    ///Dar eu il folosesc doar cand e -1, cred ca e ok
}

int main()
{
    int c, n;

    cin >> c >> n >> m >> k;
    for (int i = 1; i <= n; ++i)
        cin >> v[i];
    for (int i = 0; i < k; ++i)
        cin >> s[i];

    invm = logpow(m, MOD - 2);

//    cout << 1LL * m * invm % MOD << '\n';

    for (int i = 1; i < n; ++i) {
        int a, b;
        cin >> a >> b;
        adia[a].push_back(b);
        adia[b].push_back(a);
    }
    dfs();
    dfs2();
//    for (int i = 1; i < n; ++i)
//        if (h[i] == k - 1)
//            bigger[i] = 1LL * bigger[i] * logpow(m - s[h[i]], MOD - 2) % MOD * (m - s[h[i]] + 1) % MOD;
    dfs3();
    #ifdef DEBUG
    cout << "prod: ";
    for (int i = 1; i <= n; ++i)
        cout << prod[i] << " \n"[i == n];
    cout << "bigger: ";
    for (int i = 1; i <= n; ++i)
        cout << bigger[i] << " \n"[i == n];
    cout << "posibil: ";
    for (int i = 1; i <= n; ++i)
        cout << posibil[i] << " \n"[i == n];
    cout << "imp: ";
    for (int i = 1; i <= n; ++i)
        cout << '(' << imp[i].first << ", " << imp[i].second << ')' << " \n"[i == n];
    cout << "dp: ";
    for (int i = 1; i <= n; ++i)
        cout << '(' << dp[i].first << ", " << dp[i].second << ')' << " \n"[i == n];
    #endif // DEBUG
    if (c == 1)
        return cout << dp[1].first << '\n', 0;
    cout << dp[1].second << '\n', 0;
    return 0;
}
/**
1
3 2 2
1 -1 -1
1 2
1 2
1 3

1
8 3 3
-1 -1 2 -1 -1 -1 1 2
1 2 2
1 2
2 3
2 4
4 5
1 6
6 7
1 8

*/