package com.linkedin.dagli.dag;

import com.linkedin.dagli.generator.Generator;
import com.linkedin.dagli.objectio.ObjectIterator;
import com.linkedin.dagli.objectio.ObjectReader;
import com.linkedin.dagli.objectio.ObjectWriter;
import com.linkedin.dagli.objectio.biglist.BigListWriter;
import com.linkedin.dagli.producer.Producer;
import com.linkedin.dagli.transformer.PreparedTransformer;
import it.unimi.dsi.fastutil.objects.ObjectBigArrayBigList;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Objects;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.Future;

/* loaded from: input_file:com/linkedin/dagli/dag/FastPreparedDAGExecutor.class */
/**
 * Executes already-prepared DAGs by pushing the inputs through the DAG's nodes in minibatches.
 *
 * <p>When {@code maxThreads > 1} and there are at least {@code 2 * minInputsPerThread} inputs, the inputs are divided
 * into contiguous partitions. These partitions are processed concurrently on a {@link ForkJoinPool}, and their
 * results are concatenated in input order. This executor cannot prepare DAGs.</p>
 */
public final class FastPreparedDAGExecutor extends AbstractDAGExecutor<FastPreparedDAGExecutor> {
    private static final long serialVersionUID = 1L;

    /** Default minimum number of inputs each worker task must receive before another task is added. */
    public static final int DEFAULT_MIN_INPUTS_PER_THREAD = 128;

    // by default execution is single-threaded (_maxThreads == 1)
    private int _maxThreads = 1;
    private int _minInputsPerThread = DEFAULT_MIN_INPUTS_PER_THREAD;
    private int _maxMinibatchSize = 1024;
    private boolean _useCommonPool = true;

    /**
     * Returns a copy of this executor that runs its worker tasks on either {@link ForkJoinPool#commonPool()}
     * ({@code true}, the default) or a dedicated {@link ForkJoinPool} ({@code false}). A dedicated pool is created
     * for each multithreaded execution and shut down after it.
     *
     * @param useCommonPool whether to use the common fork/join pool
     * @return a modified copy of this executor
     */
    public FastPreparedDAGExecutor withCommonThreadPool(boolean useCommonPool) {
        return (FastPreparedDAGExecutor) clone(copy -> {
            copy._useCommonPool = useCommonPool;
        });
    }

    /**
     * Returns a copy of this executor that splits the inputs into at most {@code maxThreads} concurrently processed
     * partitions. A value of 1 or less results in single-threaded execution.
     *
     * @param maxThreads the maximum number of concurrent worker tasks
     * @return a modified copy of this executor
     */
    public FastPreparedDAGExecutor withMaxThreads(int maxThreads) {
        return (FastPreparedDAGExecutor) clone(copy -> {
            copy._maxThreads = maxThreads;
        });
    }

    /**
     * Returns a copy of this executor that gives every worker task at least {@code minInputsPerThread} inputs.
     * Fewer threads than {@code maxThreads} are used when there are not enough inputs. Values below 1 are treated
     * as 1.
     *
     * @param minInputsPerThread the minimum number of inputs per worker task
     * @return a modified copy of this executor
     */
    public FastPreparedDAGExecutor withMinInputsPerThread(int minInputsPerThread) {
        return (FastPreparedDAGExecutor) clone(copy -> {
            copy._minInputsPerThread = minInputsPerThread;
        });
    }

    /**
     * Returns a copy of this executor with the given upper bound on minibatch size. The effective minibatch size is
     * the smaller of this value and the DAG's own maximum, but never less than 1.
     *
     * @param maxMinibatchSize the maximum number of examples processed per minibatch
     * @return a modified copy of this executor
     */
    public FastPreparedDAGExecutor withMaxMinibatchSize(int maxMinibatchSize) {
        return (FastPreparedDAGExecutor) clone(copy -> {
            copy._maxMinibatchSize = maxMinibatchSize;
        });
    }

    /**
     * Always fails: this executor only applies prepared DAGs.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    protected <R, N extends PreparedDAGTransformer<R, N>, T extends PreparableDAGTransformer<R, N, T>> DAGExecutionResult<R, N> prepareAndApplyUnsafeImpl(T t, ObjectReader<Object>[] objectReaderArr) {
        throw new UnsupportedOperationException("FastPreparedDAGExecutor cannot be used to prepare DAGs, only to apply already-prepared DAGs");
    }

    /**
     * Applies a prepared DAG to the given inputs.
     *
     * @param dag the prepared DAG
     * @param inputs one reader per placeholder, in placeholder order
     * @return one reader per DAG output, in output order
     */
    @Override
    public <R, T extends PreparedDAGTransformer<R, T>> ObjectReader<?>[] applyUnsafeImpl(T dag, ObjectReader<Object>[] inputs) {
        return executeUnsafeImpl(dag.internalAPI().getDAGStructure(), inputs);
    }

