package net.myrrix.online.som;

import com.google.common.base.Preconditions;
import java.util.Collections;
import java.util.Comparator;
import java.util.Iterator;
import net.myrrix.common.LangUtils;
import net.myrrix.common.collection.FastByIDMap;
import net.myrrix.common.math.SimpleVectorMath;
import net.myrrix.common.random.RandomManager;
import net.myrrix.common.random.RandomUtils;
import org.apache.commons.math3.distribution.PascalDistribution;
import org.apache.commons.math3.optimization.direct.CMAESOptimizer;
import org.apache.commons.math3.random.RandomGenerator;
import org.apache.commons.math3.util.FastMath;
import org.apache.commons.math3.util.Pair;
import org.apache.mahout.cf.taste.impl.common.LongPrimitiveIterator;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:net/myrrix/online/som/SelfOrganizingMaps.class */
/**
 * Builds a square self-organizing map (SOM) over a set of ID-keyed feature vectors, using cosine
 * similarity to pick each vector's best-matching node.
 *
 * <p>Steps performed by {@link #buildSelfOrganizedMap(FastByIDMap, int, double)}:
 * <ol>
 *   <li>seed every node with a (randomly sampled) input vector;</li>
 *   <li>train: pull the best-matching node and its grid neighbors toward each input vector, with a
 *       strength and neighborhood radius that decay exponentially with the number of vectors
 *       processed, stopping once the decay drops below the configured minimum;</li>
 *   <li>assign (a random sample of) the vectors to their best-matching nodes and sort each node's
 *       members by descending similarity;</li>
 *   <li>compute a random 3-D projection of every node center, each coordinate in [0, 1].</li>
 * </ol>
 *
 * <p>NOTE(review): all vectors are assumed to have the same dimension (the projection step sizes
 * its buffers from the first vector) — confirm with callers.
 */
public final class SelfOrganizingMaps {

    private static final Logger log = LoggerFactory.getLogger(SelfOrganizingMaps.class);

    /** Default decay factor below which training stops. */
    public static final double DEFAULT_MIN_DECAY = 1.0E-5d;

    /** Default learning rate applied when the decay factor is 1 (start of training). */
    public static final double DEFAULT_INIT_LEARNING_RATE = 0.5d;

    /**
     * Orders (similarity, ID) pairs by descending similarity; equal similarities compare as 0.
     * Stateless, so a single shared instance is reused for every node.
     */
    private static final Comparator<Pair<Double, Long>> BY_DESCENDING_SIMILARITY =
            new Comparator<Pair<Double, Long>>() {
                @Override
                public int compare(Pair<Double, Long> a, Pair<Double, Long> b) {
                    double first = a.getFirst().doubleValue();
                    double second = b.getFirst().doubleValue();
                    if (first > second) {
                        return -1;
                    }
                    return first < second ? 1 : 0;
                }
            };

    private final double minDecay;
    private final double initLearningRate;

    /** Creates an instance using {@link #DEFAULT_MIN_DECAY} and {@link #DEFAULT_INIT_LEARNING_RATE}. */
    public SelfOrganizingMaps() {
        this(DEFAULT_MIN_DECAY, DEFAULT_INIT_LEARNING_RATE);
    }

    /**
     * @param minDecay decay factor below which training stops; must be positive
     * @param initLearningRate learning rate at decay factor 1; must be in (0, 1]
     * @throws IllegalArgumentException if either argument is out of range (or NaN)
     */
    public SelfOrganizingMaps(double minDecay, double initLearningRate) {
        Preconditions.checkArgument(minDecay > 0.0,
                "Min decay must be positive: %s", Double.valueOf(minDecay));
        Preconditions.checkArgument(initLearningRate > 0.0 && initLearningRate <= 1.0,
                "Learning rate should be in (0,1]: %s", Double.valueOf(initLearningRate));
        this.minDecay = minDecay;
        this.initLearningRate = initLearningRate;
    }

    /**
     * Builds a map with an automatically chosen sampling rate; equivalent to
     * {@code buildSelfOrganizedMap(vectors, maxMapSize, Double.NaN)}.
     */
    public Node[][] buildSelfOrganizedMap(FastByIDMap<float[]> vectors, int maxMapSize) {
        return buildSelfOrganizedMap(vectors, maxMapSize, Double.NaN);
    }

