package hex;

import hex.ContributionsWithBackgroundFrameTask;
import java.util.Arrays;
import java.util.LinkedList;
import java.util.stream.IntStream;
import water.H2O;
import water.H2ONode;
import water.Job;
import water.JobUpdatePostMap;
import water.Key;
import water.MRTask;
import water.Scope;
import water.SplitToChunksApplyCombine;
import water.fvec.Chunk;
import water.fvec.Frame;
import water.fvec.NewChunk;
import water.util.Log;

/**
 * Base {@link MRTask} that pairs every row of {@code _frame} with every row of {@code _backgroundFrame}
 * and lets a subclass compute contributions for each pair (the warning logged by
 * {@link #runAndGetOutput} refers to these as SHAP values).
 *
 * <p>The MRTask is executed over the bigger of the two frames; {@link #map(Chunk[], NewChunk[])} then walks
 * the other frame chunk by chunk and hands every (frame chunks, background chunks) pair to the subclass hook
 * {@link #map(Chunk[], Chunk[], NewChunk[])}. Two extra trailing output columns, {@code RowIdx} and
 * {@code BackgroundRowIdx}, are filled here with the absolute row indices of each pair, frame-row-major and
 * background-row-minor (presumably subclasses emit their rows in the same order -- confirm in subclasses).
 *
 * @param <T> the concrete subclass (self type required by {@link MRTask})
 */
public abstract class ContributionsWithBackgroundFrameTask<T extends ContributionsWithBackgroundFrameTask<T>> extends MRTask<T> {
    /**
     * Column type passed to {@code doAll} for every output column.
     * NOTE(review): 3 presumably corresponds to {@code water.fvec.Vec.T_NUM} -- confirm.
     */
    private static final byte OUTPUT_COL_TYPE = 3;

    // Not serialized with the task; re-resolved from the keys below by loadFrames().
    transient Frame _frame;
    transient Frame _backgroundFrame;
    Key<Frame> _frameKey;
    Key<Frame> _backgroundFrameKey;
    // true -> runAndGetOutput aggregates the per-pair rows with ContributionsMeanAggregator.
    final boolean _aggregate;
    // Whether _frame has strictly more rows than _backgroundFrame; decides which frame the MRTask runs over.
    boolean _isFrameBigger;
    // Row range [_startRow, _endRow) of _frame to process; -1 for both means "the whole frame".
    long _startRow;
    long _endRow;
    // Optional job used for progress updates and stop requests.
    Job _job;

    /**
     * Creates the task; both frames are resolved from their keys immediately.
     *
     * @param frameKey           key of the frame whose rows are explained (asserted to exist and be non-empty)
     * @param backgroundFrameKey key of the background frame (asserted to exist and be non-empty)
     * @param outputPerReference when {@code true}, {@link #runAndGetOutput} returns one row per
     *                           (frame row, background row) pair; when {@code false}, those rows are
     *                           aggregated per frame row by {@code ContributionsMeanAggregator}
     */
    public ContributionsWithBackgroundFrameTask(Key<Frame> frameKey, Key<Frame> backgroundFrameKey, boolean outputPerReference) {
        assert null != frameKey.get();
        assert null != backgroundFrameKey.get();
        this._frameKey = frameKey;
        this._backgroundFrameKey = backgroundFrameKey;
        this._frame = frameKey.get();
        this._backgroundFrame = backgroundFrameKey.get();
        assert this._frame.numRows() > 0 : "Frame has to contain at least one row.";
        assert this._backgroundFrame.numRows() > 0 : "Background frame has to contain at least one row.";
        this._isFrameBigger = this._frame.numRows() > this._backgroundFrame.numRows();
        this._aggregate = !outputPerReference;
        this._startRow = -1L;
        this._endRow = -1L;
    }

    /**
     * Re-resolves the transient frames from their keys if they are missing (e.g. after the task was
     * copied/deserialized).
     */
    protected void loadFrames() {
        if (null == this._frame) {
            this._frame = this._frameKey.get();
        }
        if (null == this._backgroundFrame) {
            this._backgroundFrame = this._backgroundFrameKey.get();
        }
        assert this._frame != null && this._backgroundFrame != null;
    }

