/*
 * Decompiled with CFR 0.152.
 */
package de.datexis.parvec.encoder;

import de.datexis.common.Resource;
import de.datexis.encoder.LookupCacheEncoder;
import de.datexis.model.Annotation;
import de.datexis.model.Dataset;
import de.datexis.model.Document;
import de.datexis.model.Sentence;
import de.datexis.model.Span;
import de.datexis.parvec.encoder.LabelSeeker;
import de.datexis.parvec.encoder.ParVecIterator;
import de.datexis.preprocess.DocumentFactory;
import de.datexis.preprocess.MinimalLowercasePreprocessor;
import java.io.BufferedOutputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.lang.reflect.Field;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import org.deeplearning4j.models.embeddings.inmemory.InMemoryLookupTable;
import org.deeplearning4j.models.embeddings.loader.WordVectorSerializer;
import org.deeplearning4j.models.embeddings.wordvectors.WordVectors;
import org.deeplearning4j.models.paragraphvectors.ParagraphVectors;
import org.deeplearning4j.models.word2vec.VocabWord;
import org.deeplearning4j.models.word2vec.wordstore.VocabCache;
import org.deeplearning4j.models.word2vec.wordstore.inmemory.AbstractCache;
import org.deeplearning4j.text.sentenceiterator.labelaware.LabelAwareSentenceIterator;
import org.deeplearning4j.text.tokenization.tokenizer.TokenPreProcess;
import org.deeplearning4j.text.tokenization.tokenizerfactory.DefaultTokenizerFactory;
import org.deeplearning4j.text.tokenization.tokenizerfactory.TokenizerFactory;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.exception.ND4JIllegalStateException;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.shade.jackson.annotation.JsonIgnore;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Encoder backed by a DL4J {@link ParagraphVectors} model.
 *
 * <p>The model is either trained from a {@link Dataset} via {@link #trainModel(Dataset)} or
 * restored from a serialized file via {@link #loadModel(Resource)}. Once a model is present,
 * the {@code encode(...)} methods infer a paragraph vector for a piece of text, and the
 * label-related methods ({@link #getWords()}, {@link #getIndex(String)}, {@link #oneHot(String)},
 * ...) operate on the label list read from the model.
 *
 * <p>Most methods require that a model has been trained or loaded first; calling them earlier
 * fails with a {@link NullPointerException}.
 */
public class ParVecEncoder
extends LookupCacheEncoder {
    protected static final Logger log = LoggerFactory.getLogger(ParVecEncoder.class);
    /** Optional pre-trained word vectors handed to the builder in {@link #trainModel(Dataset)}. */
    protected WordVectors word2Vec;
    /** The trained or loaded paragraph vector model; {@code null} until one is available. */
    protected ParagraphVectors model;
    protected double learningRate = 0.025;
    protected double minLearningRate = 0.001;
    protected int batchSize = 16;
    protected int numEpochs = 1;
    protected int iterations = 5;
    /** Configured vector size; overwritten with the model's layer size in {@link #loadModel(Resource)}. */
    protected int layerSize = 256;
    /** Number of labels in {@link #labelsList}; set after training or loading. */
    protected int targetSize;
    protected int windowSize = 10;
    protected static final TokenPreProcess preprocessor = new MinimalLowercasePreprocessor();
    protected final DefaultTokenizerFactory tokenizerFactory;
    /** Labels read reflectively from the model's {@code labelsList} field. */
    protected List<VocabWord> labelsList;
    protected List<String> stopwords = new ArrayList<String>();

    /**
     * Creates an untrained encoder whose tokenizer factory applies the shared
     * {@link MinimalLowercasePreprocessor}.
     */
    public ParVecEncoder() {
        super("PV");
        this.tokenizerFactory = new DefaultTokenizerFactory();
        this.tokenizerFactory.setTokenPreProcessor(preprocessor);
    }

    /**
     * Sets existing word vectors to be used when the model is trained.
     *
     * @param word2Vec word vectors passed to the builder; may be {@code null} to train without them
     * @return this encoder, for chaining
     */
    public ParVecEncoder withWordEmbedding(WordVectors word2Vec) {
        this.word2Vec = word2Vec;
        return this;
    }

    /**
     * Sets the model hyper-parameters used by {@link #trainModel(Dataset)}.
     *
     * @param layerSize  vector size passed to the builder
     * @param windowSize window size passed to the builder
     */
    public void setModelParams(int layerSize, int windowSize) {
        this.layerSize = layerSize;
        this.windowSize = windowSize;
    }

    /**
     * Sets the training hyper-parameters used by {@link #trainModel(Dataset)}. The learning rates
     * are also used for inference in {@link #encode(Span)} and {@link #encode(Annotation, Document)}.
     *
     * @param learningRate    initial learning rate
     * @param minLearningRate minimum learning rate
     * @param batchSize       batch size
     * @param iterations      number of iterations
     * @param numEpochs       number of epochs
     */
    public void setTrainingParams(double learningRate, double minLearningRate, int batchSize, int iterations, int numEpochs) {
        this.learningRate = learningRate;
        this.minLearningRate = minLearningRate;
        this.batchSize = batchSize;
        this.iterations = iterations;
        this.numEpochs = numEpochs;
    }

    /**
     * Sets the stop words passed to the builder during training.
     *
     * @param words stop words; {@code null} is treated as "no stop words" so that training does
     *              not hand a {@code null} list to the builder
     */
    public void setStopWords(List<String> words) {
        this.stopwords = words != null ? words : new ArrayList<String>();
    }

    /**
     * Not supported for this encoder.
     *
     * @param documents ignored
     * @throws UnsupportedOperationException always; use {@link #trainModel(Dataset)} instead
     */
    public void trainModel(Collection<Document> documents) {
        throw new UnsupportedOperationException("Please call trainModel(Dataset train)");
    }

    /**
     * Builds and fits a {@link ParagraphVectors} model on the given dataset, then reads the
     * model's label list and marks the model as available.
     *
     * @param train dataset that is wrapped in a {@link ParVecIterator} and iterated for training
     * @throws RuntimeException if the model's label list cannot be read via reflection
     */
    public void trainModel(Dataset train) {
        ParVecIterator it = new ParVecIterator(train, true);
        AbstractCache cache = new AbstractCache();
        ParagraphVectors.Builder builder = new ParagraphVectors.Builder()
            .minWordFrequency(3)
            .iterations(this.iterations)
            .epochs(this.numEpochs)
            .layerSize(this.layerSize)
            .learningRate(this.learningRate)
            .minLearningRate(this.minLearningRate)
            .batchSize(this.batchSize)
            .windowSize(this.windowSize)
            .iterate((LabelAwareSentenceIterator)it)
            .trainWordVectors(true)
            .vocabCache((VocabCache)cache)
            .tokenizerFactory((TokenizerFactory)this.tokenizerFactory)
            .stopWords(this.stopwords)
            .sampling(0.0);
        if (this.word2Vec != null) {
            builder.useExistingWordVectors(this.word2Vec);
        }
        this.model = builder.build();
        log.info("training ParVec...");
        this.model.fit();
        log.info("training complete.");
        this.captureLabelsFromModel();
        this.setModelAvailable(true);
    }

    /**
     * Infers a paragraph vector for a span.
     *
     * <p>A {@link Sentence} is encoded from its tokenized string using the configured learning
     * rates; any other span is delegated to {@link #encode(String)} with the span's text.
     *
     * @param span the span to encode
     * @return the transposed inferred vector, or a {@code layerSize x 1} zero array if inference
     *         raises {@link ND4JIllegalStateException}
     */
    public INDArray encode(Span span) {
        if (!(span instanceof Sentence)) {
            return this.encode(span.getText());
        }
        String text = escapeLineBreaksAndTabs(((Sentence)span).toTokenizedString());
        try {
            return this.model.inferVector(text, this.learningRate, this.minLearningRate, 1).transpose();
        }
        catch (ND4JIllegalStateException ex) {
            return this.emptyEmbedding();
        }
    }

    /**
     * Infers a paragraph vector for the sentences of a document that lie in the annotation's range.
     *
     * @param ann annotation whose begin/end offsets select the sentences
     * @param doc document the sentences are streamed from
     * @return the transposed inferred vector, or a {@code layerSize x 1} zero array if inference
     *         raises {@link ND4JIllegalStateException}
     */
    public INDArray encode(Annotation ann, Document doc) {
        String text = doc.streamSentencesInRange(ann.getBegin(), ann.getEnd(), false)
            .map(s -> escapeLineBreaksAndTabs(s.toTokenizedString()))
            .collect(Collectors.joining(" "));
        try {
            return this.model.inferVector(text, this.learningRate, this.minLearningRate, 1).transpose();
        }
        catch (ND4JIllegalStateException ex) {
            return this.emptyEmbedding();
        }
    }

    /**
     * Tokenizes the given text and infers a paragraph vector for it.
     *
     * <p>Unlike the other {@code encode} overloads, this one calls the single-argument
     * {@code inferVector} and therefore does not pass the configured learning rates.
     *
     * @param text raw text to encode
     * @return the transposed inferred vector, or a {@code layerSize x 1} zero array if inference
     *         raises {@link ND4JIllegalStateException}
     */
    public INDArray encode(String text) {
        text = DocumentFactory.createTokensFromText((String)text).stream()
            .map(t -> escapeLineBreaksAndTabs(t.getText()))
            .collect(Collectors.joining(" "));
        try {
            return this.model.inferVector(text).transpose();
        }
        catch (ND4JIllegalStateException ex) {
            log.trace("unknown paragraph vector for '{}'", (Object)text);
            return this.emptyEmbedding();
        }
    }

    /**
     * Serializes the model to {@code <name>.zip} below the given path and registers that file as
     * this encoder's model.
     *
     * <p>I/O failures are logged and not rethrown (best effort); in that case the model
     * reference is left unchanged.
     *
     * @param modelPath directory-like resource the file name is resolved against
     * @param name      file name without the {@code .zip} extension
     */
    public void saveModel(Resource modelPath, String name) {
        try {
            Resource modelFile = modelPath.resolve(name + ".zip");
            // Close the stream ourselves so the file handle is released even if writing fails.
            try (OutputStream out = modelFile.getOutputStream()) {
                WordVectorSerializer.writeParagraphVectors((ParagraphVectors)this.model, out);
            }
            this.setModel(modelFile);
        }
        catch (IOException ex) {
            // Same message as before, but keep the stack trace for diagnosis.
            log.error(ex.toString(), (Throwable)ex);
        }
    }

    /**
     * Creates a new encoder and loads a serialized model into it.
     *
     * @param path resource pointing to the serialized model
     * @return the encoder with the model loaded
     * @throws IOException if the model cannot be read
     */
    public static ParVecEncoder load(Resource path) throws IOException {
        ParVecEncoder encoder = new ParVecEncoder();
        encoder.loadModel(path);
        return encoder;
    }

    /**
     * Reads a serialized {@link ParagraphVectors} model, attaches this encoder's tokenizer
     * factory, adopts the model's layer size and label list, and marks the model as available.
     *
     * @param modelFile resource pointing to the serialized model
     * @throws IOException      if the model cannot be read
     * @throws RuntimeException if the model's label list cannot be read via reflection
     */
    public void loadModel(Resource modelFile) throws IOException {
        // Close the stream ourselves instead of relying on the serializer to do so.
        try (InputStream in = modelFile.getInputStream()) {
            this.model = WordVectorSerializer.readParagraphVectors(in);
        }
        this.model.setTokenizerFactory((TokenizerFactory)this.tokenizerFactory);
        this.layerSize = this.model.getLayerSize();
        this.captureLabelsFromModel();
        log.info("Loaded ParagraphVectors with {} classes and layer size {}", (Object)this.targetSize, (Object)this.layerSize);
        this.setModel(modelFile);
        this.setModelAvailable(true);
    }

    /**
     * Writes every non-null vocabulary word followed by its vector to the given stream:
     * the word via {@link DataOutputStream#writeUTF(String)}, the vector via {@code Nd4j.write}.
     *
     * <p>The given stream is wrapped and closed by this method.
     *
     * @param outputStream destination stream
     * @throws IOException if writing fails
     */
    public void writeBinaryW2VModel(OutputStream outputStream) throws IOException {
        int words = 0;
        try (BufferedOutputStream buf = new BufferedOutputStream(outputStream);
             DataOutputStream writer = new DataOutputStream(buf);){
            for (Object word : this.model.vocab().words()) {
                if (word == null) continue;
                String token = (String)word;
                INDArray wordVector = this.model.getWordVectorMatrix(token);
                log.trace("Write: {} (size {})", (Object)token, (Object)wordVector.length());
                writer.writeUTF(token);
                Nd4j.write((INDArray)wordVector, (DataOutputStream)writer);
                ++words;
            }
            writer.flush();
        }
        log.info("Wrote {} words with size {}", (Object)words, (Object)this.model.vectorSize());
    }

    /**
     * @return the label of every entry in the model's label list, in list order
     */
    @JsonIgnore
    public List<String> getWords() {
        return this.labelsList.stream().map(VocabWord::getLabel).collect(Collectors.toList());
    }

    /**
     * @return the number of entries in the model's label list
     */
    public int getTotalWords() {
        return this.labelsList.size();
    }

    /**
     * Returns the size of the vectors produced by the model.
     *
     * <p>The size is read from the model's layer size rather than by running an inference on a
     * probe string, which is costly and can fail with {@link ND4JIllegalStateException}.
     *
     * @return the model's layer size, or the configured layer size if no model is present yet
     */
    public long getEmbeddingVectorSize() {
        if (this.model == null) {
            return this.layerSize;
        }
        return this.model.getLayerSize();
    }

    /**
     * @return the number of labels captured from the model
     */
    public long getOutputVectorSize() {
        return this.targetSize;
    }

    /**
     * @return always {@code 0}
     */
    public int getInputVectorSize() {
        return 0;
    }

    /**
     * Looks up the label word at the given position of the label list.
     *
     * @param index zero-based position
     * @return the word at that position, or {@code null} if the index is negative or not smaller
     *         than the list size
     */
    public String getWord(int index) {
        // Valid positions are 0 .. size-1; everything else maps to null instead of throwing.
        if (index < 0 || index >= this.labelsList.size()) {
            return null;
        }
        return this.labelsList.get(index).getWord();
    }

    /**
     * Finds the position of the first label whose word equals the given word.
     *
     * @param word the word to search for; {@code null} is never found
     * @return the zero-based position, or {@code -1} if there is no match
     */
    public int getIndex(String word) {
        if (word == null) {
            return -1;
        }
        return IntStream.range(0, this.labelsList.size())
            .filter(i -> word.equals(this.labelsList.get(i).getWord()))
            .findFirst()
            .orElse(-1);
    }

    /**
     * Builds a one-hot vector over the label list for the given word.
     *
     * @param word label word to encode
     * @return the transpose of a {@code targetSize x 1} array with a 1 at the word's index;
     *         all zeros (and a warning is logged) if the word is not a known label
     */
    public INDArray oneHot(String word) {
        INDArray vector = Nd4j.zeros((long)this.targetSize, (long)1L);
        int i = this.getIndex(word);
        if (i >= 0) {
            vector.putScalar((long)i, 1.0);
        } else {
            log.warn("could not encode class '{}'. is it contained in training set?", (Object)word);
        }
        return vector.transpose();
    }

    /**
     * @param v score vector, see {@link #getNearestNeighbours(INDArray, int)}
     * @return the first entry returned for {@code k = 1}, or {@code null} if there is none
     */
    public String getNearestNeighbour(INDArray v) {
        return this.getNearestNeighbours(v, 1).stream().findFirst().orElse(null);
    }

    /**
     * Returns the label words belonging to the first {@code k} indices of the given vector after
     * sorting a flattened copy of it with {@code Nd4j.sortWithIndices}.
     *
     * @param v vector to rank; it is flattened and copied before sorting
     * @param k maximum number of labels to return; values larger than the vector length are
     *          capped to that length, negative values yield an empty result
     * @return up to {@code k} label words in ranked order; an entry is {@code null} when the
     *         ranked index has no corresponding label
     */
    public Collection<String> getNearestNeighbours(INDArray v, int k) {
        INDArray[] sorted = Nd4j.sortWithIndices((INDArray)Nd4j.toFlattened((INDArray[])new INDArray[]{v}).dup(), (int)1, (boolean)false);
        // NOTE(review): this test inspects sorted[0], which is used as the index array below,
        // so it may not actually detect an all-zero input - confirm against the ND4J contract.
        if (sorted[0].length() <= 1L || sorted[0].sumNumber().doubleValue() == 0.0) {
            log.warn("NearestNeighbour on zero vector - please check vector alignment!");
        }
        INDArray idx = sorted[0];
        // Never read more ranked indices than the vector provides.
        int limit = (int)Math.max(0L, Math.min((long)k, (long)idx.length()));
        ArrayList<String> result = new ArrayList<String>(limit);
        for (int i = 0; i < limit; ++i) {
            result.add(this.getWord(idx.getInt(new int[]{i})));
        }
        return result;
    }

    /**
     * Scores the given vector against all labels using a {@link LabelSeeker} built from the
     * label names and the model's lookup table.
     *
     * @param v vector to score
     * @return the transposed score vector produced by the seeker
     */
    public INDArray getPredictions(INDArray v) {
        LabelSeeker seeker = new LabelSeeker(this.getWords(), (InMemoryLookupTable<VocabWord>)((InMemoryLookupTable)this.model.getLookupTable()));
        return seeker.getScoresAsVector(v).transpose();
    }

    /** Trims the text and replaces newline characters with {@code *NL*} and tabs with {@code *t*}. */
    private static String escapeLineBreaksAndTabs(String text) {
        return text.trim().replaceAll("\n", "*NL*").replaceAll("\t", "*t*");
    }

    /** Creates the {@code layerSize x 1} zero array returned when inference fails. */
    private INDArray emptyEmbedding() {
        return Nd4j.zeros((long)this.layerSize, (long)1L);
    }

    /**
     * Reads the {@code labelsList} field of the current model via reflection and stores it,
     * together with its size, in this encoder.
     */
    @SuppressWarnings("unchecked")
    private void captureLabelsFromModel() {
        try {
            Field labelsListField = ParagraphVectors.class.getDeclaredField("labelsList");
            labelsListField.setAccessible(true);
            // The field is read as a list of VocabWord, matching how labelsList is used here.
            this.labelsList = (List<VocabWord>)labelsListField.get(this.model);
            this.targetSize = this.labelsList.size();
        }
        catch (IllegalAccessException | NoSuchFieldException e) {
            log.error(e.getMessage(), (Throwable)e);
            throw new RuntimeException(e);
        }
    }
}

