/*
 * Decompiled with CFR 0.152.
 */
package de.datexis.sector.eval;

import de.datexis.encoder.LookupCacheEncoder;
import de.datexis.sector.eval.ClassificationEvaluation;
import de.datexis.sector.tagger.SectorTagger;
import de.datexis.tagger.Tagger;
import java.util.Map;
import org.deeplearning4j.datasets.iterator.AsyncMultiDataSetIterator;
import org.deeplearning4j.datasets.iterator.MultiDataSetWrapperIterator;
import org.deeplearning4j.datasets.iterator.impl.MultiDataSetIteratorAdapter;
import org.deeplearning4j.earlystopping.scorecalc.base.BaseIEvaluationScoreCalculator;
import org.deeplearning4j.nn.api.Model;
import org.deeplearning4j.nn.conf.BackpropType;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.nd4j.evaluation.IEvaluation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.MultiDataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;

/**
 * Score calculator (based on DL4J's {@link BaseIEvaluationScoreCalculator}) that runs a
 * {@link ClassificationEvaluation} over a validation iterator and reports
 * {@link ClassificationEvaluation#getScore()} as the score. Higher scores are better
 * ({@link #minimizeScore()} returns {@code false}).
 */
public class ClassificationScoreCalculator
extends BaseIEvaluationScoreCalculator<Model, ClassificationEvaluation> {
    /** Tagger whose train log receives the validation statistics (ComputationGraph path only). */
    protected Tagger tagger;
    /** Encoder handed to every {@link ClassificationEvaluation} created by {@link #newEval()}. */
    protected LookupCacheEncoder encoder;

    /**
     * Creates a calculator that evaluates on a {@link DataSetIterator}.
     *
     * @param tagger tagger that receives the validation log
     * @param encoder encoder passed to each new {@link ClassificationEvaluation}
     * @param iterator validation data
     */
    public ClassificationScoreCalculator(Tagger tagger, LookupCacheEncoder encoder, DataSetIterator iterator) {
        super(iterator);
        this.tagger = tagger;
        this.encoder = encoder;
    }

    /**
     * Creates a calculator that evaluates on a {@link MultiDataSetIterator}.
     *
     * @param tagger tagger that receives the validation log
     * @param encoder encoder passed to each new {@link ClassificationEvaluation}
     * @param iterator validation data
     */
    public ClassificationScoreCalculator(Tagger tagger, LookupCacheEncoder encoder, MultiDataSetIterator iterator) {
        super(iterator);
        this.tagger = tagger;
        this.encoder = encoder;
    }

    /**
     * Creates a fresh, empty evaluation for one scoring pass.
     *
     * @return a new {@link ClassificationEvaluation} bound to {@link #encoder}
     */
    protected ClassificationEvaluation newEval() {
        return new ClassificationEvaluation("score calculation", this.encoder);
    }

    /**
     * Evaluates {@code network} on the configured validation iterator and returns its score.
     * For a {@link ComputationGraph}, the classification@k statistics are also appended to the
     * tagger's train log.
     *
     * @param network a {@link MultiLayerNetwork} or {@link ComputationGraph}
     * @return the score produced by {@link #finalScore(ClassificationEvaluation)}
     * @throws IllegalArgumentException if the model is of any other type
     */
    public double calculateScore(Model network) {
        ClassificationEvaluation eval = this.newEval();
        if (network instanceof MultiLayerNetwork) {
            // Whichever iterator flavour was supplied is adapted to the one the model expects.
            DataSetIterator data = this.iter != null ? this.iter : new MultiDataSetWrapperIterator(this.iterator);
            eval = ((MultiLayerNetwork) network).doEvaluation(data, new ClassificationEvaluation[]{eval})[0];
        } else if (network instanceof ComputationGraph) {
            MultiDataSetIterator data = this.iterator != null ? this.iterator : new MultiDataSetIteratorAdapter(this.iter);
            this.evaluate((ComputationGraph) network, eval, data);
            this.tagger.appendTrainLog("Validation score:\n" + eval.printClassificationAtKStats());
        } else {
            throw new IllegalArgumentException("Unknown model type: " + network.getClass());
        }
        return this.finalScore(eval);
    }

    /**
     * Feeds every batch of {@code iterator} through {@code net} and accumulates the predictions
     * into {@code evaluation}. Iteration stops at the first batch that has no features or labels.
     *
     * @param net graph to evaluate; Truncated BPTT configurations are not supported
     * @param evaluation accumulator that receives labels, predictions and label masks
     * @param iterator validation batches; reset first if it is exhausted and supports reset
     * @throws UnsupportedOperationException if {@code net} uses Truncated BPTT
     * @throws IllegalStateException if the network emits no usable prediction output
     */
    protected void evaluate(ComputationGraph net, ClassificationEvaluation evaluation, MultiDataSetIterator iterator) {
        // Reject unsupported configurations before any async prefetch thread is started,
        // otherwise that thread would never be shut down.
        if (net.getConfiguration().getBackpropType() == BackpropType.TruncatedBPTT) {
            throw new UnsupportedOperationException("Evaluation with Truncated BPTT is not implemented.");
        }
        if (iterator.resetSupported() && !iterator.hasNext()) {
            iterator.reset();
        }
        // Only an iterator created here is shut down; a caller-owned iterator is left untouched.
        boolean wrapped = iterator.asyncSupported();
        MultiDataSetIterator iter = wrapped ? new AsyncMultiDataSetIterator(iterator, 2, true) : iterator;
        try {
            while (iter.hasNext()) {
                MultiDataSet next = iter.next();
                if (next.getFeatures() == null || next.getLabels() == null) {
                    break;
                }
                INDArray predicted = selectPredictions(SectorTagger.feedForward(net, next));
                evaluation.eval(next.getLabels(0), predicted, next.getLabelsMaskArray(0));
            }
        } finally {
            // Release the prefetch thread even if feed-forward or evaluation throws.
            if (wrapped) {
                ((AsyncMultiDataSetIterator) iter).shutdown();
            }
        }
    }

    /**
     * Picks the prediction array from the named outputs returned by
     * {@link SectorTagger#feedForward}: {@code "target"} if present, otherwise the element-wise
     * mean of {@code "targetFW"} and {@code "targetBW"}.
     *
     * @param outputs named network outputs
     * @return the predictions to evaluate (a new array when averaging)
     * @throws IllegalStateException if neither output layout is present
     */
    private static INDArray selectPredictions(Map<String, INDArray> outputs) {
        if (outputs.containsKey("target")) {
            return outputs.get("target");
        }
        if (outputs.containsKey("targetFW")) {
            INDArray backward = outputs.get("targetBW");
            if (backward == null) {
                throw new IllegalStateException("Network output \"targetFW\" has no matching \"targetBW\"; available outputs: " + outputs.keySet());
            }
            // dup() so the forward output array itself is not modified in place.
            return outputs.get("targetFW").dup().addi(backward).divi(2);
        }
        throw new IllegalStateException("Network produced neither a \"target\" nor a \"targetFW\" output; available outputs: " + outputs.keySet());
    }

    /**
     * Converts the finished evaluation into the scalar score.
     *
     * @param e a completed evaluation
     * @return {@link ClassificationEvaluation#getScore()}
     */
    protected double finalScore(ClassificationEvaluation e) {
        return e.getScore();
    }

    /**
     * Indicates the score direction.
     *
     * @return {@code false}: higher scores are better
     */
    public boolean minimizeScore() {
        return false;
    }
}