    /**
     * MRTask entry point. {@code cs} belongs to the frame the task runs over (the bigger one, see
     * {@link #runAndGetOutput}); this method iterates over the chunks of the other frame, calls the subclass
     * hook for every pair and appends the row indices of each pair to the last two output columns.
     *
     * @param cs  chunks of the frame the MRTask is executed over
     * @param ncs output chunks; all but the last two are handed to the subclass, the last two receive
     *            {@code RowIdx} and {@code BackgroundRowIdx}
     */
    @Override // water.MRTask
    public void map(Chunk[] cs, NewChunk[] ncs) {
        loadFrames();
        // The frame NOT iterated by the MRTask; it is walked here chunk by chunk.
        final Frame walked = this._isFrameBigger ? this._backgroundFrame : this._frame;
        long row = 0;
        long endRow = walked.numRows();
        // A restricted row range only applies when the walked frame is _frame (see setChunkRange).
        if (!this._isFrameBigger && this._startRow != -1 && this._endRow != -1) {
            row = this._startRow;
            endRow = this._endRow;
        }
        final NewChunk rowIdxOut = ncs[ncs.length - 2];
        final NewChunk backgroundRowIdxOut = ncs[ncs.length - 1];
        while (row < endRow && !isCancelled()) {
            if (null != this._job && this._job.stop_requested()) {
                return;
            }
            final long currentRow = row;
            Chunk[] walkedChunks = IntStream.range(0, walked.numCols())
                    .mapToObj(col -> walked.vec(col).chunkForRow(currentRow))
                    .toArray(Chunk[]::new);
            Chunk[] frameChunks = this._isFrameBigger ? cs : walkedChunks;
            Chunk[] backgroundChunks = this._isFrameBigger ? walkedChunks : cs;

            map(frameChunks, backgroundChunks, Arrays.copyOf(ncs, ncs.length - 2));

            long frameStart = frameChunks[0].start();
            long backgroundStart = backgroundChunks[0].start();
            for (int i = 0; i < frameChunks[0]._len; i++) {
                for (int k = 0; k < backgroundChunks[0]._len; k++) {
                    rowIdxOut.addNum(frameStart + i);
                    backgroundRowIdxOut.addNum(backgroundStart + k);
                }
            }
            // chunkForRow(currentRow) contains currentRow, so this always advances by at least one row.
            row += walkedChunks[0]._len;
        }
    }

    /**
     * Rough estimate, in bytes, of the full pairwise output: 8 bytes per cell for
     * {@code nCols * frame.numRows() * backgroundFrame.numRows()} cells. Computed in floating point so the
     * product cannot overflow {@code long}.
     */
    public static double estimateRequiredMemory(int nCols, Frame frame, Frame backgroundFrame) {
        return 8.0 * nCols * frame.numRows() * backgroundFrame.numRows();
    }

    /**
     * Estimates the minimal free memory, in bytes, each node needs: the larger of an even share of
     * {@link #estimateRequiredMemory} across the cloud and (one chunk of the bigger frame + the whole smaller
     * frame) for {@code nCols} columns.
     */
    public static double estimatePerNodeMinimalMemory(int nCols, Frame frame, Frame backgroundFrame) {
        boolean isFrameBigger = frame.numRows() > backgroundFrame.numRows();
        Frame biggerFrame = isFrameBigger ? frame : backgroundFrame;
        Frame smallerFrame = isFrameBigger ? backgroundFrame : frame;
        double requiredMemory = estimateRequiredMemory(nCols, frame, backgroundFrame);

        // Average chunk size of the bigger frame (integer division kept from the original estimate) ...
        double chunkMemory = (16L * nCols * biggerFrame.numRows()) / biggerFrame.anyVec().nChunks();
        // ... or the largest chunk, when the chunk layout is known.
        long[] espc = biggerFrame.anyVec().espc();
        if (null != espc) {
            long maxChunkRows = 0;
            for (int c = 0; c + 1 < espc.length; c++) {
                maxChunkRows = Math.max(maxChunkRows, espc[c + 1] - espc[c]);
            }
            chunkMemory = Math.max(chunkMemory, 8.0 * nCols * maxChunkRows);
        }
        return Math.max(requiredMemory / H2O.CLOUD._memary.length,
                chunkMemory + 8.0 * nCols * smallerFrame.numRows());
    }

