package ai.h2o.mojos.runtime;

import ai.h2o.mojos.runtime.api.MojoColumnMeta;
import ai.h2o.mojos.runtime.api.MojoTransformationGroup;
import ai.h2o.mojos.runtime.transforms.I;
import ai.h2o.mojos.runtime.transforms.K;
import ai.h2o.mojos.runtime.transforms.L;
import ai.h2o.mojos.runtime.transforms.MojoTransform;
import ai.h2o.mojos.runtime.transforms.MojoTransformExecPipeBuilder;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Iterator;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Set;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:ai/h2o/mojos/runtime/PipelineWiring.class */
/**
 * Index describing how the transformations of a (possibly nested) MOJO pipeline are wired to the
 * global column space.
 *
 * <p>On construction the pipeline is flattened (transformations inside nested
 * {@link MojoTransformExecPipeBuilder}s are expanded in place, depth first), and for every global
 * column index the transformation writing that column (its "producer") is recorded. That index is
 * then used to walk the transformation graph backwards, e.g. to find the SHAP-capable model behind
 * a blended output ({@link #shapCapableOrigin}) or the column at which a transformation group starts
 * ({@link #getGroupInputColumnIndex}).
 *
 * <p>NOTE(review): instances hold mutable state ({@link #shapTransforms}, the premature-traversal
 * counter) without any synchronization; presumably they are confined to a single thread — confirm.
 */
public class PipelineWiring {
    private static final Logger log = LoggerFactory.getLogger(PipelineWiring.class);

    /** Global column metadata; the position in this list is the column index used everywhere here. */
    private final List<MojoColumnMeta> globalColumns;

    /** {@code producers[c]} is the transformation writing column {@code c}, or {@code null} if none does. */
    private final MojoTransform[] producers;

    /** Column names from {@code pipelineMeta.transformedFeatures}; used as hints to stop group traversal. */
    private final Set<String> transformedFeatures;

    /** Every non-pipeline transformation, in depth-first order with nested pipelines expanded. */
    public final List<MojoTransform> transformsFlattened = new ArrayList<>();

    /** Number of root-level transformations that are instances of {@code K} (presumably tree models). */
    private int rootLevelTreeCount = 0;

    /** Number of nested pipelines encountered while flattening. */
    private int nestedPipelineCount = 0;

    /** How many times group traversal stopped early because of a {@code transformedFeatures} hint. */
    private int prematureTraversalStops = 0;

    /**
     * Transformations on which SHAP is to be computed. Initially every flattened transformation that
     * is an instance of {@code L}; entries are taken out again via {@link #noshap(MojoTransform)}.
     */
    public List<MojoTransform> shapTransforms = new ArrayList<>();

    /**
     * Excludes a transformation from SHAP computation.
     *
     * @param mojoTransform transformation to remove from {@link #shapTransforms}
     * @throws IllegalStateException if the transformation was not present in {@link #shapTransforms}
     */
    public void noshap(MojoTransform mojoTransform) {
        log.trace("Do not compute SHAP on: {}", mojoTransform);
        if (!this.shapTransforms.remove(mojoTransform)) {
            throw new IllegalStateException("Failed to remove transformation from shapTransforms: " + mojoTransform);
        }
    }

    /**
     * Builds the wiring index for a pipeline.
     *
     * @param columns  global column metadata; its size defines the valid column-index range
     * @param pipeline root pipeline whose transformations (including nested pipelines) are indexed
     */
    public PipelineWiring(List<MojoColumnMeta> columns, MojoTransformExecPipeBuilder pipeline) {
        this.globalColumns = columns;
        this.transformedFeatures = pipeline.pipelineMeta.transformedFeatures == null
                ? Collections.emptySet()
                : new LinkedHashSet<>(pipeline.pipelineMeta.transformedFeatures);
        addChildrenFlattened(pipeline.transforms);

        // Record, for every output column, the transformation that writes it; a column written by
        // several transformations keeps the last one in flattened order.
        this.producers = new MojoTransform[columns.size()];
        for (MojoTransform transform : this.transformsFlattened) {
            for (int outputIndex : transform.oindices) {
                this.producers[outputIndex] = transform;
            }
            if (transform instanceof L) {
                this.shapTransforms.add(transform);
            }
        }

        // Only the root level is inspected here; nested pipelines are not counted.
        for (MojoTransform transform : pipeline.transforms) {
            if (transform instanceof K) {
                this.rootLevelTreeCount++;
            }
        }
    }

