/*
 * Decompiled with CFR 0.152.
 */
package de.jungblut.ner;

import de.jungblut.datastructure.ArrayUtils;
import de.jungblut.distance.CosineDistance;
import de.jungblut.distance.DistanceMeasurer;
import de.jungblut.distance.SimilarityMeasurer;
import de.jungblut.math.DoubleMatrix;
import de.jungblut.math.DoubleVector;
import de.jungblut.math.dense.DenseDoubleVector;
import de.jungblut.math.tuple.Tuple;
import gnu.trove.list.array.TIntArrayList;
import java.util.Arrays;
import java.util.Comparator;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;

/**
 * Expands a set of seed tokens over a weighted term graph.
 *
 * <p>Every term is repeatedly re-ranked by the score
 * {@code alpha * relevance + (1 - alpha) * similarity}, where
 * <ul>
 * <li><b>relevance</b> is the term's average similarity to the (found) seed tokens, and</li>
 * <li><b>similarity</b> is the term's average similarity to the current result set.</li>
 * </ul>
 * Terms scoring strictly above a threshold form the next result set. Iteration stops when the
 * ranked result set is reproduced exactly, or when an optional iteration cap is reached.
 */
public final class IterativeSimilarityAggregation {
    private static final Logger LOG = LogManager.getLogger(IterativeSimilarityAggregation.class);

    /** Weight of the relevance score; the similarity score is weighted with {@code 1 - alpha}. */
    private final double alpha;
    private final SimilarityMeasurer similarityMeasurer;
    /** Indices into {@link #termNodes} of the seed tokens that were found in the term list. */
    private final int[] seedIndices;
    /** Term labels; term index i is compared via column i of {@link #weightMatrix}. */
    private final String[] termNodes;
    /** Transpose of the supplied graph matrix, so that column i belongs to term i. */
    private final DoubleMatrix weightMatrix;

    /**
     * Creates an instance with {@code alpha = 0.5} and a {@link CosineDistance}.
     *
     * @param seedTokens tokens to expand; tokens missing from the term list are skipped.
     * @param bipartiteGraph term labels (first) and weight matrix (second). The matrix is
     *        transposed internally; presumably row i of the input belongs to term i — confirm.
     */
    public IterativeSimilarityAggregation(String[] seedTokens, Tuple<String[], DoubleMatrix> bipartiteGraph) {
        this(seedTokens, bipartiteGraph, 0.5, new CosineDistance());
    }

    /**
     * Creates an instance with an explicit weighting and distance measure.
     *
     * @param seedTokens tokens to expand; tokens missing from the term list are skipped.
     * @param bipartiteGraph term labels (first) and weight matrix (second); see the other ctor.
     * @param alpha weight of the relevance score; the similarity score gets {@code 1 - alpha}.
     * @param distance distance measure wrapped into a {@link SimilarityMeasurer}.
     */
    public IterativeSimilarityAggregation(String[] seedTokens, Tuple<String[], DoubleMatrix> bipartiteGraph, double alpha, DistanceMeasurer distance) {
        this.termNodes = bipartiteGraph.getFirst();
        this.weightMatrix = bipartiteGraph.getSecond().transpose();
        this.alpha = alpha;
        this.similarityMeasurer = new SimilarityMeasurer(distance);
        this.seedIndices = findSeedIndices(seedTokens, this.termNodes);
    }

    /**
     * Resolves each seed token to its index in the term list. Tokens that are not found are
     * logged and skipped.
     *
     * @return the term indices of the found seed tokens, in seed order.
     */
    private static int[] findSeedIndices(String[] seedTokens, String[] termNodes) {
        TIntArrayList list = new TIntArrayList();
        for (String token : seedTokens) {
            int find = ArrayUtils.find(termNodes, token);
            if (find >= 0) {
                list.add(find);
            } else {
                LOG.info("Seed token \"" + token + "\" could not be found in the term list!");
            }
        }
        if (list.isEmpty()) {
            // An empty seed set yields NaN relevance scores (see computeRelevanceScore),
            // so every expansion from this instance returns an empty result.
            LOG.warn("None of the seed tokens could be found in the term list, the expansion result will be empty.");
        }
        return list.toArray();
    }