    double estimatePerNodeMinimalMemory(int nCols) {
        return estimatePerNodeMinimalMemory(nCols, this._frame, this._backgroundFrame);
    }

    /** Smallest free memory reported by any node's heartbeat ({@code Long.MAX_VALUE} for an empty cloud). */
    public static long minMemoryPerNode() {
        long min = Long.MAX_VALUE;
        for (H2ONode node : H2O.CLOUD._memary) {
            min = Math.min(min, node._heartbeat.get_free_mem());
        }
        return min;
    }

    /** Sum of the free memory reported by all nodes' heartbeats. */
    public static long totalFreeMemory() {
        long total = 0;
        for (H2ONode node : H2O.CLOUD._memary) {
            total += node._heartbeat.get_free_mem();
        }
        return total;
    }

    /** @return {@code true} if every node reports strictly more free memory than {@code required} bytes */
    public static boolean enoughMinMemory(double required) {
        return ((double) minMemoryPerNode()) > required;
    }

    /**
     * Subclass hook computing contributions for one block of frame rows against one block of background rows.
     *
     * @param frameChunks      chunks of {@code _frame}
     * @param backgroundChunks chunks of {@code _backgroundFrame}
     * @param ncs              output chunks for the contribution columns (row-index columns excluded)
     */
    protected abstract void map(Chunk[] frameChunks, Chunk[] backgroundChunks, NewChunk[] ncs);

    /**
     * Restricts processing to the rows of {@code _frame} covered by chunks {@code firstChunkIdx} through
     * {@code lastChunkIdx} (inclusive). Only valid when the MRTask runs over the background frame.
     */
    void setChunkRange(int firstChunkIdx, int lastChunkIdx) {
        assert !this._isFrameBigger;
        this._startRow = this._frame.anyVec().chunkForChunkIdx(firstChunkIdx).start();
        Chunk last = this._frame.anyVec().chunkForChunkIdx(lastChunkIdx);
        this._endRow = last.start() + last._len;
    }

    /**
     * Runs the task and returns the resulting frame.
     *
     * <p>When not aggregating, the result has {@code colNames} plus {@code RowIdx} and {@code BackgroundRowIdx},
     * one row per (frame row, background row) pair. When aggregating, the pairwise rows are reduced by
     * {@code ContributionsMeanAggregator} to a frame with {@code colNames}; if memory is too tight to do this
     * in one pass, it is done one chunk of {@code _frame} at a time and the parts are concatenated.
     *
     * @param job            job used for progress updates and stop requests
     * @param destinationKey key of the returned frame
     * @param colNames       names of the contribution columns produced by the subclass
     * @return the result frame stored under {@code destinationKey}
     * @throws RuntimeException if the estimated memory requirement cannot be met
     */
    public Frame runAndGetOutput(Job job, Key<Frame> destinationKey, String[] colNames) {
        this._job = job;
        loadFrames();
        final int nOutCols = colNames.length + 2;
        double requiredMemory = estimateRequiredMemory(nOutCols, this._frame, this._backgroundFrame);
        double perNodeMemory = estimatePerNodeMinimalMemory(nOutCols);

        String[] colNamesWithRowIndices = Arrays.copyOf(colNames, nOutCols);
        colNamesWithRowIndices[colNames.length] = "RowIdx";
        colNamesWithRowIndices[colNames.length + 1] = "BackgroundRowIdx";
        Key<Frame> individualContribsKey = this._aggregate ? Key.make(destinationKey + "_individual_contribs") : destinationKey;

        if (!this._aggregate) {
            if (!enoughMinMemory(perNodeMemory)) {
                throw notEnoughMemory(requiredMemory, perNodeMemory);
            }
            return computeIndividualContributions(job, individualContribsKey, colNamesWithRowIndices);
        }

        if (enoughMinMemory(perNodeMemory)) {
            Frame individualContribs = computeIndividualContributions(job, individualContribsKey, colNamesWithRowIndices);
            try {
                return new ContributionsMeanAggregator(this._job, (int) this._frame.numRows(), colNames.length, (int) this._backgroundFrame.numRows())
                        .withPostMapAction(JobUpdatePostMap.forJob(job))
                        .doAll(colNames.length, OUTPUT_COL_TYPE, individualContribs)
                        .outputFrame(destinationKey, colNames, (String[][]) null);
            } finally {
                individualContribs.delete(true);
            }
        }

        if (minMemoryPerNode() < 5L * nOutCols * this._frame.numRows() * 8) {
            throw notEnoughMemory(requiredMemory, perNodeMemory);
        }
        return aggregateChunkByChunk(job, destinationKey, colNames, colNamesWithRowIndices);
    }

