/*
 * Decompiled with CFR 0.152.
 */
package de.bioforscher.singa.mathematics.algorithms.clustering;

import de.bioforscher.singa.mathematics.algorithms.clustering.Clustering;
import de.bioforscher.singa.mathematics.matrices.LabeledMatrix;
import de.bioforscher.singa.mathematics.matrices.LabeledRegularMatrix;
import de.bioforscher.singa.mathematics.matrices.Matrix;
import de.bioforscher.singa.mathematics.matrices.RegularMatrix;
import de.bioforscher.singa.mathematics.vectors.RegularVector;
import de.bioforscher.singa.mathematics.vectors.Vectors;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Affinity propagation clustering over a list of labeled data points.
 *
 * <p>Instances are obtained through the staged builder returned by {@link #create()}:
 * data points, then the labeled input matrix, then whether that matrix holds distances,
 * then optional parameters, and finally {@link ParameterStep#run()}. The whole clustering
 * is executed inside the constructor, so a returned instance already carries its result
 * (see {@link #getClusters()}).
 *
 * <p>In every epoch the responsibility and availability matrices are updated (each update
 * is dampened with {@code lambda}), exemplars are selected, and every data point is
 * assigned to an exemplar. Iteration stops once the exemplar selection has been identical
 * for {@value #MIN_STABLE_EPOCHS} consecutive epochs or when the maximal number of epochs
 * is reached.
 *
 * @param <DataType> the type of the clustered data points; also the label type of all matrices
 */
public class AffinityPropagation<DataType> implements Clustering<DataType> {
    private static final Logger logger = LoggerFactory.getLogger(AffinityPropagation.class);

    /** Number of consecutive epochs whose exemplar selections must be equal to count as converged. */
    private static final int MIN_STABLE_EPOCHS = 10;

    private final List<DataType> dataPoints;
    private final int dataSize;
    private final int maximalEpochs;
    private final LabeledMatrix<DataType> distanceMatrix;
    private final double lambda;
    private LabeledMatrix<DataType> similarityMatrix;
    private LabeledMatrix<DataType> availabilityMatrix;
    private LabeledMatrix<DataType> responsibilityMatrix;
    private int epoch;
    /** Exemplars selected per epoch, in epoch order; the last entry is the current selection. */
    private List<List<DataType>> exemplarDecisions;
    private Map<DataType, List<DataType>> clusters;

    /**
     * Prepares the similarity and distance matrices from the builder's input and runs the
     * clustering to completion.
     *
     * <p>If the builder was flagged as holding distances, the input matrix is kept as the
     * distance matrix and its additive inverse is used as the similarity matrix, with the
     * diagonal set to the negated self similarity. Otherwise the input values are used as
     * similarities, the distance matrix is their additive inverse, and the diagonal of the
     * similarities is set to the self similarity.
     *
     * @param builder the fully configured builder
     * @throws NullPointerException if the data points, the matrix or the matrix' row labels are {@code null}
     * @throws IllegalArgumentException if the data points do not equal the row labels of the matrix
     */
    private AffinityPropagation(Builder<DataType> builder) {
        this.dataPoints = Objects.requireNonNull(builder.dataPoints, "dataPoints must not be null");
        LabeledMatrix<DataType> inputMatrix = Objects.requireNonNull(builder.matrix, "matrix must not be null");
        this.dataSize = this.dataPoints.size();
        logger.info("affinity propagation initialized with {} data points", this.dataSize);
        this.lambda = builder.lambda;
        this.maximalEpochs = builder.maximalEpochs;
        double selfSimilarity = builder.selfSimilarity;
        this.checkInput(this.dataPoints, inputMatrix);

        double[][] similarityValues;
        if (builder.distance) {
            this.distanceMatrix = inputMatrix;
            similarityValues = ((Matrix) inputMatrix.additivelyInvert()).getElements();
            for (int i = 0; i < similarityValues.length; ++i) {
                similarityValues[i][i] = -selfSimilarity;
            }
        } else {
            // Work on a copy: the diagonal is overwritten below, and that must not write
            // through to the caller's matrix should getElements() expose its backing array.
            similarityValues = deepCopy(inputMatrix.getElements());
            // The distance matrix is derived before the diagonal is replaced.
            this.distanceMatrix = new LabeledRegularMatrix<DataType>(new RegularMatrix(similarityValues).additivelyInvert().getElements());
            this.distanceMatrix.setRowLabels(this.dataPoints);
            this.distanceMatrix.setColumnLabels(this.dataPoints);
            for (int i = 0; i < similarityValues.length; ++i) {
                similarityValues[i][i] = selfSimilarity;
            }
        }
        this.similarityMatrix = new LabeledRegularMatrix<DataType>(similarityValues);
        this.similarityMatrix.setRowLabels(this.dataPoints);
        this.similarityMatrix.setColumnLabels(this.dataPoints);
        this.initialize();
        this.run();
    }

    /**
     * Starts the staged construction of an affinity propagation run.
     *
     * @param <DataType> the type of the data points to cluster
     * @return the first builder step, which expects the data points
     */
    public static <DataType> DataStep<DataType> create() {
        return new Builder<DataType>();
    }

    /** Returns a row-by-row copy of the given two-dimensional array. */
    private static double[][] deepCopy(double[][] values) {
        double[][] copy = new double[values.length][];
        for (int row = 0; row < values.length; ++row) {
            copy[row] = Arrays.copyOf(values[row], values[row].length);
        }
        return copy;
    }

    /**
     * @return the data points handed to the builder (the same list instance, not a copy)
     */
    @Override
    public List<DataType> getDataPoints() {
        return this.dataPoints;
    }

    /**
     * @return the distance matrix: the input matrix itself in distance mode, otherwise the
     *         additive inverse of the input similarities
     */
    @Override
    public LabeledMatrix<DataType> getDistanceMatrix() {
        return this.distanceMatrix;
    }

    /**
     * Returns the cluster assignment of the last executed epoch.
     *
     * @return a map from exemplar to the data points assigned to it; empty if no epoch was
     *         executed. If the last epoch selected no exemplar at all, every data point is
     *         grouped under the {@code null} key.
     */
    @Override
    public Map<DataType, List<DataType>> getClusters() {
        return this.clusters;
    }

    /**
     * Creates zero-filled, labeled responsibility and availability matrices and resets the
     * exemplar history and the cluster map.
     */
    private void initialize() {
        this.responsibilityMatrix = new LabeledRegularMatrix<DataType>(new double[this.dataSize][this.dataSize]);
        this.responsibilityMatrix.setRowLabels(this.dataPoints);
        this.responsibilityMatrix.setColumnLabels(this.dataPoints);
        this.availabilityMatrix = new LabeledRegularMatrix<DataType>(new double[this.dataSize][this.dataSize]);
        this.availabilityMatrix.setRowLabels(this.dataPoints);
        this.availabilityMatrix.setColumnLabels(this.dataPoints);
        this.exemplarDecisions = new ArrayList<List<DataType>>();
        // Start with an empty result so that run() can report a cluster count even when
        // the epoch loop is never entered (maximalEpochs <= 0).
        this.clusters = new HashMap<DataType, List<DataType>>();
    }

    /**
     * Verifies that the data points equal the row labels of the matrix.
     *
     * @param data the data points to cluster
     * @param matrix the labeled input matrix
     * @throws NullPointerException if the matrix has no row labels
     * @throws IllegalArgumentException if the data points differ from the row labels
     */
    private void checkInput(List<DataType> data, LabeledMatrix<DataType> matrix) {
        List<DataType> rowLabels = matrix.getRowLabels();
        Objects.requireNonNull(rowLabels, "the row labels of the matrix must not be null");
        if (!data.equals(rowLabels)) {
            throw new IllegalArgumentException("The data does not match the labels of the provided matrix.");
        }
    }

    /**
     * Executes epochs until the exemplar selection is stable or the epoch limit is reached.
     */
    private void run() {
        while (this.epoch < this.maximalEpochs) {
            this.updateResponsibilities();
            this.updateAvailabilities();
            this.assignExemplars();
            this.assignClusters();
            if (this.isConverged()) {
                break;
            }
            ++this.epoch;
            if (this.epoch == this.maximalEpochs) {
                logger.info("terminating after reaching maximal epoch limit");
            }
        }
        logger.info("obtained {} clusters", this.clusters.size());
    }

    /**
     * Rebuilds the cluster map from the most recent exemplar selection.
     *
     * <p>A data point that is itself an exemplar is assigned to itself. Every other point
     * is assigned to the exemplar with the highest similarity; on ties the exemplar that
     * comes first in the selection wins.
     */
    private void assignClusters() {
        this.clusters = new HashMap<DataType, List<DataType>>();
        List<DataType> currentExemplars = this.exemplarDecisions.get(this.exemplarDecisions.size() - 1);
        for (DataType currentDataPoint : this.dataPoints) {
            double bestSimilarity = -Double.MAX_VALUE;
            // Stays null when the current epoch selected no exemplar; such points end up
            // grouped under the null key.
            DataType bestExemplar = null;
            for (DataType exemplar : currentExemplars) {
                if (exemplar.equals(currentDataPoint)) {
                    bestExemplar = currentDataPoint;
                    break;
                }
                double similarity = this.similarityMatrix.getValueForLabel(currentDataPoint, exemplar);
                if (similarity > bestSimilarity) {
                    bestSimilarity = similarity;
                    bestExemplar = exemplar;
                }
            }
            List<DataType> cluster = this.clusters.get(bestExemplar);
            if (cluster == null) {
                cluster = new ArrayList<DataType>();
                this.clusters.put(bestExemplar, cluster);
            }
            cluster.add(currentDataPoint);
        }
    }

    /**
     * Selects as exemplars all data points whose diagonal entry in the sum of the
     * responsibility and availability matrices is positive, and appends the selection to
     * the exemplar history.
     */
    private void assignExemplars() {
        List<DataType> exemplars = new ArrayList<DataType>();
        double[][] evidence = ((Matrix) this.responsibilityMatrix.add(this.availabilityMatrix)).getElements();
        for (int i = 0; i < evidence.length; ++i) {
            if (evidence[i][i] > 0.0) {
                exemplars.add(this.dataPoints.get(i));
            }
        }
        this.exemplarDecisions.add(exemplars);
    }

    /**
     * Recomputes the responsibilities and dampens them against the previous values.
     *
     * <p>For each pair {@code (i, j)} the new value is the similarity {@code s(i, j)} minus
     * the maximal entry in row {@code i} of (similarity + availability) with position
     * {@code j} masked out.
     */
    private void updateResponsibilities() {
        double[][] updatedResponsibilities = new double[this.dataSize][this.dataSize];
        Matrix as = this.similarityMatrix.add(this.availabilityMatrix);
        for (int i = 0; i < this.dataSize; ++i) {
            // The row does not depend on j, so it is extracted once per row.
            double[] availabilityPlusSimilarity = as.getRow(i).getElements();
            for (int j = 0; j < this.dataSize; ++j) {
                double[] maskedRow = Arrays.copyOf(availabilityPlusSimilarity, availabilityPlusSimilarity.length);
                // Exclude the candidate itself from the search for the strongest competitor.
                maskedRow[j] = -Double.MAX_VALUE;
                RegularVector maskedVector = new RegularVector(maskedRow);
                int positionOfMax = Vectors.getIndexWithMaximalElement(maskedVector);
                double maxValue = maskedVector.getElement(positionOfMax);
                updatedResponsibilities[i][j] = this.similarityMatrix.getElement(i, j) - maxValue;
            }
        }
        LabeledRegularMatrix<DataType> updatedResponsibilityMatrix = new LabeledRegularMatrix<DataType>(updatedResponsibilities);
        this.responsibilityMatrix = this.applyLambda(updatedResponsibilityMatrix, this.responsibilityMatrix);
    }

    /**
     * Recomputes the availabilities and dampens them against the previous values.
     *
     * <p>For a candidate exemplar (column) {@code c} and a point (row) {@code p}, the
     * positive responsibilities in column {@code c} are summed, skipping rows {@code c}
     * and {@code p}. The self availability ({@code p == c}) is that sum; any other entry
     * is the sum plus the self responsibility {@code r(c, c)} if that total is negative,
     * and {@code 0} otherwise.
     */
    private void updateAvailabilities() {
        double[][] updatedAvailabilities = new double[this.dataSize][this.dataSize];
        for (int candidate = 0; candidate < this.dataSize; ++candidate) {
            // Column and self responsibility only depend on the candidate, not on the point.
            double[] column = this.responsibilityMatrix.getColumn(candidate).getElements();
            double selfResponsibility = this.responsibilityMatrix.getElement(candidate, candidate);
            for (int point = 0; point < this.dataSize; ++point) {
                double sum = 0.0;
                for (int k = 0; k < column.length; ++k) {
                    if (k == candidate || k == point || !(column[k] > 0.0)) {
                        continue;
                    }
                    sum += column[k];
                }
                if (candidate == point) {
                    updatedAvailabilities[point][candidate] = sum;
                } else {
                    double total = sum + selfResponsibility;
                    updatedAvailabilities[point][candidate] = total < 0.0 ? total : 0.0;
                }
            }
        }
        LabeledRegularMatrix<DataType> updatedAvailabilityMatrix = new LabeledRegularMatrix<DataType>(updatedAvailabilities);
        this.availabilityMatrix = this.applyLambda(updatedAvailabilityMatrix, this.availabilityMatrix);
    }

    /**
     * Dampens an update: {@code (1 - lambda) * updatedMatrix + lambda * oldMatrix}.
     *
     * @param updatedMatrix the freshly computed values
     * @param oldMatrix the values of the previous epoch
     * @return a new matrix with the dampened values, labeled with the data points
     */
    private LabeledMatrix<DataType> applyLambda(LabeledMatrix<DataType> updatedMatrix, LabeledMatrix<DataType> oldMatrix) {
        LabeledRegularMatrix<DataType> dampenedMatrix = new LabeledRegularMatrix<DataType>(updatedMatrix.multiply(1.0 - this.lambda).add(oldMatrix.multiply(this.lambda)).getElements());
        dampenedMatrix.setRowLabels(this.dataPoints);
        dampenedMatrix.setColumnLabels(this.dataPoints);
        return dampenedMatrix;
    }

    /**
     * Checks whether the exemplar selection has been stable.
     *
     * @return {@code true} if at least {@value #MIN_STABLE_EPOCHS} selections exist and the
     *         most recent {@value #MIN_STABLE_EPOCHS} of them are all equal
     */
    private boolean isConverged() {
        int decisionCount = this.exemplarDecisions.size();
        if (decisionCount < MIN_STABLE_EPOCHS) {
            return false;
        }
        boolean converged = true;
        int lowerBound = decisionCount - MIN_STABLE_EPOCHS;
        for (int i = decisionCount - 1; i > lowerBound; --i) {
            if (!this.exemplarDecisions.get(i).equals(this.exemplarDecisions.get(i - 1))) {
                converged = false;
                break;
            }
        }
        if (converged) {
            logger.debug("converged in epoch {}/{}", this.epoch, this.maximalEpochs);
        } else {
            logger.debug("not converged in epoch {}/{}", this.epoch, this.maximalEpochs);
        }
        return converged;
    }

    /**
     * @return the similarity matrix used for clustering, including the self similarities on its diagonal
     */
    public LabeledMatrix<DataType> getSimilarityMatrix() {
        return this.similarityMatrix;
    }

    /**
     * Replaces the stored similarity matrix. The clustering is not re-run.
     *
     * @param similarityMatrix the new similarity matrix
     */
    public void setSimilarityMatrix(LabeledMatrix<DataType> similarityMatrix) {
        this.similarityMatrix = similarityMatrix;
    }

    /**
     * @return the availability matrix after the last executed epoch
     */
    public LabeledMatrix<DataType> getAvailabilityMatrix() {
        return this.availabilityMatrix;
    }

    /**
     * Replaces the stored availability matrix. The clustering is not re-run.
     *
     * @param availabilityMatrix the new availability matrix
     */
    public void setAvailabilityMatrix(LabeledMatrix<DataType> availabilityMatrix) {
        this.availabilityMatrix = availabilityMatrix;
    }

    /**
     * @return the responsibility matrix after the last executed epoch
     */
    public LabeledMatrix<DataType> getResponsibilityMatrix() {
        return this.responsibilityMatrix;
    }

    /**
     * Replaces the stored responsibility matrix. The clustering is not re-run.
     *
     * @param responsibilityMatrix the new responsibility matrix
     */
    public void setResponsibilityMatrix(LabeledMatrix<DataType> responsibilityMatrix) {
        this.responsibilityMatrix = responsibilityMatrix;
    }

    /**
     * Staged builder that implements every step interface; each setter returns the builder
     * typed as the next step.
     *
     * @param <DataType> the type of the data points to cluster
     */
    public static class Builder<DataType>
            implements DataStep<DataType>, MatrixStep<DataType>, DistanceStep<DataType>, ParameterStep<DataType> {
        private static final double DEFAULT_SELF_SIMILARITY = -0.5;
        private static final double DEFAULT_LAMBDA = 0.5;
        private static final int DEFAULT_MAXIMAL_EPOCHS = 1000;

        private List<DataType> dataPoints;
        private LabeledMatrix<DataType> matrix;
        private double selfSimilarity = DEFAULT_SELF_SIMILARITY;
        private double lambda = DEFAULT_LAMBDA;
        private int maximalEpochs = DEFAULT_MAXIMAL_EPOCHS;
        private boolean distance;

        @Override
        public MatrixStep<DataType> dataPoints(List<DataType> dataPoints) {
            this.dataPoints = dataPoints;
            return this;
        }

        @Override
        public DistanceStep<DataType> matrix(LabeledMatrix<DataType> matrix) {
            this.matrix = matrix;
            return this;
        }

        @Override
        public ParameterStep<DataType> selfSimilarity(double selfSimilarity) {
            this.selfSimilarity = selfSimilarity;
            return this;
        }

        @Override
        public ParameterStep<DataType> isDistance(boolean distance) {
            this.distance = distance;
            return this;
        }

        @Override
        public ParameterStep<DataType> lambda(double lambda) {
            this.lambda = lambda;
            return this;
        }

        @Override
        public ParameterStep<DataType> maximalEpochs(int maximalEpochs) {
            this.maximalEpochs = maximalEpochs;
            return this;
        }

        @Override
        public AffinityPropagation<DataType> run() {
            return new AffinityPropagation<DataType>(this);
        }
    }

    /**
     * Final builder step: optional parameters and execution.
     *
     * @param <DataType> the type of the data points to cluster
     */
    public static interface ParameterStep<DataType> {
        /**
         * Sets the value written to the diagonal of the similarity matrix (negated when the
         * input matrix holds distances). Defaults to {@code -0.5}.
         */
        public ParameterStep<DataType> selfSimilarity(double selfSimilarity);

        /**
         * Sets the damping factor: the weight of the previous epoch's values when
         * responsibilities and availabilities are updated. Defaults to {@code 0.5}.
         */
        public ParameterStep<DataType> lambda(double lambda);

        /** Sets the maximal number of epochs. Defaults to {@code 1000}. */
        public ParameterStep<DataType> maximalEpochs(int maximalEpochs);

        /** Runs the clustering and returns the finished instance. */
        public AffinityPropagation<DataType> run();
    }

    /**
     * Builder step that declares how the input matrix is to be interpreted.
     *
     * @param <DataType> the type of the data points to cluster
     */
    public static interface DistanceStep<DataType> {
        /**
         * @param distance {@code true} if the input matrix holds distances, {@code false} if it holds similarities
         */
        public ParameterStep<DataType> isDistance(boolean distance);
    }

    /**
     * Builder step that accepts the labeled input matrix.
     *
     * @param <DataType> the type of the data points to cluster
     */
    public static interface MatrixStep<DataType> {
        /**
         * @param matrix the labeled matrix; its row labels must equal the data points
         */
        public DistanceStep<DataType> matrix(LabeledMatrix<DataType> matrix);
    }

    /**
     * First builder step that accepts the data points.
     *
     * @param <DataType> the type of the data points to cluster
     */
    public static interface DataStep<DataType> {
        /**
         * @param dataPoints the data points to cluster
         */
        public MatrixStep<DataType> dataPoints(List<DataType> dataPoints);
    }
}

