package ai.h2o.mojos.runtime;

import ai.h2o.mojos.runtime.api.MojoColumnMeta;
import ai.h2o.mojos.runtime.frame.MojoColumn;
import ai.h2o.mojos.runtime.frame.MojoFrame;
import ai.h2o.mojos.runtime.frame.MojoFrameMeta;
import ai.h2o.mojos.runtime.shap.Scaler;
import ai.h2o.mojos.runtime.shap.ScalerGroup;
import ai.h2o.mojos.runtime.shap.TransformTraversal;
import ai.h2o.mojos.runtime.transforms.MojoTransform;
import ai.h2o.mojos.runtime.transforms.MojoTransformExecPipeBuilder;
import ai.h2o.mojos.runtime.transforms.MojoTransformNewColumnBuilder;
import ai.h2o.mojos.runtime.transforms.ShapCapable;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:ai/h2o/mojos/runtime/ShapBlender.class */
public class ShapBlender {
    private static final Logger log = LoggerFactory.getLogger(ShapBlender.class);

    private final MojoTransformExecPipeBuilder root;
    /** Global column metadata; SHAP columns created by this blender are appended to it, so it must be mutable. */
    private final List<MojoColumnMeta> globalColumns;
    /** When true, contributions are also forwarded to each transform's {@link OriginalMatrix}. */
    private final boolean shapOriginal;
    /** SHAP column indices each SHAP-capable transform writes into, in output order. */
    private final Map<ShapCapable, List<Integer>> pcIndicesByTransform = new LinkedHashMap<>();
    /** SHAP column name -> index into {@link #globalColumns}; used to reuse already-created columns. */
    private final Map<String, Integer> shapColumnsByName = new LinkedHashMap<>();
    /** Scaler applied to contributions, keyed by the transform output column index. */
    private Map<Integer, Scaler> scalerByOutputColumn;

    /**
     * @param globalColumns mutable list of global column metadata; new SHAP columns are appended to it
     * @param root          the root pipeline builder whose input/output columns and metadata drive column naming
     * @param shapOriginal  whether contributions should also be accumulated for the original input columns
     */
    public ShapBlender(List<MojoColumnMeta> globalColumns, MojoTransformExecPipeBuilder root, boolean shapOriginal) {
        this.root = root;
        this.globalColumns = globalColumns;
        this.shapOriginal = shapOriginal;
    }

    /**
     * Creates (or reuses) the SHAP contribution columns for every SHAP-capable transform of the wiring, records
     * which columns each transform writes to, and resolves the scaler for each transform output column.
     *
     * @param pipelineWiring wiring providing the SHAP transforms and their group input columns
     * @return indices of the SHAP columns to expose: those for the pipeline's original input columns when
     *         {@code shapOriginal} is set, otherwise all columns created for the transforms
     */
    public Set<Integer> prepareShapColumns(PipelineWiring pipelineWiring) {
        Set<Integer> shapColumns = new LinkedHashSet<>();
        int classCount = shapClassCount();
        int newColumnBuilderCount = 0;
        for (MojoTransform transform : pipelineWiring.shapTransforms) {
            List<Integer> transformShapColumns = buildShapColumns(
                    this.shapColumnsByName,
                    pipelineWiring.getGroupInputColumns(transform.getTransformationGroup(), transform.iindices),
                    classCount);
            ShapCapable shapCapable = (ShapCapable) transform;
            if (transform instanceof MojoTransformNewColumnBuilder) {
                // Successive new-column builders each take one column, cycling through the group's columns.
                int slot = newColumnBuilderCount % transformShapColumns.size();
                this.pcIndicesByTransform.put(shapCapable, transformShapColumns.subList(slot, slot + 1));
                newColumnBuilderCount++;
            } else {
                this.pcIndicesByTransform.put(shapCapable, transformShapColumns);
            }
            shapColumns.addAll(transformShapColumns);
        }

        int scalerOutputCount = this.root.pipelineMeta.outputClassLabels == null
                ? 1
                : this.root.pipelineMeta.outputClassLabels.size();
        this.scalerByOutputColumn = TransformTraversal.searchScalers(pipelineWiring, this.root.oindices, scalerOutputCount);
        if (this.scalerByOutputColumn.isEmpty()) {
            // searchScalers found nothing: share a single ScalerGroup across all outputs of the SHAP transforms.
            Map<Integer, Scaler> fallback = new LinkedHashMap<>();
            ScalerGroup sharedGroup = new ScalerGroup();
            for (ShapCapable shapCapable : this.pcIndicesByTransform.keySet()) {
                for (int outputIndex : ((MojoTransform) shapCapable).oindices) {
                    fallback.put(outputIndex, sharedGroup);
                }
            }
            this.scalerByOutputColumn = fallback;
        }
        return this.shapOriginal ? switchToShapOriginalColumns() : shapColumns;
    }

