package ai.djl.examples.training;

import ai.djl.Device;
import ai.djl.Model;
import ai.djl.basicdataset.Mnist;
import ai.djl.basicmodelzoo.basic.Mlp;
import ai.djl.examples.training.util.Arguments;
import ai.djl.examples.training.util.TrainingUtils;
import ai.djl.metric.Metrics;
import ai.djl.ndarray.types.Shape;
import ai.djl.training.DefaultTrainingConfig;
import ai.djl.training.Trainer;
import ai.djl.training.TrainingResult;
import ai.djl.training.dataset.Dataset;
import ai.djl.training.dataset.RandomAccessDataset;
import ai.djl.training.evaluator.Accuracy;
import ai.djl.training.listener.TrainingListener;
import ai.djl.training.loss.Loss;
import ai.djl.training.util.ProgressBar;
import java.io.IOException;
import java.nio.file.Paths;
import java.util.Locale;
import org.apache.commons.cli.ParseException;

/**
 * Example that trains a multilayer perceptron ({@link Mlp}) on the MNIST dataset, then saves the
 * model to the configured output directory together with the final epoch, validation accuracy and
 * validation loss as model properties.
 */
public final class TrainMnist {

    /** Number of input features fed to the MLP (784 = a flattened 28x28 MNIST image). */
    private static final int INPUT_SIZE = 28 * 28;

    /** Number of output classes produced by the MLP. */
    private static final int NUM_CLASSES = 10;

    /** Name passed to the fit helper and used as the prefix when saving the model. */
    private static final String MODEL_NAME = "mlp";

    /** Sizes of the MLP's hidden layers. */
    private static final int[] HIDDEN_LAYERS = {128, 64};

    private TrainMnist() {}

    /**
     * Command-line entry point; delegates to {@link #runExample(String[])}.
     *
     * @param args command-line arguments, parsed by {@link Arguments#parseArgs(String[])}
     * @throws IOException if a dataset cannot be prepared or the model cannot be saved
     * @throws ParseException if the command-line arguments cannot be parsed
     */
    public static void main(String[] args) throws IOException, ParseException {
        runExample(args);
    }

    /**
     * Builds the MLP, trains it on the MNIST training split, validates on the test split, records
     * the results as model properties and saves the model to the output directory.
     *
     * @param args command-line arguments, parsed by {@link Arguments#parseArgs(String[])}
     * @return the training result reported by the trainer after fitting
     * @throws IOException if a dataset cannot be prepared or the model cannot be saved
     * @throws ParseException if the command-line arguments cannot be parsed
     */
    public static TrainingResult runExample(String[] args) throws IOException, ParseException {
        Arguments arguments = Arguments.parseArgs(args);
        Mlp mlp = new Mlp(INPUT_SIZE, NUM_CLASSES, HIDDEN_LAYERS.clone());

        // try-with-resources closes the trainer, then the model, exactly once each. A failure
        // while closing is attached as a suppressed exception instead of masking the primary one.
        try (Model model = Model.newInstance()) {
            model.setBlock(mlp);

            RandomAccessDataset trainingSet = getDataset(Dataset.Usage.TRAIN, arguments);
            RandomAccessDataset validateSet = getDataset(Dataset.Usage.TEST, arguments);

            try (Trainer trainer = model.newTrainer(setupTrainingConfig(arguments))) {
                trainer.setMetrics(new Metrics());
                // Shape (1, INPUT_SIZE); the leading 1 is presumably the batch axis — confirm.
                trainer.initialize(new Shape(1, INPUT_SIZE));

                TrainingUtils.fit(
                        trainer,
                        arguments.getEpoch(),
                        trainingSet,
                        validateSet,
                        arguments.getOutputDir(),
                        MODEL_NAME);

                TrainingResult result = trainer.getTrainingResult();
                recordResultProperties(model, result);
                // Saved while the trainer is still open, matching the original ordering.
                model.save(Paths.get(arguments.getOutputDir()), MODEL_NAME);
                return result;
            }
        }
    }

    /**
     * Writes the final epoch, validation accuracy and validation loss of {@code result} onto
     * {@code model} as string properties, before the model is saved.
     *
     * @param model the model receiving the properties
     * @param result the training result to read from
     */
    private static void recordResultProperties(Model model, TrainingResult result) {
        // "Accuracy" is assumed to be the name reported by the Accuracy evaluator — confirm.
        float accuracy = result.getValidateEvaluation("Accuracy").floatValue();
        model.setProperty("Epoch", String.valueOf(result.getEpoch()));
        // Locale.ROOT keeps '.' as the decimal separator regardless of the JVM default locale.
        model.setProperty("Accuracy", String.format(Locale.ROOT, "%.5f", accuracy));
        model.setProperty("Loss", String.format(Locale.ROOT, "%.5f", result.getValidateLoss()));
    }

    /**
     * Creates the training configuration.
     *
     * <p>The configuration uses softmax cross-entropy loss and an accuracy evaluator. It trains on
     * the devices returned for the requested maximum GPU count and attaches the default logging
     * listeners writing to the output directory.
     *
     * @param arguments parsed command-line arguments
     * @return the training configuration
     */
    private static DefaultTrainingConfig setupTrainingConfig(Arguments arguments) {
        return new DefaultTrainingConfig(Loss.softmaxCrossEntropyLoss())
                .addEvaluator(new Accuracy())
                .optDevices(Device.getDevices(arguments.getMaxGpus()))
                .addTrainingListeners(TrainingListener.Defaults.logging(arguments.getOutputDir()));
    }

    /**
     * Builds and prepares the MNIST split for {@code usage}.
     *
     * <p>The split is sampled with the configured batch size and limit, and download/preparation
     * progress is reported through a {@link ProgressBar}.
     *
     * @param usage which split to load (e.g. train or test)
     * @param arguments parsed command-line arguments (batch size, limit)
     * @return the prepared dataset
     * @throws IOException if the dataset cannot be prepared
     */
    private static RandomAccessDataset getDataset(Dataset.Usage usage, Arguments arguments)
            throws IOException {
        Mnist mnist =
                Mnist.builder()
                        .optUsage(usage)
                        // second argument presumably enables random sampling — confirm
                        .setSampling(arguments.getBatchSize(), true)
                        .optLimit(arguments.getLimit())
                        .build();
        mnist.prepare(new ProgressBar());
        return mnist;
    }
}
