Pagini recente » Cod sursa (job #622524) | Cod sursa (job #1049208) | Cod sursa (job #1573658) | Cod sursa (job #852913) | Cod sursa (job #2887359)
#include <iostream>
#include <vector>
#include <bitset>
using namespace std;
#define DEBUG
///TINE MINTE SA SCHIMBI N INAPOI DUPA DEBUG
const int N = 5e2 + 7, MOD = 1e9 + 9;
const long long INF = 1e18 + 7;
int h[N], v[N], t[N], g[N], prod[N], bigger[N];
vector < int > adia[N];
bitset < N > posibil;
pair < long long, int > imp[N], dp[N];
int s[N];
int k;
int m, invm;
int logpow(int b, int e) {
int a(1);
for (; e; e >>= 1, b = 1LL * b * b % MOD)
if (e&1)
a = 1LL * a * b % MOD;
return a;
}
///SOL:
/// mereu ai de ales intre <, =, >
/// niciodata nu pui < de buna voie
/// mereu tintesti sa faci S pe unul dintre fii
/// pe restul ii pui >
/// nu are sens sa ii pui =, pentru ca cu siguranta incetinesti cu macar 1
/// pentru restul fiilor care raman cu ceva =, calculezi:
/// imp = cu cat te impiedica
/// dp[nod] = suma(g[fiu] | fiu < ) +
/// + suma(imp[fiu] | fiu = ) +
/// + dp[fiu], unde fiu este optim
/// imp[nod] = suma(g[fiu] | fiu < ) +
/// + suma(imp[fiu] | fiu = )
/// mai prelucrezi putin formula si ajungi la un min(dp[fiu] - imp[fiu])
/// multe alte detalii sunt luate in considerare pentru C = 2 si da
/// niste detalii:
/// cand esti >, o sa inmultesti si toate posibilitatile
/// cand e posibil sa atingi minim in mai multe locuri, insumezi niste fractii si dupa inmultesti rasp
/// etc...
/// scriu asta pentru ca stiu ca nu o sa apuc sa o depanez in concurs
void dfs3(int nod = 1) {
///niciodata nu o sa pui < s[i]
///v[nod] < s[i] => iti face doar rau
///v[nod] == s[i] nu este preferabil
///o sa faci pasi in plus daca nu intentionezi sa il creezi de acolo
///in rest faci mai mare strict
if (h[nod] >= k) {
dp[nod] = {0, prod[nod]};
return;
}
if (!posibil[nod]) {
dp[nod] = {INF, 0};
return;
}
if (h[nod] == k - 1) {
dp[nod] = {1, 1LL * prod[nod] * (v[nod] == -1 ? invm : 1) % MOD};
return;
}
for (auto i : adia[nod]) {
if (i == t[nod])
continue;
dfs3(i);
}
dp[nod] = {1, 1};
long long minim(INF);
int need = s[h[nod] + 1];
if (need == m) {
for (auto i : adia[nod]) {
if (i == t[nod])
continue;
if (v[i] != -1 && v[i] < m) {
dp[nod].first += g[i];
dp[nod].second = 1LL * dp[nod].second * prod[i] % MOD;
continue;
}
dp[nod].first += imp[i].first;
dp[nod].second = 1LL * dp[nod].second * imp[i].second % MOD;
if (minim > dp[i].first - imp[i].first)
m
inim = dp[i].first - imp[i].first;
}
dp[nod].first += minim;
int bruh(0);
for (auto i : adia[nod]) {
if (i == t[nod])
continue;
if (v[i] != -1 && v[i] < m)
continue;
if (minim == dp[i].first - imp[i].first)
bruh = (1LL * bruh + 1LL * dp[i].second * logpow(imp[i].second, MOD - 2)) % MOD;
}
dp[nod].second = 1LL * dp[nod].second * bruh % MOD;
}
else {
for (auto i : adia[nod]) {
if (i == t[nod])
continue;
if (v[i] != -1 && v[i] < need) {
dp[nod].first += g[i];
dp[nod].second = 1LL * dp[nod].second * prod[i] % MOD;
continue;
}
if (v[i] > need) {
dp[nod].second = 1LL * dp[nod].second * prod[i] % MOD;
continue;
}
if (v[i] == need) {
dp[nod].first += imp[i].first;
dp[nod].second = 1LL * dp[nod].second * imp[i].second % MOD;
if (minim > dp[i].first - imp[i].first)
minim = dp[i].first - imp[i].first;
}
else {
dp[nod].second = 1LL * dp[nod].second * bigger[i] % MOD;
if (minim > dp[i].first)
minim = dp[i].first;
}
}
dp[nod].first += minim;
int bruh(0);
for (auto i : adia[nod]) {
if (i == t[nod])
continue;
if (v[i] != -1 && v[i] < need)
continue;
if (v[i] == need) {
if (minim == dp[i].first - imp[i].first)
bruh = (1LL * bruh + 1LL * dp[i].second * logpow(imp[i].second, MOD - 2)) % MOD;
}
else {
if (minim == dp[i].first)
bruh = (1LL * bruh + 1LL * dp[i].second * logpow(bigger[i], MOD - 2)) % MOD;
}
}
dp[nod].second = 1LL * dp[nod].second * bruh % MOD;
}
}
void dfs2(int nod = 1) {
///daca s[i] == M, nu poti sa il ignori, cat te impiedica
for (auto i : adia[nod]) {
if (i == t[nod])
continue;
dfs2(i);
}
if (h[nod] >= k - 1 || (v[nod] >= 1 && v[nod] != s[h[nod]])) {
imp[nod] = {0, 1LL * prod[nod] * (v[nod] == -1 ? invm : 1) % MOD};
return;
}
imp[nod] = {1, 1};
for (auto i : adia[nod]) {
if (i == t[nod])
continue;
if (v[i] >= 1 && v[i] < s[h[i]]) {
imp[nod].first += g[i];
imp[nod].second = 1LL * imp[nod].second * prod[i] % MOD;
continue;
}
if (v[i] > s[h[i]]) {
imp[nod].second = 1LL * imp[nod].second * prod[i] % MOD;
continue;
}
if (v[i] == s[h[i]]) {
imp[nod].first += imp[i].first;
imp[nod].second = 1LL * imp[nod].second * imp[i].second % MOD;
continue;
}
//if (v[i] == -1)
if (s[h[i]] == m) {
imp[nod].first += imp[i].first;
imp[nod].second = 1LL * imp[nod].second * imp[i].second % MOD;
}
else {
if (imp[i].first == 0)
imp[nod].second = 1LL * imp[nod].second * (imp[i].second + bigger[i]) % MOD;
else
imp[nod].second = 1LL * imp[nod].second * bigger[i] % MOD;
}
}
}
void dfs(int nod = 1) {
bool pos((h[nod] == k - 1));
g[nod] = 1;
prod[nod] = (v[nod] == -1 ? m : 1);
for (auto i : adia[nod]) {
if (t[nod] == i)
continue;
t[i] = nod;
h[i] = h[nod] + 1;
dfs(i);
prod[nod] = 1LL * prod[i] * prod[nod] % MOD;
pos |= posibil[i];
g[nod] += g[i];
}
if (h[nod] >= k) {
posibil[nod] = 0;
bigger[nod] = prod[nod];
}
else {
posibil[nod] = (s[h[nod]] == v[nod] || v[nod] == -1) && pos;
bigger[nod] = 1LL * prod[nod] * invm % MOD * (m - s[h[nod]]) % MOD;
}
///cand nu e -1, in biggger e o problema
///Dar eu il folosesc doar cand e -1, cred ca e ok
}
int main()
{
int c, n;
cin >> c >> n >> m >> k;
for (int i = 1; i <= n; ++i)
cin >> v[i];
for (int i = 0; i < k; ++i)
cin >> s[i];
invm = logpow(m, MOD - 2);
// cout << 1LL * m * invm % MOD << '\n';
for (int i = 1; i < n; ++i) {
int a, b;
cin >> a >> b;
adia[a].push_back(b);
adia[b].push_back(a);
}
dfs();
dfs2();
// for (int i = 1; i < n; ++i)
// if (h[i] == k - 1)
// bigger[i] = 1LL * bigger[i] * logpow(m - s[h[i]], MOD - 2) % MOD * (m - s[h[i]] + 1) % MOD;
dfs3();
#ifdef DEBUG
cout << "prod: ";
for (int i = 1; i <= n; ++i)
cout << prod[i] << " \n"[i == n];
cout << "bigger: ";
for (int i = 1; i <= n; ++i)
cout << bigger[i] << " \n"[i == n];
cout << "posibil: ";
for (int i = 1; i <= n; ++i)
cout << posibil[i] << " \n"[i == n];
cout << "imp: ";
for (int i = 1; i <= n; ++i)
cout << '(' << imp[i].first << ", " << imp[i].second << ')' << " \n"[i == n];
cout << "dp: ";
for (int i = 1; i <= n; ++i)
cout << '(' << dp[i].first << ", " << dp[i].second << ')' << " \n"[i == n];
#endif // DEBUG
if (c == 1)
return cout << dp[1].first << '\n', 0;
cout << dp[1].second << '\n', 0;
return 0;
}
/**
1
3 2 2
1 -1 -1
1 2
1 2
1 3
1
8 3 3
-1 -1 2 -1 -1 -1 1 2
1 2 2
1 2
2 3
2 4
4 5
1 6
6 7
1 8
*/