    /**
     * Builds, trains and populates a square self-organizing map.
     *
     * @param vectors non-empty map from ID to feature vector
     * @param maxMapSize maximum side length of the square map; must be positive
     * @param samplingRate fraction of vectors used, in (0, 1], or NaN to choose a rate that aims
     *     for about one assigned vector per node
     * @return the map, indexed {@code [row][column]}; its side length is
     *     {@code min(maxMapSize, (int) sqrt(vectors.size() * samplingRate))}, so it can be smaller
     *     than requested
     * @throws NullPointerException if {@code vectors} is null
     * @throws IllegalArgumentException if {@code vectors} is empty or a numeric argument is out of
     *     range
     */
    public Node[][] buildSelfOrganizedMap(FastByIDMap<float[]> vectors, int maxMapSize, double samplingRate) {
        Preconditions.checkNotNull(vectors);
        Preconditions.checkArgument(!vectors.isEmpty(), "No vectors to map");
        Preconditions.checkArgument(maxMapSize > 0, "Map size must be positive: %s", maxMapSize);
        Preconditions.checkArgument(
                Double.isNaN(samplingRate) || (samplingRate > 0.0 && samplingRate <= 1.0),
                "Sampling rate should be NaN or in (0,1]: %s", Double.valueOf(samplingRate));

        if (Double.isNaN(samplingRate)) {
            // Aim for about one assigned vector per node. Floating-point arithmetic: integer
            // division would truncate (150 / 100 == 1) and maxMapSize^2 could overflow an int.
            double vectorsPerNode = vectors.size() / ((double) maxMapSize * maxMapSize);
            samplingRate = vectorsPerNode > 1.0 ? 1.0 / vectorsPerNode : 1.0;
        }
        log.debug("Sampling rate: {}", Double.valueOf(samplingRate));

        int mapSize = FastMath.min(maxMapSize, (int) FastMath.sqrt(vectors.size() * samplingRate));
        Node[][] map = buildInitialMap(vectors, mapSize);
        sketchMapParallel(vectors, samplingRate, map);

        // Start the final assignment pass from empty member lists.
        for (Node[] row : map) {
            for (Node node : row) {
                node.clearAssignedIDs();
            }
        }
        assignVectorsParallel(vectors, samplingRate, map);
        sortMembers(map);

        int dimension = vectors.entrySet().iterator().next().getValue().length;
        buildProjections(dimension, map);
        return map;
    }

    /**
     * Trains the map by streaming over the vectors (sequentially, despite the name). The decay
     * factor for the t-th vector is {@code exp(-t / timeScale)}, chosen so that after
     * {@code vectors.size() * samplingRate} vectors it reaches {@code 1 / mapSize}, shrinking the
     * neighborhood radius ({@code mapSize * decay}) from the whole map to about one node. Stops as
     * soon as the decay falls below {@code minDecay}.
     */
    private void sketchMapParallel(FastByIDMap<float[]> vectors, double samplingRate, Node[][] map) {
        // For a 1x1 map log(1) == 0, so timeScale is +Infinity and the decay stays at 1.
        double timeScale = (vectors.size() * samplingRate) / Math.log(map.length);
        int step = 0;
        for (FastByIDMap.MapEntry<float[]> entry : vectors.entrySet()) {
            double decay = FastMath.exp(-step / timeScale);
            step++;
            if (decay < minDecay) {
                return;
            }
            float[] vector = entry.getValue();
            int[] bestMatch = findBestMatchingUnit(vector, map);
            if (bestMatch != null) {
                updateNeighborhood(map, vector, bestMatch[0], bestMatch[1], decay);
            }
        }
    }

    /**
     * Adds each vector (kept with probability {@code samplingRate} when it is below 1) to its
     * best-matching node as a (cosine similarity, ID) pair. Runs sequentially despite the name.
     */
    private static void assignVectorsParallel(FastByIDMap<float[]> vectors, double samplingRate, Node[][] map) {
        boolean sampling = samplingRate < 1.0;
        RandomGenerator random = RandomManager.getRandom();
        for (FastByIDMap.MapEntry<float[]> entry : vectors.entrySet()) {
            if (sampling && random.nextDouble() > samplingRate) {
                continue;
            }
            float[] vector = entry.getValue();
            int[] bestMatch = findBestMatchingUnit(vector, map);
            if (bestMatch == null) {
                continue;
            }
            Node node = map[bestMatch[0]][bestMatch[1]];
            float[] center = node.getCenter();
            double similarity = SimpleVectorMath.dot(vector, center)
                    / (SimpleVectorMath.norm(center) * SimpleVectorMath.norm(vector));
            node.addAssignedID(new Pair<>(Double.valueOf(similarity), Long.valueOf(entry.getKey())));
        }
    }

