Cod sursa(job #2124631)

Utilizator: Marin Dragos (silk) | Data: 7 februarie 2018 13:51:02
Problema: Heavy Path Decomposition | Scor: 100
Compilator: cpp | Status: done
Runda: Arhiva educationala | Marime: 2.3 kb
#include <bits/stdc++.h>
#define MaxN 100000
using namespace std;

// Shared state used by every routine below. Node ids and positions start at 1;
// all arrays are zero-initialised because they have static storage duration.
vector<int> A[MaxN+1];   // A[x]: neighbours of x, exactly as read from the edge list
vector<int> L[MaxN+1];   // L[x]: children of x once the tree is rooted (filled by Get_L, reordered by DFS)
int aint[4*MaxN+1];      // max segment tree in heap layout: node i has children 2*i and 2*i+1, leaves start at index N
int tata[MaxN+1];        // tata[x]: parent of x; 0 means "not reached yet", the root is set to itself in main
int nxt[MaxN+1];         // nxt[x]: topmost node of the chain that contains x (assigned in DF)
int in[MaxN+1];          // in[x]: position of x in the visiting order produced by DF
int sz[MaxN+1];          // sz[x]: number of nodes in the subtree of x (computed by DFS)
int h[MaxN+1];           // h[x]: depth of x, parent depth + 1 (the root keeps the initial 0)
int a[MaxN+1];           // a[p]: value of the node whose position in[...] is p
int v[MaxN+1];           // v[x]: value of node x as read from the input
// n = number of nodes, t = running position counter used by DF,
// N = number of segment tree leaves (smallest power of two >= n), M = number of operations.
int n,t,N,M;

void Get_L(int x)
{
    for(auto y : A[x])
    if(!tata[y])
    {
        L[x].push_back(y);
        tata[y] = x;
        Get_L(y);
    }
}

void DFS(int x)
{
    sz[x] = 1;
    for(auto& y : L[x])
    {
        h[y] = h[x] + 1;
        DFS(y);

        sz[x] += sz[y];
        if(sz[y] > sz[ L[x][0] ]) swap(y, L[x][0]);
    }
}

void DF(int x)
{
    in[x] = ++t;
    a[t] = v[x];
    for(auto y : L[x])
    {
        nxt[y] = (y == L[x][0] ? nxt[x] : y);
        DF(y);
    }
}

/**
 * Stores x in segment tree node p and refreshes the maxima of all its ancestors.
 *
 * @param p index of the node inside aint[] (callers pass a leaf index)
 * @param x new value for that node
 */
void Update(int p, int x)
{
    aint[p] = x;

    // Walk towards the root (index 1); every ancestor is recomputed from its two children.
    for(int parent = p / 2; parent >= 1; parent /= 2)
    {
        const int left = 2 * parent;
        const int right = left + 1;
        aint[parent] = max(aint[left], aint[right]);
    }
}

/**
 * Returns the maximum stored over positions [x, y] of the segment tree.
 *
 * @param nod index in aint[] of the current node
 * @param x   left end of the queried range (inclusive)
 * @param y   right end of the queried range (inclusive)
 * @param st  left end of the range covered by nod
 * @param dr  right end of the range covered by nod
 * @return the maximum of the leaves in [x, y]; expects st <= x <= y <= dr
 */
int Query(int nod, int x, int y, int st, int dr)
{
    // The node covers exactly the requested range: its stored maximum is the answer.
    if(st == x && dr == y) return aint[nod];

    const int mid = (st + dr) / 2;
    const int leftSon = 2 * nod;
    const int rightSon = leftSon + 1;

    // Range lies entirely in one half.
    if(y <= mid) return Query(leftSon, x, y, st, mid);
    if(x > mid) return Query(rightSon, x, y, mid + 1, dr);

    // Range straddles the midpoint: split it and combine both partial answers.
    const int bestLeft = Query(leftSon, x, mid, st, mid);
    const int bestRight = Query(rightSon, mid + 1, y, mid + 1, dr);
    return max(bestLeft, bestRight);
}

/**
 * Returns the maximum node value on the tree path between x and y.
 *
 * While the two nodes lie on different chains, the node whose chain head is
 * deeper is lifted: the segment from that chain head down to the node is
 * queried, then the node jumps to the parent of the chain head. Once both nodes
 * share a chain, the positions between them form one contiguous range.
 *
 * The running maximum starts at INT_MIN rather than 0, so that a path holding
 * only negative values is reported correctly instead of being clamped to 0.
 * Results for non-negative values are unchanged.
 *
 * @param x first endpoint of the path
 * @param y second endpoint of the path
 * @return the largest value stored on the path, endpoints included
 */
int Solve(int x, int y)
{
    int res = INT_MIN;
    while(nxt[x] != nxt[y])
    {
        // Always lift the node whose chain starts deeper, so neither node jumps past the meeting chain.
        if(h[ nxt[x] ] < h[ nxt[y] ]) swap(x, y);
        res = max(res, Query(1, in[ nxt[x] ], in[x], 1, N));
        x = tata[ nxt[x] ];
    }

    // Same chain now: the shallower node has the smaller position.
    if(h[x] > h[y]) swap(x, y);
    return max(res, Query(1, in[x], in[y], 1, N));
}

int main(){
    FILE* fin = fopen("heavypath.in","r");
    FILE* fout = fopen("heavypath.out","w");

    int i,x,y,op;

    fscanf(fin,"%d %d",&n,&M);
    for(i = 1; i <= n; ++i) fscanf(fin,"%d",&v[i]);
    for(i = 1; i < n; ++i)
    {
        fscanf(fin,"%d %d",&x,&y);
        A[x].push_back(y);
        A[y].push_back(x);
    }

    nxt[1] = tata[1] = 1;
    Get_L(1); DFS(1); DF(1);

    for(N = 1; N < n; N *= 2);
    for(i = N; i < N+n; ++i) aint[i] = a[i-N+1];
    for(i = N-1; i >= 1; --i) aint[i] = max(aint[2*i], aint[2*i+1]);
    for(i = 1; i <= M; ++i)
    {
        fscanf(fin,"%d %d %d",&op,&x,&y);
        if(!op) Update(N+in[x]-1, y);
        else fprintf(fout,"%d\n",Solve(x,y));
    }


return 0;
}