/*
 * Decompiled with CFR 0.152.
 */
package de.citec.ml.rng;

import de.citec.ml.rng.ArrayFunctions;
import de.citec.ml.rng.CheckFunctions;
import de.citec.ml.rng.RNGErrorModel;
import de.citec.ml.rng.RNGModel;
import de.citec.ml.rng.RNGModelImpl;
import de.citec.ml.rng.RelationalDistances;
import java.util.ArrayList;
import java.util.Random;

/**
 * Static utility class implementing batch training of a relational neural gas
 * (RNG) model on a dissimilarity matrix, plus helpers to query a trained model
 * (assignments, classification, cluster members, exemplars).
 *
 * <p>Every prototype is represented by one row of a coefficient matrix
 * {@code Alpha}; the rows written by {@link #train(double[][], int, int)} are
 * non-negative and sum to one.
 */
public final class RelationalNeuralGas {
    /**
     * Cut-off for the exponent {@code rank / lambda}: neighborhood weights
     * {@code exp(-rank / lambda)} that would fall below 0.001 are treated as zero.
     */
    private static final double APPROX_THRESHOLD = -Math.log(0.001);

    /** Neighborhood range used in the last training epoch. */
    private static final double LAMBDA_END = 0.01;

    /** Not instantiable: this class only offers static methods. */
    private RelationalNeuralGas() {
    }

    /**
     * Trains a model with the default of 30 epochs.
     *
     * @param D the dissimilarity matrix of the training data.
     * @param K the number of prototypes.
     * @return the trained model including the error per epoch.
     * @throws IllegalArgumentException under the same conditions as
     *         {@link #train(double[][], int, int)}.
     */
    public static RNGErrorModel train(double[][] D, int K) {
        return RelationalNeuralGas.train(D, K, 30);
    }

    /**
     * Trains a relational neural gas model in batch mode.
     *
     * <p>The prototypes are initialised with random coefficients (unseeded
     * {@link Random}, so results differ between runs). In each epoch every
     * datapoint ranks the prototypes by distance, prototype {@code k} receives
     * the weight {@code exp(-rank / lambda)} from that datapoint, and the new
     * coefficients of a prototype are its weights normalised to sum to one.
     *
     * <p>The returned error array has {@code T + 1} entries: entry {@code t < T}
     * is the neighborhood-weighted sum of distances measured before the update of
     * epoch {@code t}; entry {@code T} is the sum over all datapoints of the
     * distance to their closest prototype after training.
     *
     * @param D the dissimilarity matrix of the training data; validated by
     *          {@code CheckFunctions.checkDissimilarityMatrix}.
     * @param K the number of prototypes, at least 1.
     * @param T the number of epochs, at least 1.
     * @return the trained model including the error per epoch.
     * @throws IllegalArgumentException if the matrix check reports a problem or
     *         if {@code K} or {@code T} is smaller than 1.
     */
    public static RNGErrorModel train(double[][] D, int K, int T) {
        final IllegalArgumentException ex = CheckFunctions.checkDissimilarityMatrix(D);
        if (ex != null) {
            throw ex;
        }
        if (K < 1) {
            throw new IllegalArgumentException("The number of prototypes must be positive!");
        }
        if (T < 1) {
            throw new IllegalArgumentException("The number of epochs must be positive!");
        }
        final int m = D.length;
        final double[][] Alpha = initializeCoefficients(K, m, new Random());
        final double[] errors = new double[T + 1];

        for (int t = 0; t < T; ++t) {
            // hs[r] is the weight a datapoint gives to its r-th closest prototype;
            // ranks beyond hs.length are approximated by a weight of zero.
            final double[] hs = neighborhoodWeights(RelationalNeuralGas.getLambda(t, K, T), K);
            final int maxRank = hs.length;

            final double[] Z = RelationalDistances.getNormalizationTerms(D, Alpha);
            final double[][] Dp = RelationalDistances.getDistancesToPrototypes(D, Alpha, Z);

            // H[k][i] is the weight datapoint i contributes to prototype k.
            final double[][] H = new double[K][m];
            for (int i = 0; i < m; ++i) {
                final int[] ranking = closestPrototypes(Dp[i], K, maxRank);
                for (int r = 0; r < maxRank; ++r) {
                    H[ranking[r]][i] = hs[r];
                }
            }

            for (int k = 0; k < K; ++k) {
                double nrml = 0.0;
                for (int i = 0; i < m; ++i) {
                    if (H[k][i] == 0.0) {
                        continue;
                    }
                    nrml += H[k][i];
                    errors[t] += H[k][i] * Dp[i][k];
                }
                if (nrml == 0.0) {
                    // No datapoint ranked this prototype within the cut-off, so there is
                    // nothing to average over. Dividing by zero here would turn the whole
                    // coefficient row into NaN; keep the previous coefficients instead.
                    continue;
                }
                for (int i = 0; i < m; ++i) {
                    Alpha[k][i] = H[k][i] / nrml;
                }
            }
        }

        // Final quantization error with respect to the trained prototypes.
        final double[] Z = RelationalDistances.getNormalizationTerms(D, Alpha);
        final double[][] Dp = RelationalDistances.getDistancesToPrototypes(D, Alpha, Z);
        for (int i = 0; i < m; ++i) {
            errors[T] += Dp[i][ArrayFunctions.getMinIdx(Dp[i])];
        }
        return new RNGModelImpl(Alpha, Dp, Z, errors);
    }

