Cod sursa(job #2231872)

Utilizator Radu Stancu (radustn92) Data 16 august 2018 13:08:23
Problema Heavy Path Decomposition Scor 0
Compilator java Status done
Runda Arhiva educationala Marime 10.63 kb
import java.io.*;
import java.time.Clock;
import java.util.*;

public class Main {

    /**
     * Reads a tree with a value on every node and answers path-maximum queries.
     *
     * <p>{@code heavypath.in} holds {@code N M}, then the {@code N} node values, then
     * {@code N - 1} edges, then {@code M} queries {@code type x y}. Type {@code 0} sets the value
     * of node {@code x} to {@code y}; type {@code 1} writes the maximum value on the path between
     * {@code x} and {@code y} as one line of {@code heavypath.out}.
     *
     * @param args ignored
     * @throws IOException if the input or output file cannot be opened or closed
     */
    public static void main(String[] args) throws IOException {
        // Both streams are resources of the same try so neither leaks if the other fails to open.
        try (InputReader inputReader = new InputReader(new FileInputStream("heavypath.in"));
            PrintWriter printWriter = new PrintWriter(new FileOutputStream("heavypath.out"))) {
            int N = inputReader.nextInt();
            int M = inputReader.nextInt();

            int[] valuesPerNode = new int[N + 1];
            for (int node = 1; node <= N; node++) {
                valuesPerNode[node] = inputReader.nextInt();
            }

            Graph graph = new Graph(N);
            for (int edgeNo = 1; edgeNo < N; edgeNo++) {
                graph.addEdge(inputReader.nextInt(), inputReader.nextInt());
            }

            HeavyPathDecomp heavyPathDecomp = new HeavyPathDecomp(graph, valuesPerNode);

            for (int queryNo = 1; queryNo <= M; queryNo++) {
                int type = inputReader.nextInt();
                int x = inputReader.nextInt();
                int y = inputReader.nextInt();

                switch (type) {
                    case 0:
                        heavyPathDecomp.setValue(x, y);
                        break;
                    case 1:
                        printWriter.println(heavyPathDecomp.getMaxOnPath(x, y));
                        break;
                    default:
                        throw new IllegalArgumentException("Unexpected type: " + type);
                }
            }
        }
    }

    /**
     * Heavy-path decomposition of a tree rooted at node 1, supporting point updates and
     * maximum-on-path queries.
     *
     * <p>Every node belongs to exactly one path ({@code pathIdx}) and sits at a 1-based position
     * ({@code posInPath}) counted from the path's shallowest node. Each path keeps its node values
     * in its own {@link SegmentTree}, indexed by that same 1-based position.
     *
     * <p>All traversals are iterative, so very deep trees (e.g. a long chain) do not overflow the
     * call stack.
     */
    static class HeavyPathDecomp {

        private final Graph graph;

        /** Path id (1..noPaths) of each node. */
        private final int[] pathIdx;

        /** Shallowest node of each path. */
        private final int[] firstNodeInPath;

        /** Number of paths created so far; path ids are 1..noPaths. */
        private int noPaths;

        /** Number of nodes on each path. */
        private final int[] pathLength;

        /** 1-based position of each node on its path (the path's first node is at 1). */
        private final int[] posInPath;

        /** Size of the subtree rooted at each node, the node itself included. */
        private final int[] childrenPerNode;

        /** Parent of each node; 0 for the root. */
        private final int[] parent;

        /** Distance of each node from the root (the root has depth 0). */
        private final int[] depth;

        /** One segment tree per path, indexed by path id. */
        private final SegmentTree[] segmentTrees;

        /**
         * Decomposes {@code graph} (expected to be a tree on nodes 1..N) rooted at node 1.
         *
         * @param graph the tree
         * @param costPerNode initial value of each node, indexed 1..N (index 0 is ignored)
         * @throws IllegalArgumentException if not every node is reachable from node 1
         */
        public HeavyPathDecomp(Graph graph, int[] costPerNode) {
            int n = graph.getNumberOfNodes();
            this.graph = graph;
            this.pathIdx = new int[n + 1];
            this.firstNodeInPath = new int[n + 1];
            this.pathLength = new int[n + 1];
            this.posInPath = new int[n + 1];
            this.childrenPerNode = new int[n + 1];
            this.parent = new int[n + 1];
            this.depth = new int[n + 1];

            int[] order = computeChildrenPerNode(1);
            int[] heavyChild = computeHeavyChildren(order);
            decomposeGraph(order, heavyChild);

            // Lay out the initial values of every path by position, then build its tree.
            int[][] values = new int[noPaths + 1][];
            for (int path = 1; path <= noPaths; path++) {
                values[path] = new int[pathLength[path] + 1];
            }
            for (int node = 1; node <= n; node++) {
                values[pathIdx[node]][posInPath[node]] = costPerNode[node];
            }

            segmentTrees = new SegmentTree[noPaths + 1];
            for (int path = 1; path <= noPaths; path++) {
                segmentTrees[path] = new SegmentTree(pathLength[path], values[path]);
            }
        }

        /**
         * Sets the value stored on {@code node}.
         *
         * @param node node id in 1..N
         * @param newValue the new value
         */
        public void setValue(int node, int newValue) {
            // posInPath is already 1-based, exactly like the segment tree positions.
            segmentTrees[pathIdx[node]].setValue(posInPath[node], newValue);
        }

        /** Returns the shallowest node of the path containing {@code node}. */
        private int getStartNodeOnPath(int node) {
            return firstNodeInPath[pathIdx[node]];
        }

        /**
         * Returns the maximum node value on the tree path between {@code node1} and
         * {@code node2}, both endpoints included.
         *
         * @param node1 node id in 1..N
         * @param node2 node id in 1..N
         * @return the maximum value on the path
         */
        int getMaxOnPath(int node1, int node2) {
            int ans = Integer.MIN_VALUE;
            while (pathIdx[node1] != pathIdx[node2]) {
                int startNodePath1 = getStartNodeOnPath(node1);
                int startNodePath2 = getStartNodeOnPath(node2);
                // The path whose head is deeper cannot contain the LCA; consume it up to its head
                // and continue from the head's parent.
                if (depth[startNodePath1] > depth[startNodePath2]) {
                    ans = Math.max(ans, segmentTrees[pathIdx[node1]].getMax(1, posInPath[node1]));
                    node1 = parent[startNodePath1];
                } else {
                    ans = Math.max(ans, segmentTrees[pathIdx[node2]].getMax(1, posInPath[node2]));
                    node2 = parent[startNodePath2];
                }
            }

            // Same path: positions grow with depth, so the segment lies between the two positions.
            int from = Math.min(posInPath[node1], posInPath[node2]);
            int to = Math.max(posInPath[node1], posInPath[node2]);
            return Math.max(ans, segmentTrees[pathIdx[node1]].getMax(from, to));
        }

        /**
         * Iteratively traverses the tree from {@code root}, filling {@code parent},
         * {@code depth} and the subtree sizes in {@code childrenPerNode}.
         *
         * @param root the root node
         * @return every node, in an order where each node appears after its parent
         * @throws IllegalArgumentException if some node is not reachable from {@code root}
         */
        private int[] computeChildrenPerNode(int root) {
            int n = graph.getNumberOfNodes();
            int[] order = new int[n];
            if (n == 0) {
                return order;
            }

            boolean[] visited = new boolean[n + 1];
            int[] stack = new int[n]; // every node is pushed exactly once
            int stackSize = 0;
            int orderSize = 0;

            visited[root] = true;
            parent[root] = 0;
            depth[root] = 0;
            stack[stackSize++] = root;
            while (stackSize > 0) {
                int node = stack[--stackSize];
                order[orderSize++] = node;
                for (int neighbour : graph.getEdges(node)) {
                    if (!visited[neighbour]) {
                        visited[neighbour] = true;
                        parent[neighbour] = node;
                        depth[neighbour] = depth[node] + 1;
                        stack[stackSize++] = neighbour;
                    }
                }
            }

            if (orderSize != n) {
                throw new IllegalArgumentException(
                        "Graph is not connected: reached " + orderSize + " of " + n + " nodes");
            }

            // Children come after their parent in `order`, so a reverse sweep finishes every
            // subtree before adding its size to the parent's total.
            for (int i = n - 1; i >= 0; i--) {
                int node = order[i];
                childrenPerNode[node] += 1;
                if (node != root) {
                    childrenPerNode[parent[node]] += childrenPerNode[node];
                }
            }
            return order;
        }

        /**
         * Picks, for every node, the child with the largest subtree (the first such child in
         * adjacency order on ties).
         *
         * @param order all nodes, each appearing after its parent
         * @return the heavy child of each node, or 0 for leaves
         */
        private int[] computeHeavyChildren(int[] order) {
            int[] heavyChild = new int[graph.getNumberOfNodes() + 1];
            for (int node : order) {
                int bestChildWeight = 0;
                for (int child : graph.getEdges(node)) {
                    if (child != parent[node] && childrenPerNode[child] > bestChildWeight) {
                        bestChildWeight = childrenPerNode[child];
                        heavyChild[node] = child;
                    }
                }
            }
            return heavyChild;
        }

        /**
         * Splits the tree into heavy paths: every node that is not its parent's heavy child
         * starts a new path, which then follows heavy children down to a leaf.
         *
         * @param order all nodes, each appearing after its parent
         * @param heavyChild heavy child of each node, 0 for leaves
         */
        private void decomposeGraph(int[] order, int[] heavyChild) {
            for (int top : order) {
                if (parent[top] != 0 && heavyChild[parent[top]] == top) {
                    continue; // already placed on its parent's path
                }
                noPaths++;
                firstNodeInPath[noPaths] = top;
                for (int node = top; node != 0; node = heavyChild[node]) {
                    pathIdx[node] = noPaths;
                    posInPath[node] = ++pathLength[noPaths];
                }
            }
        }
    }

    /**
     * Segment tree over positions 1..N answering range-maximum queries with point assignment.
     */
    static class SegmentTree {

        private final int N;

        /** Maximum of each tree node's range; node 1 is the root and covers 1..N. */
        private final int[] maxPerNode;

        /**
         * Builds the tree from {@code values[1..N]} ({@code values[0]} is ignored).
         *
         * @param N number of positions
         * @param values initial values, indexed 1..N
         */
        public SegmentTree(int N, int[] values) {
            this.N = N;
            int reqSize = 1;
            while (reqSize < N) {
                reqSize *= 2;
            }
            maxPerNode = new int[reqSize * 2 + 1];

            cstr(1, N, 1, values);
        }

        /** Recursively builds the subtree {@code node} covering {@code [left, right]}. */
        private void cstr(int left, int right, int node, int[] values) {
            if (left == right) {
                maxPerNode[node] = values[left];
                return;
            }

            int mid = (left + right) / 2;
            cstr(left, mid, node * 2, values);
            cstr(mid + 1, right, node * 2 + 1, values);
            maxPerNode[node] = Math.max(maxPerNode[node * 2], maxPerNode[node * 2 + 1]);
        }

        /**
         * Assigns {@code value} to position {@code pos}.
         *
         * @param pos position in 1..N
         * @param value the new value
         */
        public void setValue(int pos, int value) {
            setValue(1, N, 1, pos, value);
        }

        /**
         * Returns the maximum over positions {@code from..to} (inclusive), or
         * {@link Integer#MIN_VALUE} if the range misses 1..N.
         *
         * @param from first position, 1-based
         * @param to last position, 1-based
         * @return the range maximum
         */
        public int getMax(int from, int to) {
            return getMax(1, N, 1, from, to);
        }

        private int getMax(int left, int right, int node, int from, int to) {
            if (to < left || right < from) {
                return Integer.MIN_VALUE;
            }
            if (from <= left && right <= to) {
                return maxPerNode[node];
            }

            int mid = (left + right) / 2;
            return Math.max(
                    getMax(left, mid, node * 2, from, to),
                    getMax(mid + 1, right, node * 2 + 1, from, to));
        }

        private void setValue(int left, int right, int node, int pos, int value) {
            if (left == right) {
                maxPerNode[node] = value;
                return;
            }

            int mid = (left + right) / 2;
            if (pos <= mid) {
                setValue(left, mid, node * 2, pos, value);
            } else {
                setValue(mid + 1, right, node * 2 + 1, pos, value);
            }
            maxPerNode[node] = Math.max(maxPerNode[node * 2], maxPerNode[node * 2 + 1]);
        }
    }

    /** Undirected graph on nodes 1..N stored as adjacency lists (list 0 is allocated but unused). */
    static class Graph {

        private final int N;

        private final List<List<Integer>> edges;

        /** Creates a graph with nodes 1..N and no edges. */
        public Graph(int N) {
            this.N = N;

            edges = new ArrayList<>(N + 1);
            for (int node = 0; node <= N; node++) {
                edges.add(new ArrayList<>());
            }
        }

        /** Returns the live adjacency list of {@code node}. */
        public List<Integer> getEdges(int node) {
            return edges.get(node);
        }

        /** Adds the undirected edge {@code from - to}. */
        public void addEdge(int from, int to) {
            edges.get(from).add(to);
            edges.get(to).add(from);
        }

        /** Returns N, the number of nodes. */
        public int getNumberOfNodes() {
            return N;
        }
    }

    /** Whitespace-separated token reader over an input stream. */
    static class InputReader implements AutoCloseable {

        private final BufferedReader bufferedReader;

        private StringTokenizer stringTokenizer;

        public InputReader(InputStream inputStream) {
            bufferedReader = new BufferedReader(new InputStreamReader(inputStream));
            stringTokenizer = null;
        }

        /**
         * Returns the next token, skipping blank lines.
         *
         * @throws NoSuchElementException if the input ends before another token
         * @throws UncheckedIOException if reading fails
         */
        private String nextToken() {
            // Loop (not a single check) so lines with no tokens are skipped.
            while (stringTokenizer == null || !stringTokenizer.hasMoreTokens()) {
                String line;
                try {
                    line = bufferedReader.readLine();
                } catch (IOException e) {
                    throw new UncheckedIOException(e);
                }
                if (line == null) {
                    throw new NoSuchElementException("Unexpected end of input");
                }
                stringTokenizer = new StringTokenizer(line);
            }

            return stringTokenizer.nextToken();
        }

        /** Reads the next token as a decimal {@code int}. */
        public int nextInt() {
            return Integer.parseInt(nextToken());
        }

        @Override
        public void close() throws IOException {
            bufferedReader.close();
        }
    }
}