// Source code (job #2079559)
//
// User: MaligMamaliga cu smantana Malig    Date: 1 December 2017 15:45:20
// Problem: Heavy Path Decomposition        Score: 100
// Compiler: cpp                            Status: done
// Round: Arhiva educationala (educational archive)    Size: 4.1 kb
#include <algorithm>
#include <cstring>
#include <fstream>
#include <iostream>
#include <limits>
#include <vector>

using namespace std;

// Input/output files used by the whole program.
ifstream in("heavypath.in");
ofstream out("heavypath.out");

// Type aliases rather than textual macros: an alias is a real, scoped type
// name (so e.g. ll(x) is a valid functional cast), a macro is not.
using ll = long long;
using ull = unsigned long long;
#define pb push_back

constexpr int NMax = 100000 + 5;   // per-node / per-chain arrays are sized for ids below NMax
constexpr int arbMax = 4 * NMax;   // total aint size shared by all chain segment trees

int N;        // number of nodes
int M;        // number of operations
int nrChain;  // number of chains created by dfs

int value[NMax];        // initial value of each node (consumed by build)
int depth[NMax];        // depth from root 1 (root has depth 1); 0 means "not visited yet" during dfs
int sub[NMax];          // subtree size of each node
int chainOf[NMax];      // id of the chain that contains each node
int chainOffset[NMax];  // base index of each chain's segment tree inside aint
int chainDim[NMax];     // number of nodes in each chain
int chainDad[NMax];     // parent of each chain's top node (stays 0 for the root chain)
int chainDepth[NMax];   // depth of chainDad; depth[x] - chainDepth[chainOf[x]] is x's 1-based position in its chain
int aint[arbMax];       // max segment trees of all chains, packed back to back

vector<int> v[NMax];      // adjacency lists of the tree
vector<int> chain[NMax];  // nodes of each chain, ordered from the chain top downwards after getChains

void read();
void getChains();
void solveQueries();
void dfs(int);
void build(int,int,int,int);
void update(int,int,int,int,int,int);
int query(int,int,int,int,int,int);

// Program entry: load the tree, split it into heavy chains, then answer the
// operations; answers are written through the global `out` stream.
int main() {
    read();
    getChains();
    solveQueries();

    // Release both file handles explicitly before exiting.
    out.close();
    in.close();
}

// Reads N and M, then the N node values (nodes are numbered from 1), then
// N - 1 undirected edges, each recorded in both endpoints' adjacency lists.
void read() {
    in >> N >> M;

    for (int node = 1; node <= N; ++node)
        in >> value[node];

    for (int edgesLeft = N - 1; edgesLeft > 0; --edgesLeft) {
        int a, b;
        in >> a >> b;
        v[a].push_back(b);
        v[b].push_back(a);
    }
}

// Roots the tree at node 1, forms the heavy chains with dfs, and then gives
// every chain its own segment tree inside aint: chains are laid out one after
// another, each reserving 4 * (chain size) slots.
void getChains() {
    // depth != 0 is what dfs uses to recognise visited nodes.
    depth[1] = 1;
    dfs(1);

    int nextFree = 0;  // first aint slot not claimed by an earlier chain
    for (int id = 1; id <= nrChain; ++id) {
        // dfs appends nodes from the bottom up; flip so index 0 is the chain top.
        std::reverse(chain[id].begin(), chain[id].end());

        chainOffset[id] = nextFree;
        nextFree += 4 * chainDim[id];

        build(1, 1, chainDim[id], id);
    }
}

// Depth-first pass over the subtree of `node` (the caller sets depth[node]
// first; a non-zero depth marks a node as visited). It fills depth[] and sub[],
// and builds the chains bottom-up:
//   - a leaf opens a new chain;
//   - an internal node joins the chain of its heavy child (the child with the
//     largest subtree; the first such child wins ties);
//   - every other child heads its own chain, which records `node` as the node
//     it hangs from (chainDad) together with that node's depth (chainDepth).
// Recursion depth equals the height of the tree.
void dfs(int node) {
    sub[node] = 1;
    bool hasChild = false;
    int heavy = 0;  // sub[0] is 0, so any real child replaces this placeholder

    for (int child : v[node]) {
        if (depth[child])  // already visited: this neighbour is the parent
            continue;

        hasChild = true;
        depth[child] = depth[node] + 1;
        dfs(child);
        sub[node] += sub[child];

        if (sub[child] > sub[heavy])
            heavy = child;
    }

    if (!hasChild) {
        ++nrChain;
        chain[nrChain].push_back(node);
        chainDim[nrChain] = 1;
        chainOf[node] = nrChain;
        return;
    }

    // Extend the heavy child's chain upwards with this node.
    const int id = chainOf[heavy];
    chainOf[node] = id;
    ++chainDim[id];
    chain[id].push_back(node);

    // Light children: their chains start at the child and hang from `node`.
    for (int child : v[node]) {
        const bool isParent = depth[child] < depth[node];
        if (child == heavy || isParent)
            continue;

        chainDad[chainOf[child]] = node;
        chainDepth[chainOf[child]] = depth[node];
    }
}

