package net.savantly.learning.graphite.learners.timeseries;

import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import net.savantly.learning.graphite.convert.GraphiteToCsv;
import net.savantly.learning.graphite.domain.GraphiteMultiSeries;
import org.datavec.api.records.reader.impl.csv.CSVSequenceRecordReader;
import org.datavec.api.split.NumberedFileInputSplit;
import org.deeplearning4j.datasets.datavec.SequenceRecordReaderDataSetIterator;
import org.deeplearning4j.eval.Evaluation;
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.GradientNormalization;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.Updater;
import org.deeplearning4j.nn.conf.layers.GravesLSTM;
import org.deeplearning4j.nn.conf.layers.RnnOutputLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.optimize.api.IterationListener;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.dataset.api.preprocessor.DataNormalization;
import org.nd4j.linalg.dataset.api.preprocessor.NormalizerStandardize;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import org.nd4j.linalg.primitives.Pair;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Trains a two-class LSTM network that classifies Graphite time series into one of two states.
 *
 * <p>Usage: obtain an instance with {@link #builder()}, configure it with the fluent setters,
 * call {@link #build()} to write the labelled examples as CSV sequence files into the working
 * directory, then call {@link #train()}.
 *
 * <p>Positive examples are labelled class {@code "1"} and negative examples class {@code "0"}.
 * Instances hold mutable configuration and perform no synchronization.
 */
public class TimeSeriesStateClassifier {
    private static final Logger log = LoggerFactory.getLogger(TimeSeriesStateClassifier.class);

    /** Number of output classes (negative and positive). */
    private static final int NUMBER_OF_POSSIBLE_LABELS = 2;

    /** Label written for positive examples. */
    private static final String POSITIVE_LABEL = "1";

    /** Label written for negative examples. */
    private static final String NEGATIVE_LABEL = "0";

    /** Number of units in the LSTM layer (and inputs to the output layer). */
    private static final int LSTM_LAYER_SIZE = 10;

    /** Result of writing the CSV files (train/test file counts); set by {@link #build()}. */
    private GraphiteToCsv.CsvResult csvResult;
    private List<GraphiteMultiSeries> positiveExamples = new ArrayList<>();
    private List<GraphiteMultiSeries> negativeExamples = new ArrayList<>();
    private File workingDirectory = new File("./data");
    private int miniBatchSize = 10;
    private boolean doRegression = false;
    /** Number of full passes over the training data performed by {@link #train()}. */
    private int numberOfIterations = 40;
    private boolean isBuilt = false;
    private DataNormalization normalizer = new NormalizerStandardize();
    private List<IterationListener> iterationListeners = new ArrayList<>();
    private double learningRate = 0.005d;

    private TimeSeriesStateClassifier() {
    }

    /**
     * Creates a new, unconfigured classifier.
     *
     * @return a fresh instance to configure with the fluent setters
     */
    public static TimeSeriesStateClassifier builder() {
        return new TimeSeriesStateClassifier();
    }

    /**
     * Writes the labelled positive and negative examples as CSV files into the working
     * directory and marks this instance as ready for {@link #train()}.
     *
     * @return this instance
     * @throws IOException if the CSV files cannot be written
     */
    public TimeSeriesStateClassifier build() throws IOException {
        this.csvResult = createCsvFiles();
        this.isBuilt = true;
        return this;
    }

    /**
     * Labels every example (positive = "1", negative = "0") and delegates to
     * {@link GraphiteToCsv} to write the CSV file sequence, then logs the directory contents.
     */
    private GraphiteToCsv.CsvResult createCsvFiles() throws IOException {
        List<Pair<String, GraphiteMultiSeries>> labelledExamples =
                new ArrayList<>(positiveExamples.size() + negativeExamples.size());
        for (GraphiteMultiSeries series : positiveExamples) {
            labelledExamples.add(Pair.of(POSITIVE_LABEL, series));
        }
        for (GraphiteMultiSeries series : negativeExamples) {
            labelledExamples.add(Pair.of(NEGATIVE_LABEL, series));
        }
        GraphiteToCsv.CsvResult result = GraphiteToCsv.get(workingDirectory.getAbsolutePath())
                .createFileSequence(labelledExamples);

        // File.list() returns null if the path is not a directory or cannot be read.
        String[] entries = workingDirectory.list();
        if (entries != null) {
            for (String entry : entries) {
                log.info(entry);
            }
        }
        return result;
    }

    /**
     * Trains a new network on the training CSV files, evaluating against the test CSV files
     * after every pass and logging accuracy and F1.
     *
     * @return the trained network
     * @throws IllegalStateException if {@link #build()} has not been called
     * @throws IOException if the CSV files cannot be read
     * @throws InterruptedException if record-reader initialization is interrupted
     */
    public MultiLayerNetwork train() throws IOException, InterruptedException {
        if (!isBuilt) {
            throw new IllegalStateException("must call build() first");
        }
        DataSetIterator trainingData = getTrainingDataSets();
        // Normalization statistics come from the training data only, then apply to both sets.
        normalizer.fit(trainingData);
        trainingData.reset();
        trainingData.setPreProcessor(normalizer);
        DataSetIterator testingData = getTestingDataSets();
        testingData.setPreProcessor(normalizer);

        MultiLayerNetwork network = new MultiLayerNetwork(createNetworkConfiguration());
        network.init();
        network.setListeners(iterationListeners);

        for (int epoch = 0; epoch < numberOfIterations; epoch++) {
            network.fit(trainingData);
            Evaluation evaluation = network.evaluate(testingData);
            log.info(String.format("Test set evaluation at epoch %d: Accuracy = %.2f, F1 = %.2f",
                    epoch, evaluation.accuracy(), evaluation.f1()));
            testingData.reset();
            trainingData.reset();
        }
        log.info("----- completed training and testing -----");
        return network;
    }

    /** Builds the GravesLSTM + softmax RNN output layer configuration. */
    private MultiLayerConfiguration createNetworkConfiguration() {
        return new NeuralNetConfiguration.Builder()
                .seed(123)
                .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
                .iterations(1)
                .weightInit(WeightInit.XAVIER)
                .updater(Updater.NESTEROVS)
                .learningRate(learningRate)
                // Element-wise gradient clipping at an absolute value of 0.5.
                .gradientNormalization(GradientNormalization.ClipElementWiseAbsoluteValue)
                .gradientNormalizationThreshold(0.5d)
                .list()
                // NOTE(review): nIn(1) assumes a single feature column per time step in the
                // feature CSVs written by GraphiteToCsv — confirm.
                .layer(0, new GravesLSTM.Builder()
                        .activation(Activation.TANH)
                        .nIn(1)
                        .nOut(LSTM_LAYER_SIZE)
                        .build())
                .layer(1, new RnnOutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
                        .activation(Activation.SOFTMAX)
                        .nIn(LSTM_LAYER_SIZE)
                        .nOut(NUMBER_OF_POSSIBLE_LABELS)
                        .build())
                .pretrain(false)
                .backprop(true)
                .build();
    }

    /** Iterator over the {@code %d.train.*.csv} files. */
    private DataSetIterator getTrainingDataSets() throws IOException, InterruptedException {
        return createSequenceIterator("train", csvResult.getTrainFileCount());
    }

    /** Iterator over the {@code %d.test.*.csv} files. */
    private DataSetIterator getTestingDataSets() throws IOException, InterruptedException {
        return createSequenceIterator("test", csvResult.getTestFileCount());
    }

    /**
     * Creates a sequence iterator pairing {@code N.<split>.features.csv} with
     * {@code N.<split>.labels.csv} for N = 1..fileCount.
     */
    private DataSetIterator createSequenceIterator(String split, int fileCount)
            throws IOException, InterruptedException {
        CSVSequenceRecordReader featureReader = new CSVSequenceRecordReader();
        featureReader.initialize(
                new NumberedFileInputSplit(numberedFilePattern(split, "features"), 1, fileCount));
        CSVSequenceRecordReader labelReader = new CSVSequenceRecordReader();
        labelReader.initialize(
                new NumberedFileInputSplit(numberedFilePattern(split, "labels"), 1, fileCount));
        // ALIGN_END: label sequences are aligned with the end of the feature sequences
        // (presumably one label per sequence — confirm against GraphiteToCsv output).
        return new SequenceRecordReaderDataSetIterator(featureReader, labelReader, miniBatchSize,
                NUMBER_OF_POSSIBLE_LABELS, doRegression,
                SequenceRecordReaderDataSetIterator.AlignmentMode.ALIGN_END);
    }

    /** Absolute path pattern such as {@code /abs/dir/%d.train.features.csv}. */
    private String numberedFilePattern(String split, String kind) {
        return workingDirectory.toPath()
                .resolve("%d." + split + "." + kind + ".csv")
                .toAbsolutePath()
                .toString();
    }

    /** @return the positive examples (labelled "1") */
    public List<GraphiteMultiSeries> getPositiveExamples() {
        return this.positiveExamples;
    }

    /**
     * @return the positive examples
     * @deprecated misspelled; use {@link #getPositiveExamples()}
     */
    @Deprecated
    public List<GraphiteMultiSeries> getPostiveExamples() {
        return getPositiveExamples();
    }

    /** Sets the positive examples (labelled "1"); returns this instance. */
    public TimeSeriesStateClassifier setPositiveExamples(List<GraphiteMultiSeries> list) {
        this.positiveExamples = list;
        return this;
    }

    /** @return the negative examples (labelled "0") */
    public List<GraphiteMultiSeries> getNegativeExamples() {
        return this.negativeExamples;
    }

    /** Sets the negative examples (labelled "0"); returns this instance. */
    public TimeSeriesStateClassifier setNegativeExamples(List<GraphiteMultiSeries> list) {
        this.negativeExamples = list;
        return this;
    }

    /** @return the directory CSV files are written to and read from */
    public File getWorkingDirectory() {
        return this.workingDirectory;
    }

    /** Sets the directory CSV files are written to and read from; returns this instance. */
    public TimeSeriesStateClassifier setWorkingDirectory(File file) {
        this.workingDirectory = file;
        return this;
    }

    /** @return the mini-batch size used by the data set iterators */
    public int getMiniBatchSize() {
        return this.miniBatchSize;
    }

    /** Sets the mini-batch size used by the data set iterators; returns this instance. */
    public TimeSeriesStateClassifier setMiniBatchSize(int i) {
        this.miniBatchSize = i;
        return this;
    }

    /** @return the regression flag passed to the data set iterators */
    public boolean isDoRegression() {
        return this.doRegression;
    }

    /**
     * Sets the regression flag passed to both data set iterators; returns this instance.
     * NOTE(review): the network's output layer is softmax/MCXENT, which suits classification.
     */
    public TimeSeriesStateClassifier setDoRegression(boolean z) {
        this.doRegression = z;
        return this;
    }

    /** @return the number of full passes over the training data */
    public int getNumberOfIterations() {
        return this.numberOfIterations;
    }

    /** Sets the number of full passes over the training data; returns this instance. */
    public TimeSeriesStateClassifier setNumberOfIterations(int i) {
        this.numberOfIterations = i;
        return this;
    }

    /** @return the normalizer fit on the training data */
    public DataNormalization getNormalizer() {
        return this.normalizer;
    }

    /** Sets the normalizer fit on the training data; returns this instance. */
    public TimeSeriesStateClassifier setNormalizer(DataNormalization dataNormalization) {
        this.normalizer = dataNormalization;
        return this;
    }

    /** @return the listeners attached to the network during training (live list) */
    public List<IterationListener> getIterationListeners() {
        return this.iterationListeners;
    }

    /**
     * Replaces the listener list; returns this instance. The list must be mutable if
     * {@link #addIterationListener(IterationListener)} is called afterwards.
     */
    public TimeSeriesStateClassifier setIterationListeners(List<IterationListener> list) {
        this.iterationListeners = list;
        return this;
    }

    /** Appends a listener to the current listener list; returns this instance. */
    public TimeSeriesStateClassifier addIterationListener(IterationListener iterationListener) {
        this.iterationListeners.add(iterationListener);
        return this;
    }

    /** @return the learning rate used in the network configuration */
    public double getLearningRate() {
        return this.learningRate;
    }

    /** Sets the learning rate used in the network configuration; returns this instance. */
    public TimeSeriesStateClassifier setLearningRate(double d) {
        this.learningRate = d;
        return this;
    }
}
