/*
 * Decompiled with CFR 0.152.
 */
package de.datexis.index.impl;

import de.datexis.common.Resource;
import de.datexis.index.ArticleRef;
import de.datexis.index.encoder.EntityEncoder;
import de.datexis.index.impl.LuceneArticleIndex;
import java.io.BufferedOutputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.io.OutputStream;
import java.util.ArrayList;
import java.util.List;
import java.util.Optional;
import org.apache.lucene.document.Document;
import org.apache.lucene.index.IndexReader;
import org.deeplearning4j.models.embeddings.WeightLookupTable;
import org.deeplearning4j.models.embeddings.inmemory.InMemoryLookupTable;
import org.deeplearning4j.models.embeddings.reader.ModelUtils;
import org.deeplearning4j.models.embeddings.reader.impl.BasicModelUtils;
import org.deeplearning4j.models.paragraphvectors.ParagraphVectors;
import org.deeplearning4j.models.word2vec.VocabWord;
import org.deeplearning4j.models.word2vec.wordstore.VocabCache;
import org.deeplearning4j.models.word2vec.wordstore.VocabularyHolder;
import org.deeplearning4j.models.word2vec.wordstore.inmemory.InMemoryLookupCache;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Article index that adds nearest-neighbour lookup over entity embeddings on
 * top of the Lucene-backed {@link LuceneArticleIndex}.
 *
 * <p>On construction, every document of the underlying index reader is turned
 * into an {@link ArticleRef}, its id is registered in a vocabulary, and the
 * vector produced by the {@link EntityEncoder} is stored under that id in an
 * in-memory lookup table. The {@code querySimilarArticles} overloads then ask
 * the lookup utilities for the nearest ids and resolve them back to articles.
 */
