package hex.genmodel.easy;

import hex.ModelCategory;
import hex.genmodel.GenModel;
import hex.genmodel.algos.deepwater.DeepwaterMojoModel;
import hex.genmodel.algos.deepwater.caffe.nano.Deepwater;
import hex.genmodel.algos.word2vec.WordEmbeddingModel;
import hex.genmodel.easy.exception.PredictException;
import hex.genmodel.easy.exception.PredictNumberFormatException;
import hex.genmodel.easy.exception.PredictUnknownCategoricalLevelException;
import hex.genmodel.easy.exception.PredictUnknownTypeException;
import hex.genmodel.easy.prediction.AbstractPrediction;
import hex.genmodel.easy.prediction.AutoEncoderModelPrediction;
import hex.genmodel.easy.prediction.BinomialModelPrediction;
import hex.genmodel.easy.prediction.ClusteringModelPrediction;
import hex.genmodel.easy.prediction.DimReductionModelPrediction;
import hex.genmodel.easy.prediction.MultinomialModelPrediction;
import hex.genmodel.easy.prediction.RegressionModelPrediction;
import hex.genmodel.easy.prediction.SortedClassProbability;
import hex.genmodel.easy.prediction.Word2VecPrediction;
import java.awt.image.BufferedImage;
import java.io.ByteArrayInputStream;
import java.io.File;
import java.io.IOException;
import java.io.Serializable;
import java.net.URL;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicLong;
import java.util.regex.Pattern;
import javax.imageio.ImageIO;

/* loaded from: input_file:hex/genmodel/easy/EasyPredictModelWrapper.class */
/**
 * Convenience wrapper around a {@link GenModel} that scores a {@link RowData} (a map from column
 * name to cell value) and returns a typed prediction object.
 *
 * <p>Numeric columns accept {@code String} values (trimmed, then parsed with
 * {@link Double#parseDouble}) or {@code Double} values. Categorical columns accept {@code String}
 * levels or {@code Double.NaN}. Row entries whose name is not a model column are ignored. Model
 * columns missing from the row are scored as NaN.
 */
public class EasyPredictModelWrapper implements Serializable {
    /**
     * Strings matching this pattern are read as image URLs; any other string is treated as a file
     * path. Compiled once instead of on every call to {@code String.matches}.
     */
    private static final Pattern IMAGE_URL_PATTERN =
            Pattern.compile("^(https?|ftp|file)://[-a-zA-Z0-9+&@#/%?=~_|!:,.;]*[-a-zA-Z0-9+&@#/%=~_|]");

    /** The wrapped model. */
    public final GenModel m;
    /** Model column name -> position of that name in {@code m.getNames()}. */
    private final HashMap<String, Integer> modelColumnNameToIndexMap;
    /** Categorical column index -> (level name -> level index). Numeric columns have no entry. */
    private final HashMap<Integer, HashMap<String, Integer>> domainMap;
    /** When true, unknown categorical levels are scored as NaN instead of raising an exception. */
    private final boolean convertUnknownCategoricalLevelsToNa;
    /** When true, unparsable numeric strings are scored as NaN instead of raising an exception. */
    private final boolean convertInvalidNumbersToNa;
    /**
     * Per categorical column, the number of unknown levels converted to NaN. Populated only when
     * {@link #convertUnknownCategoricalLevelsToNa} is enabled.
     */
    private final ConcurrentHashMap<String, AtomicLong> unknownCategoricalLevelsSeenPerColumn;

    /** Builder-style configuration for {@link EasyPredictModelWrapper}. */
    public static class Config {
        private GenModel model;
        private boolean convertUnknownCategoricalLevelsToNa = false;
        private boolean convertInvalidNumbersToNa = false;

        /** Sets the model to wrap; returns this config for chaining. */
        public Config setModel(GenModel genModel) {
            this.model = genModel;
            return this;
        }

        /** Returns the model to wrap. */
        public GenModel getModel() {
            return this.model;
        }

        /** Enables/disables scoring unknown categorical levels as NaN; returns this config. */
        public Config setConvertUnknownCategoricalLevelsToNa(boolean z) {
            this.convertUnknownCategoricalLevelsToNa = z;
            return this;
        }

