/*
 * Decompiled with CFR 0.152.
 */
package de.datexis.sector.encoder;

import de.datexis.common.WordHelpers;
import de.datexis.encoder.impl.BagOfWordsEncoder;
import de.datexis.model.Span;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Bag-of-words encoder for heading (headline) text.
 *
 * <p>Builds on {@link BagOfWordsEncoder}: whenever the parent encoding of an input sums to
 * zero or less, the vector produced by {@link #encodeOtherClass()} is returned instead.
 * The label {@link #OTHER_CLASS} is added to the vocabulary during training and is used by
 * {@link #getNearestNeighbours(INDArray, int)} as the fallback answer.
 */
public class HeadingEncoder
extends BagOfWordsEncoder {
    protected static final Logger log = LoggerFactory.getLogger(HeadingEncoder.class);

    /** Identifier passed to the parent constructor. */
    public static final String ID = "HL";

    /** Label used when no vocabulary word matches. */
    public static String OTHER_CLASS = "other";

    /** Creates an encoder registered under {@link #ID}. */
    public HeadingEncoder() {
        super(ID);
    }

    /**
     * Trains the vocabulary from a list of headlines.
     *
     * <p>Each headline is split on spaces and every token is run through the preprocessor.
     * Empty results are ignored; all other tokens count towards {@code totalWords}. Stop
     * words and tokens shorter than {@code minWordLength} are not added to the vocabulary.
     * After counting, the vocabulary is truncated with {@code minWordFrequency} and the
     * preprocessed {@link #OTHER_CLASS} label is added afterwards.
     *
     * @param headlines        headline strings to learn from
     * @param minWordFrequency threshold handed to the vocabulary truncation
     * @param minWordLength    minimum length of a preprocessed token to be kept
     * @param language         language to configure on this encoder before training
     */
    public void trainModel(List<String> headlines, int minWordFrequency, int minWordLength, WordHelpers.Language language) {
        this.appendTrainLog("Training " + this.getName() + " model...");
        this.setModel(null);
        this.totalWords = 0;
        this.timer.start();
        this.setLanguage(language);
        for (String headline : headlines) {
            for (String token : WordHelpers.splitSpaces(headline)) {
                String w = this.preprocessor.preProcess(token);
                if (w.isEmpty()) {
                    continue;
                }
                ++this.totalWords;
                if (this.wordHelpers.isStopWord(w) || w.length() < minWordLength) {
                    continue;
                }
                if (this.vocab.containsWord(w)) {
                    this.vocab.incrementWordCounter(w);
                } else {
                    this.vocab.addWord(w);
                }
            }
        }
        // Vocabulary size before truncation, reported in the training log only.
        int total = this.vocab.numWords();
        this.vocab.truncateVocabulary(minWordFrequency);
        // Added after truncation so the label is present regardless of observed frequency.
        this.vocab.addWord(this.preprocessor.preProcess(OTHER_CLASS));
        this.vocab.updateHuffmanCodes();
        this.timer.stop();
        this.appendTrainLog("trained " + this.vocab.numWords() + " words (" + total + " total)", this.timer.getLong());
        this.setModelAvailable(true);
    }

    /**
     * Encodes a phrase by splitting it on spaces.
     *
     * @param phrase text to encode; may be {@code null}
     * @return the encoding of the split words, or the other-class vector for {@code null}
     */
    public INDArray encode(String phrase) {
        if (phrase == null) {
            return this.encodeOtherClass();
        }
        return this.encode(WordHelpers.splitSpaces(phrase));
    }

    /**
     * Encodes spans via the parent class, falling back to the other-class vector.
     *
     * @param spans spans to encode
     * @return the parent encoding if its sum is positive, otherwise the other-class vector
     */
    public INDArray encode(Iterable<? extends Span> spans) {
        return this.orOtherClass(super.encode(spans));
    }

    /**
     * Encodes words via the parent class, falling back to the other-class vector.
     *
     * @param words words to encode
     * @return the parent encoding if its sum is positive, otherwise the other-class vector
     */
    protected INDArray encode(String[] words) {
        return this.orOtherClass(super.encode(words));
    }

    /**
     * Subsampled encoding via the parent class, falling back to the other-class vector.
     *
     * @param phrase text to encode
     * @return the parent encoding if its sum is positive, otherwise the other-class vector
     */
    public INDArray encodeSubsampled(String phrase) {
        return this.orOtherClass(super.encodeSubsampled(phrase));
    }

    /**
     * Returns the vector used when no vocabulary word matched.
     *
     * @return an all-zero array created with dimensions ({@code getEmbeddingVectorSize()}, 1)
     */
    protected INDArray encodeOtherClass() {
        return Nd4j.zeros((long)this.getEmbeddingVectorSize(), 1L);
    }

    /**
     * Returns the words belonging to the highest-valued entries of a vector.
     *
     * <p>Entries are visited in descending order of value. The scan stops at the first entry
     * that is zero or lies below the midpoint between the maximum and the median value,
     * when {@code maxN} words were collected, or when the vector is exhausted. Entries that
     * map to {@link #OTHER_CLASS} are skipped.
     *
     * @param v    vector to inspect; it is copied before sorting
     * @param maxN maximum number of words to return
     * @return the collected words, or a single {@link #OTHER_CLASS} entry if none qualified
     */
    public Collection<String> getNearestNeighbours(INDArray v, int maxN) {
        // sortWithIndices yields { indices, values }, sorted descending here.
        INDArray[] sorted = Nd4j.sortWithIndices(v.dup(), 0, false);
        INDArray idx = sorted[0];
        INDArray values = sorted[1];
        // The check has to look at the values; summing the index array says nothing about the input.
        if (values.sumNumber().doubleValue() == 0.0) {
            log.warn("NearestNeighbour on zero vector - please check vector alignment!");
        }
        double max = values.getDouble(0L);
        double med = values.medianNumber().doubleValue();
        double threshold = (max + med) / 2.0;
        long length = values.length();
        List<String> result = new ArrayList<String>(maxN);
        // Bounded by the vector length so skipped or uniformly high entries cannot run past the end.
        for (int i = 0; i < length && result.size() < maxN; ++i) {
            double prob = values.getDouble((long)i);
            if (prob == 0.0 || prob < threshold) {
                break;
            }
            String word = this.getWord(idx.getInt(i));
            // Null-safe comparison in case an index has no word assigned.
            if (word != null && !OTHER_CLASS.equals(word)) {
                result.add(word);
            }
        }
        if (result.isEmpty()) {
            result.add(OTHER_CLASS);
        }
        return result;
    }

    /** Returns {@code vec} if its element sum is positive, otherwise the other-class vector. */
    private INDArray orOtherClass(INDArray vec) {
        return vec.sumNumber().doubleValue() > 0.0 ? vec : this.encodeOtherClass();
    }
}

