package hex.genmodel.algos.targetencoder;

import hex.genmodel.MojoModel;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

/* loaded from: input_file:hex/genmodel/algos/targetencoder/TargetEncoderMojoModel.class */
public class TargetEncoderMojoModel extends MojoModel {

    /** Maps each input column name (as passed to the constructor) to its index in the input row. */
    public final Map<String, Integer> _columnNameToIdx;

    /**
     * Per encoded column: {@code true} -> missing values are encoded like the dedicated NA category,
     * {@code false} -> missing values get the prior mean. Presumably records whether NAs were seen
     * at training time — confirm against the MOJO reader.
     */
    public Map<String, Boolean> _teColumn2HasNAs;

    /** When {@code true}, the posterior mean is blended with the prior mean (see {@link #computeLambda}). */
    public boolean _withBlending;

    /** Inflection point of the blending curve used by {@link #computeLambda}. */
    public double _inflectionPoint;

    /** Smoothing of the blending curve used by {@link #computeLambda}. */
    public double _smoothing;

    /** Input column group -> single encoded column name; drives {@link #score0}. */
    public List<ColumnsToSingleMapping> _inencMapping;

    /** Input column group -> output (encoded) column names. */
    public List<ColumnsMapping> _inoutMapping;

    // Not read by this class; presumably populated by the MOJO reader for callers.
    List<String> _nonPredictors;

    /** Encoding map per encoded column name. */
    Map<String, EncodingMap> _encodingsByCol;

    // Not read by this class; presumably populated by the MOJO reader for callers.
    boolean _keepOriginalCategoricalColumns;

    /** Whether missing values may be imputed from the NA category; always enabled in this MOJO. */
    private final boolean _imputeUnknownLevels = true;

    /**
     * Computes the blending weight applied to the posterior mean.
     *
     * <p>This is a logistic curve in {@code nrows}. For a positive {@code smoothing} it grows towards
     * 1 as {@code nrows} exceeds {@code inflectionPoint}.
     *
     * @param nrows number of observations backing the posterior estimate
     * @param inflectionPoint value of {@code nrows} at which the weight equals 0.5
     * @param smoothing steepness control of the curve (divisor of the exponent)
     * @return {@code 1 / (1 + exp((inflectionPoint - nrows) / smoothing))}
     */
    public static double computeLambda(long nrows, double inflectionPoint, double smoothing) {
        return 1.0d / (1.0d + Math.exp((inflectionPoint - nrows) / smoothing));
    }

    /**
     * Linearly blends the posterior and prior means.
     *
     * @param lambda weight given to {@code posteriorMean}
     * @param posteriorMean per-category estimate
     * @param priorMean global estimate
     * @return {@code lambda * posteriorMean + (1 - lambda) * priorMean}
     */
    public static double computeBlendedEncoding(double lambda, double posteriorMean, double priorMean) {
        return (lambda * posteriorMean) + ((1.0d - lambda) * priorMean);
    }

    /**
     * Builds a name -> position lookup for the given column names.
     *
     * <p>If a name occurs more than once, the last position wins.
     */
    static Map<String, Integer> name2Idx(String[] names) {
        Map<String, Integer> nameToIdx = new HashMap<>(names.length);
        for (int i = 0; i < names.length; i++) {
            nameToIdx.put(names[i], i);
        }
        return nameToIdx;
    }

    /**
     * @param columns input column names (forwarded to {@link MojoModel})
     * @param domains categorical domains per column (forwarded to {@link MojoModel})
     * @param responseColumn response column name (forwarded to {@link MojoModel})
     */
    public TargetEncoderMojoModel(String[] columns, String[][] domains, String responseColumn) {
        super(columns, domains, responseColumn);
        // Note: _imputeUnknownLevels is final with an initializer; re-assigning it here
        // (as the decompiled code did) does not compile.
        this._columnNameToIdx = name2Idx(columns);
    }