    /** Runs this task over the bigger frame in one pass and stores the pairwise rows under {@code key}. */
    private Frame computeIndividualContributions(Job job, Key<Frame> key, String[] colNamesWithRowIndices) {
        return withPostMapAction(JobUpdatePostMap.forJob(job))
                .doAll(colNamesWithRowIndices.length, OUTPUT_COL_TYPE, this._isFrameBigger ? this._frame : this._backgroundFrame)
                .outputFrame(key, colNamesWithRowIndices, (String[][]) null);
    }

    /**
     * Low-memory path: processes one chunk of {@code _frame} per iteration (over the background frame),
     * aggregates each part, and concatenates the parts under {@code destinationKey}.
     */
    private Frame aggregateChunkByChunk(Job job, Key<Frame> destinationKey, String[] colNames, String[] colNamesWithRowIndices) {
        int nChunks = this._frame.anyVec().nChunks();
        Log.warn("Not enough memory to calculate SHAP at once. Calculating in " + nChunks + " iterations.");
        // Map over the background frame so each iteration covers a contiguous row range of _frame.
        this._isFrameBigger = false;
        try (Scope.Safe scope = Scope.safe(new Frame[0])) {
            LinkedList<Frame> parts = new LinkedList<>();
            for (int chunkIdx = 0; chunkIdx < nChunks; chunkIdx++) {
                setChunkRange(chunkIdx, chunkIdx);
                Frame individualContribs = clone()
                        .withPostMapAction(JobUpdatePostMap.forJob(job))
                        .doAll(colNamesWithRowIndices.length, OUTPUT_COL_TYPE, this._backgroundFrame)
                        .outputFrame(Key.make(destinationKey + "_individual_contribs_" + chunkIdx), colNamesWithRowIndices, (String[][]) null);
                try {
                    parts.add(Scope.track(new ContributionsMeanAggregator(this._job, (int) (this._endRow - this._startRow), colNames.length, (int) this._backgroundFrame.numRows())
                            .setStartIndex((int) this._startRow)
                            .withPostMapAction(JobUpdatePostMap.forJob(job))
                            .doAll(colNames.length, OUTPUT_COL_TYPE, individualContribs)
                            .outputFrame(Key.make(destinationKey + "_part_" + chunkIdx), colNames, (String[][]) null)));
                } finally {
                    // Delete the intermediate pairwise frame even if aggregation fails.
                    individualContribs.delete();
                }
            }
            return Scope.untrack(SplitToChunksApplyCombine.concatFrames(parts, destinationKey));
        }
    }

    /** Builds the out-of-memory error reported by {@link #runAndGetOutput}. */
    private static RuntimeException notEnoughMemory(double requiredMemory, double perNodeMemory) {
        return new RuntimeException("Not enough memory. Estimated minimal total memory is " + requiredMemory + "B. Estimated minimal per node memory (assuming perfectly balanced datasets) is " + perNodeMemory + "B. Node with minimum memory has " + minMemoryPerNode() + "B. Total available memory is " + totalFreeMemory() + "B.");
    }
}