    /**
     * Executes the DAG over all inputs, either on the calling thread or split across several worker tasks.
     *
     * @param dag the DAG's structure
     * @param inputs one reader per placeholder; the first reader's size determines the number of examples
     * @return one reader per DAG output containing all results, in input order
     */
    private <R> ObjectReader<?>[] executeUnsafeImpl(DAGStructure<R> dag, ObjectReader<Object>[] inputs) {
        final long inputCount = inputs[0].size64();
        final Object[] executionState = dag.createExecutionStateArray(inputCount);
        final int minibatchSize = Math.max(1, Math.min(_maxMinibatchSize, dag._maxMinibatchSize));
        // clamp to >= 1 so a zero (or negative) configuration cannot cause a division by zero
        final int threadCount = (int) Math.min(_maxThreads, inputCount / Math.max(1, _minInputsPerThread));

        // NOTE(review): these iterators are never closed; confirm whether ObjectIterator holds resources that
        // should be released once execution completes.
        @SuppressWarnings("unchecked")
        final ObjectIterator<Object>[] inputIterators =
            Arrays.stream(inputs).map(reader -> reader.iterator()).toArray(ObjectIterator[]::new);

        if (threadCount <= 1) {
            return executeUnsafeImplThread(dag, inputIterators, 0L, inputCount, minibatchSize, executionState);
        }

        // Split the inputs into threadCount contiguous partitions whose sizes differ by at most one. Each partition
        // is non-empty because threadCount <= inputCount here. The previous ceiling-division scheme could yield
        // partitions that extended past the end of the input, or had a negative size.
        final long baseCount = inputCount / threadCount;
        final long remainder = inputCount % threadCount;
        final List<Callable<ObjectReader<Object>[]>> tasks = new ArrayList<>(threadCount);
        long offset = 0;
        for (int t = 0; t < threadCount; t++) {
            final long partitionStart = offset;
            final long partitionSize = baseCount + (t < remainder ? 1 : 0);
            offset += partitionSize;

            // Every partition is read from the same input iterators, in order, so buffering happens here on the
            // calling thread. Each worker then reads only its own copy.
            final BigListWriter<Object>[] buffers = newWriters(inputs.length, partitionSize);
            for (int p = 0; p < buffers.length; p++) {
                buffers[p].write(inputIterators[p], partitionSize);
            }
            tasks.add(() -> executeUnsafeImplThread(dag, toIterators(buffers), partitionStart, partitionSize,
                minibatchSize, executionState));
        }

        // create the pool only after buffering succeeds so a failure above cannot leak a dedicated pool
        final ForkJoinPool pool = _useCommonPool ? ForkJoinPool.commonPool() : new ForkJoinPool(threadCount);
        try {
            final List<Future<ObjectReader<Object>[]>> futures = pool.invokeAll(tasks);
            final BigListWriter<Object>[] results = getEmptyResultList(dag, inputCount);
            // futures are in task (and therefore input) order, so appending preserves example order
            for (Future<ObjectReader<Object>[]> future : futures) {
                final ObjectReader<Object>[] partial;
                try {
                    partial = future.get();
                } catch (InterruptedException e) {
                    Thread.currentThread().interrupt();
                    throw new RuntimeException(e);
                } catch (ExecutionException e) {
                    throw new RuntimeException(e);
                }
                for (int k = 0; k < results.length; k++) {
                    results[k].write(partial[k].iterator(), partial[k].size64());
                }
            }
            return toReaders(results);
        } finally {
            if (!_useCommonPool) {
                // the dedicated pool is created per call; shut it down so its worker threads are released
                pool.shutdown();
            }
        }
    }

    /**
     * Creates one empty result writer per DAG output.
     *
     * @param dag the DAG's structure
     * @param capacity the initial capacity of each writer's backing list
     * @return the new writers, in output order
     */
    private static <R> BigListWriter<Object>[] getEmptyResultList(DAGStructure<R> dag, long capacity) {
        return newWriters(dag._outputs.size(), capacity);
    }

    /** Creates {@code count} empty writers, each backed by a big list with the given initial capacity. */
    @SuppressWarnings("unchecked")
    private static BigListWriter<Object>[] newWriters(int count, long capacity) {
        final BigListWriter<Object>[] writers = new BigListWriter[count];
        for (int w = 0; w < count; w++) {
            writers[w] = new BigListWriter<>(new ObjectBigArrayBigList<Object>(capacity));
        }
        return writers;
    }

    /** Creates a reader over the contents of each writer, preserving array order. */
    @SuppressWarnings("unchecked")
    private static ObjectReader<Object>[] toReaders(BigListWriter<Object>[] writers) {
        return Arrays.stream(writers).map(writer -> writer.createReader()).toArray(ObjectReader[]::new);
    }

