Cod sursa(job #2352135)

Utilizator flibiaVisanu Cristian flibia Data 22 februarie 2019 23:55:11
Problema Heavy Path Decomposition Scor 40
Compilator cpp-64 Status done
Runda Arhiva educationala Marime 3.3 kb
#include <bits/stdc++.h>
// Segment-tree child indices: L and R expand to the left and right child of
// the node whose index is held in a variable named `pos` at the point of use,
// so they only work inside scopes that have such a variable.
#define L (pos << 1)
#define R (L | 1)

using namespace std;

// Input and output go through fixed file names; both streams are opened when
// these globals are constructed.
ifstream in("heavypath.in");
ofstream out("heavypath.out");

// Max segment tree over the positions 1..sz of a single chain.
//
// The tree is stored implicitly in `data`: the root is node 1 and the children
// of node `pos` are `2 * pos` and `2 * pos + 1`. Usage: set `sz`, call init()
// once, then always pass (st = 1, dr = sz, pos = 1) as the root triple to
// upd() and que().
class chain {
public:
	int sz = 0;
	std::vector<int> data;

	// Allocates the tree for `sz` positions. 4 * sz nodes are enough for a
	// recursive tree rooted at index 1, so no extra padding is reserved; this
	// matters because a tree can decompose into a very large number of tiny
	// chains. Every node starts at the "empty" value.
	void init() {
		data.assign(4 * static_cast<std::size_t>(std::max(sz, 1)), lowest());
	}

	// Sets position `id` to `val`.
	//   st, dr - range of positions covered by node `pos`
	//   pos    - index of the current node in `data`
	void upd(int st, int dr, int pos, int id, int val) {
		if (st == dr) {
			data[pos] = val;
			return;
		}
		const int mid = st + (dr - st) / 2;
		const int left = 2 * pos;
		const int right = left + 1;
		if (id <= mid)
			upd(st, mid, left, id, val);
		else
			upd(mid + 1, dr, right, id, val);
		// Re-derive the maximum of this node from its two children.
		data[pos] = std::max(data[left], data[right]);
	}

	// Returns the maximum over positions [l, r] (inclusive).
	//   st, dr - range of positions covered by node `pos`
	//   pos    - index of the current node in `data`
	// A side of the tree that does not intersect [l, r] contributes the lowest
	// int rather than 0, so ranges holding only negative values are answered
	// correctly.
	int que(int st, int dr, int pos, int l, int r) {
		if (l <= st && dr <= r)
			return data[pos];
		const int mid = st + (dr - st) / 2;
		int best = lowest();
		if (l <= mid)
			best = std::max(best, que(st, mid, 2 * pos, l, r));
		if (r > mid)
			best = std::max(best, que(mid + 1, dr, 2 * pos + 1, l, r));
		return best;
	}

private:
	// Identity element for max.
	static int lowest() {
		return std::numeric_limits<int>::min();
	}
};

// n, m: node count and operation count read in main.
// x, y, t: scratch variables main reads edges and operations into.
// vf: number of entries written to stk so far.
int n, m, x, y, vf, t;
// viz: visited flags used by calc and dfs (cleared between the two in main).
// niv: depth assigned by dfs (the root gets 1); niv[0] is set to a large
//      sentinel in main so that node 0 loses every depth comparison.
// stk: Euler tour written by dfs, 1-based, vf entries.
// aint: segment tree over stk holding the shallowest node of each range.
// dim: subtree sizes computed by calc.
// a: current value of each node.
// pos: index in stk at which a node first appears.
// wh: 1-based position of a node inside its chain (filled by hld).
int viz[100100], niv[100100], stk[200100], aint[800100], dim[100100], a[100100], pos[100100], wh[100100];
// cnt: number of chains created so far; id: chain index of each node.
// dad: parent of each node as discovered by dfs; stays 0 for the root.
int cnt, id[100100], dad[100100];
// v: adjacency lists of the input tree.
// g[c]: nodes of chain c in the order dfs finished them, deepest node first.
vector <int> v[100100], g[100100];
// ch[c]: max segment tree over the values of the nodes in chain c.
chain ch[100100];

// Visits everything reachable from x through unvisited nodes and stores in
// dim[] the size of each visited node's subtree (the node itself included).
void calc(int x) {
	viz[x] = 1;
	int total = 1;
	for (const int child : v[x]) {
		if (viz[child])
			continue;
		calc(child);
		total += dim[child];
	}
	dim[x] = total;
}

// Second traversal, relying on dim[] already being filled by calc.
// For node x at depth lvl it:
//   - appends x to the Euler tour stk on entry and again after every child,
//     remembering the entry index in pos[x];
//   - records niv[x] and the parent of every child in dad[];
//   - assigns x to a chain: a node without unvisited neighbours opens a new
//     chain, any other node joins the chain of its child with the largest
//     dim (the first such child on ties).
void dfs(int x, int lvl) {
	viz[x] = 1;
	niv[x] = lvl;
	pos[x] = ++vf;
	stk[vf] = x;

	int heavy = 0;
	int heavySize = 0;
	for (const int child : v[x]) {
		if (viz[child])
			continue;
		dad[child] = x;
		dfs(child, lvl + 1);
		stk[++vf] = x;
		if (heavySize < dim[child]) {
			heavySize = dim[child];
			heavy = child;
		}
	}

	// heavy == 0 means no child was visited, i.e. x is a leaf of the traversal.
	const int chainId = (heavy != 0) ? id[heavy] : ++cnt;
	id[x] = chainId;
	g[chainId].push_back(x);
}

// Builds the segment tree aint over stk[st..dr]; node `pos` ends up holding
// the entry of its range with the smallest niv (the right-hand one on ties).
void build(int st, int dr, int pos) {
	if (st == dr) {
		aint[pos] = stk[st];
		return;
	}
	const int mid = (st + dr) >> 1;
	const int left = pos << 1;
	const int right = left | 1;
	build(st, mid, left);
	build(mid + 1, dr, right);
	const int a_left = aint[left];
	const int a_right = aint[right];
	aint[pos] = (niv[a_left] < niv[a_right]) ? a_left : a_right;
}

// Returns the node with the smallest niv among stk[l..r], searching the
// subtree of aint rooted at `pos`, which covers stk[st..dr].
// A half that does not intersect [l, r] contributes node 0, so the result is
// only meaningful when niv[0] is larger than every real depth.
int lca(int st, int dr, int pos, int l, int r) {
	if (l <= st && dr <= r)
		return aint[pos];
	const int mid = (st + dr) >> 1;
	const int fromLeft = (l <= mid) ? lca(st, mid, pos << 1, l, r) : 0;
	const int fromRight = (mid < r) ? lca(mid + 1, dr, (pos << 1) | 1, l, r) : 0;
	return (niv[fromLeft] < niv[fromRight]) ? fromLeft : fromRight;
}

void hld() {
	for (int i = 1; i <= cnt; i++) {
		int sz = g[i].size();
		ch[i].sz = sz;
		ch[i].init();
		for (int j = 0; j < sz; j++) {
			ch[i].upd(1, sz, 1, j + 1, a[g[i][j]]);
			wh[g[i][j]] = j + 1;
		}
	}	
}

int solve(int x, int y) {
	int px = pos[x];
	int py = pos[y];
	if (px > py)
		swap(px, py);
	int p = lca(1, vf, 1, px, py);
	int ans = 0;
	while (id[x] != id[p]) {
		int sz = ch[id[x]].sz;
		ans = max(ans, ch[id[x]].que(1, sz, 1, wh[x], sz));
		x = dad[g[id[x]][sz - 1]];
	}
	while (id[y] != id[p]) {
		int sz = ch[id[y]].sz;
		ans = max(ans, ch[id[y]].que(1, sz, 1, wh[y], sz));
		y = dad[g[id[y]][sz - 1]];
	}
	if (wh[x] > wh[y])
		swap(x, y);
	ans = max(ans, ch[id[p]].que(1, ch[id[p]].sz, 1, wh[x], wh[y]));
	return ans;
}

int main() {
	in >> n >> m;
	for (int i = 1; i <= n; i++)
		in >> a[i];
	for (int i = 1; i < n; i++) {
		in >> x >> y;
		v[x].push_back(y);
		v[y].push_back(x);
	}
	calc(1);
	memset(viz, 0, sizeof viz);
	dfs(1, 1);	
	niv[0] = 2e9;
	build(1, vf, 1);
	hld();
	while (m--) {
		in >> t >> x >> y;
		if (t == 0) {
			a[x] = y;
			ch[id[x]].upd(1, ch[id[x]].sz, 1, wh[x], y);
		} else {
			out << solve(x, y) << '\n';
		}
	}
	return 0;
}