package hex.genmodel.tools;

import com.google.gson.GsonBuilder;
import com.google.gson.JsonElement;
import com.google.gson.JsonPrimitive;
import com.google.gson.JsonSerializationContext;
import com.google.gson.JsonSerializer;
import com.google.gson.reflect.TypeToken;
import com.sun.jna.platform.win32.WinError;
import hex.genmodel.MojoModel;
import hex.genmodel.algos.gbm.GbmMojoModel;
import hex.genmodel.algos.tree.ConvertTreeOptions;
import hex.genmodel.algos.tree.SharedTreeGraph;
import hex.genmodel.algos.tree.SharedTreeGraphConverter;
import hex.genmodel.algos.tree.TreeBackedMojoModel;
import hex.genmodel.tools.MojoPrinter;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.PrintStream;
import java.lang.reflect.Type;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.ServiceLoader;
import joptsimple.internal.Strings;
import water.init.AbstractBuildVersion;
import water.util.JavaVersionUtils;

/* loaded from: input_file:hex/genmodel/tools/PrintMojo.class */
/**
 * Command-line tool that prints the trees of a tree-based MOJO (any model implementing
 * {@link SharedTreeGraphConverter}) as Graphviz dot, JSON, or a raw dump.
 *
 * <p>{@link #main} chooses the {@link MojoPrinter} to use: on a known Java version newer than 7
 * it takes a {@link ServiceLoader}-discovered printer that supports the requested
 * {@code --format}; otherwise it uses this class directly.
 */
public class PrintMojo implements MojoPrinter {
    /** Build metadata printed at the top of the usage text. */
    public static final AbstractBuildVersion ABV = AbstractBuildVersion.getBuildVersion();

    /** Model loaded from {@code -i/--input}; {@code null} until an input argument is parsed. */
    protected MojoModel genModel;
    /** Rendering options assembled at the end of {@link #parseArgs(String[])}. */
    protected PrintTreeOptions pTreeOptions;
    /** Set by {@code --internal}. */
    protected boolean internal;
    /** Set by {@code --floattodouble}; JSON output then serializes {@code Float} values as doubles. */
    protected boolean floatToDouble;
    /** Output format from {@code --format} (or {@code --raw}); defaults to dot. */
    protected MojoPrinter.Format format = MojoPrinter.Format.dot;
    /** Tree number from {@code --tree}; -1 when not given (usage text: "default all"). */
    protected int treeToPrint = -1;
    /** Value of {@code --levels}. */
    protected int maxLevelsToPrintPerEdge = 10;
    /** Set by {@code --detail}. */
    protected boolean detail = false;
    /** Target of {@code -o/--output}; {@code null} means standard output. */
    protected String outputFileName = null;
    /** Value of {@code --title}; {@code null} when not given. */
    protected String optionalTitle = null;
    // NOTE(review): not referenced in this class; presumably used by subclasses — TODO confirm.
    protected final String tmpOutputFileName = "tmpOutputFileName.gv";

    /** Gson serializer that widens {@code Float} values to {@code Double}; enabled by {@code --floattodouble}. */
    public static class FloatCastingSerializer implements JsonSerializer<Float> {
        FloatCastingSerializer() {
        }

        @Override
        public JsonElement serialize(Float value, Type type, JsonSerializationContext context) {
            return new JsonPrimitive(Double.valueOf(value.floatValue()));
        }
    }

    /** Rendering options passed to {@link SharedTreeGraph}'s dot printer. */
    public static class PrintTreeOptions {
        /** Whether {@code -d/--decimalplaces} was given. */
        public boolean _setDecimalPlace;
        /** Decimal places used by {@link #roundNPlace(float)}; a negative value disables rounding. */
        public int _nPlaces = -1;
        /** Font size from {@code -f/--fontsize}. */
        public int _fontSize;
        /** Whether {@code --internal} was given. */
        public boolean _internal;

        /**
         * @param setDecimalPlace whether rounding was requested
         * @param nPlaces         number of decimal places; only used when {@code setDecimalPlace} is true
         * @param fontSize        font size for rendered strings
         * @param internal        whether the internal tree representation should be used
         */
        public PrintTreeOptions(boolean setDecimalPlace, int nPlaces, int fontSize, boolean internal) {
            this._setDecimalPlace = setDecimalPlace;
            // Keep the -1 "no rounding" default unless rounding was explicitly requested.
            this._nPlaces = setDecimalPlace ? nPlaces : this._nPlaces;
            this._fontSize = fontSize;
            this._internal = internal;
        }

