/*
 * Decompiled with CFR 0.152.
 */
package de.datexis.sector.tagger;

import com.google.common.collect.Lists;
import de.datexis.encoder.Encoder;
import de.datexis.encoder.EncoderSet;
import de.datexis.encoder.EncodingHelpers;
import de.datexis.model.Annotation;
import de.datexis.model.Dataset;
import de.datexis.model.Document;
import de.datexis.model.Sentence;
import de.datexis.model.Span;
import de.datexis.model.Token;
import de.datexis.sector.encoder.ClassEncoder;
import de.datexis.sector.encoder.HeadingEncoder;
import de.datexis.sector.model.SectionAnnotation;
import de.datexis.sector.tagger.SectorTagger;
import de.datexis.tagger.AbstractMultiDataSetIterator;
import de.datexis.tagger.DocumentSentenceIterator;
import java.util.Collection;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.stream.Collectors;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.MultiDataSet;
import org.nd4j.linalg.factory.Nd4j;

/**
 * Iterator that turns batches of {@link Document}s into {@code MultiDataSet}s for a
 * {@link SectorTagger}. Every batch is encoded at sentence granularity with the tagger's
 * bag, embedding and flag encoders as inputs and the tagger's target encoder as labels.
 */
public class SectorTaggerIterator
extends DocumentSentenceIterator {
    /** The tagger's input encoders (bag, embedding, flag), in that order. */
    protected EncoderSet inputEncoders;
    /** The tagger's target encoder, wrapped in a set. */
    protected EncoderSet targetEncoders;
    /** Tagger that supplies all encoders used by this iterator. */
    protected SectorTagger tagger;
    /** If {@code true}, heading targets are encoded with {@code HeadingEncoder.encodeSubsampled}. */
    protected boolean requireSubsampling;

    /**
     * Creates an iterator over all documents of a dataset.
     *
     * @param useMultiClassLabels forwarded unchanged as the {@code requireSubsampling} argument
     *                            of the other constructors (the name is kept for compatibility)
     */
    public SectorTaggerIterator(AbstractMultiDataSetIterator.Stage stage, Dataset dataset, SectorTagger tagger, int batchSize, boolean randomize, boolean useMultiClassLabels) {
        this(stage, dataset.getDocuments(), tagger, batchSize, randomize, useMultiClassLabels);
    }

    /** Creates an iterator over the given documents, passing {@code -1} as {@code numExamples}. */
    public SectorTaggerIterator(AbstractMultiDataSetIterator.Stage stage, Collection<Document> docs, SectorTagger tagger, int batchSize, boolean randomize, boolean requireSubsampling) {
        this(stage, docs, tagger, -1, batchSize, randomize, requireSubsampling);
    }

    /** Creates an iterator over the given documents, passing {@code -1} as {@code maxTimeSeriesLength}. */
    public SectorTaggerIterator(AbstractMultiDataSetIterator.Stage stage, Collection<Document> docs, SectorTagger tagger, int numExamples, int batchSize, boolean randomize, boolean requireSubsampling) {
        this(stage, docs, tagger, numExamples, -1, batchSize, randomize, requireSubsampling);
    }

    /**
     * Full constructor. Hands the iteration parameters to the superclass, collects the tagger's
     * encoders into {@link #inputEncoders} / {@link #targetEncoders} and resets the iterator.
     *
     * @param stage               stage this iterator is used in
     * @param docs                documents to iterate over
     * @param tagger              tagger that provides the input and target encoders
     * @param numExamples         passed through to the superclass
     * @param maxTimeSeriesLength passed through to the superclass
     * @param batchSize           passed through to the superclass
     * @param randomize           passed through to the superclass
     * @param requireSubsampling  whether heading targets are encoded subsampled
     */
    public SectorTaggerIterator(AbstractMultiDataSetIterator.Stage stage, Collection<Document> docs, SectorTagger tagger, int numExamples, int maxTimeSeriesLength, int batchSize, boolean randomize, boolean requireSubsampling) {
        super(stage, docs, numExamples, maxTimeSeriesLength, batchSize, randomize);
        this.tagger = tagger;
        this.inputEncoders = new EncoderSet(new Encoder[]{tagger.bagEncoder, tagger.embEncoder, tagger.flagEncoder});
        this.targetEncoders = new EncoderSet(new Encoder[]{tagger.targetEncoder});
        this.requireSubsampling = requireSubsampling;
        this.reset();
    }

    /** @return always {@code true} */
    @Override
    public boolean asyncSupported() {
        return true;
    }

    /**
     * Encodes one batch of documents at sentence level.
     * <p>
     * Features are the bag, embedding and flag encodings (in that order). The same target array
     * is used for both label slots, and the same sentence mask is used for all feature and label
     * masks. Targets are encoded from annotations only in the TRAIN and TEST stages; in any other
     * stage an all-zero target array of the same dimensions is used.
     *
     * @param batch documents of the batch together with its size and maximum document length
     * @return the data set with 3 feature arrays, 2 label arrays and their masks
     */
    @Override
    public org.nd4j.linalg.dataset.api.MultiDataSet generateDataSet(DocumentSentenceIterator.DocumentBatch batch) {
        INDArray inputMask = this.createMask(batch.docs, batch.maxDocLength, Sentence.class);
        INDArray bag = this.tagger.bagEncoder.encodeMatrix(batch.docs, batch.maxDocLength, Sentence.class);
        INDArray emb = this.tagger.embEncoder.encodeMatrix(batch.docs, batch.maxDocLength, Sentence.class);
        INDArray flag = this.tagger.flagEncoder.encodeMatrix(batch.docs, batch.maxDocLength, Sentence.class);

        boolean hasTargets = this.stage.equals(AbstractMultiDataSetIterator.Stage.TRAIN)
                || this.stage.equals(AbstractMultiDataSetIterator.Stage.TEST);
        INDArray targets;
        if (hasTargets) {
            targets = this.encodeTarget(batch.docs, batch.maxDocLength, Sentence.class);
        } else {
            // No labels requested for this stage: placeholder with the same dimensions as real targets.
            targets = Nd4j.zeros(new long[]{batch.size, this.tagger.targetEncoder.getEmbeddingVectorSize(), batch.maxDocLength});
        }

        return new MultiDataSet(
                new INDArray[]{bag, emb, flag},
                new INDArray[]{targets, targets},
                new INDArray[]{inputMask, inputMask, inputMask},
                new INDArray[]{inputMask, inputMask});
    }

    /**
     * Builds a float mask of dimensions {@code [input.size(), maxTimeSteps]} that is 1 for every
     * time step covered by a document's spans and 0 for the remaining (padding) steps.
     *
     * @param input         documents of the batch
     * @param maxTimeSteps  number of time steps (columns) of the mask
     * @param timeStepClass {@code Token.class} or {@code Sentence.class}; for any other class the
     *                      span count is 0 and the mask row stays all-zero
     * @return the mask array
     */
    public INDArray createMask(List<Document> input, int maxTimeSteps, Class<? extends Span> timeStepClass) {
        INDArray mask = Nd4j.zeros(DataType.FLOAT, new long[]{input.size(), maxTimeSteps});
        for (int batchIndex = 0; batchIndex < input.size(); ++batchIndex) {
            Document example = input.get(batchIndex);
            int spanCount = 0;
            if (timeStepClass == Token.class) {
                spanCount = example.countTokens();
            } else if (timeStepClass == Sentence.class) {
                spanCount = example.countSentences();
            }
            // Documents longer than maxTimeSteps are truncated to the mask width.
            int steps = Math.min(spanCount, maxTimeSteps);
            for (int t = 0; t < steps; ++t) {
                mask.putScalar(new int[]{batchIndex, t}, 1);
            }
        }
        return mask;
    }

    /**
     * Encodes the GOLD {@link SectionAnnotation}s of every document into a target array of
     * dimensions {@code [input.size(), targetVectorSize, maxTimeSteps]}.
     * <p>
     * For each document the sorted annotations are walked in parallel with its spans: a span
     * receives the encoding of the current annotation, and the walk moves on to the next
     * annotation once a span begins at or after the current annotation's end. Documents without
     * any GOLD section annotation keep all-zero targets; the remaining documents of the batch
     * are still encoded.
     *
     * @param input         documents of the batch
     * @param maxTimeSteps  number of time steps of the target array; longer documents are truncated
     * @param timeStepClass {@code Token.class} or {@code Sentence.class}; any other class yields
     *                      no spans and therefore all-zero targets
     * @return the target array
     */
    public INDArray encodeTarget(List<Document> input, int maxTimeSteps, Class<? extends Span> timeStepClass) {
        INDArray encoding = Nd4j.zeros(new long[]{input.size(), this.tagger.targetEncoder.getEmbeddingVectorSize(), maxTimeSteps});
        for (int batchIndex = 0; batchIndex < input.size(); ++batchIndex) {
            Document example = input.get(batchIndex);

            List<? extends Span> spansToEncode = Collections.emptyList();
            if (timeStepClass == Token.class) {
                spansToEncode = Lists.newArrayList(example.getTokens());
            } else if (timeStepClass == Sentence.class) {
                spansToEncode = Lists.newArrayList(example.getSentences());
            }

            List<SectionAnnotation> anns = example
                    .streamAnnotations(Annotation.Source.GOLD, SectionAnnotation.class)
                    .sorted()
                    .collect(Collectors.toList());
            Iterator<SectionAnnotation> it = anns.iterator();
            if (!it.hasNext()) {
                // Nothing to encode for this document. Skip only this one: returning here would
                // leave every later document of the batch with all-zero targets.
                continue;
            }

            SectionAnnotation ann = it.next();
            INDArray vec = this.encodeAnnotation(this.tagger.targetEncoder, ann);
            int steps = Math.min(spansToEncode.size(), maxTimeSteps);
            for (int t = 0; t < steps; ++t) {
                Span s = spansToEncode.get(t);
                // The span starts past the current section: switch to the next one, if any.
                // After the last annotation, its encoding is reused for all remaining spans.
                if (s.getBegin() >= ann.getEnd() && it.hasNext()) {
                    ann = it.next();
                    vec = this.encodeAnnotation(this.tagger.targetEncoder, ann);
                }
                // A copy is written per time step so that steps never share the same array.
                EncodingHelpers.putTimeStep(encoding, batchIndex, t, vec.dup());
            }
        }
        return encoding;
    }

    /**
     * Encodes a single section annotation with the given encoder.
     * <ul>
     *   <li>{@link HeadingEncoder}: encodes the section heading, subsampled if
     *       {@link #requireSubsampling} is set;</li>
     *   <li>{@link ClassEncoder}: encodes the section label;</li>
     *   <li>any other encoder: returns {@code Nd4j.create(1)} as a fallback.
     *       NOTE(review): this fallback does not depend on the encoder's vector size; confirm
     *       that callers never reach it with an unsupported encoder.</li>
     * </ul>
     *
     * @param enc encoder to use
     * @param ann annotation to encode
     * @return the encoded annotation
     */
    protected INDArray encodeAnnotation(Encoder enc, SectionAnnotation ann) {
        if (enc instanceof HeadingEncoder) {
            HeadingEncoder headingEncoder = (HeadingEncoder) enc;
            return this.requireSubsampling
                    ? headingEncoder.encodeSubsampled(ann.getSectionHeading())
                    : headingEncoder.encode(ann.getSectionHeading());
        }
        if (enc instanceof ClassEncoder) {
            return ((ClassEncoder) enc).encode(ann.getSectionLabel());
        }
        return Nd4j.create(1);
    }
}