    /**
     * Draws a random K x m coefficient matrix whose rows are normalised to sum to one.
     */
    private static double[][] initializeCoefficients(int K, int m, Random rand) {
        final double[][] Alpha = new double[K][m];
        for (int k = 0; k < K; ++k) {
            double nrml = 0.0;
            for (int i = 0; i < m; ++i) {
                Alpha[k][i] = rand.nextDouble();
                nrml += Alpha[k][i];
            }
            for (int i = 0; i < m; ++i) {
                Alpha[k][i] /= nrml;
            }
        }
        return Alpha;
    }

    /**
     * Computes the weights {@code exp(-r / lambda)} for ranks {@code r = 0, 1, ...}
     * up to (excluding) the first rank whose exponent exceeds
     * {@link #APPROX_THRESHOLD}, and at most {@code K} ranks.
     */
    private static double[] neighborhoodWeights(double lambda, int K) {
        final double invLambda = 1.0 / lambda;
        int maxRank = 1;
        while (maxRank < K && !(invLambda * (double) maxRank > APPROX_THRESHOLD)) {
            ++maxRank;
        }
        final double[] hs = new double[maxRank];
        hs[0] = 1.0;
        for (int r = 1; r < maxRank; ++r) {
            hs[r] = Math.exp(-(invLambda * (double) r));
        }
        return hs;
    }

    /**
     * Returns the indices of the {@code maxRank} prototypes with the smallest
     * entries in {@code dists[0..K-1]}, ordered by ascending distance, using a
     * bounded insertion sort. Expects {@code 1 <= maxRank <= K}.
     */
    private static int[] closestPrototypes(double[] dists, int K, int maxRank) {
        // Zero-initialised, i.e. prototype 0 starts out as the best candidate.
        final int[] ranking = new int[maxRank];
        for (int k = 1; k < K; ++k) {
            // Once the list is full, skip candidates that do not beat its last entry.
            if (k >= maxRank && !(dists[k] < dists[ranking[maxRank - 1]])) {
                continue;
            }
            int rank = k < maxRank ? k : maxRank - 1;
            // Shift worse entries one slot down; the last one drops out if the list is full.
            while (rank > 0 && dists[k] < dists[ranking[rank - 1]]) {
                ranking[rank] = ranking[rank - 1];
                --rank;
            }
            ranking[rank] = k;
        }
        return ranking;
    }

    /**
     * Returns the neighborhood range for epoch {@code t}.
     *
     * <p>The range starts at {@code K / 2} for {@code t == 0}, ends at 0.01 for
     * {@code t == T - 1}, and is interpolated in between as
     * {@code lambda_0 * (0.01 / lambda_0)^(t / (T - 1))}. The epoch index is
     * not range-checked; callers are expected to pass {@code 0 <= t < T}.
     *
     * @param t the current epoch.
     * @param K the number of prototypes, at least 1.
     * @param T the total number of epochs, at least 1.
     * @return the neighborhood range for epoch {@code t}.
     * @throws IllegalArgumentException if {@code T} or {@code K} is smaller than 1.
     */
    public static double getLambda(int t, int K, int T) {
        if (T < 1) {
            throw new IllegalArgumentException("The number of epochs must be positive!");
        }
        if (K < 1) {
            throw new IllegalArgumentException("The number of prototypes must be positive!");
        }
        // Checked first so that a single-epoch training (T == 1) uses the final range.
        if (t == T - 1) {
            return LAMBDA_END;
        }
        final double lambda_0 = (double) K / 2.0;
        if (t == 0) {
            return lambda_0;
        }
        return lambda_0 * Math.pow(LAMBDA_END / lambda_0, (double) t / (double) (T - 1));
    }