        /** Returns whether unknown categorical levels are scored as NaN (default {@code false}). */
        public boolean getConvertUnknownCategoricalLevelsToNa() {
            return this.convertUnknownCategoricalLevelsToNa;
        }

        /** Enables/disables scoring unparsable numeric strings as NaN; returns this config. */
        public Config setConvertInvalidNumbersToNa(boolean z) {
            this.convertInvalidNumbersToNa = z;
            return this;
        }

        /** Returns whether unparsable numeric strings are scored as NaN (default {@code false}). */
        public boolean getConvertInvalidNumbersToNa() {
            return this.convertInvalidNumbersToNa;
        }
    }

    /**
     * Creates a wrapper from a configuration.
     *
     * <p>Builds the column-name index and the per-column categorical level lookup tables up front,
     * so that scoring only performs map lookups.
     *
     * @param config wrapper configuration; its model must be non-null
     */
    public EasyPredictModelWrapper(Config config) {
        this.m = config.getModel();
        this.modelColumnNameToIndexMap = new HashMap<>();
        String[] names = this.m.getNames();
        for (int i = 0; i < names.length; i++) {
            this.modelColumnNameToIndexMap.put(names[i], i);
        }
        this.unknownCategoricalLevelsSeenPerColumn = new ConcurrentHashMap<>();
        this.convertUnknownCategoricalLevelsToNa = config.getConvertUnknownCategoricalLevelsToNa();
        this.convertInvalidNumbersToNa = config.getConvertInvalidNumbersToNa();
        setupConvertUnknownCategoricalLevelsToNa();
        this.domainMap = new HashMap<>();
        for (int col = 0; col < this.m.getNumCols(); col++) {
            String[] domainValues = this.m.getDomainValues(col);
            if (domainValues != null) {
                HashMap<String, Integer> levels = new HashMap<>();
                for (int level = 0; level < domainValues.length; level++) {
                    levels.put(domainValues[level], level);
                }
                this.domainMap.put(col, levels);
            }
        }
    }

    /** Creates a wrapper with default configuration (strict categorical and numeric parsing). */
    public EasyPredictModelWrapper(GenModel genModel) {
        this(new Config().setModel(genModel));
    }

    /** Returns the total number of unknown categorical levels converted to NaN, over all columns. */
    public long getTotalUnknownCategoricalLevelsSeen() {
        long total = 0;
        for (AtomicLong count : getUnknownCategoricalLevelsSeenPerColumn().values()) {
            total += count.get();
        }
        return total;
    }

    /**
     * Returns the live (internal, mutable) per-column counters of unknown categorical levels that
     * were converted to NaN. The map is empty unless conversion to NA is enabled.
     */
    public ConcurrentHashMap<String, AtomicLong> getUnknownCategoricalLevelsSeenPerColumn() {
        return this.unknownCategoricalLevelsSeenPerColumn;
    }

    /**
     * Scores a row as the given model category.
     *
     * @param rowData       column name -> value
     * @param modelCategory the kind of prediction to produce
     * @return the prediction; its concrete type depends on {@code modelCategory}
     * @throws PredictException if the category is {@code Unknown}, unhandled, or not supported by
     *                          the model, or if the row cannot be converted
     */
    public AbstractPrediction predict(RowData rowData, ModelCategory modelCategory) throws PredictException {
        switch (modelCategory) {
            case AutoEncoder:
                return predictAutoEncoder(rowData);
            case Binomial:
                return predictBinomial(rowData);
            case Multinomial:
                return predictMultinomial(rowData);
            case Clustering:
                return predictClustering(rowData);
            case Regression:
                return predictRegression(rowData);
            case DimReduction:
                return predictDimReduction(rowData);
            case WordEmbedding:
                return predictWord2Vec(rowData);
            case Unknown:
                throw new PredictException("Unknown model category");
            default:
                // Report the category that was actually requested, not the model's own category.
                throw new PredictException("Unhandled model category (" + modelCategory + ") in switch statement");
        }
    }

    /** Scores a row using the wrapped model's own category. */
    public AbstractPrediction predict(RowData rowData) throws PredictException {
        return predict(rowData, this.m.getModelCategory());
    }

