Cod sursa(job #3276010)

Utilizator AlexandruBenescu (Alexandru Benescu) Data 12 februarie 2025 12:36:50
Problema Lowest Common Ancestor Scor 0
Compilator cpp-64 Status done
Runda Arhiva educationala Marime 2.12 kb
#include <bits/stdc++.h>
// Capacity of every per-node array; nodes are 1-indexed, so n must be < L.
constexpr int L = 50005;
// Number of sparse-table levels. Level S-1 covers windows of 2^(S-1) = 65536
// entries, enough for any range of the Euler tour (length 2*n-1 < 2*L <= 2^S).
constexpr int S = 17;
using namespace std;

// Node count, operation count, and each node's colour (flipped 0 <-> 1 by type-1 operations).
int n;
int q;
int col[L];

// Undirected adjacency lists of the tree.
std::vector<int> G[L];

int firstAp[L];      // first index at which each node occurs in euler[]
int lev[L];          // depth of each node; the root (node 1) has depth 1
int dist[2][L];      // dist[c][v]: per-colour baseline read by type-2 queries in main()
int p2[2 * L];       // p2[len] = floor(log2(len)) lookup for sparse-table queries
int rmq[S][2 * L];   // rmq[b][i]: shallowest node among euler[i .. i + 2^b - 1]
int euler[2 * L];    // Euler tour of the tree, 1-indexed
int eulerSize;       // number of entries written to euler[]
int bucketSize;      // operations per bucket between refresh() calls

// Nodes whose colour has been flipped an odd number of times since the last refresh().
std::set<int> toggled;

/// Appends the Euler tour of the subtree rooted at `node` to euler[] (advancing
/// eulerSize) and assigns lev[] to every descendant.
/// `node` is written once before descending into each child and once more after
/// the last child, so in a tree a subtree of k nodes contributes 2k-1 entries.
/// @param node current vertex
/// @param prev vertex we arrived from (0 for the root); skipped as a neighbour
void dfsEuler(int node, int prev) {
  for (const int child : G[node]) {
    if (child != prev) {
      eulerSize++;
      euler[eulerSize] = node;
      lev[child] = lev[node] + 1;
      dfsEuler(child, node);
    }
  }
  eulerSize++;
  euler[eulerSize] = node;
}

/// Returns whichever of the nodes x and y is shallower according to lev[];
/// when both have the same depth, y is returned.
int minLev(int x, int y) {
  return lev[x] < lev[y] ? x : y;
}

/// Prepares O(1) LCA queries on the tree rooted at node 1:
///  1. dfsEuler fills euler[1..eulerSize] and lev[] (root depth 1);
///  2. firstAp[v] records the first Euler index at which v appears;
///  3. p2[len] = floor(log2(len)) for every range length a query can use;
///  4. rmq[bit][i] = shallowest node among euler[i .. i + 2^bit - 1].
/// Relies on firstAp[] being zero on entry (true for a single call on globals).
void computeLCA() {
  lev[1] = 1;
  dfsEuler(1, 0);
  for (int i = 1; i <= eulerSize; i++)
    if (firstAp[euler[i]] == 0)
      firstAp[euler[i]] = i;

  // A query range can span the whole Euler tour, so the log table must cover
  // every length 1..eulerSize. Stopping early would leave p2 = 0 for long
  // ranges and make the query compare only the two endpoint entries.
  int p = 0;
  for (int i = 1; i <= eulerSize; i++) {
    if ((1 << (p + 1)) <= i)
      p++;
    p2[i] = p;
  }

  for (int i = 1; i <= eulerSize; i++)
    rmq[0][i] = euler[i];
  // Level `bit` merges two adjacent windows of length 2^(bit-1).
  for (int bit = 1; bit < S; bit++)
    for (int i = 1; i <= eulerSize - (1 << bit) + 1; i++)
      rmq[bit][i] = minLev(rmq[bit - 1][i], rmq[bit - 1][i + (1 << (bit - 1))]);
}

void computeDist() {

}

/// Starts a new bucket of operations: calls computeDist() to rebuild dist[][]
/// from the current colours, then empties `toggled` (the nodes flipped since
/// the previous rebuild).
void refresh() {
  computeDist();
  toggled.clear();
}

int distBetween(int x, int y) {
  x = firstAp[x];
  y = firstAp[y];
  int p = p2[y - x + 1];
  return minLev(rmq[p][x], rmq[p][y - (1 << p) + 1]);
}

/// Reads from standard input:
///  - n, q;
///  - the 0/1 colour of nodes 1..n;
///  - n-1 undirected edges.
/// Then processes q operations "t node":
///  - t == 1: flips the colour of `node` and records it in `toggled`;
///  - otherwise: prints dist[col[node]][node], corrected by distBetween(node, u)
///    for every u in `toggled`. The distance is added if u now shares node's
///    colour and subtracted if not. With dist[][] and distBetween() holding
///    tree distances, this is the sum of distances from `node` to all nodes of
///    its current colour.
/// Operations are grouped into buckets of about sqrt(q); refresh() runs at each
/// bucket boundary so `toggled` stays small.
int main() {
  std::ios::sync_with_stdio(false);
  std::cin.tie(nullptr);

  std::cin >> n >> q;
  for (int i = 1; i <= n; i++)
    std::cin >> col[i];
  for (int i = 1; i < n; i++) {
    int a, b;
    std::cin >> a >> b;
    G[a].push_back(b);
    G[b].push_back(a);
  }

  computeLCA();
  computeDist();

  // Never let the bucket size reach 0: it is used as a modulus below.
  bucketSize = std::max(1, static_cast<int>(std::sqrt(static_cast<double>(q))));
  for (int i = 1; i <= q; i++) {
    if (i % bucketSize == 0)
      refresh();
    int t, node;
    std::cin >> t >> node;
    if (t == 1) {
      col[node] = 1 - col[node];
      // A second flip cancels the first, so the node leaves the set again.
      if (toggled.erase(node) == 0)
        toggled.insert(node);
    } else {
      long long sum = dist[col[node]][node];
      for (const int other : toggled) {
        const long long d = distBetween(node, other);
        if (col[other] == col[node])
          sum += d;  // newly joined node's colour since the rebuild
        else
          sum -= d;  // was counted in the baseline, has since left
      }
      std::cout << sum << '\n';
    }
  }
  return 0;
}