/*
 * Decompiled with CFR 0.152.
 */
package de.datexis.sector.encoder;

import de.datexis.encoder.LookupCacheEncoder;
import de.datexis.model.Document;
import de.datexis.model.Span;
import de.datexis.preprocess.LowercasePreprocessor;
import java.util.AbstractMap;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Map;
import org.deeplearning4j.models.word2vec.wordstore.VocabularyWord;
import org.deeplearning4j.text.tokenization.tokenizer.TokenPreProcess;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.slf4j.LoggerFactory;

/**
 * Encodes class labels as one-hot column vectors over a vocabulary of labels collected during
 * training, and maps score vectors back to the best-matching labels.
 *
 * <p>Every label is passed through a {@link LowercasePreprocessor} before it is counted, looked up
 * or encoded, so training and lookup use the same normalized form.
 */
public class ClassEncoder
extends LookupCacheEncoder {

    /** Normalizer applied to every label before it is added to or looked up in the vocabulary. */
    private static final TokenPreProcess PREPROCESSOR = new LowercasePreprocessor();

    /** Default identifier used by the no-argument constructor. */
    public static final String ID = "CLS";

    /** Creates an encoder with the default identifier {@value #ID}. */
    public ClassEncoder() {
        this(ID);
    }

    /**
     * Creates an encoder with a custom identifier.
     *
     * @param id identifier handed to {@link LookupCacheEncoder}
     */
    public ClassEncoder(String id) {
        super(id);
        this.log = LoggerFactory.getLogger(ClassEncoder.class);
    }

    /** @return the human-readable name of this encoder */
    public String getName() {
        return "Classification Encoder";
    }

    /**
     * Encodes the text of a span as a one-hot class vector.
     *
     * @param classLabel span whose text is the class label
     * @return one-hot column vector, see {@link #oneHot(String)}
     */
    public INDArray encode(Span classLabel) {
        return this.encode(classLabel.getText());
    }

    /** @return the number of classes in the vocabulary, i.e. the length of every encoded vector */
    public long getEmbeddingVectorSize() {
        return this.vocab.numWords();
    }

    /**
     * Encodes a class label as a one-hot vector.
     *
     * @param classLabel raw (un-normalized) class label
     * @return one-hot column vector, see {@link #oneHot(String)}
     */
    public INDArray encode(String classLabel) {
        return this.oneHot(classLabel);
    }

    /**
     * Looks up the vocabulary index of a label after normalizing it.
     *
     * @param word raw class label
     * @return the label's vocabulary index; a negative value is treated as "unknown" by
     *     {@link #oneHot(String)}
     */
    public int getIndex(String word) {
        return this.vocab.indexOf(PREPROCESSOR.preProcess(word));
    }

    /**
     * Builds a column vector of shape ({@code getEmbeddingVectorSize()}, 1) that is zero everywhere
     * except for a 1.0 at the label's vocabulary index.
     *
     * <p>Unknown labels are not an error: a warning is logged and an all-zero vector is returned.
     *
     * @param word raw class label
     * @return the one-hot vector, or an all-zero vector if the label is unknown
     */
    public INDArray oneHot(String word) {
        INDArray vector = Nd4j.zeros(this.getEmbeddingVectorSize(), 1L);
        int i = this.getIndex(word);
        if (i >= 0) {
            vector.put(i, 0, (Number) 1.0);
        } else {
            this.log.warn("could not encode class '{}'. is it contained in training set?", word);
        }
        return vector;
    }

    /**
     * @param classLabel raw class label
     * @return {@code true} if the normalized label is not contained in the vocabulary
     */
    public boolean isUnknown(String classLabel) {
        return !this.vocab.containsWord(PREPROCESSOR.preProcess(classLabel));
    }

    /**
     * Not supported: this encoder is trained on label strings, not on documents.
     *
     * @throws UnsupportedOperationException always
     */
    public void trainModel(Collection<Document> documents) {
        throw new UnsupportedOperationException("cannot train classification on Documents");
    }

    /**
     * Trains on all given labels, then truncates the vocabulary using the mean label count as the
     * threshold.
     *
     * <p>NOTE(review): presumably this keeps only the more frequent ("head") classes; the exact
     * cut-off semantics depend on {@code truncateVocabulary} — confirm against the vocabulary
     * implementation.
     *
     * @param classes raw class labels, one entry per occurrence
     */
    public void trainModelUsingHead(Iterable<String> classes) {
        this.trainModel(classes, 0);
        double totalCount = 0.0;
        for (VocabularyWord word : this.vocab.words()) {
            totalCount += word.getCount();
        }
        int numWords = this.vocab.numWords();
        // An empty vocabulary would give 0/0 = NaN; make the resulting threshold of 0 explicit.
        int meanCount = numWords == 0 ? 0 : (int) (totalCount / numWords);
        this.vocab.truncateVocabulary(meanCount);
        this.vocab.updateHuffmanCodes();
        this.appendTrainLog("truncated to " + this.vocab.numWords() + " classes");
    }

    /**
     * Builds the label vocabulary from the given labels, counting one occurrence per entry.
     *
     * <p>Labels that are empty after normalization are counted in {@code totalWords} but are not
     * added to the vocabulary. Any previously set model is cleared first.
     *
     * @param classes raw class labels, one entry per occurrence
     * @param minClassFrequency threshold handed to {@code truncateVocabulary} after counting
     */
    public void trainModel(Iterable<String> classes, int minClassFrequency) {
        this.appendTrainLog("Training " + this.getName() + " model...");
        this.setModel(null);
        this.timer.start();
        this.totalWords = 0;
        for (String label : classes) {
            String w = PREPROCESSOR.preProcess(label);
            ++this.totalWords;
            if (w.isEmpty()) {
                continue;
            }
            if (this.vocab.containsWord(w)) {
                this.vocab.incrementWordCounter(w);
            } else {
                this.vocab.addWord(w);
            }
        }
        int total = this.vocab.numWords();
        this.vocab.truncateVocabulary(minClassFrequency);
        this.vocab.updateHuffmanCodes();
        this.timer.stop();
        this.appendTrainLog("trained " + this.vocab.numWords() + " classes (" + total + " total)", this.timer.getLong());
        this.setModelAvailable(true);
    }

    /**
     * Returns the label whose entry in {@code v} has the highest score.
     *
     * @param v score vector indexed like the vocabulary
     * @return the best label, or {@code null} if {@code v} is empty
     */
    public String getNearestNeighbour(INDArray v) {
        Collection<String> knn = this.getNearestNeighbours(v, 1);
        if (knn.isEmpty()) {
            return null;
        }
        return knn.iterator().next();
    }

    /**
     * Returns the labels of the {@code k} highest-scoring entries of {@code v}, best first. On equal
     * scores the lower index comes first.
     *
     * @param v score vector indexed like the vocabulary
     * @param k number of labels requested; at most {@code v.length()} are returned
     * @return up to {@code k} distinct labels, ordered by descending score
     * @throws IllegalArgumentException if {@code k} is negative
     */
    public Collection<String> getNearestNeighbours(INDArray v, int k) {
        int[] indices = topIndices(v, k);
        Collection<String> result = new ArrayList<>(indices.length);
        for (int index : indices) {
            result.add(this.getWord(index));
        }
        return result;
    }

    /**
     * Like {@link #getNearestNeighbours(INDArray, int)}, but pairs each label with its score.
     *
     * @param v score vector indexed like the vocabulary
     * @param k number of entries requested; at most {@code v.length()} are returned
     * @return up to {@code k} (label, score) entries, ordered by descending score
     * @throws IllegalArgumentException if {@code k} is negative
     */
    public Collection<Map.Entry<String, Double>> getNearestNeighbourEntries(INDArray v, int k) {
        int[] indices = topIndices(v, k);
        Collection<Map.Entry<String, Double>> result = new ArrayList<>(indices.length);
        for (int index : indices) {
            result.add(new AbstractMap.SimpleEntry<>(this.getWord(index), v.getDouble((long) index)));
        }
        return result;
    }

    /**
     * Selects the indices of the {@code min(k, v.length())} largest entries of {@code v}, in
     * descending order of value (ties resolved towards the lower index).
     *
     * <p>Already-selected indices are tracked explicitly instead of overwriting their values with a
     * sentinel. A sentinel such as {@code Double.MIN_VALUE} (the smallest positive double) would
     * wrongly exclude zero and negative scores.
     */
    private static int[] topIndices(INDArray v, int k) {
        if (k < 0) {
            throw new IllegalArgumentException("k must not be negative: " + k);
        }
        final int n = (int) v.length();
        final double[] values = new double[n];
        for (int j = 0; j < n; j++) {
            values[j] = v.getDouble((long) j);
        }
        final int count = Math.min(k, n);
        final boolean[] taken = new boolean[n];
        final int[] result = new int[count];
        for (int i = 0; i < count; i++) {
            int best = -1;
            for (int j = 0; j < n; j++) {
                if (!taken[j] && (best < 0 || values[j] > values[best])) {
                    best = j;
                }
            }
            // count <= n guarantees at least one untaken index remains, so best >= 0 here
            taken[best] = true;
            result[i] = best;
        }
        return result;
    }
}