    /**
     * Runs an auto-encoder prediction.
     *
     * @return the expanded original input, the raw reconstruction, and the reconstruction mapped
     *         back to column names
     */
    public AutoEncoderModelPrediction predictAutoEncoder(RowData rowData) throws PredictException {
        validateModelCategory(ModelCategory.AutoEncoder);
        double[] preds = new double[this.m.getPredsSize(ModelCategory.AutoEncoder)];
        double[] rawData = fillRawData(rowData, nanArray(this.m.nfeatures()));
        double[] reconstructed = this.m.score0(rawData, preds);
        AutoEncoderModelPrediction prediction = new AutoEncoderModelPrediction();
        prediction.original = expandRawData(rawData, reconstructed.length);
        prediction.reconstructed = reconstructed;
        prediction.reconstructedRowData = reconstructedToRowData(reconstructed);
        return prediction;
    }

    /**
     * Expands a raw row into the one-hot layout of the reconstruction. Numeric columns take one
     * slot. A categorical column takes {@code domain.length + 1} slots: one per level, plus a final
     * slot that is set when the value is NaN.
     *
     * @param rawData raw row (categoricals encoded as level indices)
     * @param size    length of the expanded vector
     */
    private double[] expandRawData(double[] rawData, int size) {
        double[] expanded = new double[size];
        int pos = 0;
        for (int col = 0; col < rawData.length; col++) {
            String[] domain = this.m._domains[col];
            if (domain == null) {
                expanded[pos] = rawData[col];
                pos++;
            } else {
                int hot = Double.isNaN(rawData[col]) ? domain.length : (int) rawData[col];
                expanded[pos + hot] = 1.0d;
                pos += domain.length + 1;
            }
        }
        return expanded;
    }

    /**
     * Maps an expanded reconstruction vector back to column names. Numeric columns become a
     * {@code Double}. Categorical columns become a {@code Map<String, Double>} from level to value,
     * where the {@code null} key holds the trailing NaN slot.
     */
    private RowData reconstructedToRowData(double[] reconstructed) {
        RowData rowData = new RowData();
        int pos = 0;
        for (int col = 0; col < this.m.nfeatures(); col++) {
            String[] domain = this.m._domains[col];
            Object value;
            if (domain == null) {
                value = reconstructed[pos];
                pos++;
            } else {
                value = catValuesAsMap(domain, reconstructed, pos);
                pos += domain.length + 1;
            }
            rowData.put(this.m._names[col], value);
        }
        return rowData;
    }

    /**
     * Collects {@code domain.length + 1} consecutive values starting at {@code offset}. Each value
     * is keyed by its level name, and the final value is keyed by {@code null}.
     */
    private static Map<String, Double> catValuesAsMap(String[] domain, double[] values, int offset) {
        Map<String, Double> byLevel = new HashMap<>(domain.length + 1);
        for (int i = 0; i < domain.length; i++) {
            byLevel.put(domain[i], values[offset + i]);
        }
        byLevel.put(null, values[offset + domain.length]);
        return byLevel;
    }

    /** Runs a dimensionality-reduction prediction; the result holds the raw score vector. */
    public DimReductionModelPrediction predictDimReduction(RowData rowData) throws PredictException {
        double[] preds = preamble(ModelCategory.DimReduction, rowData);
        DimReductionModelPrediction prediction = new DimReductionModelPrediction();
        prediction.dimensions = preds;
        return prediction;
    }

    /**
     * Embeds every {@code String} value of the row with the word-embedding model. Non-string
     * values are skipped.
     *
     * @throws PredictException if the model is not a {@link WordEmbeddingModel}
     */
    public Word2VecPrediction predictWord2Vec(RowData rowData) throws PredictException {
        validateModelCategory(ModelCategory.WordEmbedding);
        if (!(this.m instanceof WordEmbeddingModel)) {
            throw new PredictException("Model is not of the expected type, class = " + this.m.getClass().getSimpleName());
        }
        WordEmbeddingModel embeddingModel = (WordEmbeddingModel) this.m;
        int vecSize = embeddingModel.getVecSize();
        HashMap<String, float[]> embeddings = new HashMap<>(rowData.size());
        for (String column : rowData.keySet()) {
            Object value = rowData.get(column);
            if (value instanceof String) {
                embeddings.put(column, embeddingModel.transform0((String) value, new float[vecSize]));
            }
        }
        Word2VecPrediction prediction = new Word2VecPrediction();
        prediction.wordEmbeddings = embeddings;
        return prediction;
    }

