Cod sursa(job #3156158)

Utilizator IvanAndreiIvan Andrei IvanAndrei Data 10 octombrie 2023 18:29:08
Problema Heavy Path Decomposition Scor 100
Compilator cpp-64 Status done
Runda Arhiva educationala Marime 3.1 kb
#include <fstream>
#include <limits>
#include <vector>

using namespace std;

// Problem I/O streams: input comes from heavypath.in, answers go to heavypath.out.
std::ifstream in ("heavypath.in");
std::ofstream out ("heavypath.out");

// Fixed capacities: node ids are 1-based, so max_size - 1 nodes fit; the segment
// tree uses 1-based heap indexing (children 2*i, 2*i+1) and gets 4 slots per node.
constexpr int max_size = 100000 + 1;
constexpr int max_aint = 400000 + 1;

int v[max_size];                 // initial value of each node
int cap[max_size];               // head (topmost node) of the heavy chain holding each node
int t[max_size];                 // parent of each node in the rooted tree
int sz[max_size];                // subtree size of each node
int poz[max_size];               // position of each node in the heavy-first DFS order
int aint[max_aint];              // segment tree of maxima over that DFS order
int timp;                        // running DFS position counter
int lin[max_size];               // inverse of poz: the node stored at each DFS position
int n;                           // number of nodes
int lvl[max_size];               // depth of each node
std::vector <int> mc[max_size];  // adjacency lists of the tree

// Roots the tree at `nod`, whose parent is recorded as `par`. For every node reachable
// from `nod` without passing through `par`, it fills:
//   lvl[] - depth, with lvl[nod] = lvl[par] + 1;
//   t[]   - parent;
//   cap[] - the node itself (dfsh later overwrites it for heavy children);
//   sz[]  - subtree size.
// An explicit stack is used instead of recursion. On a path-shaped tree with ~1e5
// nodes, recursion depth equals the tree height and can exhaust the call stack.
void dfssz (int nod, int par)
{
    vector<int> stk(1, nod);  // nodes waiting to be visited
    vector<int> order;        // visit order: every node appears after its parent
    t[nod] = par;
    while (!stk.empty())
    {
        int u = stk.back();
        stk.pop_back();
        int p = t[u];  // parent was recorded when u was pushed
        lvl[u] = lvl[p] + 1;
        cap[u] = u;
        sz[u] = 1;
        order.push_back(u);
        for (int f : mc[u])
        {
            if (f != p)
            {
                t[f] = u;
                stk.push_back(f);
            }
        }
    }
    // Walk the visit order backwards. All descendants of a node come after it, so its
    // size is complete before it is added to its parent. order[0] is `nod` itself,
    // whose size must not be added into `par`.
    for (int i = (int)order.size() - 1; i > 0; i--)
    {
        sz[t[order[i]]] += sz[order[i]];
    }
}

// Heavy-first preorder traversal starting at `nod`, whose parent is `par`.
// Each visited node gets the next value of the global counter `timp` as its
// position: poz[node] = position and lin[position] = node.
// The heavy child of every internal node is visited immediately after the node and
// inherits the node's chain head in cap[]. The heavy child is the first child in
// adjacency order with the largest sz[]. Light children are visited afterwards in
// adjacency order, and their cap[] entries are not modified here.
// Each heavy chain therefore occupies a contiguous run of positions.
// Reads sz[], so sz[] must already hold subtree sizes for the same rooting.
// An explicit stack replaces recursion so that a deep (path-shaped) tree cannot
// overflow the call stack.
void dfsh (int nod, int par)
{
    vector<int> stk_node(1, nod);  // pending nodes...
    vector<int> stk_par(1, par);   // ...and their parents, kept in lockstep
    while (!stk_node.empty())
    {
        int u = stk_node.back();
        int p = stk_par.back();
        stk_node.pop_back();
        stk_par.pop_back();
        lin[++timp] = u;
        poz[u] = timp;
        int mx = -1, urm = 0;
        for (int f : mc[u])
        {
            if (f != p && mx < sz[f])
            {
                mx = sz[f];
                urm = f;
            }
        }
        if (mx == -1)
        {
            continue;  // leaf: no children to schedule
        }
        cap[urm] = cap[u];
        // Push the light children in reverse adjacency order, then the heavy child last.
        // The heavy subtree is numbered first, and the light subtrees follow in
        // adjacency order.
        for (auto it = mc[u].rbegin(); it != mc[u].rend(); ++it)
        {
            if (*it != p && *it != urm)
            {
                stk_node.push_back(*it);
                stk_par.push_back(u);
            }
        }
        stk_node.push_back(urm);
        stk_par.push_back(u);
    }
}

// Builds the max segment tree. Tree node `nod` covers DFS positions [l, r] and stores
// the largest v[lin[i]] for i in that range. Its children are 2*nod and 2*nod + 1.
void init (int l, int r, int nod)
{
    if (l != r)
    {
        const int mid = l + (r - l) / 2;
        const int lc = 2 * nod;
        const int rc = lc + 1;
        init(l, mid, lc);
        init(mid + 1, r, rc);
        aint[nod] = max(aint[lc], aint[rc]);
    }
    else
    {
        aint[nod] = v[lin[l]];  // leaf: value of the node at this DFS position
    }
}

// Point assignment in the max segment tree. Within tree node `nod`, which covers
// [l, r], the leaf for DFS position `poz` is set to `val`. The maxima are then
// refreshed on every node between that leaf and `nod`.
void upd (int l, int r, int poz, int val, int nod)
{
    const int top = nod;
    // Descend to the leaf holding position `poz`.
    while (l != r)
    {
        const int mid = l + (r - l) / 2;
        if (poz > mid)
        {
            l = mid + 1;
            nod = 2 * nod + 1;
        }
        else
        {
            r = mid;
            nod = 2 * nod;
        }
    }
    aint[nod] = val;
    // Climb back up: in heap numbering, a node's parent is nod / 2.
    while (nod != top)
    {
        nod /= 2;
        aint[nod] = max(aint[2 * nod], aint[2 * nod + 1]);
    }
}

// Range-maximum query on the segment tree.
// `nod` is the tree node covering DFS positions [l, r]. The result is the maximum
// stored value over positions [st, dr] ∩ [l, r]. The top-level call is
// query(1, n, st, dr, 1). The requested range is expected to be non-empty
// (st <= dr) and to overlap [l, r].
int query (int l, int r, int st, int dr, int nod)
{
    if (st <= l && r <= dr)
    {
        return aint[nod];  // node lies entirely inside the requested range
    }
    int m = (l + r) / 2;
    // Start from the identity element of max rather than 0. With 0, a range whose
    // values are all negative would be reported as 0 whenever only one child overlaps.
    int best = numeric_limits<int>::min();
    if (st <= m)
    {
        best = query(l, m, st, dr, 2 * nod);
    }
    if (dr > m)
    {
        best = max(best, query(m + 1, r, st, dr, 2 * nod + 1));
    }
    return best;
}

// Maximum value on the tree path between nodes x and y.
// While x and y lie on different heavy chains, take the endpoint whose chain head is
// deeper. Its chain segment [poz[cap], poz[node]] joins the answer, and that endpoint
// moves to the parent of its chain head. Once both share a chain, the segment
// between them completes the path.
int queryheavy (int x, int y)
{
    int best = 0;
    bool jumped = false;  // true once at least one chain segment has been taken
    while (cap[x] != cap[y])
    {
        if (lvl[cap[x]] < lvl[cap[y]])
        {
            swap(x, y);
        }
        const int part = query(1, n, poz[cap[x]], poz[x], 1);
        best = jumped ? max(best, part) : part;
        jumped = true;
        x = t[cap[x]];
    }
    int lo = poz[x];
    int hi = poz[y];
    if (lo > hi)
    {
        swap(lo, hi);
    }
    const int last = query(1, n, lo, hi, 1);
    return jumped ? max(best, last) : last;
}

int main ()
{
    int q;
    in >> n >> q;
    for (int i = 1; i <= n; i++)
    {
        in >> v[i];
    }
    for (int i = 1; i < n; i++)
    {
        int x, y;
        in >> x >> y;
        mc[x].push_back(y);
        mc[y].push_back(x);
    }
    dfssz(1, 0);
    dfsh(1, 0);
    init(1, n, 1);
    while (q--)
    {
        int op, x, y;
        in >> op >> x >> y;
        if (op == 0)
        {
            upd(1, n, poz[x], y, 1);
        }
        else
        {
            out << queryheavy(x, y) << '\n';
        }
    }
    in.close();
    out.close();
    return 0;
}