    /** Creates a fresh iterator over the contents of each writer, preserving array order. */
    @SuppressWarnings("unchecked")
    private static ObjectIterator<Object>[] toIterators(BigListWriter<Object>[] writers) {
        return Arrays.stream(writers)
            .map(writer -> writer.createReader())
            .map(reader -> reader.iterator())
            .toArray(ObjectIterator[]::new);
    }

    /**
     * Executes the DAG over {@code count} consecutive examples, consuming them from the given iterators in
     * minibatches of at most {@code minibatchSize}.
     *
     * @param dag the DAG's structure
     * @param inputIterators one iterator per placeholder, each positioned at the first example to process
     * @param startIndex the absolute index of the first example (passed on to generators)
     * @param count the number of examples to process
     * @param minibatchSize the maximum number of examples per minibatch
     * @param executionState the per-node execution state array shared by all workers
     * @return one reader per DAG output holding this range's results
     */
    public static <R> ObjectReader<Object>[] executeUnsafeImplThread(DAGStructure<R> dag,
            ObjectIterator<Object>[] inputIterators, long startIndex, long count, int minibatchSize,
            Object[] executionState) {
        // scratch array whose leading slots are pointed at a transformer's parent rows before it is applied
        final Object[][] parentValues = new Object[dag._maxParentCount][minibatchSize];
        // one minibatch buffer per node; rows [0, inputIterators.length) are filled from the inputs
        final Object[][] nodeValues = new Object[dag._nodes.length][minibatchSize];
        final BigListWriter<Object>[] results = getEmptyResultList(dag, count);

        final long end = startIndex + count;
        for (long batchStart = startIndex; batchStart < end; batchStart += minibatchSize) {
            final int batchSize = (int) Math.min(minibatchSize, end - batchStart);
            for (int p = 0; p < inputIterators.length; p++) {
                inputIterators[p].next(nodeValues[p], 0, batchSize);
            }
            apply(batchStart, batchSize, dag, nodeValues, parentValues, executionState);
            for (int o = 0; o < dag._outputIndices.length; o++) {
                results[o].write(nodeValues[dag._outputIndices[o]], 0, batchSize);
            }
        }
        return toReaders(results);
    }

    /**
     * Evaluates every non-placeholder node of the DAG over one minibatch, in node-index order.
     *
     * <p>NOTE(review): this presumes node indices are topologically ordered, so that parents are computed before
     * children. Confirm against {@code DAGStructure}.</p>
     *
     * @param firstIndex the absolute index of the first example in the minibatch (generators receive
     *                   {@code firstIndex + offset})
     * @param batchSize the number of examples in the minibatch
     * @param dag the DAG's structure
     * @param nodeValues per-node minibatch buffers; placeholder rows must already be filled
     * @param parentValues scratch array of length at least the DAG's maximum parent count
     * @param executionState per-node execution state, indexed by node index
     * @throws IllegalStateException if a node is neither a {@link Generator} nor a {@link PreparedTransformer}
     */
    protected static <R> void apply(long firstIndex, int batchSize, DAGStructure<R> dag, Object[][] nodeValues,
            Object[][] parentValues, Object[] executionState) {
        for (int node = dag._placeholders.size(); node < dag._nodes.length; node++) {
            final Producer<?> producer = dag._nodes[node];
            if (producer instanceof Generator) {
                final Generator generator = (Generator) producer;
                for (int e = 0; e < batchSize; e++) {
                    nodeValues[node][e] = generator.generate(firstIndex + e);
                }
            } else if (producer instanceof PreparedTransformer) {
                final int[] parents = dag._parents[node];
                for (int p = 0; p < parents.length; p++) {
                    parentValues[p] = nodeValues[parents[p]];
                }
                ((PreparedTransformer) producer).internalAPI()
                    .applyAllUnsafe(executionState[node], batchSize, parentValues, nodeValues[node]);
            } else {
                throw new IllegalStateException("DAG is not prepared; this executor only accepts prepared DAGs");
            }
        }
    }

    /**
     * Two executors are equal when all of their configuration settings are equal.
     */
    @Override
    public boolean equals(Object obj) {
        if (this == obj) {
            return true;
        }
        if (obj == null || getClass() != obj.getClass()) {
            return false;
        }
        final FastPreparedDAGExecutor other = (FastPreparedDAGExecutor) obj;
        return _maxThreads == other._maxThreads
            && _minInputsPerThread == other._minInputsPerThread
            && _maxMinibatchSize == other._maxMinibatchSize
            && _useCommonPool == other._useCommonPool;
    }

    @Override
    public int hashCode() {
        return Objects.hash(_maxThreads, _minInputsPerThread, _maxMinibatchSize, _useCommonPool);
    }

    // marked by the decompiler as a compiler-generated bridge method; kept as-is
    @Override // com.linkedin.dagli.dag.AbstractDAGExecutor, com.linkedin.dagli.dag.PreparedDAGExecutor
    public /* bridge */ /* synthetic */ AbstractDAGExecutor internalAPI() {
        return super.internalAPI();
    }
}
