package ai.h2o.mojos.cli;

import ai.h2o.mojos.check.PipelineCheck;
import ai.h2o.mojos.daimojo.ColumnType;
import ai.h2o.mojos.daimojo.NativeModel;
import ai.h2o.mojos.daimojo.NativePipeline;
import ai.h2o.mojos.daimojo.OutputColumn;
import ai.h2o.mojos.daimojo.TransformOps;
import ai.h2o.mojos.daimojo.jna.DaimojoLibrary;
import java.io.File;
import java.io.FileNotFoundException;
import java.io.FileReader;
import java.io.FileWriter;
import java.io.IOException;
import java.io.InputStreamReader;
import java.io.OutputStreamWriter;
import java.io.PrintStream;
import java.io.Reader;
import java.io.Writer;
import java.util.Map;
import java.util.concurrent.Callable;
import org.bridj.BridJ;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.slf4j.impl.SimpleLogger;
import picocli.CommandLine;

@CommandLine.Command(name = "ExecuteNativeMojo", mixinStandardHelpOptions = true, sortOptions = false, description = {"Score a given dataset using a given MOJO2 pipeline."}, versionProvider = VersionInfo.class)
/**
 * Command-line entry point that scores a CSV dataset with a MOJO2 pipeline through the native
 * daimojo library, prints pipeline metadata, or checks pipeline output against an expected CSV.
 *
 * <p>Modes (selected by arguments):
 * <ul>
 *   <li>no input dataset (positional #1 left at its {@code SHOW_PIPELINE} default): print pipeline info;</li>
 *   <li>{@code --check <expected>}: run {@link PipelineCheck} and return 1 if any errors were found;</li>
 *   <li>otherwise: score the input ({@code -} = stdin) and write rows to the output file or stdout.</li>
 * </ul>
 */
public class ExecuteNativeMojo implements Callable<Integer> {
    /** Input file name meaning "read from standard input". */
    private static final String STDIN_FILENAME = "-";
    /** Output file name meaning "write to standard output". */
    private static final String STDOUT_FILENAME = "-";
    /** Default output marker: stdout when scoring, no row output when checking. */
    private static final String STDOUT_NONE = "!";
    /** Default input marker: print pipeline info instead of scoring. */
    private static final String SHOW_PIPELINE = "SHOW_PIPELINE";
    /** Batch size used when it cannot be derived from the input file size (stdin). */
    private static final int DEFAULT_BATCH_SIZE = 100000;
    /** Log levels from most to least verbose; index 2 ("info") is the baseline. */
    private static final String[] LOG_LEVELS = {"trace", "debug", "info", "warn", "error", "off"};

    @CommandLine.Option(names = {"--batch"}, defaultValue = "0", description = {"Set batch size; if 0, a good value is determined automatically"})
    int batchSize;

    @CommandLine.Option(names = {"-s"}, description = {"Make logging more silent (tip: try -ss or -sss)"})
    boolean[] silent;

    @CommandLine.Option(names = {"-v"}, description = {"Make logging more verbose (tip: try -vv or -vvv)"})
    boolean[] verbose;

    @CommandLine.Option(names = {"--lib"}, description = {"Path to daimojo library (libdaimojo.so)"})
    String lib;

    @CommandLine.Option(names = {"--auxres"}, description = {"Override path to auxiliary parsing resource"})
    String auxres;

    @CommandLine.Parameters(index = "0", description = {"MOJO2 pipeline file."})
    String mojoFilename;

    @CommandLine.Parameters(index = "1", defaultValue = SHOW_PIPELINE, description = {"Dataset file to score in CSV format or '-' for read from standard input."})
    String inputCsvFilename;

    @CommandLine.Parameters(index = "2", defaultValue = STDOUT_NONE, description = {"Output file name or stdout by default."})
    String outputCsvFilename;

    @CommandLine.Option(hidden = true, names = {"--no-show-predictions"}, description = {"Don't show model predictions."})
    @Deprecated
    boolean noShowPredictions;

    @CommandLine.Option(hidden = true, names = {"--show-contributions"}, negatable = true, description = {"Show model contributions (default: ${DEFAULT-VALUE})."})
    @Deprecated
    boolean showContributionsTransformed;

    @CommandLine.Option(names = {"--check", "-c"}, paramLabel = "<expected output>", description = {"Check pipeline's output against expected output"})
    File expectedCsvOutput;

    @CommandLine.Option(hidden = true, names = {"--show-predictions"}, defaultValue = "true", description = {"Show model predictions."})
    @Deprecated
    boolean showPredictions = true;

    @CommandLine.Option(hidden = true, names = {"--show-contributions-original"}, negatable = true, description = {"Show model contributions (default: ${DEFAULT-VALUE}) transformed to original input features."})
    @Deprecated
    boolean showContributionsOriginal = false;

    @CommandLine.Option(hidden = true, names = {"--with-prediction-interval"}, negatable = true, description = {"Enhance output with prediction interval bounds (default: ${DEFAULT-VALUE})."})
    @Deprecated
    boolean withPredictionInterval = false;