        /**
         * Rounds {@code value} to {@link #_nPlaces} decimal places.
         *
         * @return {@code value} unchanged when {@code _nPlaces} is negative, otherwise the rounded value
         */
        public float roundNPlace(float value) {
            if (this._nPlaces < 0) {
                return value;
            }
            double scale = Math.pow(10.0d, this._nPlaces);
            return (float) (Math.round(value * scale) / scale);
        }
    }

    /**
     * Entry point: selects a printer, parses {@code strArr}, runs it and exits.
     * Exit status is 0 on success, 1 when no printer supports the format, 2 when printing throws.
     */
    public static void main(String[] strArr) {
        MojoPrinter mojoPrinter = null;
        boolean newerThanJava7 = JavaVersionUtils.JAVA_VERSION.isKnown() && JavaVersionUtils.JAVA_VERSION.getMajor() > 7;
        if (!newerThanJava7) {
            mojoPrinter = new PrintMojo();
        } else {
            MojoPrinter.Format requestedFormat = getFormat(strArr);
            for (MojoPrinter candidate : ServiceLoader.load(MojoPrinter.class)) {
                // The last matching provider wins.
                if (candidate.supportsFormat(requestedFormat)) {
                    mojoPrinter = candidate;
                }
            }
            if (mojoPrinter == null) {
                System.out.println("No supported MojoPrinter for the format required found. Please make sure you are using h2o-genmodel.jar for executing this tool.");
                System.exit(1);
                return;
            }
        }
        mojoPrinter.parseArgs(strArr);
        try {
            mojoPrinter.run();
        } catch (Exception e) {
            e.printStackTrace();
            System.exit(2);
        }
        System.exit(0);
    }

    /** Supports every format except png (including a {@code null} format). */
    @Override
    public boolean supportsFormat(MojoPrinter.Format format) {
        return !MojoPrinter.Format.png.equals(format);
    }

    /**
     * Returns the format following the first {@code --format} argument, or {@code null} when the
     * option is absent, has no value, or names an unknown format.
     */
    static MojoPrinter.Format getFormat(String[] strArr) {
        for (int i = 0; i < strArr.length; i++) {
            if (!"--format".equals(strArr[i])) {
                continue;
            }
            if (i + 1 >= strArr.length || strArr[i + 1] == null) {
                return null;
            }
            try {
                return MojoPrinter.Format.valueOf(strArr[i + 1]);
            } catch (IllegalArgumentException e) {
                return null;
            }
        }
        return null;
    }

    private void loadMojo(String str) throws IOException {
        this.genModel = MojoModel.load(str);
    }

    /** Prints build information and command-line help, then exits with status 1. */
    protected static void usage() {
        System.out.println("Build git branch: " + ABV.branchName());
        System.out.println("Build git hash: " + ABV.lastCommitHash());
        System.out.println("Build git describe: " + ABV.describe());
        System.out.println("Build project version: " + ABV.projectVersion());
        System.out.println("Built by: '" + ABV.compiledBy() + Strings.SINGLE_QUOTE);
        System.out.println("Built on: '" + ABV.compiledOn() + Strings.SINGLE_QUOTE);
        System.out.println();
        System.out.println("Emit a human-consumable graph of a model for use with dot (graphviz).");
        System.out.println("The currently supported model types are DRF, GBM and XGBoost.");
        System.out.println();
        System.out.println("Usage:  java [...java args...] hex.genmodel.tools.PrintMojo [--tree n] [--levels n] [--title sss] [-o outputFileName]");
        System.out.println();
        System.out.println("    --format        Output format. For .png output at least Java 8 is required.");
        System.out.println("                    dot|json|raw|png [default dot]");
        System.out.println();
        System.out.println("    --tree          Tree number to print.");
        System.out.println("                    [default all]");
        System.out.println();
        System.out.println("    --levels        Number of levels per edge to print.");
        System.out.println("                    [default 10]");
        System.out.println();
        System.out.println("    --title         (Optional) Force title of tree graph.");
        System.out.println();
        System.out.println("    --detail        Specify to print additional detailed information like node numbers.");
        System.out.println();
        System.out.println("    --input | -i    Input mojo file.");
        System.out.println();
        System.out.println("    --output | -o   Output filename. Taken as a directory name in case of .png format and multiple trees to visualize.");
        System.out.println("                    [default stdout]");
        System.out.println("    --decimalplaces | -d    Set decimal places of all numerical values.");
        System.out.println();
        System.out.println("    --fontsize | -f    Set font sizes of strings.");
        System.out.println();
        System.out.println("    --internal    Internal H2O representation of the decision tree (splits etc.) is used for generating the GRAPHVIZ format.");
        System.out.println();
        System.out.println();
        System.out.println("Example:");
        System.out.println();
        System.out.println("    (brew install graphviz)");
        System.out.println("    java -cp h2o.jar hex.genmodel.tools.PrintMojo --tree 0 -i model_mojo.zip -o model.gv -f 20 -d 3");
        System.out.println("    dot -Tpng model.gv -o model.png");
        System.out.println("    open model.png");
        System.out.println();
        System.exit(1);
    }