    /**
     * Runs a binomial prediction. Slot 0 of the scores is the predicted label index; the following
     * slots are class probabilities. Calibrated probabilities are filled only when the model
     * reports that it calibrated them.
     */
    public BinomialModelPrediction predictBinomial(RowData rowData) throws PredictException {
        double[] preds = preamble(ModelCategory.Binomial, rowData);
        BinomialModelPrediction prediction = new BinomialModelPrediction();
        prediction.labelIndex = (int) preds[0];
        prediction.label = this.m.getDomainValues(this.m.getResponseIdx())[prediction.labelIndex];
        prediction.classProbabilities = new double[this.m.getNumResponseClasses()];
        System.arraycopy(preds, 1, prediction.classProbabilities, 0, prediction.classProbabilities.length);
        // calibrateClassProbabilities is called after the uncalibrated copy above and may rewrite
        // preds in place; slots 1.. are then copied again as the calibrated probabilities.
        if (this.m.calibrateClassProbabilities(preds)) {
            prediction.calibratedClassProbabilities = new double[this.m.getNumResponseClasses()];
            System.arraycopy(preds, 1, prediction.calibratedClassProbabilities, 0, prediction.calibratedClassProbabilities.length);
        }
        return prediction;
    }

    /**
     * Runs a multinomial prediction. Slot 0 of the scores is the predicted label index; the
     * following slots are class probabilities.
     */
    public MultinomialModelPrediction predictMultinomial(RowData rowData) throws PredictException {
        double[] preds = preamble(ModelCategory.Multinomial, rowData);
        MultinomialModelPrediction prediction = new MultinomialModelPrediction();
        prediction.classProbabilities = new double[this.m.getNumResponseClasses()];
        prediction.labelIndex = (int) preds[0];
        prediction.label = this.m.getDomainValues(this.m.getResponseIdx())[prediction.labelIndex];
        System.arraycopy(preds, 1, prediction.classProbabilities, 0, prediction.classProbabilities.length);
        return prediction;
    }

    /** Pairs class names with probabilities and sorts them in reverse natural order. */
    private SortedClassProbability[] sortByDescendingClassProbability(String[] names, double[] probabilities) {
        assert probabilities.length == names.length
                : "class name/probability length mismatch: " + names.length + " vs " + probabilities.length;
        SortedClassProbability[] sorted = new SortedClassProbability[names.length];
        for (int i = 0; i < names.length; i++) {
            SortedClassProbability entry = new SortedClassProbability();
            entry.name = names[i];
            entry.probability = probabilities[i];
            sorted[i] = entry;
        }
        Arrays.sort(sorted, Collections.reverseOrder());
        return sorted;
    }

    /** Returns the response classes of a binomial prediction, sorted by descending probability. */
    public SortedClassProbability[] sortByDescendingClassProbability(BinomialModelPrediction binomialModelPrediction) {
        return sortByDescendingClassProbability(getResponseDomainValues(), binomialModelPrediction.classProbabilities);
    }

    /** Returns the response classes of a multinomial prediction, sorted by descending probability. */
    public SortedClassProbability[] sortByDescendingClassProbability(MultinomialModelPrediction multinomialModelPrediction) {
        return sortByDescendingClassProbability(getResponseDomainValues(), multinomialModelPrediction.classProbabilities);
    }

    /** Runs a clustering prediction; slot 0 of the scores is the cluster index. */
    public ClusteringModelPrediction predictClustering(RowData rowData) throws PredictException {
        double[] preds = preamble(ModelCategory.Clustering, rowData);
        ClusteringModelPrediction prediction = new ClusteringModelPrediction();
        prediction.cluster = (int) preds[0];
        return prediction;
    }

    /** Runs a regression prediction; slot 0 of the scores is the predicted value. */
    public RegressionModelPrediction predictRegression(RowData rowData) throws PredictException {
        double[] preds = preamble(ModelCategory.Regression, rowData);
        RegressionModelPrediction prediction = new RegressionModelPrediction();
        prediction.value = preds[0];
        return prediction;
    }

    /** Returns the wrapped model's category. */
    public ModelCategory getModelCategory() {
        return this.m.getModelCategory();
    }