    @CommandLine.Option(names = {"--transforms", "-t"}, paramLabel = "<ops>", description = {"Transform operations - one or more of 'PIRO'; P=predict I=interval R=raw(shap) O=orig(shap)"})
    String transformOps = "";

    /** Guards {@link #handleLibArgument} so it runs at most once (it is called from main and call). */
    private boolean libHandled = false;

    /**
     * Computes the SLF4J SimpleLogger level from the -v / -s repeat counts.
     *
     * @return one of {@link #LOG_LEVELS}; "info" with no flags, clamped to the array bounds
     */
    private String logLevel() {
        int louder = this.verbose == null ? 0 : this.verbose.length;
        int quieter = this.silent == null ? 0 : this.silent.length;
        int index = 2 - louder + quieter;
        index = Math.max(0, Math.min(LOG_LEVELS.length - 1, index));
        return LOG_LEVELS[index];
    }

    /**
     * Runs the command.
     *
     * @return process exit code: 0 on success, 1 when a check reports errors
     * @throws Exception on I/O failures, library loading problems, or invalid argument combinations
     */
    @Override
    public Integer call() throws Exception {
        // Set before obtaining the logger so the selected level applies to it.
        System.setProperty(SimpleLogger.DEFAULT_LOG_LEVEL_KEY, logLevel());
        Logger logger = LoggerFactory.getLogger(ExecuteNativeMojo.class);
        if (this.lib == null) {
            FatJar.extract();
        } else {
            handleLibArgument(this.lib);
        }
        logger.debug("daimojo library will load from {}", BridJ.getNativeLibraryFile("daimojo"));
        if (this.batchSize < 1) {
            this.batchSize = batchSizeMagic(this.inputCsvFilename);
            logger.warn("Batch size was automatically set to {}", this.batchSize);
        }
        TransformOps ops = getTransformOps();
        File mojoFile = new File(this.mojoFilename);
        if (this.expectedCsvOutput == null) {
            return score(logger, mojoFile, ops);
        }
        return check(logger, mojoFile, ops);
    }

    /** Loads the model and either prints pipeline info or scores the input CSV; returns 0. */
    private int score(Logger logger, File mojoFile, TransformOps ops) throws Exception {
        try (NativeModel model = NativeModel.load(mojoFile, this.auxres)) {
            NativePipeline pipeline = model.newPipeline(ops);
            if (SHOW_PIPELINE.equals(this.inputCsvFilename)) {
                printPipelineInfo(pipeline, System.out);
                return 0;
            }
            boolean fromStdin = STDIN_FILENAME.equals(this.inputCsvFilename);
            // In scoring mode the STDOUT_NONE default means "write to stdout".
            boolean toStdout = STDOUT_FILENAME.equals(this.outputCsvFilename)
                    || STDOUT_NONE.equals(this.outputCsvFilename);
            // Resources are closed in reverse order (writer, then reader), then the model.
            // Closing the stdout writer also closes System.out, which flushes all scored rows.
            try (Reader reader = fromStdin ? new InputStreamReader(System.in) : new FileReader(this.inputCsvFilename);
                 Writer writer = toStdout ? new OutputStreamWriter(System.out) : new FileWriter(this.outputCsvFilename)) {
                long rows = new Predict(pipeline, this.batchSize).predict(reader, writer);
                logger.info("Total rows: {}", rows);
            }
        }
        return 0;
    }

    /** Runs {@link PipelineCheck} against the expected output; returns 1 if errors were reported. */
    private int check(Logger logger, File mojoFile, TransformOps ops) throws Exception {
        if (SHOW_PIPELINE.equals(this.inputCsvFilename)) {
            throw new IllegalArgumentException("Check requires an input dataset");
        }
        if (STDIN_FILENAME.equals(this.inputCsvFilename)) {
            throw new IllegalArgumentException("Check cannot run against stdin");
        }
        if (STDOUT_FILENAME.equals(this.outputCsvFilename)) {
            throw new IllegalArgumentException("Check cannot write output rows to stdout");
        }
        // In check mode the STDOUT_NONE default means "do not write actual output rows".
        File actualOutput = STDOUT_NONE.equals(this.outputCsvFilename) ? null : new File(this.outputCsvFilename);
        PipelineCheck pipelineCheck = new PipelineCheck(mojoFile, ops, this.batchSize);
        pipelineCheck.checkCsv(CsvConfig.fromJvmProperties(), new File(this.inputCsvFilename), this.expectedCsvOutput, actualOutput);
        int warningCount = pipelineCheck.getWarningCount();
        if (warningCount > 0) {
            logger.warn("Total warnings: {}", warningCount);
        }
        int errorCount = pipelineCheck.getErrorCount();
        if (errorCount > 0) {
            logger.error("Total errors: {}", errorCount);
            return 1;
        }
        return 0;
    }

