Cod sursa(job #2231599)

Utilizator: Radu Stancu (radustn92)   Data: 15 august 2018 02:34:57
Problema Heavy Path Decomposition Scor 0
Compilator java Status done
Runda Arhiva educationala Marime 12.42 kb
import java.io.*;
import java.util.*;

public class Main {

    /**
     * Reads a tree with node values from {@code heavypath.in} and answers the queries it contains,
     * writing answers to {@code heavypath.out}.
     *
     * <p>Input layout, as consumed below: {@code N M}, then N node values, then N - 1 edges
     * {@code u v}, then M queries {@code type x y}. {@code type == 0} sets the value of node x to y;
     * {@code type == 1} prints the maximum node value on the path between x and y.
     *
     * @throws IOException if either file cannot be opened or closed
     */
    public static void main(String[] args) throws IOException {
        // Both streams are owned by try-with-resources, so neither leaks if the other fails to open.
        try (InputReader inputReader = new InputReader(new FileInputStream("heavypath.in"));
                PrintWriter printWriter = new PrintWriter(new FileOutputStream("heavypath.out"))) {
            int n = inputReader.nextInt();
            int m = inputReader.nextInt();

            int[] valuesPerNode = new int[n + 1];
            for (int node = 1; node <= n; node++) {
                valuesPerNode[node] = inputReader.nextInt();
            }

            Graph graph = new Graph(n);
            for (int edgeNo = 1; edgeNo < n; edgeNo++) {
                graph.addEdge(inputReader.nextInt(), inputReader.nextInt());
            }

            HeavyPathDecomp heavyPathDecomp = new HeavyPathDecomp(graph, valuesPerNode);

            for (int queryNo = 1; queryNo <= m; queryNo++) {
                int type = inputReader.nextInt();
                int x = inputReader.nextInt();
                int y = inputReader.nextInt();

                switch (type) {
                    case 0:
                        heavyPathDecomp.setValue(x, y);
                        break;
                    case 1:
                        printWriter.println(heavyPathDecomp.getMaxOnPath(x, y));
                        break;
                    default:
                        throw new IllegalArgumentException("Unexpected type: " + type);
                }
            }
        }
    }

    /**
     * Heavy-path decomposition of a tree rooted at node 1 supporting point updates of node values
     * and maximum queries over the path between two nodes.
     *
     * <p>Every heavy path is backed by its own {@link SegmentTree}. A path query is split at the
     * lowest common ancestor, and each half climbs towards it one heavy path at a time.
     */
    static class HeavyPathDecomp {

        private static final int ROOT = 1;

        private final Graph graph;

        private final LcaSolver lcaSolver;

        /** Nodes of every heavy path, listed from the path head (closest to the root) downwards. */
        private final List<List<Integer>> paths;

        /** Index into {@link #paths} of the heavy path containing each node. */
        private final int[] nodePathIdx;

        /** 0-based position of each node inside its heavy path. */
        private final int[] nodePosInPath;

        /** Size of the subtree rooted at each node, the node itself included. */
        private final int[] childrenPerNode;

        /** Parent of each node in the rooted tree; 0 for the root. */
        private final int[] parent;

        /** One range-maximum tree per heavy path, indexed like {@link #paths}. */
        private final SegmentTree[] segmentTrees;

        /**
         * Builds the decomposition.
         *
         * @param graph tree on nodes 1..N (assumed connected, N >= 1)
         * @param costPerNode initial value of every node, indexed by node id (index 0 unused)
         */
        public HeavyPathDecomp(Graph graph, int[] costPerNode) {
            this.graph = graph;
            this.lcaSolver = new LcaSolver(graph);
            int n = graph.getNumberOfNodes();
            this.nodePathIdx = new int[n + 1];
            this.nodePosInPath = new int[n + 1];
            this.parent = new int[n + 1];
            this.childrenPerNode = new int[n + 1];
            this.paths = new ArrayList<>();

            int[] topDownOrder = computeChildrenPerNode(ROOT);
            decomposeGraph(topDownOrder);

            segmentTrees = new SegmentTree[paths.size()];
            for (int pathIdx = 0; pathIdx < paths.size(); pathIdx++) {
                List<Integer> currPath = paths.get(pathIdx);
                // SegmentTree is 1-based, so path position p is stored at index p + 1.
                int[] valuesPath = new int[currPath.size() + 1];
                for (int pathEntryIdx = 0; pathEntryIdx < currPath.size(); pathEntryIdx++) {
                    valuesPath[pathEntryIdx + 1] = costPerNode[currPath.get(pathEntryIdx)];
                }
                segmentTrees[pathIdx] = new SegmentTree(currPath.size(), valuesPath);
            }
        }

        /** Sets the value stored at {@code node} to {@code newValue}. */
        public void setValue(int node, int newValue) {
            int pathIdx = nodePathIdx[node];
            int posInPath = nodePosInPath[node];
            segmentTrees[pathIdx].setValue(posInPath + 1, newValue);
        }

        /** Returns the maximum node value on the path between the two nodes, endpoints included. */
        public int getMaxOnPath(int firstNode, int secondNode) {
            int lca = lcaSolver.getLca(firstNode, secondNode);

            return Math.max(
                    getMaxOnAncestorPath(firstNode, lca),
                    getMaxOnAncestorPath(secondNode, lca));
        }

        /**
         * Returns the maximum value on the vertical path from {@code node} up to {@code ancestor},
         * both inclusive. {@code ancestor} must be an ancestor of {@code node} (or the node itself).
         */
        private int getMaxOnAncestorPath(int node, int ancestor) {
            int ans = Integer.MIN_VALUE;
            while (nodePathIdx[node] != nodePathIdx[ancestor]) {
                int pathIdx = nodePathIdx[node];
                // Take everything from the path head down to `node`, then hop above the head.
                ans = Math.max(ans, segmentTrees[pathIdx].getMax(1, nodePosInPath[node] + 1));
                node = parent[paths.get(pathIdx).get(0)];
            }

            // Same heavy path now: `ancestor` sits at a smaller position than `node`.
            int pathIdx = nodePathIdx[node];
            ans = Math.max(ans,
                    segmentTrees[pathIdx].getMax(nodePosInPath[ancestor] + 1, nodePosInPath[node] + 1));

            return ans;
        }

        /**
         * Fills {@link #parent} and {@link #childrenPerNode} using an iterative DFS from
         * {@code root}. Recursion would overflow the JVM stack on long, path-like trees.
         *
         * @return the reached nodes in an order where every node appears after its parent
         */
        private int[] computeChildrenPerNode(int root) {
            int n = graph.getNumberOfNodes();
            int[] order = new int[n];
            int orderSize = 0;
            // A node is pushed only when first seen, so n slots are enough.
            int[] stack = new int[n];
            int stackSize = 0;
            BitSet visited = new BitSet(n + 1);

            stack[stackSize++] = root;
            visited.set(root);
            parent[root] = 0;
            while (stackSize > 0) {
                int node = stack[--stackSize];
                order[orderSize++] = node;
                for (int neighbour : graph.getEdges(node)) {
                    if (!visited.get(neighbour)) {
                        visited.set(neighbour);
                        parent[neighbour] = node;
                        stack[stackSize++] = neighbour;
                    }
                }
            }

            // Children follow their parent in `order`, so a reverse sweep finishes every child's
            // subtree size before adding it to the parent.
            for (int i = orderSize - 1; i >= 0; i--) {
                int node = order[i];
                childrenPerNode[node] += 1;
                if (parent[node] != 0) {
                    childrenPerNode[parent[node]] += childrenPerNode[node];
                }
            }
            return Arrays.copyOf(order, orderSize);
        }

        /**
         * Splits the tree into heavy paths and fills {@link #paths}, {@link #nodePathIdx} and
         * {@link #nodePosInPath}.
         *
         * <p>The heavy child of a node is the child with the largest subtree; on ties, the first one
         * in adjacency order wins. A node extends its parent's path when it is that parent's heavy
         * child and starts a new path otherwise. Each path is then walked top-down with an
         * explicitly tracked path index.
         *
         * @param topDownOrder all nodes, each listed after its parent
         */
        private void decomposeGraph(int[] topDownOrder) {
            // heavyChild[v] == 0 means v is a leaf; index 0 stays 0 and acts as the root's "parent".
            int[] heavyChild = new int[graph.getNumberOfNodes() + 1];
            for (int node : topDownOrder) {
                for (int child : graph.getEdges(node)) {
                    if (parent[child] == node
                            && (heavyChild[node] == 0
                                    || childrenPerNode[child] > childrenPerNode[heavyChild[node]])) {
                        heavyChild[node] = child;
                    }
                }
            }

            for (int head : topDownOrder) {
                if (heavyChild[parent[head]] == head) {
                    continue; // already placed on its parent's heavy path
                }
                int pathIdx = paths.size();
                List<Integer> path = new ArrayList<>();
                paths.add(path);
                for (int node = head; node != 0; node = heavyChild[node]) {
                    nodePathIdx[node] = pathIdx;
                    nodePosInPath[node] = path.size();
                    path.add(node);
                }
            }
        }
    }

    /** 1-based segment tree over {@code values[1..N]} answering range-maximum queries. */
    static class SegmentTree {

        private final int n;

        private final int[] maxPerNode;

        /**
         * @param n number of stored values
         * @param values array whose entries 1..n are the initial values (index 0 ignored)
         */
        public SegmentTree(int n, int[] values) {
            this.n = n;
            // Smallest power of two >= n; this layout never uses a node index >= 2 * reqSize.
            int reqSize = 1;
            while (reqSize < n) {
                reqSize *= 2;
            }
            maxPerNode = new int[reqSize * 2 + 1];

            if (n > 0) {
                cstr(1, n, 1, values);
            }
        }

        private void cstr(int left, int right, int node, int[] values) {
            if (left == right) {
                maxPerNode[node] = values[left];
                return;
            }

            int mid = (left + right) >>> 1;
            cstr(left, mid, node * 2, values);
            cstr(mid + 1, right, node * 2 + 1, values);
            maxPerNode[node] = Math.max(maxPerNode[node * 2], maxPerNode[node * 2 + 1]);
        }

        /** Sets the value at 1-based position {@code pos}. */
        public void setValue(int pos, int value) {
            setValue(1, n, 1, pos, value);
        }

        /**
         * Returns the maximum over positions {@code from..to} (1-based, inclusive), or
         * {@link Integer#MIN_VALUE} when the range selects nothing.
         */
        public int getMax(int from, int to) {
            return getMax(1, n, 1, from, to);
        }

        private int getMax(int left, int right, int node, int from, int to) {
            if (to < left || right < from) {
                return Integer.MIN_VALUE;
            }
            if (from <= left && right <= to) {
                return maxPerNode[node];
            }

            int mid = (left + right) >>> 1;
            return Math.max(
                    getMax(left, mid, node * 2, from, to),
                    getMax(mid + 1, right, node * 2 + 1, from, to));
        }

        private void setValue(int left, int right, int node, int pos, int value) {
            if (left == right) {
                maxPerNode[node] = value;
                return;
            }

            int mid = (left + right) >>> 1;
            if (pos <= mid) {
                setValue(left, mid, node * 2, pos, value);
            } else {
                setValue(mid + 1, right, node * 2 + 1, pos, value);
            }
            maxPerNode[node] = Math.max(maxPerNode[node * 2], maxPerNode[node * 2 + 1]);
        }
    }

    /**
     * Lowest-common-ancestor queries on the tree rooted at node 1.
     *
     * <p>Preprocessing builds an Euler tour plus a sparse table of the shallowest node over every
     * power-of-two window of the tour, in O(N log N). Each query does two table lookups.
     */
    static class LcaSolver {

        private static final int ROOT = 1;

        private final Graph graph;

        /** Depth of each node; the root has depth 0. */
        private final int[] depth;

        /** 1-based Euler tour: a node is recorded on entry and again after each child finishes. */
        private final int[] eulerTraversal;

        /** Number of entries written to {@link #eulerTraversal}. */
        private int eulerTraversalLastPos;

        /** Position of each node's first occurrence in {@link #eulerTraversal}. */
        private final int[] posInEulerTraversal;

        /** {@code log2[i] == floor(log2(i))} for i >= 1. */
        private final int[] log2;

        /** {@code lowestNodeInterv[k][p]}: shallowest node among tour positions p .. p + 2^k - 1. */
        private final int[][] lowestNodeInterv;

        public LcaSolver(Graph graph) {
            this.graph = graph;
            int n = graph.getNumberOfNodes();
            depth = new int[n + 1];

            eulerTraversal = new int[n * 2 + 1];
            eulerTraversalLastPos = 0;
            posInEulerTraversal = new int[n + 1];

            eulerTraversal(ROOT);

            log2 = new int[eulerTraversalLastPos + 1];
            for (int i = 2; i <= eulerTraversalLastPos; i++) {
                log2[i] = log2[i >> 1] + 1;
            }

            lowestNodeInterv = new int[log2[eulerTraversalLastPos] + 1][eulerTraversalLastPos + 1];
            prepareLcaStructure();
        }

        /**
         * Writes the Euler tour from {@code root} and records node depths. An explicit stack of
         * (node, next adjacency index) frames replaces recursion so that deep trees do not
         * overflow the JVM stack. The visiting order is the same as a recursive DFS.
         */
        private void eulerTraversal(int root) {
            int n = graph.getNumberOfNodes();
            int[] frameNode = new int[n];
            int[] frameNextEdge = new int[n];
            BitSet visited = new BitSet(n + 1);

            int top = 0;
            frameNode[0] = root;
            frameNextEdge[0] = 0;
            visited.set(root);
            eulerTraversal[++eulerTraversalLastPos] = root;
            posInEulerTraversal[root] = eulerTraversalLastPos;

            while (top >= 0) {
                int node = frameNode[top];
                List<Integer> neighbours = graph.getEdges(node);
                if (frameNextEdge[top] < neighbours.size()) {
                    int neighbour = neighbours.get(frameNextEdge[top]++);
                    if (!visited.get(neighbour)) {
                        visited.set(neighbour);
                        depth[neighbour] = depth[node] + 1;
                        eulerTraversal[++eulerTraversalLastPos] = neighbour;
                        posInEulerTraversal[neighbour] = eulerTraversalLastPos;
                        top++;
                        frameNode[top] = neighbour;
                        frameNextEdge[top] = 0;
                    }
                } else {
                    // `node` is finished; its parent is recorded again in the tour.
                    top--;
                    if (top >= 0) {
                        eulerTraversal[++eulerTraversalLastPos] = frameNode[top];
                    }
                }
            }
        }

        private int getLowestNode(int node1, int node2) {
            return depth[node1] < depth[node2] ? node1 : node2;
        }

        private void prepareLcaStructure() {
            for (int pos = 1; pos <= eulerTraversalLastPos; pos++) {
                lowestNodeInterv[0][pos] = eulerTraversal[pos];
            }

            for (int lvl = 1, currInterv = 2; lvl <= log2[eulerTraversalLastPos]; lvl++, currInterv *= 2) {
                for (int pos = 1; pos <= eulerTraversalLastPos - currInterv + 1; pos++) {
                    lowestNodeInterv[lvl][pos] = getLowestNode(
                            lowestNodeInterv[lvl - 1][pos], lowestNodeInterv[lvl - 1][pos + currInterv / 2]);
                }
            }
        }

        /** Shallowest node among tour positions {@code from..to}, via two overlapping windows. */
        private int getLowestInInterv(int from, int to) {
            int logInterv = log2[to - from + 1];
            return getLowestNode(
                    lowestNodeInterv[logInterv][from],
                    lowestNodeInterv[logInterv][to - (1 << logInterv) + 1]);
        }

        /** Returns the lowest common ancestor of the two nodes. */
        public int getLca(int node1, int node2) {
            int pos1 = posInEulerTraversal[node1];
            int pos2 = posInEulerTraversal[node2];
            return pos1 < pos2 ? getLowestInInterv(pos1, pos2) : getLowestInInterv(pos2, pos1);
        }
    }

    /** Undirected graph on nodes 1..N stored as adjacency lists (index 0 unused). */
    static class Graph {

        private final int n;

        private final List<List<Integer>> edges;

        public Graph(int n) {
            this.n = n;

            edges = new ArrayList<>(n + 1);
            for (int node = 0; node <= n; node++) {
                edges.add(new ArrayList<>());
            }
        }

        /** Returns the neighbours of {@code node} in insertion order (a live view of the list). */
        public List<Integer> getEdges(int node) {
            return edges.get(node);
        }

        /** Adds the undirected edge {@code from - to}. */
        public void addEdge(int from, int to) {
            edges.get(from).add(to);
            edges.get(to).add(from);
        }

        public int getNumberOfNodes() {
            return n;
        }
    }

    /** Reads whitespace-separated integers from a stream, skipping blank lines. */
    static class InputReader implements AutoCloseable {

        private final BufferedReader bufferedReader;

        private StringTokenizer stringTokenizer;

        public InputReader(InputStream inputStream) {
            bufferedReader = new BufferedReader(new InputStreamReader(inputStream));
            stringTokenizer = null;
        }

        /**
         * Returns the next token, reading further lines as needed (blank lines are skipped).
         *
         * @throws NoSuchElementException if the input ends before another token
         * @throws UncheckedIOException if reading the underlying stream fails
         */
        private String nextToken() {
            while (stringTokenizer == null || !stringTokenizer.hasMoreTokens()) {
                String line;
                try {
                    line = bufferedReader.readLine();
                } catch (IOException e) {
                    throw new UncheckedIOException(e);
                }
                if (line == null) {
                    throw new NoSuchElementException("Unexpected end of input");
                }
                stringTokenizer = new StringTokenizer(line);
            }

            return stringTokenizer.nextToken();
        }

        /** Parses the next token as an {@code int}. */
        public int nextInt() {
            return Integer.parseInt(nextToken());
        }

        @Override
        public void close() throws IOException {
            bufferedReader.close();
        }
    }
}