    /** Returns the level names of the response column (may be null for a numeric response). */
    public String[] getResponseDomainValues() {
        return this.m.getDomainValues(this.m.getResponseIdx());
    }

    /** Returns the wrapped model's header. */
    public String getHeader() {
        return this.m.getHeader();
    }

    /**
     * Registers a zeroed counter for every categorical column when unknown-level conversion is
     * enabled; otherwise clears the counters.
     */
    private void setupConvertUnknownCategoricalLevelsToNa() {
        if (!this.convertUnknownCategoricalLevelsToNa) {
            this.unknownCategoricalLevelsSeenPerColumn.clear();
            return;
        }
        String[] names = this.m.getNames();
        for (int col = 0; col < this.m.getNumCols(); col++) {
            if (this.m.getDomainValues(col) != null) {
                this.unknownCategoricalLevelsSeenPerColumn.put(names[col], new AtomicLong());
            }
        }
    }

    /** @throws PredictException if the model does not support {@code modelCategory} */
    private void validateModelCategory(ModelCategory modelCategory) throws PredictException {
        if (!this.m.getModelCategories().contains(modelCategory)) {
            throw new PredictException(modelCategory + " prediction type is not supported for this model.");
        }
    }

    /** Validates the category, then scores the row into a prediction buffer sized for it. */
    private double[] preamble(ModelCategory modelCategory, RowData rowData) throws PredictException {
        validateModelCategory(modelCategory);
        return predict(rowData, new double[this.m.getPredsSize(modelCategory)]);
    }

    /** Returns a new array of length {@code n} filled with NaN (the "missing" marker). */
    private static double[] nanArray(int n) {
        double[] values = new double[n];
        Arrays.fill(values, Double.NaN);
        return values;
    }

    /**
     * Converts a row into the model's raw input encoding by writing into {@code rawData}.
     *
     * <p>Numeric cells become doubles. Categorical cells become level indices, or NaN.
     *
     * <p>For Deepwater image models, the first image-valued numeric cell is converted to pixels, and
     * that pixel vector is returned instead of {@code rawData}.
     *
     * @param rowData column name -> value
     * @param rawData pre-filled (NaN) buffer indexed by model column; entries for columns present in
     *                the row are overwritten
     * @return {@code rawData}, or a freshly allocated pixel vector for image models
     * @throws PredictException on unparsable numbers, unknown categorical levels, unsupported value
     *                          types (including {@code null}), or unreadable images
     */
    private double[] fillRawData(RowData rowData, double[] rawData) throws PredictException {
        final boolean isImage = isDeepwaterProblemType("image");
        final boolean isText = isDeepwaterProblemType("text");
        for (String column : rowData.keySet()) {
            Integer index = this.modelColumnNameToIndexMap.get(column);
            // Skip row entries that are not model inputs (unknown names, or indices past the features).
            if (index == null || index >= rawData.length) {
                continue;
            }
            Object value = rowData.get(column);
            if (this.m.getDomainValues(index) != null) {
                rawData[index] = categoricalValue(column, index, value);
                continue;
            }
            double number = Double.NaN;
            BufferedImage image = null;
            if (value instanceof String) {
                String text = ((String) value).trim();
                if (isImage) {
                    image = readImage(text);
                } else if (isText) {
                    throw new PredictException("MOJO scoring for text classification is not yet implemented.");
                } else {
                    number = parseNumber(column, text);
                }
            } else if (value instanceof Double) {
                number = (Double) value;
            } else if (value instanceof byte[] && isImage) {
                image = readImage((byte[]) value);
            } else {
                String type = value == null ? "null" : value.getClass().getName();
                throw new PredictUnknownTypeException("Unexpected object type " + type + " for numeric column " + column);
            }
            if (isImage && image != null) {
                // For image models, the pixel vector replaces the whole raw row.
                return vectorizeImage(image);
            }
            rawData[index] = number;
        }
        return rawData;
    }

    /** Returns true when the wrapped model is a Deepwater MOJO of the given problem type. */
    private boolean isDeepwaterProblemType(String problemType) {
        return this.m instanceof DeepwaterMojoModel
                && problemType.equals(((DeepwaterMojoModel) this.m)._problem_type);
    }

