// Cod sursa (job #2581250)
//
// Utilizator: radustn92 (Radu Stancu)   Data: 14 martie 2020 19:25:09
// Problema: Heavy Path Decomposition   Scor: 100
// Compilator: cpp-64   Status: done
// Runda: Arhiva educationala   Marime: 3 kb
#include <algorithm>
#include <cstdio>
#include <limits>
#include <utility>
#include <vector>
using namespace std;

// Capacity of the per-node arrays; node ids are used as indices 1..N.
constexpr int NMAX = 100505;
// Capacity of the segment-tree array. With the 1-based node / 2*node /
// 2*node+1 layout, 2^18 slots cover any interval of up to 2^17 leaves,
// and NMAX fits below 2^17.
constexpr int LMAX = 1 << 18;
// Large positive constant (0x3f3f3f3f = 1061109567), usable as an "infinity" sentinel.
constexpr int INF = 0x3f3f3f3f;

// Node count and query count, read by main().
int N;
int queries;

// Per-node data, indexed by node id.
int values[NMAX];      // value currently stored at the node
int head[NMAX];        // top node of the heavy chain containing the node
int parent[NMAX];      // parent in the traversal tree rooted at node 1
int heavyChild[NMAX];  // child with the largest subtree (main() presets -1)
int depth[NMAX];       // distance from the root
int posNode[NMAX];     // 1-based position of the node inside `nodes`

// Max segment tree over chain order; index 1 is the root.
int segmTree[LMAX];

// Adjacency lists, plus all nodes listed in chain (decomposition) order.
std::vector<int> G[NMAX];
std::vector<int> nodes;

// Traversal bookkeeping for dfs().
bool visited[NMAX];

// Depth-first traversal from `node` over the adjacency lists in G.
// Each newly reached vertex is marked visited and gets parent[] and depth[]
// (depth relative to depth[node]). For every vertex with children,
// heavyChild[] becomes the first child, in adjacency order, whose subtree is
// strictly the largest. Returns the number of vertices in the subtree rooted
// at `node`.
//
// An explicit stack replaces recursion. On a path-shaped tree of ~1e5 nodes
// recursive calls nest ~1e5 deep, which can exceed the default stack size.
// The visiting order and results are the same as the recursive formulation.
int dfs(int node) {
	// One frame per vertex on the current root-to-vertex path.
	struct Frame {
		int vertex;
		size_t nextEdge;   // index in G[vertex] of the next neighbour to examine
		int subtreeSize;   // vertices counted so far in this subtree, itself included
		int largestChild;  // largest child-subtree size seen so far
	};

	std::vector<Frame> frames;
	visited[node] = true;
	frames.push_back(Frame{node, 0, 1, 0});
	int rootSize = 0;

	while (!frames.empty()) {
		Frame &top = frames.back();
		if (top.nextEdge < G[top.vertex].size()) {
			const int neighbour = G[top.vertex][top.nextEdge++];
			if (!visited[neighbour]) {
				visited[neighbour] = true;
				parent[neighbour] = top.vertex;
				depth[neighbour] = depth[top.vertex] + 1;
				// push_back may reallocate and invalidate `top`; it is not used after this.
				frames.push_back(Frame{neighbour, 0, 1, 0});
			}
			continue;
		}

		// All neighbours handled: fold the finished subtree into its parent's frame.
		const Frame done = top;
		frames.pop_back();
		if (frames.empty()) {
			rootSize = done.subtreeSize;
		} else {
			Frame &up = frames.back();
			up.subtreeSize += done.subtreeSize;
			// Strict '>' keeps the first child among equally large subtrees.
			if (done.subtreeSize > up.largestChild) {
				up.largestChild = done.subtreeSize;
				heavyChild[up.vertex] = done.vertex;
			}
		}
	}

	return rootSize;
}

// Lays out heavy-path chains in preorder, starting at `node`, whose chain head
// is `startingNode`. Each vertex reached gets head[] = top vertex of its chain.
// It is appended to `nodes` and gets posNode[] = its 1-based index in `nodes`.
// The heavy child is handled first, so each chain occupies a contiguous range
// of positions. Every other child (not the parent, not the heavy child) starts
// a new chain headed by itself.
//
// Uses an explicit stack instead of recursion, because one heavy chain can be
// ~N vertices long. To reproduce the recursive order, light children are
// pushed in reverse adjacency order and the heavy child is pushed last, so
// the heavy child's whole subtree is emitted before the light children.
void createChains(int node, int startingNode) {
	std::vector<std::pair<int, int> > pending;  // (vertex, head of its chain)
	pending.push_back(std::make_pair(node, startingNode));

	while (!pending.empty()) {
		const int current = pending.back().first;
		const int chainHead = pending.back().second;
		pending.pop_back();

		head[current] = chainHead;
		nodes.push_back(current);
		posNode[current] = (int) nodes.size();

		const std::vector<int> &adjacent = G[current];
		for (auto it = adjacent.rbegin(); it != adjacent.rend(); ++it) {
			const int neighbour = *it;
			if (neighbour != parent[current] && neighbour != heavyChild[current]) {
				pending.push_back(std::make_pair(neighbour, neighbour));
			}
		}
		// Pushed last so it is popped next: the chain continues through it.
		if (heavyChild[current] != -1) {
			pending.push_back(std::make_pair(heavyChild[current], chainHead));
		}
	}
}

