/*
 * Decompiled with CFR 0.152.
 */
package de.datexis.sector.tagger;

import de.datexis.common.Resource;
import de.datexis.encoder.Encoder;
import de.datexis.encoder.EncodingHelpers;
import de.datexis.model.Document;
import de.datexis.model.Sentence;
import de.datexis.model.Span;
import de.datexis.sector.tagger.SectorTagger;
import de.datexis.sector.tagger.SectorTaggerIterator;
import de.datexis.tagger.AbstractMultiDataSetIterator;
import de.datexis.tagger.DocumentSentenceIterator;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.NDArrayIndex;
import org.nd4j.shade.jackson.annotation.JsonIgnore;
import org.slf4j.LoggerFactory;

/**
 * {@link Encoder} adapter that exposes the embeddings produced by a {@link SectorTagger}.
 *
 * <p>All encoding work is delegated to the wrapped tagger. SECTOR only encodes {@link Sentence}
 * spans in the context of their {@link Document}. The word-, span- and single-sentence entry
 * points therefore throw {@link IllegalArgumentException}. Training is not supported here; train
 * the {@link SectorTagger} instead. Configuration accessors (name, batch size, embedding size,
 * epochs, ...) read from and write to the wrapped tagger.
 */
public class SectorEncoder
extends Encoder {
    /** The tagger that performs the actual encoding; configuration accessors delegate to it. */
    protected SectorTagger tagger;

    /** Creates an encoder with id {@code "SECTOR"} backed by a new {@link SectorTagger}. */
    public SectorEncoder() {
        this("SECTOR", new SectorTagger());
    }

    /**
     * Creates an encoder with the given id backed by a new {@link SectorTagger}.
     *
     * @param id encoder identifier passed to {@link Encoder}
     */
    public SectorEncoder(String id) {
        this(id, new SectorTagger());
    }

    /**
     * Creates an encoder with the given id that delegates to {@code sector}.
     *
     * @param id encoder identifier passed to {@link Encoder}
     * @param sector tagger to delegate to; must not be {@code null}
     * @throws NullPointerException if {@code sector} is {@code null}
     */
    public SectorEncoder(String id, SectorTagger sector) {
        super(id);
        this.log = LoggerFactory.getLogger(SectorEncoder.class);
        this.tagger = Objects.requireNonNull(sector, "sector");
        this.setModelFilename(this.tagger.getModel());
        // NOTE(review): the model is flagged available unconditionally, even for a freshly
        // constructed tagger; confirm this is intended.
        this.setModelAvailable(true);
    }

    /** @return the wrapped tagger (excluded from JSON serialization) */
    @JsonIgnore
    public SectorTagger getTagger() {
        return this.tagger;
    }

    /**
     * Replaces the wrapped tagger. The model filename is not updated here.
     *
     * @param tagger new tagger to delegate to
     */
    public void setTagger(SectorTagger tagger) {
        this.tagger = tagger;
    }

    /** @return the tagger's embedding layer size */
    public long getEmbeddingVectorSize() {
        return this.tagger.getEmbeddingLayerSize();
    }

    /**
     * Not supported: SECTOR only encodes over Documents.
     *
     * @throws IllegalArgumentException always
     */
    public INDArray encode(Span span) {
        throw new IllegalArgumentException("SECTOR is only implemented to encode over Documents.");
    }

    /**
     * Not supported: SECTOR only encodes over Documents.
     *
     * @throws IllegalArgumentException always
     */
    public INDArray encode(String word) {
        throw new IllegalArgumentException("SECTOR is only implemented to encode over Documents.");
    }

    /**
     * Encodes the sentences of a single document by running the tagger on it.
     *
     * @param d document to encode
     * @param elementClass must be {@link Sentence}
     * @throws IllegalArgumentException if {@code elementClass} is not {@link Sentence}
     */
    public void encodeEach(Document d, Class<? extends Span> elementClass) {
        this.encodeEach(Collections.singleton(d), elementClass);
    }

    /**
     * Encodes the sentences of all given documents by running {@link SectorTagger#tag} on them.
     *
     * @param docs documents to encode
     * @param elementClass must be {@link Sentence}
     * @throws IllegalArgumentException if {@code elementClass} is not {@link Sentence}
     */
    public void encodeEach(Collection<Document> docs, Class<? extends Span> elementClass) {
        if (elementClass != Sentence.class) {
            throw new IllegalArgumentException("SECTOR is only implemented to encode Sentences over a Document");
        }
        this.tagger.tag(docs);
    }

    /**
     * Encodes {@code input} batch-wise and returns the stacked embedding matrices.
     *
     * <p>As a side effect, for every document at most the first {@code maxTimeSteps} sentences
     * receive the batch's {@code "target"} vector (keyed by the target encoder's class) and
     * {@code "embedding"} vector (keyed by {@code SectorEncoder.class}), when those activations
     * are present.
     *
     * <p>Each batch's embedding matrix is zero-padded, or truncated, along axis 2 to exactly
     * {@code maxTimeSteps}. The per-batch matrices are then concatenated along axis 0.
     *
     * @param input documents to encode
     * @param maxTimeSteps number of sentence time steps in the returned matrix
     * @param timeStepClass must be {@link Sentence}
     * @return concatenated embedding matrix, or {@code null} if the iterator yields no batches
     * @throws IllegalArgumentException if {@code timeStepClass} is not {@link Sentence}
     * @throws IllegalStateException if the tagger returns no {@code "embedding"} activation
     */
    public INDArray encodeMatrix(List<Document> input, int maxTimeSteps, Class<? extends Span> timeStepClass) {
        if (timeStepClass != Sentence.class) {
            throw new IllegalArgumentException("SECTOR is only implemented to encode Sentences over a Document");
        }
        SectorTaggerIterator it = new SectorTaggerIterator(AbstractMultiDataSetIterator.Stage.ENCODE, input,
                this.tagger, this.tagger.getBatchSize(), false, this.tagger.isRequireSubsampling());
        INDArray result = null;
        while (it.hasNext()) {
            DocumentSentenceIterator.DocumentBatch batch = it.nextDocumentBatch();
            Map<String, INDArray> weights = this.tagger.encodeMatrix(batch);
            INDArray target = weights.get("target");
            INDArray embedding = weights.get("embedding");
            this.attachSentenceVectors(batch, target, embedding, maxTimeSteps);
            if (embedding == null) {
                // The returned matrix is built from the embedding; without it there is nothing to return.
                throw new IllegalStateException("SECTOR tagger returned no \"embedding\" activation; cannot build encoding matrix.");
            }
            embedding = fitTimeSteps(embedding, batch.maxDocLength, maxTimeSteps);
            result = result == null ? embedding : Nd4j.concat(0, result, embedding);
        }
        return result;
    }

    /** Stores the per-sentence target/embedding vectors of one batch on its first {@code maxTimeSteps} sentences. */
    private void attachSentenceVectors(DocumentSentenceIterator.DocumentBatch batch, INDArray target,
                                       INDArray embedding, int maxTimeSteps) {
        int batchNum = 0;
        for (Document doc : batch.docs) {
            int t = 0;
            for (Sentence s : doc.getSentences()) {
                if (t >= maxTimeSteps) {
                    break;
                }
                if (target != null) {
                    s.putVector(this.tagger.getTargetEncoder().getClass(), EncodingHelpers.getTimeStep(target, batchNum, t));
                }
                if (embedding != null) {
                    s.putVector(SectorEncoder.class, EncodingHelpers.getTimeStep(embedding, batchNum, t));
                }
                ++t;
            }
            ++batchNum;
        }
    }

    /**
     * Zero-pads or truncates axis 2 of {@code matrix} so it spans exactly {@code maxTimeSteps}.
     * Assumes axis 2 is the time axis and currently has {@code docLength} steps, as the original
     * padding logic did. TODO confirm against SectorTagger.encodeMatrix.
     */
    private static INDArray fitTimeSteps(INDArray matrix, long docLength, int maxTimeSteps) {
        if (maxTimeSteps > docLength) {
            return Nd4j.append(matrix, (int) (maxTimeSteps - docLength), 0.0, 2);
        }
        if (maxTimeSteps < docLength) {
            // Keep the matrix consistent with the sentence loop, which stops at maxTimeSteps.
            return matrix.get(NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.interval(0, maxTimeSteps));
        }
        return matrix;
    }

    /**
     * Not supported: SECTOR only encodes over Documents.
     *
     * @throws IllegalArgumentException always
     */
    public void encodeEach(Sentence input, Class<? extends Span> elementClass) {
        throw new IllegalArgumentException("SECTOR is only implemented to encode over Documents.");
    }

    /**
     * Not supported: train the {@link SectorTagger} instead.
     *
     * @throws UnsupportedOperationException always
     */
    public void trainModel(Collection<Document> documents) {
        throw new UnsupportedOperationException("You need to train SectorTagger.");
    }

    /**
     * Loads the tagger model from {@code file}, marks the model available and records
     * {@code file} as this encoder's model.
     *
     * @param file model resource to load
     */
    public void loadModel(Resource file) {
        this.tagger.loadModel(file);
        this.setModelAvailable(true);
        this.setModel(file);
    }

    /**
     * Saves the tagger model and updates this encoder's model filename from the tagger.
     *
     * @param dir target directory
     * @param name model name
     */
    public void saveModel(Resource dir, String name) {
        this.tagger.saveModel(dir, name);
        this.setModelFilename(this.tagger.getModel());
    }

    /** @return the tagger's name */
    public String getName() {
        return this.tagger.getName();
    }

    /** @param name new name for the tagger */
    public void setName(String name) {
        this.tagger.setName(name);
    }

    /** @return the tagger's batch size */
    public int getBatchSize() {
        return this.tagger.getBatchSize();
    }

    /** @param size new batch size for the tagger */
    public void setBatchSize(int size) {
        this.tagger.setBatchSize(size);
    }

    /** @return the tagger's embedding layer size */
    public int getEmbeddingLayerSize() {
        return this.tagger.getEmbeddingLayerSize();
    }

    /** @param size new embedding layer size for the tagger */
    public void setEmbeddingLayerSize(int size) {
        this.tagger.setEmbeddingLayerSize(size);
    }

    /**
     * Delegates to {@link SectorTagger#setRequireSubsampling}.
     * NOTE(review): "multi class" is mapped onto the tagger's subsampling flag; confirm the naming.
     *
     * @param isMultiClass value forwarded to the tagger
     */
    public void setMultiClass(boolean isMultiClass) {
        this.tagger.setRequireSubsampling(isMultiClass);
    }

    /** @return the tagger's subsampling flag (see {@link #setMultiClass}) */
    public boolean isMultiClass() {
        return this.tagger.isRequireSubsampling();
    }

    /** @param numEpochs number of training epochs for the tagger */
    public void setNumEpochs(int numEpochs) {
        this.tagger.setNumEpochs(numEpochs);
    }

    /** @return the tagger's number of training epochs */
    public int getNumEpochs() {
        return this.tagger.getNumEpochs();
    }

    /** @param rand randomization flag forwarded to the tagger */
    public void setRandomize(boolean rand) {
        this.tagger.setRandomize(rand);
    }

    /** @return the tagger's randomization flag */
    public boolean isRandomize() {
        return this.tagger.isRandomize();
    }
}