public class KNNArticleIndex
extends LuceneArticleIndex {
    protected static final Logger log = LoggerFactory.getLogger(KNNArticleIndex.class);
    /** Not read or written anywhere in this class. */
    protected ParagraphVectors parvec;
    /** Encoder used to turn articles, entity names and contexts into vectors. */
    EntityEncoder encoder;
    /** Vocabulary of article ids backing {@link #lookupVectors}. */
    protected VocabCache<VocabWord> vocabCache;
    /** Article id to embedding table; stays {@code null} if building the cache failed. */
    protected WeightLookupTable<VocabWord> lookupVectors = null;
    /** Nearest-neighbour search helper initialised with {@link #lookupVectors}. */
    protected ModelUtils<VocabWord> lookupUtils;

    /**
     * Creates the index and immediately builds the embedding lookup table.
     *
     * @param parVec resource handed to the {@link EntityEncoder} (created with
     *               {@code EntityEncoder.Strategy.NAME})
     * @throws IOException declared by the constructor signature
     */
    public KNNArticleIndex(Resource parVec) throws IOException {
        this.encoder = new EntityEncoder(parVec, EntityEncoder.Strategy.NAME);
        this.generateLookupCache();
    }

    /**
     * Builds the id vocabulary and fills the lookup table with one encoded
     * vector per distinct article id found in the index reader.
     *
     * <p>An {@link IOException} while reading the index is logged (with stack
     * trace) and not rethrown; in that case the lookup structures may be left
     * partially initialised.
     */
    protected void generateLookupCache() {
        log.debug("building entity list....");
        VocabularyHolder ids = new VocabularyHolder.Builder().build();
        this.vocabCache = new InMemoryLookupCache();
        this.lookupUtils = new BasicModelUtils<VocabWord>();
        try {
            IndexReader reader = this.searcher.getIndexReader();
            int maxDoc = reader.maxDoc();
            List<ArticleRef> entries = new ArrayList<ArticleRef>(maxDoc);
            for (int i = 0; i < maxDoc; ++i) {
                Document d = reader.document(i);
                ArticleRef ref = this.createWikidataArticleRef(d);
                String id = ref.getId();
                if (!ids.containsWord(id)) {
                    ids.addWord(id);
                    // Remember the article so its embedding is inserted below.
                    // Previously nothing was ever added to this list, so the
                    // table kept only its reset weights and no encoded vectors.
                    // Only the first article per id is kept: one vector per id.
                    entries.add(ref);
                } else {
                    ids.incrementWordCounter(id);
                }
            }
            ids.updateHuffmanCodes();
            ids.transferBackToVocabCache(this.vocabCache, true);
            this.lookupVectors = new InMemoryLookupTable<VocabWord>(this.vocabCache, (int)this.encoder.getEmbeddingVectorSize(), true, 0.01, Nd4j.getRandom(), 0.0, true);
            // Allocate the weight matrix before overwriting rows with encoded vectors.
            this.lookupVectors.resetWeights();
            int num = 0;
            for (ArticleRef ref : entries) {
                INDArray embedding = this.encoder.encodeEntity(ref);
                this.lookupVectors.putVector(ref.getId(), embedding);
                // Progress message every 100000 inserted vectors.
                if (++num % 100000 == 0) {
                    log.info("inserted {} vectors into lookup table", num);
                }
            }
            log.info("generated {} entity vectors", entries.size());
            this.lookupUtils.init(this.lookupVectors);
            log.info("initialized lookup tables");
        }
        catch (IOException ex) {
            // Same message as before, but keep the stack trace for diagnosis.
            log.error(ex.toString(), ex);
        }
    }

    /**
     * Writes the lookup table to {@code <name>_lookup.bin} below the given path.
     *
     * @param modelPath directory resource the file name is resolved against
     * @param name      prefix of the output file name
     * @throws IOException           if writing fails
     * @throws IllegalStateException if the lookup table was never built
     */
    public void saveModel(Resource modelPath, String name) throws IOException {
        if (this.lookupVectors == null) {
            throw new IllegalStateException("lookup table has not been generated, nothing to save");
        }
        KNNArticleIndex.writeBinaryModel(this.lookupVectors, modelPath.resolve(name + "_lookup.bin").getOutputStream());
    }

    /**
     * Serialises every non-null vocabulary word as a UTF string followed by its
     * vector written via {@code Nd4j.write}. The given stream is closed when done.
     *
     * @param vec          lookup table to serialise
     * @param outputStream destination stream; closed by this method
     * @throws IOException if writing fails
     */
    private static void writeBinaryModel(WeightLookupTable<VocabWord> vec, OutputStream outputStream) throws IOException {
        int words = 0;
        try (BufferedOutputStream buf = new BufferedOutputStream(outputStream);
             DataOutputStream writer = new DataOutputStream(buf)) {
            for (String word : vec.getVocabCache().words()) {
                if (word == null) {
                    continue;
                }
                INDArray wordVector = vec.vector(word);
                log.trace("Write: {} (size {})", word, wordVector.length());
                writer.writeUTF(word);
                Nd4j.write(wordVector, writer);
                ++words;
            }
            writer.flush();
        }
        log.info("Wrote {} words with size {}", words, vec.layerSize());
    }

    /**
     * Returns articles whose ids are nearest to the given id in the lookup table.
     *
     * @param wikidataId id to search neighbours for
     * @param hits       number of neighbours requested from the lookup utilities
     * @return resolved articles; ids that cannot be resolved are skipped
     */
    public List<ArticleRef> querySimilarArticles(String wikidataId, int hits) {
        return this.resolveArticles(this.lookupUtils.wordsNearest(wikidataId, hits), hits);
    }

    /**
     * Encodes entity name and context separately, stacks both vectors
     * horizontally and returns the articles nearest to the stacked vector.
     *
     * <p>NOTE(review): presumably the stacked width has to match the vector
     * size of the lookup table; confirm against the encoder.
     *
     * @param entityName text encoded as the first part of the query vector
     * @param context    text encoded as the second part of the query vector
     * @param hits       number of neighbours requested from the lookup utilities
     * @return resolved articles; ids that cannot be resolved are skipped
     */
    public List<ArticleRef> querySimilarArticles(String entityName, String context, int hits) {
        INDArray eVec = this.encoder.encode(entityName);
        INDArray cVec = this.encoder.encode(context);
        INDArray query = Nd4j.hstack(new INDArray[]{eVec, cVec});
        return this.resolveArticles(this.lookupUtils.wordsNearest(query, hits), hits);
    }

    /**
     * Returns the articles nearest to the given query vector.
     *
     * @param vec  query vector passed to the lookup utilities
     * @param hits number of neighbours requested from the lookup utilities
     * @return resolved articles; ids that cannot be resolved are skipped
     */
    public List<ArticleRef> querySimilarArticles(INDArray vec, int hits) {
        return this.resolveArticles(this.lookupUtils.wordsNearest(vec, hits), hits);
    }

    /** Resolves neighbour ids to articles in iteration order, dropping ids without a match. */
    private List<ArticleRef> resolveArticles(Iterable<String> nearestIds, int hits) {
        List<ArticleRef> result = new ArrayList<ArticleRef>(hits);
        for (String id : nearestIds) {
            Optional<ArticleRef> article = this.queryWikidataID(id);
            if (article.isPresent()) {
                result.add(article.get());
            }
        }
        return result;
    }
}