void solveQueries() {
    while (M--) {
        int tip,x,y;
        in>>tip>>x>>y;

        if (!tip) {
            update(1,1,chainDim[ chainOf[x] ],chainOffset[ chainOf[x] ],depth[x] - chainDepth[ chainOf[x] ],y);
        }
        else {
            int mx = -1;

            while (chainOf[x] != chainOf[y]) {
                if (chainDepth[ chainOf[x] ] < chainDepth[ chainOf[y] ]) {
                    swap(x,y);
                }

                int val = query(1,1,chainDim[ chainOf[x] ],chainOffset[ chainOf[x] ],1,depth[x] - chainDepth[ chainOf[x] ]);
                mx = max(mx,val);
                x = chainDad[ chainOf[x] ];
            }
            if (depth[x] > depth[y]) {
                swap(x,y);
            }

            int val = query(1,1,chainDim[ chainOf[x] ],chainOffset[ chainOf[x] ],depth[x] - chainDepth[ chainOf[x] ],depth[y] - chainDepth[ chainOf[x] ]);
            mx = max(mx,val);

            out<<mx<<'\n';
        }
    }
}

// Segment-tree shorthands used by the functions below: midpoint of [st, dr],
// left child and right child of `node`.
#define mij ((st+dr)>>1)
#define fs (node<<1)
#define ss (fs+1)
// Fills the max segment tree of chain `id` for positions [st, dr]. Node k sits
// at aint[chainOffset[id] + k], and its children are 2k and 2k + 1. The leaf for
// position p holds the value of chain[id][p - 1], the p-th node counted from
// the chain top once getChains has reversed the chain.
void build(int node, int st, int dr, int id) {
    const int base = chainOffset[id];

    if (st != dr) {
        const int half = (st + dr) >> 1;
        build(2 * node, st, half, id);
        build(2 * node + 1, half + 1, dr, id);
        aint[base + node] = std::max(aint[base + 2 * node], aint[base + 2 * node + 1]);
    } else {
        aint[base + node] = value[chain[id][st - 1]];
    }
}

// Point assignment in a max segment tree whose node `node` covers [st, dr].
// Nodes are stored at aint[off + k], and k's children are 2k and 2k + 1.
// The function writes `val` into the leaf for `pos`, then recomputes the maximum
// of every ancestor back up to the starting node.
void update(int node, int st, int dr, int off, int pos, int val) {
    const int top = node;

    // Descend to the leaf responsible for `pos`.
    while (st != dr) {
        const int half = (st + dr) >> 1;
        if (pos <= half) {
            node <<= 1;
            dr = half;
        } else {
            node = (node << 1) + 1;
            st = half + 1;
        }
    }
    aint[off + node] = val;

    // Climb back, refreshing each ancestor from its two children (bottom-up).
    while (node != top) {
        node >>= 1;
        aint[off + node] = std::max(aint[off + (node << 1)], aint[off + (node << 1) + 1]);
    }
}

// Maximum over positions [a, b] in a max segment tree whose node `node` covers
// [st, dr]. Nodes are stored at aint[off + k], and k's children are 2k and 2k + 1.
// Only the halves that intersect [a, b] are visited, so no seed value can leak
// into the result. The previous version seeded partial overlaps with 0 and
// returned 0 when every value in range was negative.
// An empty range (a > b) or one disjoint from [st, dr] yields
// std::numeric_limits<int>::min(), the identity element of max.
int query(int node, int st, int dr, int off, int a, int b) {
    if (a > b || b < st || dr < a) {
        return std::numeric_limits<int>::min();
    }
    if (a <= st && dr <= b) {
        return aint[node + off];
    }

    // [a, b] intersects [st, dr] but does not cover it, so st < dr here.
    const int mid = (st + dr) >> 1;
    const int left = node << 1;
    const int right = left + 1;

    if (b <= mid) {
        return query(left, st, mid, off, a, b);
    }
    if (a > mid) {
        return query(right, mid + 1, dr, off, a, b);
    }
    return std::max(query(left, st, mid, off, a, b),
                    query(right, mid + 1, dr, off, a, b));
}