/*
 * Decompiled with CFR 0.152.
 */
package org.apache.ignite.ml.knn.utils.indices;

import java.util.ArrayDeque;
import java.util.Collections;
import java.util.Deque;
import java.util.List;
import java.util.PriorityQueue;
import java.util.Queue;
import org.apache.ignite.ml.knn.utils.PointWithDistance;
import org.apache.ignite.ml.knn.utils.PointWithDistanceUtil;
import org.apache.ignite.ml.knn.utils.indices.SpatialIndex;
import org.apache.ignite.ml.math.distances.DistanceMeasure;
import org.apache.ignite.ml.math.primitives.vector.Vector;
import org.apache.ignite.ml.structures.LabeledVector;

/**
 * Spatial index backed by a k-d tree.
 *
 * <p>The tree is built by inserting the points one at a time, in the order they are supplied, and is
 * never rebalanced; its depth therefore depends on the insertion order and can reach the number of
 * points (for example when many points are equal, because equal coordinates always descend to the
 * left). Both insertion and search are iterative, so a deep tree costs time but cannot exhaust the
 * thread stack.
 *
 * @param <L> Label type of the indexed vectors.
 */
public class KDTreeSpatialIndex<L> implements SpatialIndex<L> {
    /** Distance measure used to rank candidate neighbours. */
    private final DistanceMeasure distanceMeasure;

    /** Root of the tree; {@code null} while the index holds no points. */
    private TreeNode<L> root;

    /**
     * Builds the index by inserting every element of {@code data} in iteration order.
     *
     * @param data Points to index.
     * @param distanceMeasure Distance measure used by {@link #findKClosest(int, Vector)}.
     */
    public KDTreeSpatialIndex(List<LabeledVector<L>> data, DistanceMeasure distanceMeasure) {
        this.distanceMeasure = distanceMeasure;
        for (LabeledVector<L> dataPnt : data) {
            this.add(dataPnt);
        }
    }

    /**
     * Collects up to {@code k} indexed points that are closest to {@code pnt}.
     *
     * @param k Number of neighbours to look for; must be positive.
     * @param pnt Query point.
     * @return The collected candidates as converted by
     *     {@code PointWithDistanceUtil.transformToListOrdered} (presumably ordered by distance).
     * @throws IllegalArgumentException If {@code k} is not positive.
     */
    @Override
    public List<LabeledVector<L>> findKClosest(int k, Vector pnt) {
        if (k <= 0) {
            throw new IllegalArgumentException("Number of neighbours should be positive.");
        }
        // Reverse natural ordering: the head of the queue is the greatest PointWithDistance,
        // which is presumably the farthest of the candidates kept so far.
        PriorityQueue<PointWithDistance<L>> heap = new PriorityQueue<>(Collections.reverseOrder());
        this.collectKClosest(pnt, heap, k);
        return PointWithDistanceUtil.transformToListOrdered(heap);
    }

    /**
     * Walks the tree depth-first, offering every visited point to {@code heap}.
     *
     * <p>At each node the branch on the query's side of the splitting coordinate is explored first.
     * The opposite branch is explored afterwards only if the heap still holds fewer than {@code k}
     * entries or the distance from the query to the splitting plane is smaller than the distance of
     * the heap's head. An explicit stack of postponed branches is used instead of recursion so that
     * a degenerate (chain-shaped) tree cannot cause a {@link StackOverflowError}.
     *
     * <p>NOTE(review): the pruning test compares a single-coordinate difference with a value
     * produced by {@code distanceMeasure}; this looks sound only for measures that are never smaller
     * than any single-coordinate difference - confirm for the measures used with this index.
     *
     * @param pnt Query point.
     * @param heap Bounded collection of the best candidates found so far.
     * @param k Maximum number of neighbours to keep.
     */
    private void collectKClosest(Vector pnt, Queue<PointWithDistance<L>> heap, int k) {
        Deque<PendingBranch<L>> pending = new ArrayDeque<>();
        TreeNode<L> node = this.root;
        int splitDim = 0;
        while (node != null) {
            PointWithDistanceUtil.tryToAddIntoHeap(heap, k, node.val, this.distanceMeasure.compute(pnt, node.val.features()));
            double pntPrj = pnt.get(splitDim);
            double splitPrj = node.val.get(splitDim);
            boolean goRight = pntPrj > splitPrj;
            TreeNode<L> primaryBranch = goRight ? node.right : node.left;
            TreeNode<L> secondaryBranch = goRight ? node.left : node.right;
            int nextDim = (splitDim + 1) % pnt.size();
            if (secondaryBranch != null) {
                // Postpone the far side; whether it is worth visiting is decided only after the
                // near side has been fully explored, i.e. when this entry is popped.
                pending.push(new PendingBranch<>(secondaryBranch, nextDim, Math.abs(pntPrj - splitPrj)));
            }
            node = primaryBranch;
            splitDim = nextDim;
            // Near side exhausted: resume with the most recently postponed branch that can still
            // contain a better candidate, discarding the ones that cannot.
            while (node == null && !pending.isEmpty()) {
                PendingBranch<L> candidate = pending.pop();
                if (heap.size() < k || candidate.distToPlane < heap.peek().getDistance()) {
                    node = candidate.node;
                    splitDim = candidate.splitDim;
                }
            }
        }
    }

    /**
     * Inserts {@code val} as a new leaf.
     *
     * <p>Starting at the root with coordinate 0, descends right when the new point's coordinate is
     * strictly greater than the node's and left otherwise, moving to the next coordinate (cyclically)
     * at every level, until an empty child slot is found.
     *
     * @param val Point to insert.
     */
    private void add(LabeledVector<L> val) {
        if (this.root == null) {
            this.root = new TreeNode<>(val);
            return;
        }
        TreeNode<L> node = this.root;
        for (int splitDim = 0; ; splitDim = (splitDim + 1) % val.size()) {
            boolean goRight = val.get(splitDim) > node.val.get(splitDim);
            TreeNode<L> child = goRight ? node.right : node.left;
            if (child == null) {
                if (goRight) {
                    node.right = new TreeNode<>(val);
                } else {
                    node.left = new TreeNode<>(val);
                }
                return;
            }
            node = child;
        }
    }

    /**
     * Tree node holding one indexed point and its two children.
     *
     * @param <T> Label type of the stored vector.
     */
    private static final class TreeNode<T> {
        /** Point stored in this node; also defines the splitting plane at this level. */
        private final LabeledVector<T> val;

        /** Subtree of points whose splitting coordinate is not greater than this node's. */
        private TreeNode<T> left;

        /** Subtree of points whose splitting coordinate is strictly greater than this node's. */
        private TreeNode<T> right;

        TreeNode(LabeledVector<T> val) {
            this.val = val;
        }
    }

    /**
     * Far-side subtree whose exploration has been postponed during a search.
     *
     * @param <T> Label type of the stored vectors.
     */
    private static final class PendingBranch<T> {
        /** Root of the postponed subtree. */
        private final TreeNode<T> node;

        /** Coordinate index used for splitting at {@link #node}. */
        private final int splitDim;

        /** Absolute difference between the query and the parent's splitting coordinate. */
        private final double distToPlane;

        PendingBranch(TreeNode<T> node, int splitDim, double distToPlane) {
            this.node = node;
            this.splitDim = splitDim;
            this.distToPlane = distToPlane;
        }
    }
}

