/*
 * Decompiled with CFR 0.152.
 */
package de.datexis.index.encoder;

import de.datexis.common.Resource;
import de.datexis.encoder.Encoder;
import de.datexis.index.ArticleRef;
import de.datexis.index.WikiDataArticle;
import de.datexis.model.Annotation;
import de.datexis.model.Document;
import de.datexis.model.Sentence;
import de.datexis.model.Span;
import de.datexis.preprocess.MinimalLowercasePreprocessor;
import java.io.IOException;
import java.io.InputStream;
import java.util.Collection;
import org.deeplearning4j.models.embeddings.loader.WordVectorSerializer;
import org.deeplearning4j.models.paragraphvectors.ParagraphVectors;
import org.deeplearning4j.text.tokenization.tokenizer.TokenPreProcess;
import org.deeplearning4j.text.tokenization.tokenizerfactory.DefaultTokenizerFactory;
import org.deeplearning4j.text.tokenization.tokenizerfactory.TokenizerFactory;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.ops.transforms.Transforms;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Encoder that maps entities (knowledge base articles) and entity mentions to
 * vectors using a pre-trained {@link ParagraphVectors} model.
 *
 * <p>Entities are looked up by their id in the model, mentions and contexts are
 * inferred from text. All vectors are passed through {@link Transforms#unitVec}.
 * Depending on the configured {@link Strategy}, the result is either the name
 * embedding alone or the name embedding horizontally stacked with a context
 * embedding.
 */
public class EntityEncoder
extends Encoder {
    protected static final Logger log = LoggerFactory.getLogger(EntityEncoder.class);

    /** The paragraph vector model, assigned by {@link #loadModel(Resource)}. */
    protected ParagraphVectors parvec;

    /** Selects which embeddings are combined into the resulting vector. */
    protected Strategy strategy;

    /**
     * Creates an encoder and immediately loads the given paragraph vector model.
     *
     * @param paragraphVectors resource to read the serialized model from
     * @param strategy         the encoding strategy to use
     * @throws IOException if the model cannot be read
     */
    public EntityEncoder(Resource paragraphVectors, Strategy strategy) throws IOException {
        this.loadModel(paragraphVectors);
        this.strategy = strategy;
    }

    /**
     * Reads the paragraph vector model from the given resource and configures it
     * with a {@link DefaultTokenizerFactory} that uses a
     * {@link MinimalLowercasePreprocessor}.
     *
     * @param paragraphVectors resource to read the serialized model from
     * @throws IOException if the stream cannot be opened, read or closed
     */
    public void loadModel(Resource paragraphVectors) throws IOException {
        log.info("loading paragraph vectors...");
        // Close the stream on every path; closing an already closed stream is harmless.
        try (InputStream in = paragraphVectors.getInputStream()) {
            this.parvec = WordVectorSerializer.readParagraphVectors(in);
        }
        DefaultTokenizerFactory tokenizerFactory = new DefaultTokenizerFactory();
        tokenizerFactory.setTokenPreProcessor(new MinimalLowercasePreprocessor());
        this.parvec.setTokenizerFactory(tokenizerFactory);
        log.info("loaded {} paragraph labels with size {}",
                this.parvec.getLabelsSource().getLabels().size(), this.parvec.getLayerSize());
    }

    /**
     * Returns the length of the vectors produced under the current strategy:
     * the model's layer size for {@link Strategy#NAME}, twice the layer size for
     * {@link Strategy#NAME_CONTEXT}.
     *
     * @return the embedding vector size
     * @throws IllegalArgumentException if the strategy is neither of the known values
     */
    public long getEmbeddingVectorSize() {
        if (this.strategy == Strategy.NAME) {
            return this.parvec.getLayerSize();
        }
        if (this.strategy == Strategy.NAME_CONTEXT) {
            return this.parvec.getLayerSize() * 2;
        }
        throw new IllegalArgumentException("invalid strategy");
    }

    /**
     * Encodes an article using its id, title and description.
     *
     * @param art the article to encode
     * @return the entity vector, or {@code null} if no vector is found for the article id
     */
    public INDArray encodeEntity(WikiDataArticle art) {
        return this.encodeEntity(art.getId(), art.getTitle(), art.getDescription());
    }

    /**
     * Encodes an article reference using its id, title and description.
     *
     * @param ref the article reference to encode
     * @return the entity vector, or {@code null} if no vector is found for the referenced id
     */
    public INDArray encodeEntity(ArticleRef ref) {
        return this.encodeEntity(ref.getId(), ref.getTitle(), ref.getDescription());
    }

    /**
     * Shared implementation for the public {@code encodeEntity} overloads.
     *
     * @param id          label used to look up the name embedding
     * @param title       entity title, used as the first part of the context text (may be null)
     * @param description entity description, appended to the context text (may be null)
     * @return the name embedding ({@code NAME}) or name and context embeddings stacked
     *         horizontally ({@code NAME_CONTEXT}); {@code null} if the id has no vector
     * @throws IllegalArgumentException if the strategy is neither of the known values
     */
    private INDArray encodeEntity(String id, String title, String description) {
        INDArray nameEmbedding = this.encodeID(id, title);
        if (this.strategy == Strategy.NAME) {
            return nameEmbedding;
        }
        if (this.strategy == Strategy.NAME_CONTEXT) {
            if (nameEmbedding == null) {
                // Unknown id: report it the same way as the NAME strategy does
                // instead of handing a null array to hstack.
                return null;
            }
            INDArray contextEmbedding = this.encode(contextText(title, description));
            // A maximum of 0.0 is treated as "no context could be inferred";
            // the name embedding is then used for both halves.
            if (contextEmbedding.maxNumber().doubleValue() == 0.0) {
                contextEmbedding = nameEmbedding;
            }
            return Nd4j.hstack(nameEmbedding, contextEmbedding);
        }
        throw new IllegalArgumentException("invalid strategy");
    }

    /**
     * Joins title and description with a single space, leaving out a missing
     * title so that the literal text "null" never ends up in the context.
     */
    private static String contextText(String title, String description) {
        if (description == null) {
            return title;
        }
        return title == null ? description : title + " " + description;
    }

    /**
     * Looks up the stored vector for the given id.
     *
     * @param id       label to look up in the model
     * @param fallback currently unused
     * @return the vector for the id, or {@code null} if the model returns none
     *         or the lookup fails
     */
    public INDArray encodeID(String id, String fallback) {
        try {
            return this.normalize(this.parvec.getWordVectorMatrix(id));
        }
        catch (Exception e) {
            log.debug("could not look up vector for id '{}'", id, e);
            return null;
        }
    }

    /**
     * Encodes a mention together with its surrounding context.
     *
     * @param mention the mention text
     * @param context the context text, only used by {@link Strategy#NAME_CONTEXT}
     * @return the mention embedding ({@code NAME}) or mention and context
     *         embeddings stacked horizontally ({@code NAME_CONTEXT})
     * @throws IllegalArgumentException if the strategy is neither of the known values
     */
    public INDArray encodeMention(String mention, String context) {
        INDArray nameEmbedding = this.encode(mention);
        if (this.strategy == Strategy.NAME) {
            return nameEmbedding;
        }
        if (this.strategy == Strategy.NAME_CONTEXT) {
            INDArray contextEmbedding = this.encode(context);
            return Nd4j.hstack(nameEmbedding, contextEmbedding);
        }
        throw new IllegalArgumentException("invalid strategy");
    }

    /**
     * Encodes the text of the given span.
     *
     * @param span the span whose text is encoded
     * @return the inferred vector, see {@link #encode(String)}
     */
    public INDArray encode(Span span) {
        return this.encode(span.getText());
    }

    /**
     * Infers a vector for the given text.
     *
     * @param word the text to encode
     * @return the inferred vector, or a zero vector of the model's layer size
     *         if inference fails
     */
    public INDArray encode(String word) {
        try {
            return this.normalize(this.parvec.inferVector(word));
        }
        catch (Exception e) {
            log.debug("could not infer vector for '{}', using zero vector", word, e);
            return Nd4j.zeros((int)this.parvec.getLayerSize());
        }
    }

    /**
     * Encodes every matching annotation of the document and attaches the
     * resulting vector to the annotation under the {@code EntityEncoder} key.
     * The annotation text is used as the mention, the tokenized sentence at the
     * annotation's begin position as the context.
     *
     * <p>NOTE(review): the sentence lookup result is unwrapped with {@code get()}
     * without a presence check; presumably this fails for an annotation that is
     * not covered by any sentence - confirm against {@code Document}.
     *
     * @param doc    the document whose annotations are encoded
     * @param source the annotation source to select
     * @param type   the annotation type to select
     */
    public void encodeEach(Document doc, Annotation.Source source, Class<? extends Annotation> type) {
        doc.streamAnnotations(source, type).forEach(ann -> {
            String entityMention = ann.getText();
            String entityContext = ((Sentence)doc.getSentenceAtPosition(ann.getBegin()).get()).toTokenizedString();
            INDArray vec = this.encodeMention(entityMention, entityContext);
            ann.putVector(EntityEncoder.class, vec);
        });
    }

    /**
     * Applies {@link Transforms#unitVec} to the vector, passing {@code null} through.
     */
    private INDArray normalize(INDArray vec) {
        return vec != null ? Transforms.unitVec(vec) : null;
    }

    /**
     * Not supported.
     *
     * @throws UnsupportedOperationException always
     */
    public void trainModel(Collection<Document> documents) {
        throw new UnsupportedOperationException("Not implemented yet.");
    }

    /**
     * Not supported.
     *
     * @throws UnsupportedOperationException always
     */
    public void saveModel(Resource dir, String name) {
        throw new UnsupportedOperationException("Not implemented yet.");
    }

    /** Determines which embeddings make up an encoded vector. */
    public static enum Strategy {
        /** Only the name (id or mention) embedding. */
        NAME,
        /** Name embedding followed by a context embedding. */
        NAME_CONTEXT;

    }
}

