/*
 * Decompiled with CFR 0.152.
 */
package de.datexis.parvec.encoder;

import java.util.ArrayList;
import java.util.List;
import org.apache.commons.lang3.ArrayUtils;
import org.deeplearning4j.models.embeddings.inmemory.InMemoryLookupTable;
import org.deeplearning4j.models.word2vec.VocabWord;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.ops.transforms.Transforms;
import org.nd4j.linalg.primitives.Pair;

/**
 * Scores an input vector against the vectors of a fixed list of labels.
 *
 * <p>Each label is resolved to its vector through the supplied lookup table and
 * compared to the input vector with {@link Transforms#cosineSim(INDArray, INDArray)}.
 * Scores are available either as (label, score) pairs or as a single
 * normalized {@link INDArray}.
 */
public class LabelSeeker {
    /** Labels to score against, in the order scores are reported. Held by reference, not copied. */
    private final List<String> labelsUsed;
    /** Source of the per-label vectors. */
    private final InMemoryLookupTable<VocabWord> lookupTable;

    /**
     * Creates a seeker for the given labels.
     *
     * @param labelsUsed  labels to score against; must be non-null and non-empty
     * @param lookupTable table used to resolve each label to its vector; must be non-null
     * @throws NullPointerException  if {@code labelsUsed} or {@code lookupTable} is null
     * @throws IllegalStateException if {@code labelsUsed} is empty
     */
    public LabelSeeker(List<String> labelsUsed, InMemoryLookupTable<VocabWord> lookupTable) {
        // Fail fast with a named argument instead of an anonymous NPE further down.
        if (labelsUsed == null) {
            throw new NullPointerException("labelsUsed must not be null");
        }
        if (lookupTable == null) {
            throw new NullPointerException("lookupTable must not be null");
        }
        if (labelsUsed.isEmpty()) {
            throw new IllegalStateException("You can't have 0 labels used for ParagraphVectors");
        }
        this.lookupTable = lookupTable;
        this.labelsUsed = labelsUsed;
    }

    /**
     * Computes the similarity between {@code vector} and every label vector.
     *
     * @param vector the vector to compare against each label
     * @return one (label, score) pair per label, in the same order as the label list;
     *         a score that is NaN or infinite is reported as {@code 0.0}
     * @throws IllegalStateException if the lookup table has no vector for a label
     */
    public List<Pair<String, Double>> getScores(INDArray vector) {
        List<Pair<String, Double>> result = new ArrayList<Pair<String, Double>>(this.labelsUsed.size());
        for (String label : this.labelsUsed) {
            INDArray vecLabel = this.lookupTable.vector(label);
            if (vecLabel == null) {
                throw new IllegalStateException("Label '" + label + "' has no known vector!");
            }
            double sim = Transforms.cosineSim(vector, vecLabel);
            // Non-finite similarities are replaced so downstream min/max/sum stay well-defined.
            if (!Double.isFinite(sim)) {
                sim = 0.0;
            }
            result.add(new Pair<String, Double>(label, Double.valueOf(sim)));
        }
        return result;
    }

    /**
     * Computes the label scores and returns them as a normalized vector.
     *
     * <p>The raw scores from {@link #getScores(INDArray)} are min-max scaled and then
     * divided by their sum (when that sum is non-zero). If all scores are equal,
     * a vector of zeros with the same shape is returned.
     *
     * @param vector the vector to compare against each label
     * @return a vector with one entry per label, in label-list order
     * @throws IllegalStateException if the lookup table has no vector for a label
     */
    public INDArray getScoresAsVector(INDArray vector) {
        List<Pair<String, Double>> resultPairs = this.getScores(vector);
        // Unbox straight into a double[] rather than going through a Double[].
        double[] scores = resultPairs.stream().mapToDouble(Pair::getSecond).toArray();
        INDArray vec = Nd4j.create(scores);
        double min = vec.minNumber().doubleValue();
        double max = vec.maxNumber().doubleValue();
        double range = max - min;
        if (range == 0.0) {
            // All scores identical: no label is preferred over another.
            return Nd4j.zerosLike(vec);
        }
        double scale = 1.0 / range;
        // sub() allocates a fresh array, so the in-place muli() does not touch vec.
        INDArray scaled = vec.sub(min).muli(scale);
        double sum = scaled.sumNumber().doubleValue();
        return sum != 0.0 ? scaled.div(sum) : scaled;
    }
}