    /**
     * Creates a {@code mapSize x mapSize} grid of nodes seeded with input vectors. When there are
     * more vectors than nodes, keys are picked by skipping a geometrically distributed number of
     * keys (Pascal with r = 1) between picks, wrapping around to the start of the key set as needed.
     */
    private static Node[][] buildInitialMap(FastByIDMap<float[]> vectors, int mapSize) {
        // Per-vector selection probability so that about mapSize^2 vectors seed the map. This must
        // be floating-point: integer division yields 0 whenever there are more vectors than nodes,
        // producing a degenerate Pascal(1, 0) distribution whose enormous skips exhaust the key set
        // on every retry, so the wrap-around loop below would never finish.
        double selectionProbability = ((double) mapSize * mapSize) / vectors.size();
        PascalDistribution skipDistribution = selectionProbability >= 1.0
                ? null
                : new PascalDistribution(RandomManager.getRandom(), 1, selectionProbability);

        LongPrimitiveIterator keys = vectors.keySetIterator();
        Node[][] map = new Node[mapSize][mapSize];
        for (Node[] row : map) {
            for (int col = 0; col < mapSize; col++) {
                if (skipDistribution != null) {
                    keys.skip(skipDistribution.sample());
                }
                // Ran past the end: restart from the first key and skip again.
                while (!keys.hasNext()) {
                    keys = vectors.keySetIterator();
                    Preconditions.checkState(keys.hasNext());
                    if (skipDistribution != null) {
                        keys.skip(skipDistribution.sample());
                    }
                }
                // Copy the seed: centers are updated in place during training, and Node may keep
                // the array it is given, which would otherwise mutate the caller's vectors and let
                // two nodes seeded from the same vector share one center.
                row[col] = new Node(vectors.get(keys.nextLong()).clone());
            }
        }
        return map;
    }

    /**
     * Returns {@code {row, column}} of the node whose center has the highest finite cosine
     * similarity to {@code vector}, or {@code null} if no node yields a finite similarity (for
     * example a zero-norm input vector).
     */
    private static int[] findBestMatchingUnit(float[] vector, Node[][] map) {
        int mapSize = map.length;
        double vectorNorm = SimpleVectorMath.norm(vector);
        double bestSimilarity = Double.NEGATIVE_INFINITY;
        int bestRow = -1;
        int bestCol = -1;
        for (int row = 0; row < mapSize; row++) {
            Node[] mapRow = map[row];
            for (int col = 0; col < mapSize; col++) {
                float[] center = mapRow[col].getCenter();
                double similarity = SimpleVectorMath.dot(vector, center)
                        / (SimpleVectorMath.norm(center) * vectorNorm);
                if (LangUtils.isFinite(similarity) && similarity > bestSimilarity) {
                    bestSimilarity = similarity;
                    bestRow = row;
                    bestCol = col;
                }
            }
        }
        if (bestRow == -1 || bestCol == -1) {
            return null;
        }
        return new int[] {bestRow, bestCol};
    }