    /**
     * Parses command-line options into this printer's fields and builds {@link #pTreeOptions}.
     * Invalid or unknown options print an error and exit with status 1 (directly or via {@link #usage()}).
     */
    @Override
    public void parseArgs(String[] strArr) {
        int decimalPlaces = -1;
        int fontSize = 14;
        boolean setDecimalPlaces = false;
        try {
            for (int i = 0; i < strArr.length; i++) {
                String arg = strArr[i];
                switch (arg) {
                    case "--format": {
                        String value = optionValue(strArr, ++i);
                        try {
                            this.format = MojoPrinter.Format.valueOf(value);
                        } catch (IllegalArgumentException e) {
                            System.out.println("ERROR: invalid --format argument (" + value + ")");
                            System.exit(1);
                        }
                        break;
                    }
                    case "--tree":
                        this.treeToPrint = parseIntOption(arg, optionValue(strArr, ++i));
                        break;
                    case "--levels":
                        this.maxLevelsToPrintPerEdge = parseIntOption(arg, optionValue(strArr, ++i));
                        break;
                    case "--title":
                        this.optionalTitle = optionValue(strArr, ++i);
                        break;
                    case "--detail":
                        this.detail = true;
                        break;
                    case "--input":
                    case "-i":
                        loadMojo(optionValue(strArr, ++i));
                        break;
                    case "--fontsize":
                    case "-f":
                        fontSize = parseIntOption(arg, optionValue(strArr, ++i));
                        break;
                    case "--decimalplaces":
                    case "-d":
                        setDecimalPlaces = true;
                        decimalPlaces = parseIntOption(arg, optionValue(strArr, ++i));
                        break;
                    case "--raw":
                        this.format = MojoPrinter.Format.raw;
                        break;
                    case "--internal":
                        this.internal = true;
                        break;
                    case "--floattodouble":
                        this.floatToDouble = true;
                        break;
                    case "-o":
                    case "--output":
                        this.outputFileName = optionValue(strArr, ++i);
                        break;
                    default:
                        System.out.println("ERROR: Unknown command line argument: " + arg);
                        usage();
                        break;
                }
            }
        } catch (Exception e) {
            // E.g. an IOException from loading the MOJO.
            e.printStackTrace();
            usage();
            return;
        }
        this.pTreeOptions = new PrintTreeOptions(setDecimalPlaces, decimalPlaces, fontSize, this.internal);
    }

    /** Returns {@code args[index]}, printing usage and exiting when an option's value is missing. */
    private static String optionValue(String[] args, int index) {
        if (index >= args.length) {
            usage();
        }
        return args[index];
    }

    /** Parses an integer option value, printing an error and exiting with status 1 when it is malformed. */
    private static int parseIntOption(String option, String value) {
        try {
            return Integer.parseInt(value);
        } catch (NumberFormatException e) {
            System.out.println("ERROR: invalid " + option + " argument (" + value + ")");
            System.exit(1);
            return -1; // not reached: System.exit does not return normally
        }
    }

    /** Exits via {@link #usage()} when no input MOJO was loaded. */
    protected void validateArgs() {
        if (this.genModel == null) {
            System.out.println("ERROR: Must specify -i");
            usage();
        }
    }