    /**
     * Parses a numeric cell.
     *
     * @return the parsed value, or NaN when parsing fails and invalid numbers are converted to NA
     * @throws PredictException (a {@link PredictNumberFormatException}) when parsing fails and
     *                          conversion to NA is disabled
     */
    private double parseNumber(String column, String text) throws PredictException {
        try {
            return Double.parseDouble(text);
        } catch (NumberFormatException e) {
            if (!this.convertInvalidNumbersToNa) {
                throw new PredictNumberFormatException("Unable to parse value: " + text + ", from column: " + column + ", as Double; " + e.getMessage());
            }
            return Double.NaN;
        }
    }

    /**
     * Encodes a categorical cell as its level index.
     *
     * @return the level index, or NaN for a NaN cell or for an unknown level when conversion to NA
     *         is enabled (the latter is counted per column)
     * @throws PredictException for an unknown level when conversion is disabled, or for a value that
     *                          is neither a {@code String} nor {@code Double.NaN} (including null)
     */
    private double categoricalValue(String column, int index, Object value) throws PredictException {
        if (value instanceof String) {
            String level = (String) value;
            HashMap<String, Integer> levels = this.domainMap.get(index);
            Integer levelIndex = levels.get(level);
            if (levelIndex == null) {
                // Fall back to a "<column>.<level>" key; presumably some domains store prefixed levels.
                levelIndex = levels.get(column + "." + level);
            }
            if (levelIndex != null) {
                return levelIndex;
            }
            if (!this.convertUnknownCategoricalLevelsToNa) {
                throw new PredictUnknownCategoricalLevelException("Unknown categorical level (" + column + "," + level + ")", column, level);
            }
            this.unknownCategoricalLevelsSeenPerColumn.get(column).incrementAndGet();
            return Double.NaN;
        }
        if (value instanceof Double && Double.isNaN((Double) value)) {
            return Double.NaN;
        }
        String type = value == null ? "null" : value.getClass().getName();
        throw new PredictUnknownTypeException("Unexpected object type " + type + " for categorical column " + column);
    }

    /**
     * Reads an image from a URL (when the string matches {@link #IMAGE_URL_PATTERN}) or otherwise
     * from a local file path. May return null if ImageIO cannot decode the content.
     */
    private static BufferedImage readImage(String location) throws PredictException {
        try {
            return IMAGE_URL_PATTERN.matcher(location).matches()
                    ? ImageIO.read(new URL(location))
                    : ImageIO.read(new File(location));
        } catch (IOException e) {
            throw withCause(new PredictException("Couldn't read image from " + location), e);
        }
    }

    /** Decodes raw image bytes. May return null if ImageIO cannot decode them. */
    private static BufferedImage readImage(byte[] bytes) throws PredictException {
        try {
            return ImageIO.read(new ByteArrayInputStream(bytes));
        } catch (IOException e) {
            throw withCause(new PredictException("Couldn't interpret raw bytes as an image."), e);
        }
    }

    /** Converts an image to the Deepwater model's pixel input, widened from float to double. */
    private double[] vectorizeImage(BufferedImage image) throws PredictException {
        DeepwaterMojoModel model = (DeepwaterMojoModel) this.m;
        int width = model._width;
        int height = model._height;
        int channels = model._channels;
        float[] pixels = new float[width * height * channels];
        try {
            GenModel.img2pixels(image, width, height, channels, pixels, 0, model._meanImageData);
        } catch (IOException e) {
            throw withCause(new PredictException("Couldn't vectorize image."), e);
        }
        double[] raw = new double[pixels.length];
        for (int i = 0; i < pixels.length; i++) {
            raw[i] = pixels[i];
        }
        return raw;
    }

    /** Attaches {@code cause} to a freshly created exception and returns it, so the root failure is kept. */
    private static PredictException withCause(PredictException exception, Throwable cause) {
        try {
            exception.initCause(cause);
        } catch (IllegalStateException ignored) {
            // The exception's constructor already fixed a cause; keep the exception as-is.
        }
        return exception;
    }

    /** Converts the row to raw model input and scores it into {@code preds}. */
    private double[] predict(RowData rowData, double[] preds) throws PredictException {
        return this.m.score0(fillRawData(rowData, nanArray(this.m.nfeatures())), preds);
    }
}