    /**
     * Fills in default column mappings when none were provided.
     *
     * <p>If both mappings are empty, each encoded column becomes a single-column input mapping.
     * Outputs are named {@code <col>_te}, or {@code <col>_<k>_te} for k = 1..n when more than one
     * encoded value is produced per predictor. Does nothing if no encodings are set.
     *
     * <p>NOTE(review): the decompiler reports this method was originally {@code protected}; it is
     * kept {@code public} so existing callers keep compiling.
     */
    public void init() {
        if (this._encodingsByCol == null) {
            return;
        }
        if (this._inencMapping == null) {
            this._inencMapping = new ArrayList<>();
        }
        if (this._inoutMapping == null) {
            this._inoutMapping = new ArrayList<>();
        }
        if (this._inencMapping.isEmpty() && this._inoutMapping.isEmpty()) {
            for (String col : this._encodingsByCol.keySet()) {
                String[] in = {col};
                this._inencMapping.add(new ColumnsToSingleMapping(in, col, null));
                String[] out = new String[getNumEncColsPerPredictor()];
                if (out.length > 1) {
                    for (int i = 0; i < out.length; i++) {
                        out[i] = col + "_" + (i + 1) + "_te";
                    }
                } else {
                    out[0] = col + "_te";
                }
                this._inoutMapping.add(new ColumnsMapping(in, out));
            }
        }
    }

    /**
     * Installs the per-column encoding maps.
     *
     * <p>NOTE(review): the decompiler reports this method was originally {@code protected}.
     */
    public void setEncodings(EncodingMaps encodingMaps) {
        this._encodingsByCol = encodingMaps.encodingMap();
    }

    /** Number of output values: encoded columns times values per predictor (0 if no encodings). */
    @Override // hex.genmodel.GenModel, water.genmodel.IGeneratedModel
    public int getPredsSize() {
        if (this._encodingsByCol == null) {
            return 0;
        }
        return this._encodingsByCol.size() * getNumEncColsPerPredictor();
    }

    /** Encoded values per predictor: {@code nclasses() - 1} when {@code nclasses() > 1}, else 1. */
    int getNumEncColsPerPredictor() {
        if (nclasses() > 1) {
            return nclasses() - 1;
        }
        return 1;
    }

    /**
     * Target-encodes one row.
     *
     * <p>For each entry of {@link #_inencMapping}, the category is resolved. It is either the raw
     * value of a single column, or the index of a multi-column interaction. Missing categories
     * ({@code NaN}) go through {@link #encodeNA}; others go through {@link #encodeCategory}. Results
     * are written consecutively into {@code preds}.
     *
     * @param row input row, indexed via {@link #_columnNameToIdx}
     * @param preds output buffer, filled in place and returned
     * @return {@code preds}
     * @throws IllegalStateException if no encoding map was set, or a mapping references an unknown
     *     column
     */
    @Override // hex.genmodel.GenModel
    public double[] score0(double[] row, double[] preds) {
        if (this._encodingsByCol == null) {
            throw new IllegalStateException("Encoding map is missing.");
        }
        int predsIdx = 0;
        for (ColumnsToSingleMapping colMap : this._inencMapping) {
            String[] colGroup = colMap.from();
            String teColumn = colMap.toSingle();
            EncodingMap encodings = this._encodingsByCol.get(teColumn);
            int[] colsIdx = columnsIndices(colGroup);

            double category;
            if (colsIdx.length == 1) {
                category = row[colsIdx[0]];
            } else {
                assert colMap.toDomainAsNum() != null
                        : "Missing domain for interaction between columns " + Arrays.toString(colGroup);
                category = interactionValue(row, colsIdx, colMap.toDomainAsNum());
            }
            // Categorical levels are assumed to be represented by integral values.
            predsIdx += Double.isNaN(category)
                    ? encodeNA(preds, predsIdx, encodings, teColumn)
                    : encodeCategory(preds, predsIdx, encodings, (int) category);
        }
        return preds;
    }

    /** Returns the encoding map for {@code column}, or {@code null} if there is none. */
    public EncodingMap getEncodings(String column) {
        return this._encodingsByCol.get(column);
    }

    /**
     * Resolves column names to row indices.
     *
     * @throws IllegalStateException if a name is not among the model's input columns
     */
    private int[] columnsIndices(String[] names) {
        int[] indices = new int[names.length];
        for (int i = 0; i < names.length; i++) {
            Integer idx = this._columnNameToIdx.get(names[i]);
            if (idx == null) {
                throw new IllegalStateException("Unknown column '" + names[i] + "' referenced by target encoding mapping.");
            }
            indices[i] = idx;
        }
        return indices;
    }

