/*
 * Decompiled with CFR 0.152.
 */
package de.datexis.sector.tagger;

import com.google.common.collect.Lists;
import de.datexis.common.Resource;
import de.datexis.encoder.Encoder;
import de.datexis.encoder.LookupCacheEncoder;
import de.datexis.evaluation.ModelEvaluation;
import de.datexis.model.Dataset;
import de.datexis.model.Document;
import de.datexis.model.Sentence;
import de.datexis.sector.eval.ClassificationScoreCalculator;
import de.datexis.sector.tagger.DocumentSentenceIterator;
import de.datexis.sector.tagger.SectorEncoder;
import de.datexis.sector.tagger.SectorTaggerIterator;
import de.datexis.tagger.AbstractMultiDataSetIterator;
import de.datexis.tagger.Tagger;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.util.Collection;
import java.util.List;
import java.util.Map;
import org.deeplearning4j.api.storage.StatsStorage;
import org.deeplearning4j.api.storage.StatsStorageRouter;
import org.deeplearning4j.earlystopping.EarlyStoppingConfiguration;
import org.deeplearning4j.earlystopping.EarlyStoppingResult;
import org.deeplearning4j.earlystopping.listener.EarlyStoppingListener;
import org.deeplearning4j.earlystopping.scorecalc.ScoreCalculator;
import org.deeplearning4j.earlystopping.trainer.EarlyStoppingGraphTrainer;
import org.deeplearning4j.nn.api.Model;
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.BackpropType;
import org.deeplearning4j.nn.conf.CacheMode;
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
import org.deeplearning4j.nn.conf.GradientNormalization;
import org.deeplearning4j.nn.conf.InputPreProcessor;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.WorkspaceMode;
import org.deeplearning4j.nn.conf.graph.GraphVertex;
import org.deeplearning4j.nn.conf.graph.MergeVertex;
import org.deeplearning4j.nn.conf.graph.PreprocessorVertex;
import org.deeplearning4j.nn.conf.graph.SubsetVertex;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.LSTM;
import org.deeplearning4j.nn.conf.layers.Layer;
import org.deeplearning4j.nn.conf.layers.RnnOutputLayer;
import org.deeplearning4j.nn.conf.layers.recurrent.Bidirectional;
import org.deeplearning4j.nn.conf.preprocessor.FeedForwardToRnnPreProcessor;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.deeplearning4j.optimize.api.TrainingListener;
import org.deeplearning4j.optimize.listeners.PerformanceListener;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.deeplearning4j.ui.api.UIServer;
import org.deeplearning4j.ui.stats.StatsListener;
import org.deeplearning4j.ui.storage.InMemoryStatsStorage;
import org.deeplearning4j.util.ModelSerializer;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.MultiDataSet;
import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.learning.config.Adam;
import org.nd4j.linalg.learning.config.IUpdater;
import org.nd4j.linalg.lossfunctions.ILossFunction;
import org.nd4j.linalg.schedule.ExponentialSchedule;
import org.nd4j.linalg.schedule.ISchedule;
import org.nd4j.linalg.schedule.ScheduleType;
import org.nd4j.shade.jackson.annotation.JsonIgnore;
import org.nd4j.shade.jackson.databind.JsonNode;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Tagger backed by a DL4J {@link ComputationGraph} with three named inputs
 * ({@code "bag"}, {@code "emb"}, {@code "flag"}), a bidirectional LSTM whose forward and
 * backward halves are split into {@code "FW"} / {@code "BW"}, and two output layers
 * ({@code "targetFW"}, {@code "targetBW"}).
 *
 * <p>Besides building and training the graph, the class attaches the averaged target
 * activations and the optional embedding activations to the {@link Sentence}s of the
 * processed documents.
 */