    /**
     * Number of per-class SHAP column groups: 1 when a probability complement was detected, otherwise the number
     * of output class labels, falling back to the number of root output columns when there are no labels.
     * Shared by transformed and original SHAP so both use the same per-class layout.
     */
    private int shapClassCount() {
        if (this.root.pipelineMeta.probabilityComplementDetected) {
            return 1;
        }
        if (this.root.pipelineMeta.outputClassLabels != null) {
            return this.root.pipelineMeta.outputClassLabels.size();
        }
        return this.root.oindices.length;
    }

    /** Creates (or reuses) the SHAP columns for the root pipeline's input columns and returns their indices. */
    private Set<Integer> switchToShapOriginalColumns() {
        Set<String> inputNames = new LinkedHashSet<>();
        for (int inputIndex : this.root.iindices) {
            inputNames.add(this.globalColumns.get(inputIndex).getColumnName());
        }
        List<Integer> originalShapColumns = buildShapColumns(this.shapColumnsByName, inputNames, shapClassCount());
        log.trace("Original SHAP column indices are: {}", originalShapColumns);
        return new LinkedHashSet<>(originalShapColumns);
    }

    /**
     * Builds the SHAP column layout for a set of features: for each class, one {@code contrib_<feature>} column per
     * feature followed by a {@code contrib_bias} column. When {@code classCount > 1} every name is suffixed with
     * {@code .<class label>}.
     *
     * @param columnsByName registry of already-created SHAP columns (updated with any new ones)
     * @param featureNames  feature names, in the order their columns should appear
     * @param classCount    number of per-class column groups to create
     * @return the global column indices, {@code classCount * (featureNames.size() + 1)} entries
     * @throws IllegalStateException if several classes are requested but the pipeline has no class labels
     */
    private List<Integer> buildShapColumns(Map<String, Integer> columnsByName, Set<String> featureNames, int classCount) {
        if (classCount > 1 && this.root.pipelineMeta.outputClassLabels == null) {
            throw new IllegalStateException(String.format(
                    "Output class labels are required to name SHAP columns for %d classes", classCount));
        }
        List<Integer> columns = new ArrayList<>(classCount * (featureNames.size() + 1));
        for (int c = 0; c < classCount; c++) {
            String suffix = classCount > 1 ? "." + this.root.pipelineMeta.outputClassLabels.get(c) : "";
            for (String feature : featureNames) {
                columns.add(shapColumn(columnsByName, "contrib_" + feature + suffix));
            }
            columns.add(shapColumn(columnsByName, "contrib_bias" + suffix));
        }
        return columns;
    }

    /**
     * Returns the global index of the SHAP column named {@code name}, appending a new Float64 column to
     * {@link #globalColumns} (and registering it in {@code columnsByName}) if it does not exist yet.
     */
    private int shapColumn(Map<String, Integer> columnsByName, String name) {
        Integer index = columnsByName.get(name);
        if (index == null) {
            index = this.globalColumns.size();
            columnsByName.put(name, index);
            this.globalColumns.add(MojoColumnMeta.create(name, MojoColumn.Type.Float64));
        }
        return index;
    }

    /** Returns the scaler for a transform output column, failing loudly if none was resolved. */
    private Scaler getScaler(int outputColumn) {
        Scaler scaler = this.scalerByOutputColumn.get(outputColumn);
        if (scaler == null) {
            throw new IllegalStateException(String.format("Error in blender - no scaler found for column %d('%s')", outputColumn, this.globalColumns.get(outputColumn).getColumnName()));
        }
        return scaler;
    }