    /**
     * Appends the transformations of {@code transforms} to {@link #transformsFlattened}, recursing
     * into nested pipelines (which are counted, but not themselves added).
     */
    private void addChildrenFlattened(List<MojoTransform> transforms) {
        for (MojoTransform transform : transforms) {
            if (transform instanceof MojoTransformExecPipeBuilder) {
                addChildrenFlattened(((MojoTransformExecPipeBuilder) transform).transforms);
                this.nestedPipelineCount++;
            } else {
                this.transformsFlattened.add(transform);
            }
        }
    }

    /**
     * Returns the transformation that writes the given column.
     *
     * @param i global column index
     * @return the producing transformation, or {@code null} if no transformation writes column
     *     {@code i} (presumably a pipeline input column)
     * @throws IllegalStateException if {@code i} is outside {@code [0, getColumns().size())}
     */
    public MojoTransform getProducer(int i) {
        // Must be '>=': index == length is already out of range and would otherwise surface as a raw
        // ArrayIndexOutOfBoundsException instead of this descriptive error.
        if (i < 0 || i >= this.producers.length) {
            throw new IllegalStateException(String.format("Column index invalid (%d), no transformer can produce it", Integer.valueOf(i)));
        }
        return this.producers[i];
    }

    /** Returns the global column metadata this wiring was built for (not a copy). */
    public List<MojoColumnMeta> getColumns() {
        return this.globalColumns;
    }

    /**
     * Walks backwards from a column to the nearest producing transformation that is an instance of
     * {@code L} (the "ShapCapable" model).
     *
     * <p>Transformations of type {@code I} (treated as softmax) are crossed by following the input
     * paired with the sought output; every other transformation crossed must have exactly one input.
     *
     * @param i global column index to start from
     * @return the first {@code L} transformation found upstream of column {@code i}
     * @throws IllegalArgumentException if a column on the way has no producer, a softmax does not
     *     output the sought index, or a crossed transformation has other than one input
     * @throws IllegalStateException if a column index on the way is out of range
     */
    public MojoTransform shapCapableOrigin(int i) {
        int columnIndex = i;
        while (true) {
            MojoTransform producer = getProducer(columnIndex);
            if (producer == null) {
                // Previously this surfaced as an uninformative NullPointerException.
                throw new IllegalArgumentException(String.format(
                        "column #%d has no producer; no ShapCapable model found upstream of column #%d",
                        Integer.valueOf(columnIndex), Integer.valueOf(i)));
            }
            // Pass the transform itself so SLF4J calls toString() only when TRACE is enabled.
            log.trace("traversing through {}({}), seeking index {}", producer, producer.getTransformationGroup(), Integer.valueOf(columnIndex));
            if (producer instanceof I) {
                int inputIndex = pairedInputIndex(producer, columnIndex);
                if (inputIndex < 0) {
                    throw new IllegalArgumentException("output index not found in softmax: " + columnIndex);
                }
                columnIndex = inputIndex;
            } else if (producer instanceof L) {
                return producer;
            } else {
                if (producer.iindices.length != 1) {
                    throw new IllegalArgumentException("only 1:1 transform expected while traversing from blending operations up to the ShapCapable model, but found " + producer.getName());
                }
                columnIndex = producer.iindices[0];
            }
        }
    }

    /**
     * Returns the input index positionally paired with {@code outputIndex} in {@code producer}, or
     * {@code -1} if {@code producer} does not output that index (last match wins).
     */
    private static int pairedInputIndex(MojoTransform producer, int outputIndex) {
        int found = -1;
        for (int k = 0; k < producer.oindices.length; k++) {
            if (producer.oindices[k] == outputIndex) {
                found = producer.iindices[k];
            }
        }
        return found;
    }