    /**
     * Converts the selected tree(s) and prints them in {@link #format}.
     *
     * <p>All validation happens before the output file is opened, so a failed run leaves no empty
     * file behind. An output file is closed when printing finishes, and a write failure is reported
     * as an {@link IOException}. {@code raw} output goes through {@code SharedTreeGraph.print()},
     * which takes no stream, so {@code -o} does not apply to it.
     */
    @Override
    public void run() throws Exception {
        validateArgs();
        if (!(this.genModel instanceof SharedTreeGraphConverter)) {
            System.out.println("ERROR: Unsupported MOJO type");
            System.exit(1);
            return;
        }
        SharedTreeGraphConverter converter = (SharedTreeGraphConverter) this.genModel;
        if (this.format == MojoPrinter.Format.json && !(converter instanceof TreeBackedMojoModel)) {
            System.out.println("ERROR: Printing XGBoost MOJO as JSON not supported");
            System.exit(1);
            return;
        }
        SharedTreeGraph graph = converter.convert(this.treeToPrint, null, new ConvertTreeOptions().withTreeConsistencyCheckEnabled());
        switch (this.format) {
            case raw:
                graph.print();
                return;
            case dot:
            case json:
                break;
            default:
                // Other formats (e.g. png) are not produced by this printer; see supportsFormat().
                return;
        }
        boolean toFile = this.outputFileName != null;
        PrintStream out = toFile ? new PrintStream(new FileOutputStream(this.outputFileName)) : System.out;
        try {
            if (this.format == MojoPrinter.Format.dot) {
                graph.printDot(out, this.maxLevelsToPrintPerEdge, this.detail, this.optionalTitle, this.pTreeOptions);
            } else {
                printJson((TreeBackedMojoModel) converter, graph, out);
            }
        } finally {
            if (toFile) {
                out.close();
            } else {
                out.flush(); // never close System.out
            }
        }
        // PrintStream swallows IOExceptions; surface them for file output.
        if (toFile && out.checkError()) {
            throw new IOException("Failed to write output file " + this.outputFileName);
        }
    }

    /** Collects model-level metadata for the "params" section of the JSON output. */
    private Map<String, Object> getParamsAsJson(TreeBackedMojoModel treeBackedMojoModel) {
        Map<String, Object> params = new LinkedHashMap<>();
        params.put("h2o_version", this.genModel._h2oVersion);
        params.put("mojo_version", Double.valueOf(this.genModel._mojo_version));
        params.put("algo", this.genModel._algoName);
        params.put("model_category", this.genModel._category.toString());
        params.put("classifier", Boolean.valueOf(this.genModel.isClassifier()));
        params.put("supervised", Boolean.valueOf(this.genModel._supervised));
        params.put("nfeatures", Integer.valueOf(this.genModel._nfeatures));
        params.put("nclasses", Integer.valueOf(this.genModel._nclasses));
        params.put("balance_classes", Boolean.valueOf(this.genModel._balanceClasses));
        params.put("n_tree_groups", Integer.valueOf(treeBackedMojoModel.getNTreeGroups()));
        params.put("n_trees_in_group", Integer.valueOf(treeBackedMojoModel.getNTreesPerGroup()));
        params.put("base_score", Double.valueOf(treeBackedMojoModel.getInitF()));
        if (this.genModel.isClassifier()) {
            params.put("class_labels", this.genModel.getDomainValues(this.genModel.getResponseIdx()));
        }
        if (this.genModel instanceof GbmMojoModel) {
            GbmMojoModel gbmMojoModel = (GbmMojoModel) this.genModel;
            params.put("family", gbmMojoModel._family.toString());
            params.put("link_function", gbmMojoModel._link_function.toString());
        }
        return params;
    }

    /**
     * Lists the categorical domains of the model's columns (those with a non-null domain) for the
     * "domainValues" section of the JSON output.
     */
    private List<Object> getDomainValuesAsJSON() {
        List<Object> columns = new ArrayList<>();
        String[][] domainValues = this.genModel.getDomainValues();
        // The last column is skipped — presumably the response column; TODO confirm for unsupervised models.
        for (int i = 0; i < domainValues.length - 1; i++) {
            if (domainValues[i] == null) {
                continue;
            }
            Map<String, Object> column = new LinkedHashMap<>();
            column.put("colId", Integer.valueOf(i));
            column.put("colName", this.genModel._names[i]);
            column.put("values", domainValues[i]);
            columns.add(column);
        }
        return columns;
    }

    /** Writes model metadata, column domains and the converted trees as pretty-printed JSON. */
    private void printJson(TreeBackedMojoModel treeBackedMojoModel, SharedTreeGraph sharedTreeGraph, PrintStream printStream) {
        Map<String, Object> json = new LinkedHashMap<>();
        json.put("params", getParamsAsJson(treeBackedMojoModel));
        json.put("domainValues", getDomainValuesAsJSON());
        json.put("trees", sharedTreeGraph.toJson());
        if (this.optionalTitle != null) {
            json.put("title", this.optionalTitle);
        }
        GsonBuilder gsonBuilder = new GsonBuilder().setPrettyPrinting();
        if (this.floatToDouble) {
            gsonBuilder.registerTypeAdapter(Float.class, new FloatCastingSerializer());
        }
        printStream.print(gsonBuilder.create().toJson(json));
    }
}