public class SectorTagger
extends Tagger {
    protected static final Logger log = LoggerFactory.getLogger(SectorTagger.class);
    /** Encoder feeding the {@code "bag"} graph input. */
    protected Encoder bagEncoder = null;
    /** Encoder feeding the {@code "emb"} graph input. */
    protected Encoder embEncoder = null;
    /** Encoder feeding the {@code "flag"} graph input. */
    protected Encoder flagEncoder = null;
    /** Encoder that defines the size of the output layers. */
    protected Encoder targetEncoder = null;
    protected int workers = 4;
    protected boolean requireSubsampling;
    protected ModelEvaluation eval = new ModelEvaluation("null");
    /** Used by {@link #encodeMatrix} to reshape the embedding activations. */
    protected final FeedForwardToRnnPreProcessor ff2rnn = new FeedForwardToRnnPreProcessor();

    /** Creates a tagger with the id {@code "SECTOR"}. */
    public SectorTagger() {
        super("SECTOR");
    }

    /**
     * Creates a tagger with the given id.
     *
     * @param id identifier passed to the {@link Tagger} constructor
     */
    public SectorTagger(String id) {
        super(id);
    }

    /**
     * Creates a tagger from a model resource and sets its id to {@code "SECTOR"}.
     *
     * @param modelPath resource passed to the {@link Tagger} constructor
     */
    public SectorTagger(Resource modelPath) {
        super(modelPath);
        this.setId("SECTOR");
    }

    /**
     * @return the underlying network cast to {@link ComputationGraph}
     */
    @JsonIgnore
    public ComputationGraph getNN() {
        return (ComputationGraph)this.net;
    }

    public boolean isRequireSubsampling() {
        return this.requireSubsampling;
    }

    public void setRequireSubsampling(boolean requireSubsampling) {
        this.requireSubsampling = requireSubsampling;
    }

    /**
     * Sets the three input encoders in one call.
     *
     * @param bagEncoder  encoder for the {@code "bag"} input
     * @param embEncoder  encoder for the {@code "emb"} input
     * @param flagEncoder encoder for the {@code "flag"} input
     */
    public void setInputEncoders(Encoder bagEncoder, Encoder embEncoder, Encoder flagEncoder) {
        this.bagEncoder = bagEncoder;
        this.embEncoder = embEncoder;
        this.flagEncoder = flagEncoder;
    }

    public void setTargetEncoder(Encoder targetEncoder) {
        this.targetEncoder = targetEncoder;
    }

    /**
     * Stores the number of workers.
     *
     * @param workers value assigned to {@link #workers}
     * @return this tagger, for chaining
     */
    public SectorTagger setWorkspaceParams(int workers) {
        this.workers = workers;
        return this;
    }

    /**
     * @return a new mutable list holding the bag, emb, flag and target encoders, in that order
     */
    @JsonIgnore
    public List<Encoder> getEncoders() {
        // Pass the elements directly so the list is inferred as List<Encoder>; the previous
        // (Object[]) cast produced a List<Object> that is not assignable to the return type.
        return Lists.newArrayList(this.bagEncoder, this.embEncoder, this.flagEncoder, this.targetEncoder);
    }

    @JsonIgnore
    public Encoder getTargetEncoder() {
        return this.targetEncoder;
    }

    /**
     * Assigns all four encoders from a list ordered as bag, emb, flag, target.
     *
     * @param encoders list of exactly four encoders
     * @throws IllegalArgumentException if the list does not contain exactly four elements
     */
    public void setEncoders(List<Encoder> encoders) {
        if (encoders.size() != 4) {
            throw new IllegalArgumentException("wrong number of encoders given (expected=4, actual=" + encoders.size() + ")");
        }
        this.bagEncoder = encoders.get(0);
        this.embEncoder = encoders.get(1);
        this.flagEncoder = encoders.get(2);
        this.targetEncoder = encoders.get(3);
    }

    /**
     * Builds and initializes the computation graph and installs it as this tagger's network.
     *
     * <p>Graph layout: if {@code ffwLayerSize > 0} the {@code "bag"} input passes through two
     * dense layers ({@code "FF1"}, {@code "FF2"}) before being merged with {@code "emb"} and
     * {@code "flag"}; otherwise the three inputs are merged directly. The merged vector feeds a
     * bidirectional LSTM whose concatenated output is split into {@code "FW"} and {@code "BW"}.
     * If {@code embeddingLayerSize > 0} a dense embedding layer is inserted per direction before
     * the output layers.
     *
     * @param ffwLayerSize       size of the dense layers on the bag input; {@code <= 0} disables them
     * @param lstmLayerSize      output size of each LSTM direction
     * @param embeddingLayerSize size of the per-direction embedding layers; {@code <= 0} disables them
     * @param iterations         not used by this method
     * @param learningRate       initial value of the exponential learning-rate schedule
     * @param dropout            value passed to the configuration builder's {@code dropOut}
     * @param lossFunc           loss function of both output layers
     * @param activation         activation of both output layers
     * @return this tagger, for chaining
     */
    public SectorTagger buildSECTORModel(int ffwLayerSize, int lstmLayerSize, int embeddingLayerSize, int iterations, double learningRate, double dropout, ILossFunction lossFunc, Activation activation) {
        long sentenceVectorSize;
        log.info("initializing graph with layer sizes bag={}, lstm={}, emb={} and {} loss", ffwLayerSize, lstmLayerSize, embeddingLayerSize, lossFunc.name());
        this.embeddingLayerSize = embeddingLayerSize;
        ComputationGraphConfiguration.GraphBuilder gb = new NeuralNetConfiguration.Builder()
            .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
            .updater(new Adam(new ExponentialSchedule(ScheduleType.EPOCH, learningRate, 0.85)))
            .weightInit(WeightInit.XAVIER)
            .l2(1.0E-5)
            .dropOut(dropout)
            .gradientNormalization(GradientNormalization.ClipL2PerLayer)
            .trainingWorkspaceMode(WorkspaceMode.ENABLED)
            .inferenceWorkspaceMode(WorkspaceMode.ENABLED)
            .cacheMode(CacheMode.HOST)
            .graphBuilder()
            .addInputs("bag")
            .addInputs("emb")
            .addInputs("flag");
        if (ffwLayerSize > 0) {
            sentenceVectorSize = (long)ffwLayerSize + this.embEncoder.getEmbeddingVectorSize() + this.flagEncoder.getEmbeddingVectorSize();
            gb.addLayer("FF1", new DenseLayer.Builder()
                    .nIn(this.bagEncoder.getEmbeddingVectorSize())
                    .nOut(ffwLayerSize)
                    .activation(Activation.ELU)
                    .weightInit(WeightInit.RELU)
                    .build(), "bag")
                .addLayer("FF2", new DenseLayer.Builder()
                    .nIn(ffwLayerSize)
                    .nOut(ffwLayerSize)
                    .activation(Activation.ELU)
                    .weightInit(WeightInit.RELU)
                    .build(), "FF1")
                .addVertex("surf", new PreprocessorVertex(new FeedForwardToRnnPreProcessor()), "FF2")
                .addVertex("sentence", new MergeVertex(), "surf", "emb", "flag");
        } else {
            sentenceVectorSize = this.bagEncoder.getEmbeddingVectorSize() + this.embEncoder.getEmbeddingVectorSize() + this.flagEncoder.getEmbeddingVectorSize();
            gb.addVertex("sentence", new MergeVertex(), "bag", "emb", "flag");
        }
        gb.addLayer("BLSTM", new Bidirectional(Bidirectional.Mode.CONCAT, new LSTM.Builder()
                .nIn(sentenceVectorSize)
                .nOut(lstmLayerSize)
                .activation(Activation.TANH)
                .gateActivationFunction(Activation.SIGMOID)
                .build()), "sentence");
        // The CONCAT output has 2 * lstmLayerSize units; the first half is taken as FW, the second as BW.
        gb.addVertex("FW", new SubsetVertex(0, lstmLayerSize - 1), "BLSTM");
        gb.addVertex("BW", new SubsetVertex(lstmLayerSize, 2 * lstmLayerSize - 1), "BLSTM");
        if (this.embeddingLayerSize > 0) {
            gb.addLayer("embeddingFW", new DenseLayer.Builder()
                    .nIn(lstmLayerSize)
                    .nOut(embeddingLayerSize)
                    .activation(Activation.TANH)
                    .build(), "FW")
                .addLayer("embeddingBW", new DenseLayer.Builder()
                    .nIn(lstmLayerSize)
                    .nOut(embeddingLayerSize)
                    .activation(Activation.TANH)
                    .build(), "BW");
            gb.addLayer("targetFW", this.buildTargetLayer(embeddingLayerSize, lossFunc, activation), "embeddingFW")
                .addLayer("targetBW", this.buildTargetLayer(embeddingLayerSize, lossFunc, activation), "embeddingBW");
        } else {
            gb.addLayer("targetFW", this.buildTargetLayer(lstmLayerSize, lossFunc, activation), "FW")
                .addLayer("targetBW", this.buildTargetLayer(lstmLayerSize, lossFunc, activation), "BW");
        }
        gb.setOutputs("targetFW", "targetBW")
            .setInputTypes(InputType.recurrent(this.inputVectorSize), InputType.recurrent(this.inputVectorSize), InputType.recurrent(this.inputVectorSize))
            .backpropType(BackpropType.Standard);
        ComputationGraphConfiguration conf = gb.build();
        ComputationGraph lstm = new ComputationGraph(conf);
        lstm.init();
        this.net = lstm;
        this.net.setListeners(new TrainingListener[]{new PerformanceListener(128, true), new ScoreIterationListener(16)});
        return this;
    }

    /** Builds one recurrent output layer sized by the target encoder; shared by targetFW and targetBW. */
    private RnnOutputLayer buildTargetLayer(int nIn, ILossFunction lossFunc, Activation activation) {
        return new RnnOutputLayer.Builder(lossFunc)
            .nIn(nIn)
            .nOut(this.targetEncoder.getEmbeddingVectorSize())
            .activation(activation)
            .weightInit(WeightInit.SIGMOID_UNIFORM)
            .build();
    }

    /**
     * Trains the model for the configured {@code numEpochs}.
     *
     * @param dataset training documents
     */
    public void trainModel(Dataset dataset) {
        this.trainModel(dataset, this.numEpochs);
    }

    /**
     * Trains the model for a fixed number of epochs and marks the model as available afterwards.
     *
     * <p>Periodic GC is switched off for the duration of the training and an explicit GC is
     * invoked after every epoch; periodic GC is switched back on even if training throws.
     *
     * @param dataset   training documents
     * @param numEpochs number of passes over the iterator
     */
    public void trainModel(Dataset dataset, int numEpochs) {
        SectorTaggerIterator it = new SectorTaggerIterator(AbstractMultiDataSetIterator.Stage.TRAIN, dataset.getDocuments(), this, this.numExamples, this.maxTimeSeriesLength, this.batchSize, true, this.requireSubsampling);
        int batches = this.numExamples / this.batchSize;
        this.timer.start();
        this.appendTrainLog("Training " + this.getName() + " with " + this.numExamples + " examples in " + batches + " batches for " + numEpochs + " epochs.");
        Nd4j.getMemoryManager().togglePeriodicGc(false);
        try {
            for (int i = 1; i <= numEpochs; ++i) {
                this.appendTrainLog("Starting epoch " + i + " of " + numEpochs);
                this.triggerEpochListeners(true, i - 1);
                this.getNN().fit((MultiDataSetIterator)it);
                this.timer.setSplit("epoch");
                this.appendTrainLog("Completed epoch " + i + " of " + numEpochs, this.timer.getLong("epoch"));
                this.triggerEpochListeners(false, i - 1);
                // No reset after the final epoch: the iterator is not used again.
                if (i < numEpochs) {
                    it.reset();
                }
                Nd4j.getMemoryManager().invokeGc();
            }
            this.timer.stop();
            this.appendTrainLog("Training complete", this.timer.getLong());
        } finally {
            // Restore periodic GC on every exit path, including a failed fit().
            Nd4j.getMemoryManager().togglePeriodicGc(true);
        }
        this.setModelAvailable(true);
    }

    /**
     * Trains the model with early stopping and replaces the network with the best model found.
     *
     * <p>The score calculator of {@code conf} is overwritten with a
     * {@link ClassificationScoreCalculator} over the validation iterator.
     *
     * @param train      training documents
     * @param validation validation documents used for scoring
     * @param conf       early stopping configuration; its score calculator is replaced
     * @return the result returned by the early stopping trainer
     */
    public EarlyStoppingResult<ComputationGraph> trainModel(Dataset train, Dataset validation, EarlyStoppingConfiguration conf) {
        SectorTaggerIterator trainIt = new SectorTaggerIterator(AbstractMultiDataSetIterator.Stage.TRAIN, train.getDocuments(), this, this.numExamples, this.maxTimeSeriesLength, this.batchSize, true, this.requireSubsampling);
        SectorTaggerIterator validationIt = new SectorTaggerIterator(AbstractMultiDataSetIterator.Stage.TEST, validation.getDocuments(), this, -1, this.maxTimeSeriesLength, this.batchSize, false, this.requireSubsampling);
        int batches = trainIt.getNumExamples() / this.batchSize;
        this.timer.start();
        this.appendTrainLog("Training " + this.getName() + " with " + trainIt.getNumExamples() + " examples in " + batches + " batches using early stopping.");
        conf.setScoreCalculator((ScoreCalculator)new ClassificationScoreCalculator((Tagger)this, (LookupCacheEncoder)this.targetEncoder, (MultiDataSetIterator)validationIt));
        EarlyStoppingListener<ComputationGraph> listener = new EarlyStoppingListener<ComputationGraph>(){

            public void onStart(EarlyStoppingConfiguration<ComputationGraph> conf, ComputationGraph net) {
            }

            public void onEpoch(int epochNum, double score, EarlyStoppingConfiguration<ComputationGraph> conf, ComputationGraph net) {
                // Periodic GC is disabled while training, so collect explicitly once per epoch.
                Nd4j.getMemoryManager().invokeGc();
            }

            public void onCompletion(EarlyStoppingResult<ComputationGraph> result) {
                log.info("Finished training with result {}", result.toString());
            }
        };
        EarlyStoppingGraphTrainer trainer = new EarlyStoppingGraphTrainer(conf, this.getNN(), (MultiDataSetIterator)trainIt, listener);
        EarlyStoppingResult<ComputationGraph> result;
        Nd4j.getMemoryManager().togglePeriodicGc(false);
        try {
            result = trainer.fit();
        } finally {
            // Restore periodic GC on every exit path, including a failed fit().
            Nd4j.getMemoryManager().togglePeriodicGc(true);
        }
        this.timer.stop();
        this.appendTrainLog("Training complete", this.timer.getLong());
        this.net = result.getBestModel();
        this.setModelAvailable(true);
        return result;
    }

    /**
     * Runs the model over the dataset in the TEST stage and attaches the resulting vectors
     * to its sentences.
     *
     * @param dataset documents to process
     */
    public void testModel(Dataset dataset) {
        this.timer.start();
        this.attachVectors(dataset.getDocuments(), AbstractMultiDataSetIterator.Stage.TEST, this.targetEncoder.getClass());
        this.timer.stop();
        this.appendTestLog("Testing complete", this.timer.getLong());
    }

    /**
     * Not supported by this tagger.
     *
     * @throws UnsupportedOperationException always
     */
    public void tag(Collection<Document> docs) {
        throw new UnsupportedOperationException("not implemented");
    }

    /**
     * Feeds a batch through the network and returns the activations by layer name.
     *
     * <p>An {@code "embedding"} activation, if present, is reshaped with the feed-forward-to-RNN
     * preprocessor. Otherwise, if {@code "embeddingFW"} is present, both directional embeddings
     * are reshaped and their element-wise mean is added under {@code "embedding"}.
     *
     * @param batch batch whose {@code dataset} is fed forward and whose {@code size} is used for reshaping
     * @return the activation map, including the {@code "target"} entry added by {@link #feedForward}
     */
    public Map<String, INDArray> encodeMatrix(DocumentSentenceIterator.DocumentBatch batch) {
        MultiDataSet next = batch.dataset;
        Map<String, INDArray> weights = SectorTagger.feedForward(this.getNN(), next);
        if (weights.containsKey("embedding")) {
            weights.put("embedding", this.ff2rnn.preProcess(weights.get("embedding"), batch.size, LayerWorkspaceMgr.noWorkspaces()));
        } else if (weights.containsKey("embeddingFW")) {
            INDArray fw = this.ff2rnn.preProcess(weights.get("embeddingFW"), batch.size, LayerWorkspaceMgr.noWorkspaces());
            INDArray bw = this.ff2rnn.preProcess(weights.get("embeddingBW"), batch.size, LayerWorkspaceMgr.noWorkspaces());
            weights.put("embeddingFW", fw);
            weights.put("embeddingBW", bw);
            weights.put("embedding", fw.add(bw).divi(2));
        }
        return weights;
    }

    /**
     * Runs a non-training forward pass and returns the activations by layer name.
     *
     * <p>If the result has no {@code "target"} entry but has {@code "targetFW"}, the element-wise
     * mean of {@code "targetFW"} and {@code "targetBW"} is added under {@code "target"}.
     *
     * @param net  graph to evaluate; its layer mask arrays are set from {@code next}
     * @param next features and mask arrays to feed forward
     * @return the activation map returned by the graph, possibly extended with {@code "target"}
     */
    public static Map<String, INDArray> feedForward(ComputationGraph net, MultiDataSet next) {
        INDArray[] features = next.getFeatures();
        INDArray[] featuresMasks = next.getFeaturesMaskArrays();
        INDArray[] labelMasks = next.getLabelsMaskArrays();
        net.setLayerMaskArrays(featuresMasks, labelMasks);
        Map<String, INDArray> weights = net.feedForward(features, false, true);
        if (!weights.containsKey("target") && weights.containsKey("targetFW")) {
            INDArray fw = weights.get("targetFW");
            INDArray bw = weights.get("targetBW");
            weights.put("target", fw.add(bw).divi(2));
        }
        return weights;
    }

    /**
     * Sets the epoch count on the graph configuration and notifies all registered listeners.
     *
     * @param epochStart {@code true} to call {@code onEpochStart}, {@code false} for {@code onEpochEnd}
     * @param epochNum   epoch count written to the graph configuration
     */
    protected void triggerEpochListeners(boolean epochStart, int epochNum) {
        Collection<TrainingListener> listeners = this.getNN().getListeners();
        this.getNN().getConfiguration().setEpochCount(epochNum);
        if (listeners == null || listeners.isEmpty()) {
            return;
        }
        for (TrainingListener l : listeners) {
            if (epochStart) {
                l.onEpochStart((Model)this.getNN());
            } else {
                l.onEpochEnd((Model)this.getNN());
            }
        }
    }

    /**
     * Iterates over the documents in batches and attaches the network's vectors to every sentence.
     *
     * @param docs        documents to process
     * @param stage       stage passed to the iterator
     * @param targetClass forwarded to the per-batch overload
     */
    public void attachVectors(Collection<Document> docs, AbstractMultiDataSetIterator.Stage stage, Class<? extends Encoder> targetClass) {
        SectorTaggerIterator it = new SectorTaggerIterator(stage, docs, this, this.batchSize, false, this.requireSubsampling);
        while (it.hasNext()) {
            this.attachVectors(it.nextDocumentBatch(), targetClass);
        }
    }

    /**
     * Attaches the target vector, and the embedding vectors if available, to each sentence of
     * the batch. Activations are addressed by document index in the batch (row) and sentence
     * index in the document (column).
     *
     * @param batch       documents plus the encoded dataset to feed forward
     * @param targetClass currently not read by this method
     */
    protected void attachVectors(DocumentSentenceIterator.DocumentBatch batch, Class<? extends Encoder> targetClass) {
        Map<String, INDArray> weights = this.encodeMatrix(batch);
        INDArray target = weights.get("target");
        // Map.get yields null for absent keys, which the null checks below rely on.
        INDArray embedding = weights.get("embedding");
        INDArray embeddingFW = null;
        INDArray embeddingBW = null;
        if (weights.containsKey("embeddingFW")) {
            embeddingFW = weights.get("embeddingFW");
            embeddingBW = weights.get("embeddingBW");
        }
        int batchIndex = 0;
        for (Document doc : batch.docs) {
            int t = 0;
            for (Sentence s : doc.getSentences()) {
                INDArray targetVec = target.getRow(batchIndex).getColumn(t);
                // NOTE(review): the vector is keyed by the target encoder's class, not by
                // targetClass; confirm whether targetClass was meant to be used here.
                s.putVector(this.targetEncoder.getClass(), targetVec);
                if (embedding != null) {
                    INDArray embeddingVec = embedding.getRow(batchIndex).getColumn(t);
                    s.putVector(SectorEncoder.class, embeddingVec);
                }
                if (embeddingFW != null) {
                    INDArray fw = embeddingFW.getRow(batchIndex).getColumn(t);
                    INDArray bw = embeddingBW.getRow(batchIndex).getColumn(t);
                    s.putVector("embeddingFW", fw);
                    s.putVector("embeddingBW", bw);
                }
                ++t;
            }
            ++batchIndex;
        }
    }

    /**
     * Clears every layer and vertex of the graph, then the graph itself and its mask arrays.
     *
     * @param net graph to clear
     */
    protected static void clearLayerStates(ComputationGraph net) {
        for (org.deeplearning4j.nn.api.Layer layer : net.getLayers()) {
            layer.clear();
            layer.clearNoiseWeightParams();
        }
        // Fully qualified: the simple name GraphVertex is already imported from the conf package.
        for (org.deeplearning4j.nn.graph.vertex.GraphVertex vertex : net.getVertices()) {
            vertex.clearVertex();
        }
        net.clear();
        net.clearLayerMaskArrays();
    }

    /**
     * Adds a stats listener backed by in-memory storage to the network and attaches that
     * storage to the DL4J UI server, with the remote listener enabled.
     */
    public void enableTrainingUI() {
        InMemoryStatsStorage stats = new InMemoryStatsStorage();
        this.net.addListeners(new TrainingListener[]{new StatsListener((StatsStorageRouter)stats, 1)});
        UIServer.getInstance().attach((StatsStorage)stats);
        UIServer.getInstance().enableRemoteListener((StatsStorageRouter)stats, true);
    }

    /**
     * Writes the network to {@code <name>.zip} under the given path and records that file as
     * the model. An {@link IOException} is logged, not rethrown.
     *
     * @param modelPath directory resource to resolve the file against
     * @param name      file name without the {@code .zip} extension
     */
    public void saveModel(Resource modelPath, String name) {
        Resource modelFile = modelPath.resolve(name + ".zip");
        try (OutputStream os = modelFile.getOutputStream();){
            ModelSerializer.writeModel((Model)this.net, os, true);
            this.setModel(modelFile);
        }
        catch (IOException ex) {
            // Pass the exception itself so the stack trace is not lost.
            log.error(ex.toString(), ex);
        }
    }

    /**
     * Restores the computation graph from the given file and marks the model as available.
     * An {@link IOException} is logged, not rethrown; the model is then not marked available.
     *
     * @param modelFile resource to read the serialized graph from
     */
    public void loadModel(Resource modelFile) {
        try (InputStream is = modelFile.getInputStream();){
            this.net = ModelSerializer.restoreComputationGraph(is, false);
            this.setModel(modelFile);
            this.setModelAvailable(true);
            log.info("loaded Computation Graph from {}", modelFile.getFileName());
        }
        catch (IOException ex) {
            // Pass the exception itself so the stack trace is not lost.
            log.error(ex.toString(), ex);
        }
    }

    /**
     * @return always {@code null}
     */
    public ComputationGraphConfiguration getGraphConfiguration() {
        return null;
    }

    /**
     * Does nothing; the argument is ignored.
     */
    public void setGraphConfiguration(JsonNode conf) {
    }
}