    /** Returns metadata of the column at which the group-backwards traversal from {@code i} stops. */
    private MojoColumnMeta getGroupInputColumn(MojoTransformationGroup mojoTransformationGroup, int i) {
        return this.globalColumns.get(getGroupInputColumnIndex(mojoTransformationGroup, i));
    }

    /**
     * Walks backwards from column {@code i} through transformations belonging to the given group and
     * returns the index of the first column not produced inside that group.
     *
     * <p>The walk also stops early, at the current column, if that column's name is listed in
     * {@code pipelineMeta.transformedFeatures}; such stops are logged at DEBUG and counted (see
     * {@link #reportPrematureTraversals()}).
     *
     * @param mojoTransformationGroup group to traverse; {@code null} returns {@code i} unchanged
     * @param i                       global column index to start from
     * @return the index of the column where the traversal stopped
     * @throws IllegalStateException if a crossed transformation does not have exactly one input and
     *     one output, or a column index on the way is out of range
     */
    public int getGroupInputColumnIndex(MojoTransformationGroup mojoTransformationGroup, int i) {
        int columnIndex = i;
        if (mojoTransformationGroup == null) {
            return columnIndex;
        }
        while (true) {
            MojoTransform producer = getProducer(columnIndex);
            if (producer == null) {
                return columnIndex;
            }
            MojoTransformationGroup producerGroup = producer.getTransformationGroup();
            if (producerGroup == null || !producerGroup.getId().equals(mojoTransformationGroup.getId())) {
                return columnIndex;
            }
            int[] inputs = producer.iindices;
            if (inputs.length != 1) {
                throw new IllegalStateException(String.format("producer of column #%d has %d columns; exactly 1 is required: %s", Integer.valueOf(columnIndex), Integer.valueOf(inputs.length), producer));
            }
            if (producer.oindices.length != 1) {
                throw new IllegalStateException(String.format("producer of column #%d has %d columns; exactly 1 is required: %s", Integer.valueOf(columnIndex), Integer.valueOf(producer.oindices.length), producer));
            }
            String columnName = this.globalColumns.get(columnIndex).getColumnName();
            if (this.transformedFeatures.contains(columnName)) {
                log.debug("traversal stops on '{}' PRIOR reaching boundary of group '{}', due to a hint from `Pipeline.transformed`; is it constructed correctly?", columnName, mojoTransformationGroup);
                this.prematureTraversalStops++;
                return columnIndex;
            }
            columnIndex = inputs[0];
        }
    }

    /**
     * Maps each column index to the name of its group input column (see
     * {@link #getGroupInputColumnIndex}).
     *
     * @return distinct column names in first-encounter order
     */
    public Set<String> getGroupInputColumns(MojoTransformationGroup mojoTransformationGroup, int[] iArr) {
        Set<String> names = new LinkedHashSet<>();
        for (int columnIndex : iArr) {
            names.add(getGroupInputColumn(mojoTransformationGroup, columnIndex).getColumnName());
        }
        return names;
    }

    /** Logs a WARN summary if any group traversal stopped early because of a {@code transformedFeatures} hint. */
    public void reportPrematureTraversals() {
        if (this.prematureTraversalStops > 0) {
            log.warn("Premature traversal stops occurred {} times. See DEBUG log for full list. Columns in `Pipeline.transformed` might need review.", Integer.valueOf(this.prematureTraversalStops));
        }
    }

    /** Returns {@code true} if the pipeline contains at least one nested pipeline. */
    public boolean isEnsemble() {
        return this.nestedPipelineCount > 0;
    }

    /**
     * Returns {@code true} if the pipeline is an ensemble (see {@link #isEnsemble()}) that also has
     * at least one root-level {@code K} transformation.
     */
    public boolean isTreeMetalearner() {
        return this.rootLevelTreeCount > 0 && isEnsemble();
    }
}