    /**
     * Computes the interaction category of several columns.
     *
     * <p>Each column value is combined in mixed radix {@code (cardinality + 1)}. A {@code NaN} or
     * out-of-domain value is mapped to the extra slot {@code cardinality}. The combined value is
     * then located in {@code interactionDomain} by binary search.
     *
     * @param interactionDomain sorted encoded interaction values (required by binary search)
     * @return position of the combined value in {@code interactionDomain}, or {@code NaN} if absent
     */
    private double interactionValue(double[] row, int[] colsIdx, long[] interactionDomain) {
        long interaction = 0;
        long multiplier = 1;
        for (int colIdx : colsIdx) {
            double val = row[colIdx];
            // NOTE(review): assumes every interacting column is categorical (non-null domain).
            int domainCard = getDomainValues(colIdx).length;
            if (Double.isNaN(val) || val >= domainCard) {
                val = domainCard;
            }
            interaction += multiplier * val; // implicit narrowing back to long, as before
            multiplier *= domainCard + 1;
        }
        int pos = Arrays.binarySearch(interactionDomain, interaction);
        return pos < 0 ? Double.NaN : pos;
    }

    /**
     * Turns a {numerator, denominator} pair into an encoded value.
     *
     * <p>The posterior is {@code numDen[0] / numDen[1]}. When blending is enabled, the denominator
     * (truncated to long) is used as the row count for {@link #computeLambda}.
     */
    private double computeEncodedValue(double[] numDen, double priorMean) {
        double posteriorMean = numDen[0] / numDen[1];
        if (!this._withBlending) {
            return posteriorMean;
        }
        double lambda = computeLambda((long) numDen[1], this._inflectionPoint, this._smoothing);
        return computeBlendedEncoding(lambda, posteriorMean, priorMean);
    }

    /**
     * Writes the encoding of {@code category} at {@code result[startIdx..]}.
     *
     * <p>With {@code nclasses() <= 2}, a single value is written. Otherwise one value is written per
     * target class 1..nclasses()-1; class 0 is skipped.
     *
     * @return number of values written
     */
    int encodeCategory(double[] result, int startIdx, EncodingMap encodings, int category) {
        if (nclasses() <= 2) {
            result[startIdx] = computeEncodedValue(encodings.getNumDen(category), encodings.getPriorMean());
            return 1;
        }
        for (int i = 0; i < nclasses() - 1; i++) {
            int targetClass = i + 1;
            result[startIdx + i] = computeEncodedValue(
                    encodings.getNumDen(category, targetClass), encodings.getPriorMean(targetClass));
        }
        return nclasses() - 1;
    }

    /**
     * Encodes a missing value for {@code teColumn}.
     *
     * <p>If {@link #_teColumn2HasNAs} is true for the column, the NA category is used. Otherwise the
     * prior mean is used.
     *
     * @return number of values written
     */
    int encodeNA(double[] result, int startIdx, EncodingMap encodings, String teColumn) {
        // Auto-unboxing: a missing entry in _teColumn2HasNAs throws NullPointerException (unchanged).
        boolean useNACategory = this._imputeUnknownLevels && this._teColumn2HasNAs.get(teColumn);
        return useNACategory
                ? encodeCategory(result, startIdx, encodings, encodings.getNACategory())
                : encodeWithPriorMean(result, startIdx, encodings);
    }

    /**
     * Writes the prior mean(s) at {@code result[startIdx..]}, using the same class layout as
     * {@link #encodeCategory}.
     *
     * @return number of values written
     */
    private int encodeWithPriorMean(double[] result, int startIdx, EncodingMap encodings) {
        if (nclasses() <= 2) {
            result[startIdx] = encodings.getPriorMean();
            return 1;
        }
        for (int i = 0; i < nclasses() - 1; i++) {
            result[startIdx + i] = encodings.getPriorMean(i + 1);
        }
        return nclasses() - 1;
    }
}
