/*
 * Source code (job #2917612)
 *
 * User:      PopoviciRobert (Popovici Robert)
 * Date:      5 August 2022 22:29:06
 * Problem:   Heavy Path Decomposition
 * Score:     0
 * Compiler:  cpp-64
 * Status:    done
 * Round:     Arhiva educationala (educational archive)
 * Size:      4.42 kb
 */
#include <bits/stdc++.h>
#define lsb(x) (x & (-x))
#define ll long long
#define ull unsigned long long
#define Test(tt) cout << "Case #" << tt << ": "

using namespace std;

/// Segment tree over positions 1..n supporting point assignment and
/// range-maximum queries, each in O(log n).
class SegTree {
public:
    SegTree() = default;

    /// (Re)initialises the tree for positions 1..n with every position at 0.
    /// assign() (not resize()) is used so that calling this again discards the
    /// aggregates of a previous layout instead of silently keeping them.
    /// A negative n is treated as 0.
    void resize(int n) {
        this->n = std::max(n, 0);
        tree.assign(4 * static_cast<std::size_t>(this->n) + 1, 0);
    }

    /// Sets position pos (1-based) to val. Positions outside [1, n] are ignored:
    /// the recursive descent would otherwise overwrite the nearest end leaf, or
    /// never terminate when the tree is empty.
    void set(int pos, int val) {
        if (pos < 1 || pos > n) {
            return;
        }
        set(1, 1, n, pos, val);
    }

    /// Returns the maximum over positions [l, r] (1-based, inclusive) after
    /// clipping the range to [1, n]. An empty range (including any query on an
    /// empty tree) yields std::numeric_limits<int>::min(), the identity of max.
    int get(int l, int r) const {
        l = std::max(l, 1);
        r = std::min(r, n);
        if (l > r) {
            return std::numeric_limits<int>::min();
        }
        return get(1, 1, n, l, r);
    }

private:
    // Writes val into leaf pos inside the subtree of `node` (covering
    // [left, right]) and refreshes the maxima on the way back up.
    void set(int node, int left, int right, int pos, int val) {
        if (left == right) {
            tree[node] = val;
            return;
        }
        const int mid = left + (right - left) / 2;
        if (pos <= mid) {
            set(2 * node, left, mid, pos, val);
        } else {
            set(2 * node + 1, mid + 1, right, pos, val);
        }
        tree[node] = std::max(tree[2 * node], tree[2 * node + 1]);
    }

    // Maximum of [l, r] intersected with the subtree of `node` (covering
    // [left, right]); the caller guarantees the intersection is non-empty.
    int get(int node, int left, int right, int l, int r) const {
        if (l <= left && right <= r) {
            return tree[node];
        }
        const int mid = left + (right - left) / 2;
        int answer = std::numeric_limits<int>::min();
        if (l <= mid) {
            answer = std::max(answer, get(2 * node, left, mid, l, r));
        }
        if (mid < r) {
            answer = std::max(answer, get(2 * node + 1, mid + 1, right, l, r));
        }
        return answer;
    }

    std::vector<int> tree;  // tree[1] is the root; children of i are 2i and 2i+1
    int n = 0;              // number of positions; 0 until resize() is called
};


/// Reads a tree with node values, builds a heavy-path decomposition with one
/// max segment tree per chain, then answers queries:
///   "0 node value" -> assign value to node
///   "1 x y"        -> print the maximum value on the path x..y
int main() {
#ifdef HOME
    ifstream cin("A.in");
    ofstream cout("A.out");
#else
    ifstream cin("heavypath.in");
    ofstream cout("heavypath.out");
#endif
    ios::sync_with_stdio(false);
    cin.tie(nullptr), cout.tie(nullptr);

    int n, q;
    cin >> n >> q;

    vector<int> val(n + 1);
    for (int i = 1; i <= n; i++) {
        cin >> val[i];
    }

    vector<vector<int>> g(n + 1);
    for (int i = 1; i < n; i++) {
        int x, y;
        cin >> x >> y;
        g[x].push_back(y);
        g[y].push_back(x);
    }

    // chain[v]  : id of the heavy chain containing v
    // weight[v] : number of nodes in the subtree rooted at v
    // lvl[v]    : depth of v (root is 1; the sentinel parent 0 stays at 0)
    // parent[v] : parent of v (0 for the root)
    vector<int> chain(n + 1), weight(n + 1, 1);
    vector<int> lvl(n + 1), parent(n + 1);
    int chain_count = 0;

    // First pass: depths, parents and subtree sizes. Each node joins the chain
    // of its heaviest child; a node without children opens a new chain. Testing
    // "no children" (rather than "degree 1 and not the root") also covers a
    // single-node tree, whose root has degree 0.
    auto dfs = [&](auto self, int node, int par) -> void {
        lvl[node] = lvl[par] + 1;
        parent[node] = par;

        int heavy_son = -1;
        for (int son : g[node]) {
            if (son == par) {
                continue;
            }
            self(self, son, node);
            weight[node] += weight[son];
            if (heavy_son == -1 || weight[heavy_son] < weight[son]) {
                heavy_son = son;
            }
        }

        chain[node] = (heavy_son == -1) ? chain_count++ : chain[heavy_son];
    };

    dfs(dfs, 1, 0);

    const int num = chain_count;

    // Second pass (preorder): every chain is a vertical path, so ancestors are
    // appended before descendants and chains[c] is ordered top-to-bottom.
    // pos[v] is v's 1-based index inside its chain.
    vector<vector<int>> chains(num);
    vector<int> pos(n + 1);

    auto dfs2 = [&](auto self, int node, int par) -> void {
        chains[chain[node]].push_back(node);
        pos[node] = (int)chains[chain[node]].size();

        for (int son : g[node]) {
            if (son != par) {
                self(self, son, node);
            }
        }
    };

    dfs2(dfs2, 1, 0);

    vector<SegTree> st(num);
    for (int i = 0; i < num; i++) {
        st[i].resize((int)chains[i].size());
    }

    for (int i = 1; i <= n; i++) {
        st[chain[i]].set(pos[i], val[i]);
    }

    // Maximum node value on the path between x and y.
    auto Query = [&](int x, int y) -> int {
        int answer = numeric_limits<int>::min();
        while (chain[x] != chain[y]) {
            // Lift the endpoint whose chain top is deeper. The top must be read
            // after the swap so that we jump from the chain we just queried.
            if (lvl[chains[chain[x]][0]] < lvl[chains[chain[y]][0]]) {
                swap(x, y);
            }
            const int top = chains[chain[x]][0];
            answer = max(answer, st[chain[x]].get(1, pos[x]));
            x = parent[top];
        }

        // Same chain now: y is the shallower endpoint after this swap.
        if (lvl[x] < lvl[y]) {
            swap(x, y);
        }

        answer = max(answer, st[chain[x]].get(pos[y], pos[x]));

        return answer;
    };

    while (q--) {
        int type;
        cin >> type;

        if (type == 0) {
            int node, new_val;
            cin >> node >> new_val;
            st[chain[node]].set(pos[node], new_val);
        } else {
            int x, y;
            cin >> x >> y;
            cout << Query(x, y) << "\n";
        }
    }

    return 0;
}