    /** Returns the SHAP column indices registered for a transform by {@link #prepareShapColumns}. */
    private List<Integer> getShapColumnIndices(ShapCapable shapCapable) {
        List<Integer> indices = this.pcIndicesByTransform.get(shapCapable);
        if (indices == null) {
            throw new IllegalStateException(String.format("Shap indices not available for transform: %s", shapCapable));
        }
        log.trace("pcIndices ~ {}", MojoFrameMeta.debugIndicesToNames(this.globalColumns, indices));
        return indices;
    }

    /**
     * Computes the SHAP contributions of one transform for every row of the frame, scales them and adds them to
     * the transform's SHAP columns (and, when {@code shapOriginal} is set, to the transform's original matrix).
     *
     * @param mojoFrame     frame holding the inputs and the pre-allocated SHAP columns (must be {@code double[]})
     * @param mojoTransform a SHAP-capable transform previously passed through {@link #prepareShapColumns}
     * @throws UnsupportedOperationException if original SHAP is requested but the transform has no original matrix
     * @throws IllegalStateException         if a contribution is NaN, no scaler is known for an output, or the
     *                                       transform emits more values than prepared SHAP columns
     */
    public void computeShap(MojoFrame mojoFrame, MojoTransform mojoTransform) {
        ShapCapable shapCapable = (ShapCapable) mojoTransform;
        OriginalMatrix originalMatrix = shapCapable.getOriginalMatrix();
        if (this.shapOriginal && originalMatrix == null) {
            throw new UnsupportedOperationException("Missing original matrix - cannot compute original SHAP for " + mojoTransform);
        }
        List<Integer> shapColumnIndices = getShapColumnIndices(shapCapable);
        int shapColumnCount = shapColumnIndices.size();
        // Resolve target column buffers and names once; they do not change across rows.
        double[][] shapData = new double[shapColumnCount][];
        String[] shapColumnNames = new String[shapColumnCount];
        for (int c = 0; c < shapColumnCount; c++) {
            int columnIndex = shapColumnIndices.get(c);
            shapData[c] = (double[]) mojoFrame.getColumnData(columnIndex);
            shapColumnNames[c] = mojoFrame.getColumnName(columnIndex);
        }

        ShapBuffers shapBuffers = new ShapBuffers(mojoTransform);
        int nrows = mojoFrame.getNrows();
        for (int row = 0; row < nrows; row++) {
            double[] shapInputs = shapBuffers.prepareShapInputs(mojoFrame, row);
            double[][] shapOutputs = shapBuffers.prepareShapOutputs();
            shapCapable.computeShap(shapInputs, shapOutputs);
            // Outputs are laid out back-to-back across the transform's SHAP columns.
            int column = 0;
            for (int output = 0; output < shapOutputs.length; output++) {
                Scaler scaler = getScaler(mojoTransform.oindices[output]);
                double[] contributions = shapOutputs[output];
                int length = contributions.length;
                for (int k = 0; k < length; k++) {
                    if (column >= shapColumnCount) {
                        throw new IllegalStateException(String.format(
                                "Row %d: %s(%s) produced more SHAP values than the %d prepared SHAP columns",
                                row, mojoTransform.getId(), mojoTransform.getClass().getName(), shapColumnCount));
                    }
                    double value = contributions[k];
                    String columnName = shapColumnNames[column];
                    if (Double.isNaN(value)) {
                        throw new IllegalStateException(String.format("Row %d: %s(%s) did not compute shapOutput[%d][%d] : `%s`", row, mojoTransform.getId(), mojoTransform.getClass().getName(), output, k, columnName));
                    }
                    // The flag marks the last value of each output, presumably the bias term (buildShapColumns
                    // appends contrib_bias last in each class group).
                    double scaled = scaler.apply(value, mojoFrame, row, k == length - 1);
                    shapData[column][row] += scaled;
                    if (this.shapOriginal) {
                        originalMatrix.incrementOrigShap(mojoFrame, row, columnName, scaled);
                    }
                    column++;
                }
            }
        }
    }
}