    /**
     * Builds the transform operations: from {@code --transforms} when given, otherwise from the
     * deprecated individual flags. Note that {@code --no-show-predictions} overrides
     * {@code showPredictions}, which is rewritten here.
     */
    private TransformOps getTransformOps() {
        if (!this.transformOps.isEmpty()) {
            return TransformOps.parse(this.transformOps);
        }
        this.showPredictions = !this.noShowPredictions;
        long ops = this.showPredictions ? DaimojoLibrary.MOJO_Transform_Operations.PREDICT.value : 0L;
        if (this.withPredictionInterval) {
            ops |= DaimojoLibrary.MOJO_Transform_Operations.INTERVAL.value;
        }
        if (this.showContributionsTransformed) {
            ops |= DaimojoLibrary.MOJO_Transform_Operations.CONTRIBS_RAW.value;
        }
        if (this.showContributionsOriginal) {
            ops |= DaimojoLibrary.MOJO_Transform_Operations.CONTRIBS_ORIGINAL.value;
        }
        return new TransformOps(ops);
    }

    /**
     * Picks a batch size from the input: {@link #DEFAULT_BATCH_SIZE} for stdin, otherwise
     * file size / 100, clamped to the range [100, Integer.MAX_VALUE].
     */
    private static int batchSizeMagic(String str) {
        if (STDIN_FILENAME.equals(str)) {
            return DEFAULT_BATCH_SIZE;
        }
        long estimate = new File(str).length() / 100;
        // Clamp in long arithmetic so very large files cannot overflow the int cast.
        return (int) Math.max(100L, Math.min((long) Integer.MAX_VALUE, estimate));
    }

    /** Prints the model UUID, creation time, input features and output columns to {@code printStream}. */
    public static void printPipelineInfo(NativePipeline nativePipeline, PrintStream printStream) {
        printStream.println("UUID: " + nativePipeline.model.getUUID());
        printStream.println("Created: " + nativePipeline.model.getTimeCreated());
        printStream.println("Inputs:");
        for (Map.Entry<String, ColumnType> entry : nativePipeline.model.getFeatureDesc().entrySet()) {
            printStream.printf("* %s %s\n", entry.getValue(), entry.getKey());
        }
        printStream.println("Outputs:");
        for (OutputColumn outputColumn : nativePipeline.outputs()) {
            printStream.printf("#%d: %s %s %s\n", outputColumn.index, outputColumn.dataType, outputColumn.name, outputColumn.operations);
        }
    }

    /**
     * Registers the daimojo native library location with BridJ (at most once per instance).
     *
     * @param str path to the library file or to a directory containing {@code libdaimojo.so};
     *            when null, the directory of this class's classpath element is added as a search path
     * @throws FileNotFoundException if the path does not exist or the directory lacks the library
     */
    private void handleLibArgument(String str) throws IOException {
        if (this.libHandled) {
            return;
        }
        this.libHandled = true;
        if (str == null) {
            File classpathElement = findThisClasspathElement();
            if (classpathElement != null) {
                BridJ.addLibraryPath(classpathElement.getParentFile().getAbsolutePath());
            }
            return;
        }
        File given = new File(str);
        File library;
        if (given.isFile()) {
            library = given;
        } else if (given.isDirectory()) {
            library = new File(given, "libdaimojo.so");
        } else {
            throw new FileNotFoundException(str);
        }
        if (!library.exists()) {
            throw new FileNotFoundException(str);
        }
        // Register the resolved library file, not the (possibly directory) argument.
        BridJ.setNativeLibraryFile("daimojo", library);
    }

    /**
     * Locates the jar file or classes directory this class was loaded from.
     *
     * @return that location, or null if the class was loaded from neither a jar nor a directory
     */
    private static File findThisClasspathElement() {
        String url = ExecuteNativeMojo.class.getResource(ExecuteNativeMojo.class.getSimpleName() + ".class").toString();
        String path;
        if (url.startsWith("jar:file:")) {
            // "jar:file:/path/to.jar!/pkg/Cls.class" -> "/path/to.jar"
            path = url.substring("jar:file:".length(), url.lastIndexOf('!'));
        } else if (url.startsWith("file:")) {
            // Strip "file:" and the trailing "/<fully/qualified/Name>.class".
            int suffixLength = ExecuteNativeMojo.class.getName().length() + ".class".length() + 1;
            path = url.substring("file:".length(), url.length() - suffixLength);
        } else {
            return null;
        }
        return new File(path);
    }

    /**
     * Process entry point. A leading {@code --lib <path>} pair is applied before picocli runs;
     * a non-zero exit code terminates the JVM with that code.
     */
    public static void main(String... strArr) throws IOException {
        ExecuteNativeMojo command = new ExecuteNativeMojo();
        if (strArr.length >= 2 && "--lib".equals(strArr[0])) {
            command.handleLibArgument(strArr[1]);
        }
        int exitCode = new CommandLine(command).execute(strArr);
        if (exitCode != 0) {
            System.exit(exitCode);
        }
    }
}
