/*
 * Decompiled with CFR 0.152.
 */
package de.datexis.sector;

import de.datexis.annotator.Annotator;
import de.datexis.annotator.AnnotatorComponent;
import de.datexis.common.Resource;
import de.datexis.common.WordHelpers;
import de.datexis.encoder.Encoder;
import de.datexis.encoder.LookupCacheEncoder;
import de.datexis.model.Annotation;
import de.datexis.model.Dataset;
import de.datexis.model.Document;
import de.datexis.model.Sentence;
import de.datexis.sector.encoder.ClassEncoder;
import de.datexis.sector.encoder.ClassTag;
import de.datexis.sector.encoder.HeadingEncoder;
import de.datexis.sector.encoder.HeadingTag;
import de.datexis.sector.eval.SectorEvaluation;
import de.datexis.sector.model.SectionAnnotation;
import de.datexis.sector.tagger.ScoreImprovementMinEpochsTerminationCondition;
import de.datexis.sector.tagger.SectorEncoder;
import de.datexis.sector.tagger.SectorTagger;
import de.datexis.sector.tagger.SectorTaggerIterator;
import de.datexis.tagger.AbstractMultiDataSetIterator;
import de.datexis.tagger.Tagger;
import java.io.BufferedOutputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.OutputStream;
import java.util.Collection;
import java.util.Locale;
import java.util.stream.Collectors;
import org.apache.commons.lang3.StringUtils;
import org.deeplearning4j.datasets.iterator.AsyncMultiDataSetIterator;
import org.deeplearning4j.earlystopping.EarlyStoppingConfiguration;
import org.deeplearning4j.earlystopping.EarlyStoppingResult;
import org.deeplearning4j.earlystopping.termination.EpochTerminationCondition;
import org.deeplearning4j.nn.conf.WorkspaceMode;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.memory.MemoryWorkspace;
import org.nd4j.linalg.api.memory.abstracts.DummyWorkspace;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.rng.distribution.impl.NormalDistribution;
import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;
import org.nd4j.linalg.dimensionalityreduction.PCA;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.lossfunctions.ILossFunction;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import org.nd4j.linalg.ops.transforms.Transforms;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class SectorAnnotator
extends Annotator {
    /** Logger shared by all instances of this annotator. */
    protected static final Logger log = LoggerFactory.getLogger(SectorAnnotator.class);
    /**
     * Absolute path of the directory holding pre-exported training batches. Empty until set by
     * {@code setPresavedDatasetDirectory} or {@code exportBatchesToFiles}.
     */
    protected String presavedDatasetDirectory = "";

    /**
     * Remembers the directory from which pre-exported training batches are read.
     *
     * @param directory resource whose absolute path is stored in {@code presavedDatasetDirectory}
     */
    public void setPresavedDatasetDirectory(Resource directory) {
        final String absolutePath = directory.getPath().toAbsolutePath().toString();
        this.presavedDatasetDirectory = absolutePath;
    }

    /** Creates an annotator using only the superclass' no-argument initialisation. */
    public SectorAnnotator() {
    }

    /**
     * Creates an annotator around the given tagger.
     *
     * @param root tagger handed to the superclass constructor
     */
    public SectorAnnotator(Tagger root) {
        super(root);
    }

    /**
     * Creates an annotator around an arbitrary component; only available to subclasses.
     *
     * @param comp component handed to the superclass constructor
     */
    protected SectorAnnotator(AnnotatorComponent comp) {
        super(comp);
    }

    /**
     * Returns the tagger of this annotator, narrowed to {@link SectorTagger}.
     *
     * @return the superclass' tagger as a {@code SectorTagger}
     * @throws ClassCastException if the stored tagger is not a {@code SectorTagger}
     */
    public SectorTagger getTagger() {
        return SectorTagger.class.cast(super.getTagger());
    }

    /**
     * Returns the tagger's target encoder, narrowed to {@link LookupCacheEncoder}.
     *
     * @return the target encoder of {@link #getTagger()}
     * @throws ClassCastException if the target encoder is not a {@code LookupCacheEncoder}
     */
    public LookupCacheEncoder getTargetEncoder() {
        final SectorTagger tagger = getTagger();
        return (LookupCacheEncoder) tagger.getTargetEncoder();
    }

    /**
     * Annotates the documents using {@link SegmentationMethod#BEMD} as segmentation method.
     *
     * @param docs documents to annotate
     */
    public void annotate(Collection<Document> docs) {
        final SegmentationMethod defaultMethod = SegmentationMethod.BEMD;
        annotate(docs, defaultMethod);
    }

    /**
     * Attaches the network's vectors to the documents and, unless the method is
     * {@link SegmentationMethod#NONE}, segments them afterwards (with {@code mergeSections = true}).
     *
     * @param docs         documents to annotate
     * @param segmentation segmentation method to apply after encoding
     */
    public void annotate(Collection<Document> docs, SegmentationMethod segmentation) {
        log.info("Running SECTOR neural net encoding...");
        final SectorTagger tagger = getTagger();
        final Class<? extends LookupCacheEncoder> targetClass = getTargetEncoder().getClass();
        tagger.attachVectors(docs, AbstractMultiDataSetIterator.Stage.ENCODE, targetClass);
        if (segmentation.equals(SegmentationMethod.NONE)) {
            // encoding only: no section annotations are created
            return;
        }
        segment(docs, segmentation, true);
    }

    /**
     * Detects sections in the documents and attaches vectors to the resulting annotations.
     *
     * @param docs          documents to segment
     * @param segmentation  segmentation method passed on to {@code detectSections}
     * @param mergeSections accepted for API compatibility; this method performs no merging
     */
    public void segment(Collection<Document> docs, SegmentationMethod segmentation, boolean mergeSections) {
        log.info("Predicting segmentation {}...", (Object) segmentation.toString());
        detectSections(docs, segmentation);
        // NOTE: mergeSections is intentionally not evaluated; no merge step exists in this method.
        log.info("Attaching Annotations...");
        docs.forEach(doc -> SectorAnnotator.attachVectorsToAnnotations(doc, getTargetEncoder()));
        log.info("Segmentation done.");
    }

    /*
     * Unable to fully structure code
     */
    protected void detectSections(Collection<Document> docs, SegmentationMethod segmentation) {
        cMode = this.getTagger().getNN().getConfiguration().getInferenceWorkspaceMode();
        this.getTagger().getNN().getConfiguration().setTrainingWorkspaceMode(this.getTagger().getNN().getConfiguration().getInferenceWorkspaceMode());
        workspace = this.getTagger().getNN().getConfiguration().getTrainingWorkspaceMode() == WorkspaceMode.NONE ? new DummyWorkspace() : Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread();
lbl4:
        // 11 sources

        block16: for (Document doc : docs) {
            wsE = workspace.notifyScopeEntered();
            var8_8 = null;
            try {
                switch (1.$SwitchMap$de$datexis$sector$SectorAnnotator$SegmentationMethod[segmentation.ordinal()]) {
                    case 1: {
                        SectorAnnotator.applySectionsFromGold(doc);
                        ** break;
                    }
                    case 2: {
                        SectorAnnotator.applySectionsFromTargetLabels(doc, this.getTargetEncoder(), 2);
                        ** break;
                    }
                    case 3: {
                        mag = SectorAnnotator.detectSectionsFromEmbeddingDeviation(doc);
                        SectorAnnotator.applySectionsFromEdges(doc, SectorAnnotator.detectEdges(mag));
                        ** break;
                    }
                    case 4: {
                        mag = SectorAnnotator.detectSectionsFromBidirectionalEmbeddingDeviation(doc);
                        SectorAnnotator.applySectionsFromEdges(doc, SectorAnnotator.detectEdges(mag));
                        ** break;
                    }
                    case 5: {
                        mag = SectorAnnotator.detectSectionsFromBidirectionalEmbeddingDeviation(doc);
                        expectedNumberOfSections = (int)doc.countAnnotations(Annotation.Source.GOLD);
                        SectorAnnotator.applySectionsFromEdges(doc, SectorAnnotator.detectEdges(mag, expectedNumberOfSections));
                        ** break;
                    }
                    default: {
                        SectorAnnotator.applySectionsFromNewlines(doc);
                        continue block16;
                    }
                }
            }
            catch (Throwable var9_11) {
                var8_8 = var9_11;
                throw var9_11;
            }
            finally {
                if (wsE == null) continue;
                if (var8_8 != null) {
                    try {
                        wsE.close();
                    }
                    catch (Throwable var9_10) {
                        var8_8.addSuppressed(var9_10);
                    }
                    continue;
                }
                wsE.close();
            }
        }
        this.getTagger().getNN().getConfiguration().setTrainingWorkspaceMode(cMode);
    }

    /**
     * Evaluates the model on the test set with sentence classification, segmentation and
     * segment classification evaluation all enabled.
     *
     * @param test dataset to evaluate on
     * @return the score reported by the four-argument overload
     */
    public double evaluateModel(Dataset test) {
        final boolean enabled = true;
        return evaluateModel(test, enabled, enabled, enabled);
    }

    /**
     * Evaluates the model on the test set and appends dataset, evaluation and single-class
     * statistics to the tagger's test log.
     *
     * <p>When sentence classification is evaluated, predicted tags are removed first and gold and
     * predicted tags are (re-)created from the encoder registered as component {@code "HL"}
     * (heading target) or {@code "CLS"} (class target).
     *
     * @param test                       dataset to evaluate on
     * @param evalSentenceClassification whether to evaluate sentence-level classification
     * @param evalSegmentation           whether to evaluate segmentation
     * @param evalSegmentClassification  whether to evaluate segment-level classification
     * @return the score returned by {@code SectorEvaluation.getScore()}
     * @throws IllegalArgumentException if the target encoder is neither a {@code HeadingEncoder}
     *                                  nor a {@code ClassEncoder}
     */
    public double evaluateModel(Dataset test, boolean evalSentenceClassification, boolean evalSegmentation, boolean evalSegmentClassification) {
        SectorEvaluation eval;
        if (this.getTargetEncoder().getClass() == HeadingEncoder.class) {
            HeadingEncoder headings = (HeadingEncoder)this.getComponent("HL");
            eval = new SectorEvaluation(test.getName(), Annotation.Source.GOLD, Annotation.Source.PRED, (LookupCacheEncoder)headings);
            if (evalSentenceClassification) {
                log.info("Creating tags...");
                SectorAnnotator.removeTags(test.getDocuments(), Annotation.Source.PRED);
                this.createHeadingTags(test.getDocuments(), Annotation.Source.GOLD, headings);
                this.createHeadingTags(test.getDocuments(), Annotation.Source.PRED, headings);
            }
        } else if (this.getTargetEncoder().getClass() == ClassEncoder.class) {
            ClassEncoder classes = (ClassEncoder)this.getComponent("CLS");
            eval = new SectorEvaluation(test.getName(), Annotation.Source.GOLD, Annotation.Source.PRED, classes);
            if (evalSentenceClassification) {
                log.info("Creating tags...");
                SectorAnnotator.removeTags(test.getDocuments(), Annotation.Source.PRED);
                this.createClassTags(test.getDocuments(), Annotation.Source.GOLD, classes);
                this.createClassTags(test.getDocuments(), Annotation.Source.PRED, classes);
            }
        } else {
            throw new IllegalArgumentException("Target encoder has no evaluation: " + this.getTargetEncoder().getClass().toString());
        }
        // Each flag goes to the setter of the same name (the two segment flags used to be swapped).
        eval.withSentenceClassEvaluation(evalSentenceClassification)
            .withSegmentationEvaluation(evalSegmentation)
            .withSegmentClassEvaluation(evalSegmentClassification)
            .calculateScores(test);
        this.getTagger().appendTestLog(SectorEvaluation.printDatasetStats(test));
        this.getTagger().appendTestLog(eval.printEvaluationStats());
        this.getTagger().appendTestLog(eval.printSingleClassStats());
        return eval.getScore();
    }

    /**
     * Writes all training batches of the dataset to {@code train_<n>.bin} files (n starting at 0)
     * inside the given directory and remembers that directory as pre-saved dataset directory.
     *
     * @param directory target directory for the batch files
     * @param dataset   dataset whose documents are batched
     * @param batchsize batch size handed to the iterator
     * @param queueSize queue size of the asynchronous iterator; {@code -1} selects 256
     * @throws IOException if a batch file cannot be written
     */
    public void exportBatchesToFiles(Resource directory, Dataset dataset, int batchsize, int queueSize) throws IOException {
        final int effectiveQueueSize = queueSize == -1 ? 256 : queueSize;
        final SectorTagger tagger = this.getTagger();
        final int maxTimeSeriesLength = tagger.getMaxTimeSeriesLength();
        SectorTaggerIterator it = new SectorTaggerIterator(AbstractMultiDataSetIterator.Stage.TRAIN, dataset.getDocuments(), tagger, dataset.getDocuments().size(), maxTimeSeriesLength, batchsize, true, false);
        AsyncMultiDataSetIterator ait = new AsyncMultiDataSetIterator((MultiDataSetIterator)it, effectiveQueueSize);
        this.presavedDatasetDirectory = directory.getPath().toAbsolutePath().toString();
        int batch = 0;
        while (ait.hasNext()) {
            final String target = this.presavedDatasetDirectory + "/train_" + batch + ".bin";
            // try-with-resources: the file is flushed and closed even if saving the batch fails
            try (OutputStream bo = new BufferedOutputStream(new FileOutputStream(target))) {
                ait.next().save(bo);
            }
            ++batch;
            log.info("Exported Batch: {}", batch);
        }
    }

    /**
     * Trains the tagger from the batches stored in {@code presavedDatasetDirectory}.
     *
     * @param epochs number of epochs passed to the tagger
     */
    public void trainModelPresaved(int epochs) {
        final SectorTagger tagger = getTagger();
        tagger.trainModelPresaved(presavedDatasetDirectory, epochs);
    }

    /**
     * Records the dataset's name and language in the provenance and trains the tagger on it.
     *
     * @param train training dataset
     */
    public void trainModel(Dataset train) {
        // provenance first, so that it is already set when training starts
        provenance.setDataset(train.getName());
        provenance.setLanguage(train.getLanguage());
        final SectorTagger tagger = getTagger();
        tagger.trainModel(train);
    }

    /**
     * Records the dataset's name and language in the provenance and trains the tagger on it for
     * the given number of epochs.
     *
     * @param train     training dataset
     * @param numEpochs number of epochs passed to the tagger
     */
    public void trainModel(Dataset train, int numEpochs) {
        // provenance first, so that it is already set when training starts
        provenance.setDataset(train.getName());
        provenance.setLanguage(train.getLanguage());
        final SectorTagger tagger = getTagger();
        tagger.trainModel(train, numEpochs);
    }

    /**
     * Trains the tagger with early stopping and appends the result to the training log.
     *
     * <p>The configuration evaluates after every epoch, does not save the last model and uses a
     * single {@link ScoreImprovementMinEpochsTerminationCondition} built from the three epoch
     * parameters.
     *
     * @param train                  training dataset
     * @param validation             validation dataset
     * @param minEpochs              first argument of the termination condition
     * @param minEpochsNoImprovement second argument of the termination condition
     * @param maxEpochs              third argument of the termination condition
     */
    public void trainModelEarlyStopping(Dataset train, Dataset validation, int minEpochs, int minEpochsNoImprovement, int maxEpochs) {
        final EpochTerminationCondition stopCondition =
                new ScoreImprovementMinEpochsTerminationCondition(minEpochs, minEpochsNoImprovement, maxEpochs);
        final EarlyStoppingConfiguration conf = new EarlyStoppingConfiguration.Builder()
                .evaluateEveryNEpochs(1)
                .epochTerminationConditions(new EpochTerminationCondition[]{stopCondition})
                .saveLastModel(false)
                .build();
        final EarlyStoppingResult<ComputationGraph> result = getTagger().trainModel(train, validation, conf);
        getTagger().appendTrainLog("Training complete " + result.toString());
    }

    /**
     * Attaches heading tags of the given source to every document that does not have them yet.
     * Gold tags are derived from section annotations, predicted tags from sentence vectors;
     * any other source is ignored.
     *
     * @param docs     documents to tag
     * @param source   source of the tags to create
     * @param headings encoder used to build the tag factory
     */
    private void createHeadingTags(Iterable<Document> docs, Annotation.Source source, HeadingEncoder headings) {
        final HeadingTag.Factory factory = new HeadingTag.Factory(headings);
        for (Document document : docs) {
            if (document.isTagAvaliable(source, HeadingTag.class)) {
                continue; // already tagged
            }
            if (source.equals(Annotation.Source.GOLD)) {
                factory.attachFromSectionAnnotations(document, source);
            } else if (source.equals(Annotation.Source.PRED)) {
                factory.attachFromSentenceVectors(document, HeadingEncoder.class, source);
            }
        }
    }

    /**
     * Attaches class tags of the given source to every document that does not have them yet.
     * Gold tags are derived from section annotations, predicted tags from sentence vectors;
     * any other source is ignored.
     *
     * @param docs    documents to tag
     * @param source  source of the tags to create
     * @param classes encoder used to build the tag factory
     */
    private void createClassTags(Iterable<Document> docs, Annotation.Source source, ClassEncoder classes) {
        final ClassTag.Factory factory = new ClassTag.Factory(classes);
        for (Document document : docs) {
            if (document.isTagAvaliable(source, ClassTag.class)) {
                continue; // already tagged
            }
            if (source.equals(Annotation.Source.GOLD)) {
                factory.attachFromSectionAnnotations(document, source);
            } else if (source.equals(Annotation.Source.PRED)) {
                factory.attachFromSentenceVectors(document, ClassEncoder.class, source);
            }
        }
    }

    /**
     * Clears all tags of the given source from every sentence and marks heading and class tags
     * of that source as unavailable on each document.
     *
     * @param docs   documents to clean
     * @param source source whose tags are removed
     */
    private static void removeTags(Iterable<Document> docs, Annotation.Source source) {
        for (Document document : docs) {
            for (Sentence sentence : document.getSentences()) {
                sentence.clearTags(source);
            }
            // reset the availability flags so that the tags can be created again
            document.setTagAvailable(source, HeadingTag.class, false);
            document.setTagAvailable(source, ClassTag.class, false);
        }
    }

    /**
     * Attaches target vectors to the gold and predicted section annotations of a document.
     *
     * <p>Gold annotations receive the encoding of their section label (class target) or section
     * heading (heading target). Predicted annotations receive the sum of the vectors of the
     * sentences in their range, divided by the sentence count when there is more than one, and
     * are labelled from that vector: nearest neighbour for a class target, the two nearest
     * neighbours joined by {@code "/"} for a heading target. Target encoders of any other class
     * leave the annotations untouched.
     *
     * @param doc           document whose annotations are updated
     * @param targetEncoder encoder that defines the vector space and the labels
     */
    protected static void attachVectorsToAnnotations(Document doc, LookupCacheEncoder targetEncoder) {
        for (SectionAnnotation gold : doc.getAnnotations(Annotation.Source.GOLD, SectionAnnotation.class)) {
            if (targetEncoder.getClass() == ClassEncoder.class) {
                gold.putVector(ClassEncoder.class, targetEncoder.encode(gold.getSectionLabel()));
            } else if (targetEncoder.getClass() == HeadingEncoder.class) {
                gold.putVector(HeadingEncoder.class, targetEncoder.encode(gold.getSectionHeading()));
            }
        }
        for (SectionAnnotation predicted : doc.getAnnotations(Annotation.Source.PRED, SectionAnnotation.class)) {
            final INDArray mean = Nd4j.zeros((long[]) new long[]{targetEncoder.getEmbeddingVectorSize(), 1L});
            final Collection<Sentence> covered =
                    doc.streamSentencesInRange(predicted.getBegin(), predicted.getEnd(), false).collect(Collectors.toList());
            for (Sentence sentence : covered) {
                mean.addi(sentence.getVector(targetEncoder.getClass()));
            }
            final int sentenceCount = covered.size();
            if (sentenceCount > 1) {
                // turn the sum into an average; a single sentence is already its own average
                mean.divi((Number) sentenceCount);
            }
            if (targetEncoder.getClass() == ClassEncoder.class) {
                predicted.putVector(ClassEncoder.class, mean);
                predicted.setSectionLabel(targetEncoder.getNearestNeighbour(mean));
                predicted.setConfidence(targetEncoder.getMaxConfidence(mean));
            } else if (targetEncoder.getClass() == HeadingEncoder.class) {
                predicted.putVector(HeadingEncoder.class, mean);
                Collection nearest = targetEncoder.getNearestNeighbours(mean, 2);
                predicted.setSectionHeading(StringUtils.join((Iterable) nearest, (String) "/"));
                predicted.setConfidence(targetEncoder.getMaxConfidence(mean));
            }
        }
    }

    /**
     * Adds one predicted section annotation per gold section annotation, copying its begin and
     * end offsets.
     *
     * @param doc document that receives the predicted annotations
     */
    private static void applySectionsFromGold(Document doc) {
        for (SectionAnnotation gold : doc.getAnnotations(Annotation.Source.GOLD, SectionAnnotation.class)) {
            final SectionAnnotation copy = new SectionAnnotation(Annotation.Source.PRED);
            copy.setBegin(gold.getBegin());
            copy.setEnd(gold.getEnd());
            doc.addAnnotation((Annotation) copy);
        }
    }

    /**
     * Adds predicted section annotations that end at every sentence containing a token whose
     * text is {@code "*NL*"} or a newline character. Sentences after the last such sentence are
     * collected into a final section that ends at the document end.
     *
     * @param doc document that receives the predicted annotations
     */
    private static void applySectionsFromNewlines(Document doc) {
        SectionAnnotation open = null;
        for (Sentence sentence : doc.getSentences()) {
            if (open == null) {
                // first sentence of a new section
                open = new SectionAnnotation(Annotation.Source.PRED);
                open.setBegin(sentence.getBegin());
            }
            final boolean endsParagraph = sentence.streamTokens()
                    .anyMatch(t -> t.getText().equals("*NL*") || t.getText().equals("\n"));
            if (endsParagraph) {
                open.setEnd(sentence.getEnd());
                doc.addAnnotation((Annotation) open);
                open = null;
            }
        }
        if (open != null) {
            // trailing sentences without a newline token still form a section
            log.warn("found last sentence without newline");
            open.setEnd(doc.getEnd());
            doc.addAnnotation(open);
        }
    }

    /**
     * Adds predicted section annotations by following the predicted target labels sentence by
     * sentence. A new section starts whenever the label of the running section is not among the
     * {@code k} nearest neighbours of the current sentence vector; the running label is the
     * nearest neighbour of the averaged sentence vectors of the section so far.
     *
     * @param doc           document that receives the predicted annotations
     * @param targetEncoder encoder providing sentence vectors and nearest-neighbour lookup
     * @param k             number of nearest neighbours considered per sentence
     */
    private static void applySectionsFromTargetLabels(Document doc, LookupCacheEncoder targetEncoder, int k) {
        String runningLabel = "";
        int sentencesInSection = 0;
        INDArray summed = Nd4j.create((long[]) new long[]{1L, targetEncoder.getEmbeddingVectorSize()}).transposei();
        SectionAnnotation current = new SectionAnnotation(Annotation.Source.PRED);
        current.setBegin(doc.getBegin());
        for (Sentence sentence : doc.getSentences()) {
            final INDArray sentenceVector = sentence.getVector(targetEncoder.getClass());
            final Collection candidates = targetEncoder.getNearestNeighbours(sentenceVector, k);
            final boolean labelChanged = !candidates.contains(runningLabel);
            if (labelChanged) {
                if (!runningLabel.isEmpty()) {
                    // close the section that was built up so far
                    doc.addAnnotation((Annotation) current);
                }
                current = new SectionAnnotation(Annotation.Source.PRED);
                current.setBegin(sentence.getBegin());
                sentencesInSection = 0;
                summed = Nd4j.create((long[]) new long[]{1L, targetEncoder.getEmbeddingVectorSize()}).transposei();
            }
            summed.addi(sentenceVector);
            sentencesInSection += 1;
            runningLabel = targetEncoder.getNearestNeighbour(summed.div((Number) sentencesInSection));
            current.setEnd(sentence.getEnd());
        }
        if (!runningLabel.isEmpty()) {
            doc.addAnnotation((Annotation) current);
        }
    }

    /**
     * Adds predicted section annotations according to an edge indicator with one entry per
     * sentence: a value greater than zero starts a new section at that sentence.
     *
     * <p>An empty document only produces a warning. A document with fewer than two sentences, or
     * a {@code null} indicator, yields a single section spanning the whole document.
     *
     * @param doc      document that receives the predicted annotations
     * @param docEdges edge indicator indexed by sentence position, may be {@code null}
     */
    private static void applySectionsFromEdges(Document doc, INDArray docEdges) {
        if (doc.countSentences() < 1) {
            log.warn("Empty document");
            return;
        }
        if (docEdges == null || doc.countSentences() < 2) {
            final SectionAnnotation whole = new SectionAnnotation(Annotation.Source.PRED);
            whole.setBegin(doc.getBegin());
            whole.setEnd(doc.getEnd());
            doc.addAnnotation((Annotation) whole);
            return;
        }
        SectionAnnotation current = new SectionAnnotation(Annotation.Source.PRED);
        current.setBegin(doc.getBegin());
        int sentencesInSection = 0;
        int position = 0;
        for (Sentence sentence : doc.getSentences()) {
            final boolean startsSection = docEdges.getDouble((long) position) > 0.0;
            position++;
            if (startsSection) {
                if (sentencesInSection > 0) {
                    // flush the previous section before opening the next one
                    doc.addAnnotation((Annotation) current);
                }
                current = new SectionAnnotation(Annotation.Source.PRED);
                current.setBegin(sentence.getBegin());
                sentencesInSection = 0;
            }
            sentencesInSection++;
            current.setEnd(sentence.getEnd());
        }
        if (sentencesInSection > 0) {
            doc.addAnnotation((Annotation) current);
        }
    }

    /**
     * Computes the deviation signal of a document from its embedding matrix: the matrix is
     * projected with {@link #pca(INDArray, int)} to 16 dimensions, smoothed with the default
     * Gaussian and passed to {@link #deviation(INDArray)}.
     *
     * @param doc document to analyse
     * @return the deviation column, or {@code null} if the document has fewer than two sentences
     */
    private static INDArray detectSectionsFromEmbeddingDeviation(Document doc) {
        final int pcaDims = 16;
        if (doc.countSentences() < 2) {
            return null;
        }
        final INDArray embeddings = SectorAnnotator.getEmbeddingMatrix(doc);
        return SectorAnnotator.deviation(
                SectorAnnotator.gaussianSmooth(
                        SectorAnnotator.pca(embeddings, pcaDims)));
    }

    /**
     * Computes the deviation signal of a document from the sentence vectors stored under
     * {@code "embeddingFW"} and {@code "embeddingBW"}.
     *
     * <p>Both matrices are projected onto 16 PCA factors (computed without normalisation), the
     * first two projected columns are set to zero, the result is smoothed with a Gaussian of
     * standard deviation 1.5 and combined by {@link #deviation(INDArray, INDArray)}.
     *
     * @param doc document to analyse
     * @return the deviation column, or {@code null} if the document has no sentences
     */
    private static INDArray detectSectionsFromBidirectionalEmbeddingDeviation(Document doc) {
        final int pcaDims = 16;
        final double smoothFactor = 1.5;
        if (doc.countSentences() < 1) {
            return null;
        }
        // the width of both matrices is taken from the forward vector of the first sentence
        final long layerSize = doc.getSentence(0).getVector("embeddingFW").length();
        final INDArray forward = Nd4j.zeros((long[]) new long[]{doc.countSentences(), layerSize});
        final INDArray backward = Nd4j.zeros((long[]) new long[]{doc.countSentences(), layerSize});
        int row = 0;
        for (Sentence sentence : doc.getSentences()) {
            forward.getRow((long) row).assign(sentence.getVector("embeddingFW"));
            backward.getRow((long) row).assign(sentence.getVector("embeddingBW"));
            row++;
        }
        final INDArray forwardPca = forward.mmul(PCA.pca_factor((INDArray) forward.dup(), (int) pcaDims, (boolean) false));
        final INDArray backwardPca = backward.mmul(PCA.pca_factor((INDArray) backward.dup(), (int) pcaDims, (boolean) false));
        // blank out the first two projected columns in both directions
        final INDArray blank = Nd4j.zeros((int[]) new int[]{forward.rows(), 1});
        forwardPca.putColumn(0, blank);
        backwardPca.putColumn(0, blank);
        forwardPca.putColumn(1, blank);
        backwardPca.putColumn(1, blank);
        final INDArray forwardSmooth = SectorAnnotator.gaussianSmooth(forwardPca, smoothFactor);
        final INDArray backwardSmooth = SectorAnnotator.gaussianSmooth(backwardPca, smoothFactor);
        return SectorAnnotator.deviation(forwardSmooth, backwardSmooth);
    }

    /**
     * Stacks the sentence vectors stored under the given key into a matrix with one row per
     * sentence. The row width is taken from the vector of the first sentence.
     *
     * @param doc        document whose sentences provide the vectors; must contain a sentence 0
     * @param layerClass key under which the vectors are stored on each sentence
     * @return matrix with one row per sentence
     */
    protected static INDArray getLayerMatrix(Document doc, String layerClass) {
        final long width = doc.getSentence(0).getVector(layerClass).length();
        final INDArray matrix = Nd4j.zeros((long[]) new long[]{doc.countSentences(), width});
        int row = 0;
        for (Sentence sentence : doc.getSentences()) {
            matrix.getRow((long) row).assign(sentence.getVector(layerClass));
            row++;
        }
        return matrix;
    }

    /**
     * Stacks the sentence vectors stored under the canonical name of the given class.
     *
     * @param doc        document whose sentences provide the vectors
     * @param layerClass class whose canonical name is used as vector key
     * @return matrix with one row per sentence
     */
    protected static INDArray getLayerMatrix(Document doc, Class layerClass) {
        final String layerName = layerClass.getCanonicalName();
        return SectorAnnotator.getLayerMatrix(doc, layerName);
    }

    /**
     * Stacks the sentence vectors stored under the canonical name of {@link SectorEncoder}.
     *
     * @param doc document whose sentences provide the vectors
     * @return matrix with one row per sentence
     */
    protected static INDArray getEmbeddingMatrix(Document doc) {
        final String embeddingKey = SectorEncoder.class.getCanonicalName();
        return SectorAnnotator.getLayerMatrix(doc, embeddingKey);
    }

    /**
     * Projects a matrix onto PCA factors computed (with normalisation enabled) from a copy of it.
     *
     * @param m          matrix to project; it is duplicated before the factors are computed
     * @param dimensions number of dimensions requested from {@code PCA.pca_factor}
     * @return {@code m} multiplied by the factor matrix
     */
    protected static INDArray pca(INDArray m, int dimensions) {
        final INDArray factors = PCA.pca_factor((INDArray) m.dup(), (int) dimensions, (boolean) true);
        return m.mmul(factors);
    }

    /**
     * Smooths the rows of a matrix with a Gaussian kernel of standard deviation 2.5.
     *
     * @param target matrix to smooth
     * @return smoothed matrix as produced by {@link #gaussianSmooth(INDArray, double)}
     */
    protected static INDArray gaussianSmooth(INDArray target) {
        final double defaultSd = 2.5;
        return SectorAnnotator.gaussianSmooth(target, defaultSd);
    }

    /**
     * Smooths a matrix along its rows: row {@code t} of the result is the sum of all input rows,
     * each weighted by the density at its row index of a normal distribution with mean {@code t}
     * and standard deviation {@code sd}. The weights are not normalised.
     *
     * @param target matrix to smooth; it is not modified
     * @param sd     standard deviation of the kernel
     * @return new matrix shaped like {@code target}
     */
    protected static INDArray gaussianSmooth(INDArray target, double sd) {
        final INDArray source = target.dup('c');
        final INDArray weights = Nd4j.zeros((int) source.rows(), (int) 1, (char) 'c');
        final INDArray smoothed = Nd4j.zerosLike((INDArray) target);
        final long kernelLength = weights.length();
        for (int center = 0; (long) center < kernelLength; ++center) {
            final NormalDistribution bell = new NormalDistribution((double) center, sd);
            for (int k = 0; (long) k < kernelLength; ++k) {
                weights.putScalar((long) k, bell.density((double) k));
            }
            // weight every row, then collapse the rows into a single smoothed row
            final INDArray weighted = source.mulColumnVector(weights);
            smoothed.getRow((long) center).assign(weighted.sum(new int[]{0}));
        }
        return smoothed;
    }

    /**
     * Combines forward and backward deviations into one column with one entry per row of
     * {@code fw}. Entry {@code t} is the square root of the product of the cosine distance
     * between forward rows {@code t} and {@code t + 1} and the cosine distance between backward
     * rows {@code t - 1} and {@code t - 2}. A missing neighbour contributes 0, a NaN result is
     * stored as 0, and entry 0 is always 0.
     *
     * @param fw forward matrix, one row per position
     * @param bw backward matrix, one row per position
     * @return column vector of combined deviations
     */
    protected static INDArray deviation(INDArray fw, INDArray bw) {
        final INDArray dev = Nd4j.zeros((int[]) new int[]{fw.rows(), 1});
        final int rows = dev.rows();
        for (int t = 1; t < rows; ++t) {
            double forward = 0.0;
            if (t < rows - 1) {
                forward = Transforms.cosineDistance((INDArray) fw.getRow((long) t), (INDArray) fw.getRow((long) (t + 1)));
            }
            double backward = 0.0;
            // NOTE(review): rows t-1 and t-2 already exist for t == 2, so "t > 1" would suffice;
            // confirm whether skipping t == 2 is intended before changing it.
            if (t > 2) {
                backward = Transforms.cosineDistance((INDArray) bw.getRow((long) (t - 1)), (INDArray) bw.getRow((long) (t - 2)));
            }
            final double geom = Math.sqrt(forward * backward);
            dev.putScalar((long) t, 0L, Double.isNaN(geom) ? 0.0 : geom);
        }
        return dev;
    }

    /**
     * Computes the cosine distance between each row and its predecessor.
     *
     * @param target matrix with one row per position
     * @return column vector whose entry {@code t} is the distance between rows {@code t} and
     *         {@code t - 1}; entry 0 is 0
     */
    protected static INDArray deviation(INDArray target) {
        final INDArray dev = Nd4j.zeros((int[]) new int[]{target.rows(), 1});
        for (int t = 1; t < dev.rows(); ++t) {
            final INDArray current = target.getRow((long) t);
            final INDArray previous = target.getRow((long) (t - 1));
            dev.putScalar((long) t, 0L, Transforms.cosineDistance((INDArray) current, (INDArray) previous));
        }
        return dev;
    }

    /**
     * Marks strict local maxima of a deviation column as edges.
     *
     * @param dev deviation column, may be {@code null}
     * @return column with 1 at index 0 and at every interior index whose value is strictly
     *         greater than both neighbours, 0 elsewhere; {@code null} if {@code dev} is
     *         {@code null}
     */
    protected static INDArray detectEdges(INDArray dev) {
        if (dev == null) {
            return null;
        }
        final INDArray edges = Nd4j.zeros((int[]) new int[]{dev.rows(), 1});
        final int lastInterior = edges.rows() - 1;
        for (int t = 1; t < lastInterior; ++t) {
            final boolean risesFromLeft = dev.getDouble((long) (t - 1)) < dev.getDouble((long) t);
            final boolean fallsToRight = dev.getDouble((long) (t + 1)) < dev.getDouble((long) t);
            if (risesFromLeft && fallsToRight) {
                edges.putScalar((long) t, 0L, 1.0);
            }
        }
        // the first position always starts a section
        edges.putScalar(0L, 0L, 1.0);
        return edges;
    }

    /**
     * Marks edges in a deviation column so that, where possible, {@code count} sections result.
     *
     * <p>Index 0 is always an edge. Up to {@code count - 1} further edges are chosen: first the
     * strict local maxima in descending order of their value, then, if that is not enough,
     * remaining positions in descending order of deviation.
     *
     * @param dev   deviation column, may be {@code null}
     * @param count requested number of sections
     * @return column with 1 at edge positions and 0 elsewhere; {@code null} if {@code dev} is
     *         {@code null}
     */
    protected static INDArray detectEdges(INDArray dev, int count) {
        if (dev == null) {
            return null;
        }
        final int rows = dev.rows();
        // peaks holds the deviation value at strict local maxima and 0 everywhere else
        final INDArray peaks = Nd4j.zeros((int[]) new int[]{rows, 1});
        for (int t = 1; t < rows - 1; ++t) {
            final double here = dev.getDouble((long) t);
            if (dev.getDouble((long) (t - 1)) < here && dev.getDouble((long) (t + 1)) < here) {
                peaks.putScalar((long) t, 0L, here);
            }
        }
        final INDArray result = Nd4j.zeros((int[]) new int[]{rows, 1});
        // element 0 of sortWithIndices holds the indices in descending order of value
        final INDArray peakOrder = Nd4j.sortWithIndices((INDArray) Nd4j.toFlattened((INDArray[]) new INDArray[]{peaks}).dup(), (int) 1, (boolean) false)[0];
        final INDArray magnitudeOrder = Nd4j.sortWithIndices((INDArray) Nd4j.toFlattened((INDArray[]) new INDArray[]{dev}).dup(), (int) 1, (boolean) false)[0];
        final int wanted = count - 1;
        int placed = 0;
        // bounded by the array length so that a large count cannot index past the end
        for (int i = 0; i < wanted && (long) i < peakOrder.length(); ++i) {
            final int idx = peakOrder.getInt(new int[]{i});
            if (idx == 0) {
                continue;
            }
            if (peaks.getDouble((long) idx) == 0.0) {
                break; // no real peaks left
            }
            result.putScalar((long) idx, 0L, 1.0);
            ++placed;
        }
        // fill up with the largest remaining deviations until enough edges are set
        for (int i = 0; i < rows && placed < wanted; ++i) {
            final int idx = magnitudeOrder.getInt(new int[]{i});
            if (idx == 0 || result.getDouble((long) idx) == 1.0) {
                continue;
            }
            result.putScalar((long) idx, 0L, 1.0);
            ++placed;
        }
        result.putScalar(0L, 0L, 1.0);
        return result;
    }

    /**
     * Computes the cosine distance between each row and the row before it.
     *
     * @param data matrix with one row per position
     * @return column vector of distances; entry 0 is fixed to 1
     */
    protected static INDArray deltaMatrix(INDArray data) {
        final INDArray deltas = Nd4j.zeros((int[]) new int[]{data.rows(), 1});
        // the first row is compared against a zero vector; its entry is overwritten below
        INDArray previous = Nd4j.zeros((int[]) new int[]{data.columns()});
        for (int t = 0; t < data.rows(); ++t) {
            final INDArray current = data.getRow((long) t);
            final double distance = Transforms.cosineDistance((INDArray) previous, (INDArray) current);
            deltas.putScalar((long) t, 0L, distance);
            previous = current.dup();
        }
        deltas.putScalar(0L, 0L, 1.0);
        return deltas;
    }

    /**
     * Fluent builder that wires a {@link SectorTagger} into a {@link SectorAnnotator}.
     *
     * <p>The annotator and tagger are created in the constructor; {@link #build()} configures
     * and returns that same annotator instance.
     */
    public static class Builder {
        /** Annotator under construction; returned by {@link #build()}. */
        SectorAnnotator ann;
        /** Tagger wrapped by the annotator. */
        SectorTagger tagger;
        /** Encoders trained by {@link #pretrain(Dataset)}; this builder never fills the array itself. */
        protected Encoder[] encoders = new Encoder[0];
        protected ILossFunction lossFunc = LossFunctions.LossFunction.MCXENT.getILossFunction();
        protected Activation activation = Activation.SOFTMAX;
        protected boolean requireSubsampling = false;
        private int examplesPerEpoch = -1;
        private int maxTimeSeriesLength = -1;
        private int ffwLayerSize = 0;
        private int lstmLayerSize = 256;
        private int embeddingLayerSize = 128;
        private double learningRate = 0.01;
        private double dropOut = 0.5;
        private int iterations = 1;
        private int batchSize = 16;
        private int numEpochs = 1;
        private boolean enabletrainingUI = false;

        /** Creates a new tagger and an annotator around it. */
        public Builder() {
            this.tagger = new SectorTagger();
            this.ann = new SectorAnnotator(this.tagger);
        }

        /**
         * Sets the id of the tagger.
         *
         * @param id tagger id
         * @return this builder
         */
        public Builder withId(String id) {
            this.tagger.setId(id);
            return this;
        }

        /**
         * Records dataset name and language in the annotator's provenance.
         *
         * @param datasetName name of the dataset
         * @param lang        language; stored as its lower-cased {@code toString()}
         * @return this builder
         */
        public Builder withDataset(String datasetName, WordHelpers.Language lang) {
            this.ann.getProvenance().setDataset(datasetName);
            // Locale.ROOT keeps the result independent of the JVM's default locale
            this.ann.getProvenance().setLanguage(lang.toString().toLowerCase(Locale.ROOT));
            return this;
        }

        /**
         * Selects loss function and output activation from a predefined loss function.
         *
         * @param lossFunc           predefined loss function
         * @param activation         output activation
         * @param requireSubsampling value later passed to {@code SectorTagger.setRequireSubsampling}
         * @return this builder
         */
        public Builder withLossFunction(LossFunctions.LossFunction lossFunc, Activation activation, boolean requireSubsampling) {
            this.lossFunc = lossFunc.getILossFunction();
            this.requireSubsampling = requireSubsampling;
            this.activation = activation;
            return this;
        }

        /**
         * Selects loss function and output activation from a loss function instance.
         *
         * @param lossFunc           loss function instance
         * @param activation         output activation
         * @param requireSubsampling value later passed to {@code SectorTagger.setRequireSubsampling}
         * @return this builder
         */
        public Builder withLossFunction(ILossFunction lossFunc, Activation activation, boolean requireSubsampling) {
            this.lossFunc = lossFunc;
            this.requireSubsampling = requireSubsampling;
            this.activation = activation;
            return this;
        }

        /**
         * Sets the layer sizes passed to {@code SectorTagger.buildSECTORModel}.
         *
         * @param ffwLayerSize       feed-forward layer size (default 0)
         * @param lstmLayerSize      LSTM layer size (default 256)
         * @param embeddingLayerSize embedding layer size (default 128)
         * @return this builder
         */
        public Builder withModelParams(int ffwLayerSize, int lstmLayerSize, int embeddingLayerSize) {
            this.ffwLayerSize = ffwLayerSize;
            this.lstmLayerSize = lstmLayerSize;
            this.embeddingLayerSize = embeddingLayerSize;
            return this;
        }

        /**
         * Sets the training parameters, leaving the maximum time series length unchanged.
         *
         * @param learningRate     learning rate (default 0.01)
         * @param dropOut          dropout value (default 0.5)
         * @param examplesPerEpoch examples per epoch (default -1)
         * @param batchSize        batch size (default 16)
         * @param numEpochs        number of epochs (default 1)
         * @return this builder
         */
        public Builder withTrainingParams(double learningRate, double dropOut, int examplesPerEpoch, int batchSize, int numEpochs) {
            this.learningRate = learningRate;
            this.dropOut = dropOut;
            this.examplesPerEpoch = examplesPerEpoch;
            this.batchSize = batchSize;
            this.numEpochs = numEpochs;
            return this;
        }

        /**
         * Sets the training parameters including the maximum time series length.
         *
         * @param learningRate        learning rate (default 0.01)
         * @param dropOut             dropout value (default 0.5)
         * @param examplesPerEpoch    examples per epoch (default -1)
         * @param maxTimeSeriesLength maximum time series length (default -1)
         * @param batchSize           batch size (default 16)
         * @param numEpochs           number of epochs (default 1)
         * @return this builder
         */
        public Builder withTrainingParams(double learningRate, double dropOut, int examplesPerEpoch, int maxTimeSeriesLength, int batchSize, int numEpochs) {
            this.learningRate = learningRate;
            this.dropOut = dropOut;
            this.examplesPerEpoch = examplesPerEpoch;
            this.batchSize = batchSize;
            this.maxTimeSeriesLength = maxTimeSeriesLength;
            this.numEpochs = numEpochs;
            return this;
        }

        /**
         * Registers the three input encoders with the tagger and as components of the annotator.
         *
         * @param desc        feature description stored in the provenance
         * @param bagEncoder  first input encoder
         * @param embEncoder  second input encoder
         * @param flagEncoder third input encoder
         * @return this builder
         */
        public Builder withInputEncoders(String desc, Encoder bagEncoder, Encoder embEncoder, Encoder flagEncoder) {
            this.tagger.setInputEncoders(bagEncoder, embEncoder, flagEncoder);
            this.ann.getProvenance().setFeatures(desc);
            this.ann.addComponent((AnnotatorComponent)bagEncoder);
            this.ann.addComponent((AnnotatorComponent)embEncoder);
            this.ann.addComponent((AnnotatorComponent)flagEncoder);
            return this;
        }

        /**
         * Registers the target encoder with the tagger and as component of the annotator.
         *
         * @param targetEncoder encoder of the prediction target
         * @return this builder
         */
        public Builder withTargetEncoder(Encoder targetEncoder) {
            this.tagger.setTargetEncoder(targetEncoder);
            this.ann.addComponent((AnnotatorComponent)targetEncoder);
            return this;
        }

        /**
         * Chooses whether {@link #build()} enables the tagger's training UI.
         *
         * @param enable {@code true} to enable the UI
         * @return this builder
         */
        public Builder enableTrainingUI(boolean enable) {
            this.enabletrainingUI = enable;
            return this;
        }

        /**
         * Trains every encoder in {@link #encoders} on the documents of the dataset.
         *
         * <p>NOTE(review): {@code encoders} is an empty array unless a subclass assigns it, so
         * with this builder alone the call trains nothing; the encoders registered through
         * {@code withInputEncoders}/{@code withTargetEncoder} are not included. Confirm whether
         * that is intended.
         *
         * @param train dataset providing the training documents
         * @return this builder
         */
        public Builder pretrain(Dataset train) {
            for (Encoder e : this.encoders) {
                e.trainModel(train.streamDocuments());
            }
            return this;
        }

        /**
         * Builds the network, applies all collected settings to the tagger, names it after the
         * provenance and writes the parameter summary to the training log.
         *
         * @return the annotator created in the constructor
         */
        public SectorAnnotator build() {
            this.tagger.buildSECTORModel(this.ffwLayerSize, this.lstmLayerSize, this.embeddingLayerSize, this.iterations, this.learningRate, this.dropOut, this.lossFunc, this.activation);
            if (this.enabletrainingUI) {
                this.tagger.enableTrainingUI();
            }
            this.tagger.setRequireSubsampling(this.requireSubsampling);
            this.tagger.setTrainingParams(this.examplesPerEpoch, this.maxTimeSeriesLength, this.batchSize, this.numEpochs, true);
            this.ann.getProvenance().setTask(this.tagger.getId());
            this.tagger.setName(this.ann.getProvenance().toString());
            this.tagger.appendTrainLog(this.printParams());
            return this.ann;
        }

        /**
         * Renders encoders, network parameters and training parameters as tab-separated text.
         *
         * @return multi-line parameter summary
         */
        private String printParams() {
            StringBuilder line = new StringBuilder();
            line.append("TRAINING PARAMS: ").append(this.tagger.getName()).append("\n");
            line.append("\nEncoders:\n");
            for (Encoder e : this.tagger.getEncoders()) {
                line.append(e.getId()).append("\t").append(e.getClass().getSimpleName()).append("\t").append(e.getEmbeddingVectorSize()).append("\n");
            }
            line.append("\nNetwork Params:\n");
            line.append("FF").append("\t").append(this.ffwLayerSize).append("\n");
            line.append("BLSTM").append("\t").append(this.lstmLayerSize).append("\n");
            line.append("EMB").append("\t").append(this.embeddingLayerSize).append("\n");
            line.append("\nTraining Params:\n");
            line.append("examples per epoch").append("\t").append(this.examplesPerEpoch).append("\n");
            line.append("max time series length").append("\t").append(this.maxTimeSeriesLength).append("\n");
            line.append("epochs").append("\t").append(this.numEpochs).append("\n");
            line.append("iterations").append("\t").append(this.iterations).append("\n");
            line.append("batch size").append("\t").append(this.batchSize).append("\n");
            line.append("learning rate").append("\t").append(this.learningRate).append("\n");
            line.append("dropout").append("\t").append(this.dropOut).append("\n");
            line.append("loss").append("\t").append(this.lossFunc.toString()).append(this.requireSubsampling ? " (1-hot subsampled)" : " (1-hot/n-hot)").append("\n");
            line.append("\n");
            return line.toString();
        }
    }

    /**
     * Strategies for turning sentence-level predictions into section annotations.
     * The constant order is part of the contract (ordinals may be relied upon elsewhere).
     */
    public static enum SegmentationMethod {
        /** No segmentation; {@code annotate} only attaches vectors for this value. */
        NONE,
        /** Presumably: copy the gold sections (verify against {@code detectSections}). */
        GOLD,
        /** Presumably: split at newline tokens (verify against {@code detectSections}). */
        NL,
        /** Presumably: split where the predicted target label changes (verify). */
        MAX,
        /** Presumably: split at peaks of the embedding deviation (verify). */
        EMD,
        /** Default of {@code annotate(Collection)}; presumably bidirectional embedding deviation. */
        BEMD,
        /** Presumably: like {@code BEMD} with a fixed number of sections (verify). */
        BEMD_FIXED
    }
}