    /**
     * Moves every node within the square window around the best-matching unit toward
     * {@code vector}, by {@code initLearningRate * decay * exp(-d^2 / (2 r^2))}, where d is the grid
     * distance to the BMU and {@code r = mapSize * decay} is the neighborhood radius.
     */
    private void updateNeighborhood(Node[][] map, float[] vector, int bmuRow, int bmuCol, double decay) {
        int mapSize = map.length;
        double radius = mapSize * decay;
        double learningRate = initLearningRate * decay;
        double twoRadiusSquared = (2.0 * radius) * radius;

        // floor(bmu - r) == bmu - ceil(r) and ceil(bmu + r) == bmu + ceil(r), so the window
        // [bmu - ceil(r), bmu + ceil(r)] is symmetric only if the exclusive upper bound is
        // ceil(bmu + r) + 1; otherwise the last row/column above the BMU is never updated.
        int minRow = FastMath.max(0, (int) FastMath.floor(bmuRow - radius));
        int maxRow = FastMath.min(mapSize, (int) FastMath.ceil(bmuRow + radius) + 1);
        int minCol = FastMath.max(0, (int) FastMath.floor(bmuCol - radius));
        int maxCol = FastMath.min(mapSize, (int) FastMath.ceil(bmuCol + radius) + 1);

        for (int row = minRow; row < maxRow; row++) {
            Node[] mapRow = map[row];
            for (int col = minCol; col < maxCol; col++) {
                double gridDistance = distance(row, col, bmuRow, bmuCol);
                double strength =
                        learningRate * FastMath.exp(-(gridDistance * gridDistance) / twoRadiusSquared);
                float[] center = mapRow[col].getCenter();
                for (int k = 0; k < center.length; k++) {
                    center[k] += (float) (strength * (vector[k] - center[k]));
                }
            }
        }
    }

    /** Sorts each node's assigned (similarity, ID) pairs by descending similarity. */
    private static void sortMembers(Node[][] map) {
        for (Node[] row : map) {
            for (Node node : row) {
                Collections.sort(node.getAssignedIDs(), BY_DESCENDING_SIMILARITY);
            }
        }
    }

    /**
     * Projects each node center onto three random unit axes after subtracting the mean center.
     * Each coordinate is the cosine between the centered vector and an axis, mapped from [-1, 1]
     * to [0, 1].
     */
    private static void buildProjections(int dimension, Node[][] map) {
        int mapSize = map.length;
        float[] mean = new float[dimension];
        for (Node[] row : map) {
            for (Node node : row) {
                add(node.getCenter(), mean);
            }
        }
        divide(mean, mapSize * mapSize);

        RandomGenerator random = RandomManager.getRandom();
        float[] axisX = RandomUtils.randomUnitVector(dimension, random);
        float[] axisY = RandomUtils.randomUnitVector(dimension, random);
        float[] axisZ = RandomUtils.randomUnitVector(dimension, random);

        for (Node[] row : map) {
            for (Node node : row) {
                float[] centered = node.getCenter().clone();
                subtract(mean, centered);
                double norm = SimpleVectorMath.norm(centered);
                float[] projection = node.getProjection3D();
                if (norm > 0.0) {
                    projection[0] = cosineToUnitInterval(SimpleVectorMath.dot(centered, axisX) / norm);
                    projection[1] = cosineToUnitInterval(SimpleVectorMath.dot(centered, axisY) / norm);
                    projection[2] = cosineToUnitInterval(SimpleVectorMath.dot(centered, axisZ) / norm);
                } else {
                    // A center equal to the mean (always the case in a 1x1 map) has no direction;
                    // place it at the middle of the cube instead of emitting 0/0 = NaN.
                    projection[0] = 0.5f;
                    projection[1] = 0.5f;
                    projection[2] = 0.5f;
                }
            }
        }
    }

    /** Maps a cosine in [-1, 1] linearly onto [0, 1]. */
    private static float cosineToUnitInterval(double cosine) {
        return (float) ((1.0 + cosine) / 2.0);
    }

    /** Adds {@code addend} element-wise into {@code target} (over {@code addend.length} entries). */
    private static void add(float[] addend, float[] target) {
        for (int k = 0; k < addend.length; k++) {
            target[k] += addend[k];
        }
    }

    /** Subtracts {@code subtrahend} element-wise from {@code target} (over {@code subtrahend.length} entries). */
    private static void subtract(float[] subtrahend, float[] target) {
        for (int k = 0; k < subtrahend.length; k++) {
            target[k] -= subtrahend[k];
        }
    }

    /** Divides every element of {@code target} by {@code divisor} in place. */
    private static void divide(float[] target, float divisor) {
        for (int k = 0; k < target.length; k++) {
            target[k] /= divisor;
        }
    }

    /** Euclidean distance between grid cells (i1, j1) and (i2, j2). */
    private static double distance(int i1, int j1, int i2, int j2) {
        int di = i1 - i2;
        int dj = j1 - j2;
        return FastMath.sqrt((di * di) + (dj * dj));
    }
}