    /**
     * Assigns every training datapoint to a prototype, namely the index that
     * {@code ArrayFunctions.getMinIdx} selects from the datapoint's row of
     * distances to all prototypes.
     *
     * @param model a trained model.
     * @return one prototype index per datapoint of the model.
     */
    public static int[] getAssignments(RNGModel model) {
        final int m = model.getNumberOfDatapoints();
        final double[][] Dp = model.getDistancesToPrototypes();
        final int[] assignments = new int[m];
        for (int i = 0; i < m; ++i) {
            assignments[i] = ArrayFunctions.getMinIdx(Dp[i]);
        }
        return assignments;
    }

    /**
     * Classifies several points at once by applying
     * {@link #classify(double[], RNGModel)} to every row of {@code D}.
     *
     * @param D one row per point to classify.
     * @param model a trained model.
     * @return one prototype index per row of {@code D}.
     */
    public static int[] classify(double[][] D, RNGModel model) {
        final int[] classification = new int[D.length];
        for (int j = 0; j < D.length; ++j) {
            classification[j] = RelationalNeuralGas.classify(D[j], model);
        }
        return classification;
    }

    /**
     * Classifies a single point: its distances to all prototypes are obtained
     * from {@code RelationalDistances.getDistancesToPrototypes(d, model)} and the
     * index selected by {@code ArrayFunctions.getMinIdx} is returned.
     *
     * @param d the input row for the point; presumably its dissimilarities to the
     *          training datapoints (confirm against {@code RelationalDistances}).
     * @param model a trained model.
     * @return the index of the selected prototype.
     */
    public static int classify(double[] d, RNGModel model) {
        return ArrayFunctions.getMinIdx(RelationalDistances.getDistancesToPrototypes(d, model));
    }

    /**
     * Lists, for every prototype, the indices of the datapoints assigned to it.
     *
     * @param model a trained model.
     * @return an array with one (possibly empty) index array per prototype.
     */
    public static int[][] getClusterMembers(RNGModel model) {
        final int K = model.getNumberOfPrototypes();
        // The assignments do not depend on k, so compute them once for all clusters.
        final int[] assignments = RelationalNeuralGas.getAssignments(model);
        final int[][] members = new int[K][];
        for (int k = 0; k < K; ++k) {
            members[k] = RelationalNeuralGas.getClusterMembers(model, k, assignments);
        }
        return members;
    }

    /**
     * Lists the indices of the datapoints assigned to prototype {@code k}.
     *
     * @param model a trained model.
     * @param k the prototype index.
     * @return the indices of all datapoints assigned to {@code k}, ascending.
     */
    public static int[] getClusterMembers(RNGModel model, int k) {
        return RelationalNeuralGas.getClusterMembers(model, k, RelationalNeuralGas.getAssignments(model));
    }

    /**
     * Lists the positions in {@code assignments} whose value equals {@code k}.
     *
     * @param model not used by this overload.
     * @param k the prototype index to look for.
     * @param assignments one prototype index per datapoint.
     * @return the matching positions in ascending order.
     */
    public static int[] getClusterMembers(RNGModel model, int k, int[] assignments) {
        // First pass: count the members so the result can be allocated exactly.
        int count = 0;
        for (int i = 0; i < assignments.length; ++i) {
            if (assignments[i] == k) {
                ++count;
            }
        }
        // Second pass: collect the member indices.
        final int[] memberArr = new int[count];
        int next = 0;
        for (int i = 0; i < assignments.length; ++i) {
            if (assignments[i] == k) {
                memberArr[next++] = i;
            }
        }
        return memberArr;
    }

    /**
     * Determines for every prototype the index of the datapoint with the smallest
     * distance to it. Ties are resolved in favour of the lowest datapoint index.
     *
     * @param model a trained model.
     * @return one datapoint index per prototype.
     */
    public static int[] getExamplars(RNGModel model) {
        final int m = model.getNumberOfDatapoints();
        final int K = model.getNumberOfPrototypes();
        final double[][] Dp = model.getDistancesToPrototypes();
        final int[] exemplars = new int[K];
        for (int k = 0; k < K; ++k) {
            int best = 0;
            for (int i = 1; i < m; ++i) {
                if (Dp[i][k] < Dp[best][k]) {
                    best = i;
                }
            }
            exemplars[k] = best;
        }
        return exemplars;
    }
}

