/*
 * Decompiled with CFR 0.152.
 */
package de.datexis.ner.tagger;

import de.datexis.encoder.EncoderSet;
import de.datexis.encoder.EncodingHelpers;
import de.datexis.model.Annotation;
import de.datexis.model.Document;
import de.datexis.model.Sentence;
import de.datexis.model.Token;
import de.datexis.model.tag.Tag;
import de.datexis.tagger.CachedSentenceIterator;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.factory.Nd4j;
import org.slf4j.LoggerFactory;

public class MentionTaggerIterator
extends CachedSentenceIterator {
    protected Annotation.Source source = Annotation.Source.GOLD;

    public MentionTaggerIterator(Collection<Document> docs, String name, EncoderSet encoders, Class tagset, Annotation.Source source, int numExamples, int batchSize, boolean randomize) {
        super(docs, name, numExamples, batchSize, randomize);
        this.log = LoggerFactory.getLogger(MentionTaggerIterator.class);
        this.source = source;
        this.encoders = encoders;
        this.tagset = tagset;
        try {
            this.inputSize = encoders.getEmbeddingVectorSize();
            this.labelSize = ((Tag)this.tagset.newInstance()).getVectorSize();
        }
        catch (IllegalAccessException | InstantiationException ex) {
            this.log.error("Could not instantiate target class " + tagset.getName());
        }
        this.reset();
    }

    public MentionTaggerIterator(Collection<Document> docs, String name, EncoderSet encoders, Class tagset, int numExamples, int batchSize, boolean randomize) {
        this(docs, name, encoders, tagset, Annotation.Source.GOLD, numExamples, batchSize, randomize);
    }

    public List<Token> nextTokens() {
        return this.nextSentence().getTokens();
    }

    public DataSet generateDataSet(ArrayList<Sentence> examples, int num, int exampleSize) {
        INDArray input = EncodingHelpers.createTimeStepMatrix((long)num, (long)this.inputSize, (long)exampleSize);
        INDArray label = EncodingHelpers.createTimeStepMatrix((long)num, (long)this.labelSize, (long)exampleSize);
        INDArray featuresMask = Nd4j.zeros((long)num, (long)exampleSize);
        INDArray labelsMask = Nd4j.zeros((long)num, (long)exampleSize);
        DataSet result = new DataSet(input, label, featuresMask, labelsMask);
        for (int batchNum = 0; batchNum < num; ++batchNum) {
            Sentence example = examples.get(batchNum);
            for (int t = 0; t < example.countTokens(); ++t) {
                featuresMask.put(batchNum, t, (Number)1);
                labelsMask.put(batchNum, t, (Number)1);
                INDArray inputEncoding = example.getToken(t).getVector(this.encoders);
                EncodingHelpers.putTimeStep((INDArray)result.getFeatures(), (long)batchNum, (long)t, (INDArray)inputEncoding);
                Tag goldLabel = example.getToken(t).getTag(this.source, this.tagset);
                EncodingHelpers.putTimeStep((INDArray)result.getLabels(), (long)batchNum, (long)t, (INDArray)goldLabel.getVector());
            }
        }
        if (this.clearCache()) {
            this.log.trace("Iterate: cleared embeddings []");
        }
        return result;
    }
}