// Builds the max segment tree rooted at `node`, which covers positions [l, r].
// The leaf for position p holds the value of nodes[p - 1] (positions are
// 1-based, `nodes` is 0-based); each internal slot holds its children's max.
void cstr(int node, int l, int r) {
	if (l != r) {
		const int mid = (l + r) / 2;
		const int leftChild = 2 * node;
		const int rightChild = leftChild + 1;
		cstr(leftChild, l, mid);
		cstr(rightChild, mid + 1, r);
		segmTree[node] = std::max(segmTree[leftChild], segmTree[rightChild]);
	} else {
		segmTree[node] = values[nodes[l - 1]];
	}
}

// Sets position `pos` to `newValue` in the subtree rooted at `node`, which
// covers [l, r]. The maxima of every slot on the path from that leaf back up
// to `node` are then refreshed.
void update(int node, int l, int r, int pos, int newValue) {
	const int start = node;

	// Descend to the leaf that covers `pos`.
	while (l != r) {
		const int mid = (l + r) / 2;
		node *= 2;
		if (pos <= mid) {
			r = mid;
		} else {
			l = mid + 1;
			node += 1;
		}
	}
	segmTree[node] = newValue;

	// Climb back, recomputing each ancestor up to and including `start`.
	while (node != start) {
		node /= 2;
		segmTree[node] = std::max(segmTree[2 * node], segmTree[2 * node + 1]);
	}
}

// Returns the maximum stored value over positions [from, to] intersected with
// the interval [l, r] covered by segment-tree slot `node`.
// An empty intersection yields std::numeric_limits<int>::min(), the identity of
// max over int. Unlike a finite sentinel such as -INF (-1061109567), it can
// never exceed, and so mask, a genuinely stored value, including very negative
// values read from input or written by update().
int getMax(int node, int l, int r, int from, int to) {
	if (to < l || r < from) {
		return std::numeric_limits<int>::min();
	}
	if (from <= l && r <= to) {
		return segmTree[node];
	}

	const int mid = (l + r) / 2;
	return std::max(
		getMax(node * 2, l, mid, from, to),
		getMax(node * 2 + 1, mid + 1, r, from, to));
}

// Returns the maximum value on the tree path between nodes x and y, with both
// endpoints included. While x and y are on different chains, the endpoint whose
// chain head is deeper is lifted. The positions from its chain head down to it
// (contiguous in chain order) are queried, and it jumps to that head's parent.
// Once both share a chain, the positions between them are queried.
int getMaxBetweenNodes(int x, int y) {
	// Starts at the identity of max so that no real value can be hidden.
	int maxValue = std::numeric_limits<int>::min();
	while (head[x] != head[y]) {
		if (depth[head[x]] < depth[head[y]]) {
			std::swap(x, y);
		}
		maxValue = std::max(maxValue, getMax(1, 1, N, posNode[head[x]], posNode[x]));
		x = parent[head[x]];
	}
	// Same chain: the shallower node comes first in chain order.
	if (depth[x] > depth[y]) {
		std::swap(x, y);
	}
	maxValue = std::max(maxValue, getMax(1, 1, N, posNode[x], posNode[y]));
	return maxValue;
}

int main() {
	freopen("heavypath.in", "r", stdin);
	freopen("heavypath.out", "w", stdout);

	scanf("%d%d", &N, &queries);
	for (int node = 1; node <= N; node++) {
		scanf("%d", &values[node]);
	}

	int x, y;
	for (int edge = 1; edge < N; edge++) {
		scanf("%d%d", &x, &y);
		G[x].push_back(y);
		G[y].push_back(x);
	}

	fill(heavyChild, heavyChild + N + 1, -1);
	dfs(1);
	createChains(1, 1);

	cstr(1, 1, N);
	int type;
	for (int queryNo = 0; queryNo < queries; queryNo++) {
		scanf("%d%d%d", &type, &x, &y);
		switch (type) {
			case 0: {
				update(1, 1, N, posNode[x], y);
				break;
			}
			case 1: {
				printf("%d\n", getMaxBetweenNodes(x, y));
				break;
			}
		}
	}
	return 0;
}