Cod sursa(job #1708871)

Utilizator Vlad_317 (Vlad Panait) Data 28 mai 2016 04:03:55
Problema Lowest Common Ancestor Scor 40
Compilator cpp Status done
Runda Arhiva educationala Marime 2.09 kb
#include <stdio.h>
#include <vector>
using namespace std;

/// Sentinel larger than any depth or node id; used to reset query results.
const int INF = 0x3f3f3f3f;
/// Bound for node-indexed arrays; assumes node ids are 1..n with n < MAX.
const int MAX = 100001;

int n, m;

///graph
/// g[x] lists the children of node x (filled from the parent list in main()).
std::vector<int> g[MAX];
/// Euler tour (1-based): e[i] = node visited at step i, l[i] = its depth,
/// first[x] = step at which x is first entered. A tree of n nodes yields
/// 2n - 1 tour entries.
int e[2 * MAX + 10], l[2 * MAX + 10], first[MAX];
/// Number of tour entries written so far (the tour length after euler()).
int k;

///tree
/// Segment tree over tour positions 1..k; each node stores the tour position
/// of minimum depth within its range. A recursive segment tree over k leaves
/// uses node indices up to ~4k, and k can reach ~2 * MAX, so the previous
/// bound of 4 * MAX overflowed on large inputs.
int tree[8 * MAX + 10];
/// Running result of query(): sol = node of minimum depth, hsol = that depth.
int sol = INF, hsol = INF;

/// Returns the smaller of two ints (b when they are equal).
int min(int a, int b)
{
    return (a < b) ? a : b;
}

void euler(int node, int level)
{
    e[++k] = node;
    l[k] = level;
    first[node] = k;
    for(int i = 0; i < g[node].size(); i++)
    {
        euler(g[node][i], level + 1);

        e[++k] = node;
        l[k] = level;
    }
}

void update(int node, int left, int right)
{
    if(left == right)
    {
        tree[node] = left;
        return;
    }

    int mid = (left + right) / 2;
    update(2 * node, left, mid);
    update(2 * node + 1, mid + 1, right);

    if(l[tree[node * 2]] <= l[tree[node * 2 + 1]])
        tree[node] = tree[node * 2];
    else
        tree[node] = tree[node * 2 + 1];

}

/// Walks the segment tree from `node` (covering [left, right]) and folds every
/// node fully inside [x, y] into the running minimum: hsol holds the smallest
/// depth seen and sol the corresponding tour node. Callers reset sol/hsol to
/// INF before the first call.
void query(int node, int left, int right, int x, int y)
{
    const bool covered = (x <= left) && (right <= y);

    if (!covered)
    {
        const int half = (left + right) / 2;
        if (x <= half)
            query(2 * node, left, half, x, y);
        if (half < y)
            query(2 * node + 1, half + 1, right, x, y);
        return;
    }

    const int pos = tree[node];
    if (l[pos] < hsol)
    {
        hsol = l[pos];
        sol = e[pos];
    }
}

/// Returns the lowest common ancestor of nodes x and y: the shallowest Euler
/// tour entry between the first occurrences of x and y.
int lca(int x, int y)
{
    const int fx = first[x];
    const int fy = first[y];
    const int from = (fx < fy) ? fx : fy;
    const int to = (fx < fy) ? fy : fx;

    hsol = INF;
    sol = INF;
    query(1, 1, k, from, to);
    return sol;
}

int main()
{
    FILE *fin, *fout;

    fin = fopen("lca.in", "r");
    fout = fopen("lca.out", "w");

    fscanf(fin, "%d%d", &n, &m);

    for(int i = 2; i <= n; i++)
    {
        int x;
        fscanf(fin, "%d", &x);
        g[x].push_back(i);
    }

    euler(1, 0);
    update(1, 1, k);

    for(int i = 1; i <= m; i++)
    {
        int x, y;
        fscanf(fin, "%d%d", &x, &y);
        fprintf(fout, "%d\n", lca(x, y));
    }

    return 0;
}