/*
 * Decompiled with CFR 0.152.
 */
package de.datexis.ner.tagger;

import com.google.common.collect.Lists;
import de.datexis.encoder.Encoder;
import de.datexis.encoder.EncoderSet;
import de.datexis.model.Annotation;
import de.datexis.model.Dataset;
import de.datexis.model.Document;
import de.datexis.model.Sentence;
import de.datexis.model.Snippet;
import de.datexis.model.Token;
import de.datexis.model.tag.BIO2Tag;
import de.datexis.model.tag.BIOESTag;
import de.datexis.model.tag.Tag;
import de.datexis.ner.MentionAnnotation;
import de.datexis.ner.eval.MentionAnnotatorEval;
import de.datexis.ner.eval.MentionTaggerEval;
import de.datexis.ner.tagger.MentionTaggerIterator;
import de.datexis.tagger.AbstractIterator;
import de.datexis.tagger.Tagger;
import java.util.Collection;
import org.apache.commons.lang3.tuple.Pair;
import org.deeplearning4j.api.storage.StatsStorage;
import org.deeplearning4j.api.storage.StatsStorageRouter;
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.BackpropType;
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.WorkspaceMode;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.LSTM;
import org.deeplearning4j.nn.conf.layers.Layer;
import org.deeplearning4j.nn.conf.layers.RnnOutputLayer;
import org.deeplearning4j.nn.conf.layers.recurrent.Bidirectional;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.optimize.api.TrainingListener;
import org.deeplearning4j.parallelism.ParallelWrapper;
import org.deeplearning4j.ui.api.UIServer;
import org.deeplearning4j.ui.stats.StatsListener;
import org.deeplearning4j.ui.storage.InMemoryStatsStorage;
import org.nd4j.evaluation.classification.Evaluation;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.learning.config.Adam;
import org.nd4j.linalg.learning.config.IUpdater;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import org.nd4j.shade.jackson.annotation.JsonIgnore;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Tagger that labels the tokens of sentences with mention tags
 * ({@link BIO2Tag} or {@link BIOESTag}) predicted by a neural network.
 *
 * <p>The network is usually the bidirectional LSTM graph produced by
 * {@link #createBLSTM(long, long, long, long, int, double)}, but training and
 * inference also accept a {@link MultiLayerNetwork}.
 */
public class MentionTagger
extends Tagger {
    protected static final Logger log = LoggerFactory.getLogger(MentionTagger.class);
    /** Number of training workers; values above 1 enable a {@link ParallelWrapper}. */
    protected int workers = 1;
    /** Tag class that is predicted for every token. */
    protected Class<? extends Tag> tagset = BIOESTag.class;
    /** Type string that is attached to every created tag. */
    protected String type = "GENERIC";

    /**
     * Creates a tagger with the id {@code "BLSTM"}, the BIOES tagset and the type {@code "GENERIC"}.
     */
    public MentionTagger() {
        // The delegated constructor already configures the BIOES/"GENERIC" tagset.
        this("BLSTM");
    }

    /**
     * Creates a tagger with the given id, the BIOES tagset and the type {@code "GENERIC"}.
     *
     * @param id identifier that is passed to the parent {@link Tagger}
     */
    public MentionTagger(String id) {
        super(id);
        this.setTagset(BIOESTag.class, "GENERIC");
    }

    /**
     * Creates a tagger whose input and label sizes are taken from the given iterator and
     * immediately builds the BLSTM network for them.
     *
     * @param data          iterator providing the input and label vector sizes
     * @param ffwLayerSize  size of the two feed-forward layers; {@code 0} or less omits them
     * @param lstmLayerSize size of the LSTM layer
     * @param iterations    forwarded to {@link #createBLSTM}, which currently ignores it
     * @param learningRate  learning rate of the Adam updater
     */
    public MentionTagger(AbstractIterator data, int ffwLayerSize, int lstmLayerSize, int iterations, double learningRate) {
        super(data.getInputSize(), data.getLabelSize());
        this.net = MentionTagger.createBLSTM(this.inputVectorSize, ffwLayerSize, lstmLayerSize, this.outputVectorSize, iterations, learningRate);
    }

    /**
     * Replaces the current network with a freshly initialized BLSTM that uses the
     * tagger's current input and output vector sizes.
     *
     * @param ffwLayerSize  size of the two feed-forward layers; {@code 0} or less omits them
     * @param lstmLayerSize size of the LSTM layer
     * @param iterations    forwarded to {@link #createBLSTM}, which currently ignores it
     * @param learningRate  learning rate of the Adam updater
     * @return this tagger, for chaining
     */
    public MentionTagger setModelParams(int ffwLayerSize, int lstmLayerSize, int iterations, double learningRate) {
        this.net = MentionTagger.createBLSTM(this.inputVectorSize, ffwLayerSize, lstmLayerSize, this.outputVectorSize, iterations, learningRate);
        return this;
    }

    /**
     * @return the tag class this tagger predicts
     */
    public Class<? extends Tag> getTagset() {
        return this.tagset;
    }

    /**
     * Builds and initializes the tagging network: an optional pair of ReLU dense layers
     * ("FF1", "FF2"), a bidirectional LSTM ("BLSTM") whose directions are averaged, and a
     * softmax RNN output layer ("output") trained with multi-class cross entropy.
     *
     * @param inputVectorSize  size of one input vector
     * @param ffwLayerSize     size of both dense layers; if not positive, the LSTM reads the input directly
     * @param lstmLayerSize    size of the LSTM layer
     * @param outputVectorSize size of one output (label) vector
     * @param iterations       unused; kept for compatibility with existing callers
     * @param learningRate     learning rate of the Adam updater
     * @return the initialized computation graph
     */
    public static ComputationGraph createBLSTM(long inputVectorSize, long ffwLayerSize, long lstmLayerSize, long outputVectorSize, int iterations, double learningRate) {
        log.info("initializing BLSTM network {}:{}:{}:{}:{}", inputVectorSize, ffwLayerSize, ffwLayerSize, lstmLayerSize, outputVectorSize);
        ComputationGraphConfiguration.GraphBuilder gb = new NeuralNetConfiguration.Builder()
            .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
            .updater(new Adam(learningRate, 0.9, 0.999, 1.0E-8))
            .l2(1.0E-4)
            .trainingWorkspaceMode(WorkspaceMode.ENABLED)
            .inferenceWorkspaceMode(WorkspaceMode.ENABLED)
            .graphBuilder()
            .addInputs("input");
        if (ffwLayerSize > 0L) {
            gb.addLayer("FF1", new DenseLayer.Builder()
                    .nIn(inputVectorSize).nOut(ffwLayerSize)
                    .activation(Activation.RELU)
                    .weightInit(WeightInit.RELU)
                    .build(), "input")
                .addLayer("FF2", new DenseLayer.Builder()
                    .nIn(ffwLayerSize).nOut(ffwLayerSize)
                    .activation(Activation.RELU)
                    .weightInit(WeightInit.RELU)
                    .build(), "FF1")
                .addLayer("BLSTM", new Bidirectional(Bidirectional.Mode.AVERAGE, new LSTM.Builder()
                    .nIn(ffwLayerSize).nOut(lstmLayerSize)
                    .activation(Activation.TANH)
                    .weightInit(WeightInit.XAVIER)
                    .build()), "FF2");
        } else {
            gb.addLayer("BLSTM", new Bidirectional(Bidirectional.Mode.AVERAGE, new LSTM.Builder()
                    .nIn(inputVectorSize).nOut(lstmLayerSize)
                    .activation(Activation.TANH)
                    .weightInit(WeightInit.XAVIER)
                    .build()), "input");
        }
        gb.addLayer("output", new RnnOutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
                .nIn(lstmLayerSize).nOut(outputVectorSize)
                .activation(Activation.SOFTMAX)
                .weightInit(WeightInit.XAVIER)
                .build(), "BLSTM")
            .setOutputs("output")
            .setInputTypes(InputType.recurrent(inputVectorSize))
            .backpropType(BackpropType.Standard);
        // Build the configuration exactly once; an earlier discarded build() was redundant.
        ComputationGraphConfiguration conf = gb.build();
        ComputationGraph lstm = new ComputationGraph(conf);
        lstm.init();
        return lstm;
    }

    /**
     * @return the type string attached to created tags
     */
    public String getType() {
        return this.type;
    }

    /**
     * @param type the type string to attach to created tags
     */
    public void setType(String type) {
        this.type = type;
    }

    /**
     * Sets the tag class and derives the output vector size from a new instance of it.
     * If the instance cannot be created or queried, the error is logged and the output
     * vector size is left unchanged.
     *
     * @param tagset tag class to predict
     * @return this tagger, for chaining
     */
    public MentionTagger setTagset(Class<? extends Tag> tagset) {
        this.tagset = tagset;
        try {
            this.outputVectorSize = tagset.getDeclaredConstructor().newInstance().getVectorSize();
        }
        catch (Exception ex) {
            // Best effort: keep the previous size, but do not hide the cause.
            log.error("Could not set output vector size", ex);
        }
        return this;
    }

    /**
     * Sets the tag class (see {@link #setTagset(Class)}) together with the tag type.
     *
     * @param tagset tag class to predict
     * @param types  type string to attach to created tags
     * @return this tagger, for chaining
     */
    public MentionTagger setTagset(Class<? extends Tag> tagset, String types) {
        this.setTagset(tagset);
        this.type = types;
        return this;
    }

    /**
     * Stores the parameters used by subsequent training runs.
     *
     * @param batchSize number of examples per batch
     * @param numEpochs number of passes over the training data
     * @param randomize default randomization flag handed to the training iterator
     * @return this tagger, for chaining
     */
    public MentionTagger setTrainingParams(int batchSize, int numEpochs, boolean randomize) {
        this.batchSize = batchSize;
        this.numEpochs = numEpochs;
        this.randomize = randomize;
        return this;
    }

    /**
     * Sets the number of training workers. More than one worker makes training use a
     * {@link ParallelWrapper}.
     *
     * @param workers number of workers
     * @return this tagger, for chaining
     */
    public MentionTagger setWorkspaceParams(int workers) {
        this.workers = workers;
        return this;
    }

    /**
     * Wraps the tagger's encoders into a new {@link EncoderSet}.
     *
     * @return a new encoder set containing all encoders of this tagger
     */
    @JsonIgnore
    @Deprecated
    public EncoderSet getEncoderSet() {
        return new EncoderSet(this.getEncoders().toArray(new Encoder[0]));
    }

    /**
     * Trains on the dataset using its {@code GOLD} annotations.
     *
     * @param dataset training data
     */
    public void trainModel(Dataset dataset) {
        this.trainModel(dataset, Annotation.Source.GOLD);
    }

    /**
     * Trains on all examples of the dataset using the configured randomization flag.
     *
     * @param dataset             training data
     * @param trainingAnnotations annotation source used as training labels
     */
    public void trainModel(Dataset dataset, Annotation.Source trainingAnnotations) {
        this.trainModel(dataset, trainingAnnotations, -1, this.randomize);
    }

    /**
     * Trains on the documents of the dataset.
     *
     * @param dataset             training data
     * @param trainingAnnotations annotation source used as training labels
     * @param numExamples         example limit handed to the iterator ({@code -1} is used by the other overloads)
     * @param randomize           randomization flag handed to the iterator
     */
    public void trainModel(Dataset dataset, Annotation.Source trainingAnnotations, int numExamples, boolean randomize) {
        this.trainModel(new MentionTaggerIterator(dataset.getDocuments(), dataset.getName(), this.getEncoderSet(), this.tagset, trainingAnnotations, numExamples, this.batchSize, randomize));
    }

    /**
     * Trains on a collection of sentences, which is wrapped into a single {@link Snippet}.
     *
     * @param sentences    training sentences
     * @param trainingTags annotation source used as training labels
     * @param randomize    randomization flag handed to the snippet and the iterator
     */
    public void trainModel(Collection<Sentence> sentences, Annotation.Source trainingTags, boolean randomize) {
        this.trainModel(new MentionTaggerIterator(Lists.<Document>newArrayList(new Snippet(sentences, randomize)), "training", this.getEncoderSet(), this.tagset, trainingTags, -1, this.batchSize, randomize));
    }

    /**
     * Fits the network on the iterator for the configured number of epochs, writes progress
     * to the training log and marks the model as available afterwards.
     *
     * <p>With more than one worker, fitting is delegated to a {@link ParallelWrapper}, which
     * is shut down when training ends, whether normally or with an exception.
     *
     * @param it iterator over the training examples
     */
    protected void trainModel(MentionTaggerIterator it) {
        int perBatch = it.batch();
        // Only used for logging; guard against a zero batch size instead of failing here.
        int batches = perBatch > 0 ? it.numExamples() / perBatch : 0;
        int seen = 0;
        this.appendTrainLog("Training " + this.getName() + " with " + it.numExamples() + " examples in " + batches + " batches for " + this.numEpochs + " epochs.");
        ParallelWrapper wrapper = null;
        if (this.workers > 1) {
            wrapper = new ParallelWrapper.Builder(this.net).prefetchBuffer(this.workers * 4).workers(this.workers).trainingMode(ParallelWrapper.TrainingMode.AVERAGING).workspaceMode(WorkspaceMode.ENABLED).build();
        } else if (!(this.net instanceof ComputationGraph) && !(this.net instanceof MultiLayerNetwork)) {
            log.warn("Network of type {} cannot be fitted; epochs will run without training", this.netTypeName());
        }
        try {
            this.timer.start();
            for (int epoch = 1; epoch <= this.numEpochs; ++epoch) {
                this.timer.setSplit("epoch");
                if (wrapper != null) {
                    wrapper.fit((DataSetIterator)it);
                } else if (this.net instanceof ComputationGraph) {
                    ((ComputationGraph)this.net).fit((DataSetIterator)it);
                } else if (this.net instanceof MultiLayerNetwork) {
                    ((MultiLayerNetwork)this.net).fit((DataSetIterator)it);
                }
                seen += it.numExamples();
                this.appendTrainLog("Completed epoch " + epoch + " of " + this.numEpochs + "\t" + seen, this.timer.getLong("epoch"));
                it.reset();
            }
            this.timer.stop();
        } finally {
            if (wrapper != null) {
                // Release the wrapper's worker threads.
                wrapper.shutdown();
            }
        }
        this.appendTrainLog("Training complete", this.timer.getLong());
        this.setModelAvailable(true);
    }

    /**
     * Predicts tags for all tokens of the given documents and stores them under
     * {@code Annotation.Source.PRED}. BIOES predictions are additionally converted to BIO2.
     *
     * @param documents documents to label
     * @throws IllegalStateException if the network is neither a {@link ComputationGraph}
     *         nor a {@link MultiLayerNetwork}
     */
    public synchronized void tag(Collection<Document> documents) {
        log.debug("Labeling Documents...");
        MentionTaggerIterator it = new MentionTaggerIterator(documents, "train", this.getEncoderSet(), this.tagset, -1, this.batchSize, false);
        it.reset();
        while (it.hasNext()) {
            Pair examples = it.nextDataSet();
            DataSet batch = (DataSet)examples.getKey();
            INDArray predicted = this.runNetwork(batch.getFeatures(), batch.getFeaturesMaskArray(), batch.getLabelsMaskArray());
            MentionTagger.createTags((Iterable)examples.getValue(), predicted, it.getTagset(), Annotation.Source.PRED, this.type, false, true);
        }
        this.markPredictedTagsAvailable(it);
    }

    /**
     * Predicts tags for the given sentences by wrapping them into a single {@link Snippet}.
     *
     * @param sentences sentences to label
     */
    public void tagSentences(Collection<Sentence> sentences) {
        this.tag(Lists.<Document>newArrayList(new Snippet(sentences, false)));
    }

    /**
     * Tags the dataset, then appends tag-level and annotation-level evaluation reports to
     * the test log. Predicted mention annotations are rebuilt from the predicted BIO2 tags;
     * expected annotations are only derived from tags for documents that have none.
     *
     * @param dataset  data to evaluate on
     * @param expected annotation source holding the expected annotations
     */
    public void testModel(Dataset dataset, Annotation.Source expected) {
        MentionTaggerIterator it = new MentionTaggerIterator(dataset.getDocuments(), dataset.getName(), this.getEncoderSet(), this.tagset, -1, this.batchSize, false);
        this.test(it);
        MentionTaggerEval eval = new MentionTaggerEval(this.getName(), this.tagset);
        eval.calculateMeasures(dataset);
        this.appendTestLog(eval.printExperimentStats());
        this.appendTestLog(eval.printDatasetStats());
        this.appendTestLog(eval.printTrainingCurve());
        this.appendTestLog(eval.printSequenceClassStats(false));
        MentionAnnotatorEval annE = new MentionAnnotatorEval(this.getName());
        for (Document doc : dataset.getDocuments()) {
            if (doc.countAnnotations(expected) == 0L) {
                MentionAnnotation.annotateFromTags(doc, expected, BIO2Tag.class, this.type);
            }
            doc.clearAnnotations(Annotation.Source.PRED, MentionAnnotation.class);
            MentionAnnotation.annotateFromTags(doc, Annotation.Source.PRED, BIO2Tag.class, this.type);
        }
        annE.setTestDataset(dataset, 0L, 0L);
        annE.evaluateAnnotations();
        this.appendTestLog(annE.printAnnotationStats());
    }

    /**
     * Runs the network over all batches of the iterator, accumulates classification
     * statistics and stores the predicted tags under {@code Annotation.Source.PRED}.
     *
     * @param it iterator over the examples to evaluate
     * @return the accumulated evaluation
     * @throws IllegalStateException if the network is neither a {@link ComputationGraph}
     *         nor a {@link MultiLayerNetwork}
     */
    public Evaluation test(MentionTaggerIterator it) {
        this.timer.start();
        this.appendTrainLog("Evaluating " + this.getName() + " with " + it.numExamples() + " examples...");
        // Cast to int so that the Evaluation(int numClasses) constructor is chosen; a wider
        // numeric argument would bind to Evaluation(double binaryDecisionThreshold) instead.
        Evaluation eval = new Evaluation((int)it.getLabelSize());
        it.reset();
        while (it.hasNext()) {
            Pair examples = it.nextDataSet();
            DataSet batch = (DataSet)examples.getKey();
            INDArray labelsMask = batch.getLabelsMaskArray();
            INDArray predicted = this.runNetwork(batch.getFeatures(), batch.getFeaturesMaskArray(), labelsMask);
            try {
                eval.evalTimeSeries(batch.getLabels(), predicted, labelsMask);
            }
            catch (IllegalStateException ex) {
                // Best effort: a batch that cannot be evaluated must not abort tagging.
                log.warn("Could not evaluate batch", ex);
            }
            MentionTagger.createTags((Iterable)examples.getValue(), predicted, it.getTagset(), Annotation.Source.PRED, this.type, true, true);
        }
        this.markPredictedTagsAvailable(it);
        this.timer.stop();
        this.appendTrainLog("Evaluation complete", this.timer.getLong());
        return eval;
    }

    /**
     * Attaches a {@link StatsListener} backed by in-memory storage to the network and
     * registers that storage with the DL4J UI server.
     */
    public void enableTrainingUI() {
        InMemoryStatsStorage stats = new InMemoryStatsStorage();
        this.net.addListeners(new StatsListener(stats, 1));
        UIServer.getInstance().attach(stats);
    }

    /**
     * Creates one tag per token from the predicted output. The vector for a token is
     * read as column {@code t} of row {@code b} of {@code predicted}, where {@code b} is the
     * index of the sentence and {@code t} the index of the token within it.
     *
     * <p>Only {@link BIO2Tag} and {@link BIOESTag} are supported; for any other tag class no
     * tags are written. BIOES sequences are corrected per sentence and, if requested,
     * converted to BIO2.
     *
     * @param sents       sentences in the order of the rows of {@code predicted}
     * @param predicted   network output for the batch
     * @param tagset      tag class to create
     * @param source      annotation source under which the tags are stored
     * @param type        type string attached to every tag
     * @param keepVectors currently ignored; kept for compatibility with existing callers
     * @param convertTags whether BIOES tags are additionally converted to BIO2
     */
    public static void createTags(Iterable<Sentence> sents, INDArray predicted, Class tagset, Annotation.Source source, String type, boolean keepVectors, boolean convertTags) {
        boolean bio2 = BIO2Tag.class.equals(tagset);
        boolean bioes = BIOESTag.class.equals(tagset);
        int batchNum = 0;
        for (Sentence s : sents) {
            int t = 0;
            for (Token token : s.getTokens()) {
                INDArray vec = predicted.getRow(batchNum).getColumn(t++);
                if (bio2) {
                    token.putTag(source, (Tag)new BIO2Tag(vec, type, true));
                } else if (bioes) {
                    token.putTag(source, (Tag)new BIOESTag(vec, type, true));
                }
            }
            ++batchNum;
            if (bioes) {
                BIOESTag.correctCRF(s, source);
                if (convertTags) {
                    BIOESTag.convertToBIO2(s, source);
                }
            }
        }
    }

    /**
     * Runs inference for one batch on the configured network.
     *
     * @throws IllegalStateException if the network type is not supported
     */
    private INDArray runNetwork(INDArray input, INDArray inputMask, INDArray labelsMask) {
        if (this.net instanceof MultiLayerNetwork) {
            return ((MultiLayerNetwork)this.net).output(input, false, inputMask, labelsMask);
        }
        if (this.net instanceof ComputationGraph) {
            ComputationGraph graph = (ComputationGraph)this.net;
            graph.setLayerMaskArrays(new INDArray[]{inputMask}, new INDArray[]{labelsMask});
            return graph.outputSingle(new INDArray[]{input});
        }
        throw new IllegalStateException("Unsupported network type for inference: " + this.netTypeName());
    }

    /** Flags the predicted tags (and their BIO2 form) as available on all iterated documents. */
    private void markPredictedTagsAvailable(MentionTaggerIterator it) {
        for (Document doc : it.getDocuments()) {
            doc.setTagAvailable(Annotation.Source.PRED, it.getTagset(), true);
            if (!this.tagset.equals(BIO2Tag.class)) {
                doc.setTagAvailable(Annotation.Source.PRED, BIO2Tag.class, true);
            }
        }
    }

    /** Returns the class name of the current network, or {@code "null"} if none is set. */
    private String netTypeName() {
        return this.net == null ? "null" : this.net.getClass().getName();
    }
}