    /**
     * Runs the iterative expansion with a fixed score threshold.
     *
     * @param similarityThreshold terms whose combined rank score is strictly greater than this
     *        value are kept in the result set of each iteration.
     * @param maxIterations maximum number of iterations; a value {@code <= 0} applies no cap, so
     *        the loop only ends once the ranked result set is reproduced.
     * @param verbose if true, logs the result-set size of every iteration.
     * @return the expanded tokens of the final iteration, ordered by descending rank score.
     */
    public String[] startStaticThresholding(double similarityThreshold, int maxIterations, boolean verbose) {
        DenseDoubleVector relevanceScore = this.computeRelevanceScore(this.seedIndices);
        int[] relevantTokens = filterRelevantItems(relevanceScore, 0.0);
        for (int iteration = 0; ; ++iteration) {
            DenseDoubleVector similarityScore = this.computeRelevanceScore(relevantTokens);
            DoubleVector rankedTokens = rankScores(this.alpha, relevanceScore, similarityScore);
            int[] topRankedItems = getTopRankedItems(rankedTokens, similarityThreshold);
            // Converged once this iteration reproduces the previous result set in the same order.
            boolean converged = Arrays.equals(topRankedItems, relevantTokens);
            if (verbose) {
                LOG.info(iteration + " | Top ranked item size: " + topRankedItems.length);
            }
            relevantTokens = topRankedItems;
            // iteration is 0-based: iteration + 1 iterations have completed at this point.
            if (converged || maxIterations > 0 && iteration + 1 >= maxIterations) {
                break;
            }
        }
        String[] tokens = new String[relevantTokens.length];
        for (int i = 0; i < relevantTokens.length; ++i) {
            tokens[i] = this.termNodes[relevantTokens[i]];
        }
        return tokens;
    }

    /**
     * Selects the indices whose score is strictly above {@code similarityThreshold}.
     *
     * @param pRankedTokens rank score per term index; not modified.
     * @param similarityThreshold exclusive lower bound for a score to be kept. NaN scores never
     *        pass this test.
     * @return the selected indices, sorted by descending score, with ties broken by ascending
     *         index.
     */
    static int[] getTopRankedItems(DoubleVector pRankedTokens, double similarityThreshold) {
        int length = pRankedTokens.getLength();
        final double[] scores = new double[length];
        TIntArrayList candidates = new TIntArrayList();
        for (int i = 0; i < length; ++i) {
            scores[i] = pRankedTokens.get(i);
            if (scores[i] > similarityThreshold) {
                candidates.add(i);
            }
        }
        Integer[] order = new Integer[candidates.size()];
        for (int k = 0; k < order.length; ++k) {
            order[k] = candidates.get(k);
        }
        Arrays.sort(order, new Comparator<Integer>() {
            @Override
            public int compare(Integer a, Integer b) {
                int byScoreDescending = Double.compare(scores[b], scores[a]);
                return byScoreDescending != 0 ? byScoreDescending : Integer.compare(a, b);
            }
        });
        int[] result = new int[order.length];
        for (int k = 0; k < order.length; ++k) {
            result[k] = order[k];
        }
        return result;
    }

    /**
     * Scores every term by its average similarity to the given set of terms.
     *
     * <p>For an empty set the average is {@code 0.0 * (1.0 / 0)} = NaN for every term. NaN fails
     * every {@code >} threshold test downstream, so an empty set propagates to an empty result.
     *
     * @param seedSet term indices to compare against.
     * @return a vector holding, for each term index, the mean similarity to {@code seedSet}.
     */
    private DenseDoubleVector computeRelevanceScore(int[] seedSet) {
        int termsLength = this.termNodes.length;
        DenseDoubleVector relevanceScores = new DenseDoubleVector(termsLength);
        double constantLoss = 1.0 / (double) seedSet.length;
        // Fetch the set's column vectors once instead of once per term.
        DoubleVector[] seedColumns = new DoubleVector[seedSet.length];
        for (int s = 0; s < seedSet.length; ++s) {
            seedColumns[s] = this.weightMatrix.getColumnVector(seedSet[s]);
        }
        for (int i = 0; i < termsLength; ++i) {
            DoubleVector termColumn = this.weightMatrix.getColumnVector(i);
            double sum = 0.0;
            if (termColumn != null) {
                for (DoubleVector seedColumn : seedColumns) {
                    // A missing (null) column contributes a similarity of 0.
                    if (seedColumn != null) {
                        sum += this.similarityMeasurer.measureSimilarity(termColumn, seedColumn);
                    }
                }
            }
            relevanceScores.set(i, constantLoss * sum);
        }
        return relevanceScores;
    }

    /**
     * Combines both scores as {@code alpha * relevance + (1 - alpha) * similarity}.
     *
     * @param alpha weight of the relevance scores.
     * @param relevanceScores per-term similarity to the seed set.
     * @param similarityScores per-term similarity to the current result set.
     * @return the combined rank score per term.
     */
    static DoubleVector rankScores(double alpha, DenseDoubleVector relevanceScores, DenseDoubleVector similarityScores) {
        DoubleVector weightedRelevance = relevanceScores.multiply(alpha);
        return similarityScores.multiply(1.0 - alpha).add(weightedRelevance);
    }

    /**
     * Returns, in ascending order, the indices whose score is strictly greater than
     * {@code threshold}. NaN scores are never selected.
     */
    static int[] filterRelevantItems(DenseDoubleVector relevanceScores, double threshold) {
        TIntArrayList list = new TIntArrayList();
        for (int i = 0; i < relevanceScores.getLength(); ++i) {
            if (relevanceScores.get(i) > threshold) {
                list.add(i);
            }
        }
        return list.toArray();
    }
}

