package net.jamu.matrix;

import java.util.Arrays;
import net.frobenius.TTrans;

/* loaded from: input_file:net/jamu/matrix/TensorD.class */
/**
 * A stack of equally sized double-precision matrices ("layers") stored back to back in a single
 * {@code double[]}.
 *
 * <p>Element addressing ({@code idx}, {@code startIdx}, {@code stride}) and the dimension fields
 * ({@code rows}, {@code cols}, {@code depth}, {@code length}) are inherited from
 * {@link TensorBase}. The matrix products are delegated to the BLAS binding's
 * {@code dgemm_multi}, which is handed {@code rows} as leading dimension, so each layer is
 * presumably laid out column-major (NOTE(review): confirm against {@code TensorBase.idx}).
 *
 * <p>Naming used in the method documentation: {@code A} is this tensor, {@code B} is the
 * {@code tensorD} argument and {@code C} is the {@code tensorD2} (output) argument. All batched
 * operations process only the first {@code min(depth)} layers of the tensors involved.
 *
 * <p>Instances are mutable and perform no synchronization.
 */
public class TensorD extends TensorBase {
    /** Scalar applied to the existing content of the output tensor by the GEMM calls (accumulate into C). */
    private static final double BETA = 1.0d;
    /** Default scalar applied to the product in the overloads that take no explicit scaling factor. */
    private static final double ALPHA_ONE = 1.0d;
    /** Start offset used for every array handed to BLAS or to {@code System.arraycopy}. */
    private static final int OFFS = 0;

    /** Backing storage holding all layers one after another. */
    protected double[] a;

    /**
     * Creates a zero-filled tensor with a single layer.
     *
     * @param i  number of rows of each layer
     * @param i2 number of columns of each layer
     */
    public TensorD(int i, int i2) {
        this(i, i2, 1);
    }

    /**
     * Creates a zero-filled tensor.
     *
     * @param i  number of rows of each layer
     * @param i2 number of columns of each layer
     * @param i3 number of layers
     */
    public TensorD(int i, int i2, int i3) {
        super(i, i2, i3);
        // 'length' is provided by TensorBase once the dimensions have been passed to super(...).
        this.a = new double[this.length];
    }

    /**
     * Creates a single-layer tensor holding a copy of the given matrix' backing array.
     *
     * @param matrixD the matrix to copy; it is not modified and not retained
     */
    public TensorD(MatrixD matrixD) {
        super(matrixD.numRows(), matrixD.numColumns(), 1);
        double[] source = matrixD.getArrayUnsafe();
        this.a = Arrays.copyOf(source, source.length);
    }

    /**
     * Copy constructor: same dimensions, independent copy of the backing array.
     *
     * @param tensorD the tensor to copy
     */
    public TensorD(TensorD tensorD) {
        super(tensorD.rows, tensorD.cols, tensorD.depth);
        this.a = Arrays.copyOf(tensorD.a, tensorD.a.length);
    }

    /**
     * Stores a value after validating the indices via {@code checkIndex}.
     *
     * @param i  row index
     * @param i2 column index
     * @param i3 layer index
     * @param d  value to store
     * @return this tensor
     */
    public TensorD set(int i, int i2, int i3, double d) {
        checkIndex(i, i2, i3);
        return setUnsafe(i, i2, i3, d);
    }

    /**
     * Reads a value after validating the indices via {@code checkIndex}.
     *
     * @param i  row index
     * @param i2 column index
     * @param i3 layer index
     * @return the stored value
     */
    public double get(int i, int i2, int i3) {
        checkIndex(i, i2, i3);
        return getUnsafe(i, i2, i3);
    }

    /**
     * Stores a value without calling {@code checkIndex}; only the JVM's array bounds check applies.
     *
     * @param i  row index
     * @param i2 column index
     * @param i3 layer index
     * @param d  value to store
     * @return this tensor
     */
    public TensorD setUnsafe(int i, int i2, int i3, double d) {
        this.a[idx(i, i2, i3)] = d;
        return this;
    }

    /**
     * Reads a value without calling {@code checkIndex}; only the JVM's array bounds check applies.
     *
     * @param i  row index
     * @param i2 column index
     * @param i3 layer index
     * @return the stored value
     */
    public double getUnsafe(int i, int i2, int i3) {
        return this.a[idx(i, i2, i3)];
    }

    /**
     * Overwrites one layer with the content of the given matrix.
     *
     * @param matrixD source matrix; must pass {@code Checks.checkEqualDimension} against this tensor
     * @param i       index of the layer to overwrite
     * @return this tensor
     */
    public TensorD set(MatrixD matrixD, int i) {
        Checks.checkEqualDimension(this, matrixD);
        int layerStart = startIdx(i);
        double[] source = matrixD.getArrayUnsafe();
        System.arraycopy(source, OFFS, this.a, layerStart, source.length);
        return this;
    }

    /**
     * Returns one layer as a new matrix backed by its own copy of the data.
     *
     * @param i index of the layer to extract
     * @return a {@code SimpleMatrixD} of size {@code rows x cols} holding {@code stride()} copied elements
     */
    public MatrixD get(int i) {
        int layerStart = startIdx(i);
        int layerLength = stride();
        double[] layer = new double[layerLength];
        System.arraycopy(this.a, layerStart, layer, OFFS, layerLength);
        return new SimpleMatrixD(this.rows, this.cols, layer);
    }

    /**
     * Appends the given matrix as a new last layer. The backing array is replaced by a larger one,
     * so references previously obtained from {@link #getArrayUnsafe()} no longer alias this tensor.
     *
     * @param matrixD matrix to append; must pass {@code Checks.checkEqualDimension} against this tensor
     * @return this tensor, with {@code depth} increased by one
     */
    public TensorD append(MatrixD matrixD) {
        Checks.checkEqualDimension(this, matrixD);
        double[] grown = growAndCopyForAppend(matrixD);
        double[] source = matrixD.getArrayUnsafe();
        // The old content occupies [0, length); the new layer goes right behind it.
        System.arraycopy(source, OFFS, grown, this.length, source.length);
        this.a = grown;
        this.length = grown.length;
        this.depth++;
        return this;
    }

    /**
     * Layer-wise {@code C = d * A * B + C}.
     *
     * @param d        scalar applied to the product
     * @param tensorD  right-hand operand {@code B}
     * @param tensorD2 accumulator {@code C}; updated in place
     * @return {@code tensorD2}
     */
    public TensorD multAdd(double d, TensorD tensorD, TensorD tensorD2) {
        Checks.checkMultAdd(this, tensorD, tensorD2);
        return batchedGemmAdd(TTrans.NO_TRANS, TTrans.NO_TRANS, this.cols, d, tensorD, tensorD2);
    }

    /**
     * Layer-wise {@code C = d * A^T * B + C}.
     *
     * @param d        scalar applied to the product
     * @param tensorD  right-hand operand {@code B}
     * @param tensorD2 accumulator {@code C}; updated in place
     * @return {@code tensorD2}
     */
    public TensorD transAmultAdd(double d, TensorD tensorD, TensorD tensorD2) {
        Checks.checkTransAmultAdd(this, tensorD, tensorD2);
        return batchedGemmAdd(TTrans.TRANS, TTrans.NO_TRANS, this.rows, d, tensorD, tensorD2);
    }

    /**
     * Layer-wise {@code C = d * A * B^T + C}.
     *
     * @param d        scalar applied to the product
     * @param tensorD  right-hand operand {@code B}
     * @param tensorD2 accumulator {@code C}; updated in place
     * @return {@code tensorD2}
     */
    public TensorD transBmultAdd(double d, TensorD tensorD, TensorD tensorD2) {
        Checks.checkTransBmultAdd(this, tensorD, tensorD2);
        return batchedGemmAdd(TTrans.NO_TRANS, TTrans.TRANS, this.cols, d, tensorD, tensorD2);
    }

    /**
     * Layer-wise {@code C = d * A^T * B^T + C}.
     *
     * @param d        scalar applied to the product
     * @param tensorD  right-hand operand {@code B}
     * @param tensorD2 accumulator {@code C}; updated in place
     * @return {@code tensorD2}
     */
    public TensorD transABmultAdd(double d, TensorD tensorD, TensorD tensorD2) {
        Checks.checkTransABmultAdd(this, tensorD, tensorD2);
        return batchedGemmAdd(TTrans.TRANS, TTrans.TRANS, this.rows, d, tensorD, tensorD2);
    }

    /**
     * Layer-wise {@code C = A^T * B^T + C}.
     *
     * @param tensorD  right-hand operand {@code B}
     * @param tensorD2 accumulator {@code C}; updated in place
     * @return {@code tensorD2}
     */
    public TensorD transABmultAdd(TensorD tensorD, TensorD tensorD2) {
        return transABmultAdd(ALPHA_ONE, tensorD, tensorD2);
    }

    /**
     * Layer-wise {@code C = d * A^T * B^T}. The whole of {@code tensorD2} is zeroed first, so it
     * must not be the same object as this tensor or {@code tensorD}.
     *
     * @param d        scalar applied to the product
     * @param tensorD  right-hand operand {@code B}
     * @param tensorD2 output {@code C}; previous content is discarded
     * @return {@code tensorD2}
     */
    public TensorD transABmult(double d, TensorD tensorD, TensorD tensorD2) {
        // Validate before clearing so a dimension mismatch cannot wipe the caller's tensor.
        Checks.checkTransABmultAdd(this, tensorD, tensorD2);
        return transABmultAdd(d, tensorD, tensorD2.zeroInplace());
    }

    /**
     * Layer-wise {@code C = A^T * B^T}; see {@link #transABmult(double, TensorD, TensorD)}.
     *
     * @param tensorD  right-hand operand {@code B}
     * @param tensorD2 output {@code C}; previous content is discarded
     * @return {@code tensorD2}
     */
    public TensorD transABmult(TensorD tensorD, TensorD tensorD2) {
        return transABmult(ALPHA_ONE, tensorD, tensorD2);
    }

    /**
     * Layer-wise {@code C = A * B^T + C}.
     *
     * @param tensorD  right-hand operand {@code B}
     * @param tensorD2 accumulator {@code C}; updated in place
     * @return {@code tensorD2}
     */
    public TensorD transBmultAdd(TensorD tensorD, TensorD tensorD2) {
        return transBmultAdd(ALPHA_ONE, tensorD, tensorD2);
    }

    /**
     * Layer-wise {@code C = d * A * B^T}. The whole of {@code tensorD2} is zeroed first, so it
     * must not be the same object as this tensor or {@code tensorD}.
     *
     * @param d        scalar applied to the product
     * @param tensorD  right-hand operand {@code B}
     * @param tensorD2 output {@code C}; previous content is discarded
     * @return {@code tensorD2}
     */
    public TensorD transBmult(double d, TensorD tensorD, TensorD tensorD2) {
        // Validate before clearing so a dimension mismatch cannot wipe the caller's tensor.
        Checks.checkTransBmultAdd(this, tensorD, tensorD2);
        return transBmultAdd(d, tensorD, tensorD2.zeroInplace());
    }

    /**
     * Layer-wise {@code C = A * B^T}; see {@link #transBmult(double, TensorD, TensorD)}.
     *
     * @param tensorD  right-hand operand {@code B}
     * @param tensorD2 output {@code C}; previous content is discarded
     * @return {@code tensorD2}
     */
    public TensorD transBmult(TensorD tensorD, TensorD tensorD2) {
        return transBmult(ALPHA_ONE, tensorD, tensorD2);
    }

    /**
     * Layer-wise {@code C = A^T * B + C}.
     *
     * @param tensorD  right-hand operand {@code B}
     * @param tensorD2 accumulator {@code C}; updated in place
     * @return {@code tensorD2}
     */
    public TensorD transAmultAdd(TensorD tensorD, TensorD tensorD2) {
        return transAmultAdd(ALPHA_ONE, tensorD, tensorD2);
    }

    /**
     * Layer-wise {@code C = d * A^T * B}. The whole of {@code tensorD2} is zeroed first, so it
     * must not be the same object as this tensor or {@code tensorD}.
     *
     * @param d        scalar applied to the product
     * @param tensorD  right-hand operand {@code B}
     * @param tensorD2 output {@code C}; previous content is discarded
     * @return {@code tensorD2}
     */
    public TensorD transAmult(double d, TensorD tensorD, TensorD tensorD2) {
        // Validate before clearing so a dimension mismatch cannot wipe the caller's tensor.
        Checks.checkTransAmultAdd(this, tensorD, tensorD2);
        return transAmultAdd(d, tensorD, tensorD2.zeroInplace());
    }

    /**
     * Layer-wise {@code C = A^T * B}; see {@link #transAmult(double, TensorD, TensorD)}.
     *
     * @param tensorD  right-hand operand {@code B}
     * @param tensorD2 output {@code C}; previous content is discarded
     * @return {@code tensorD2}
     */
    public TensorD transAmult(TensorD tensorD, TensorD tensorD2) {
        return transAmult(ALPHA_ONE, tensorD, tensorD2);
    }

    /**
     * Layer-wise {@code C = A * B + C}.
     *
     * @param tensorD  right-hand operand {@code B}
     * @param tensorD2 accumulator {@code C}; updated in place
     * @return {@code tensorD2}
     */
    public TensorD multAdd(TensorD tensorD, TensorD tensorD2) {
        return multAdd(ALPHA_ONE, tensorD, tensorD2);
    }

    /**
     * Layer-wise {@code C = d * A * B}. The whole of {@code tensorD2} is zeroed first, so it
     * must not be the same object as this tensor or {@code tensorD}.
     *
     * @param d        scalar applied to the product
     * @param tensorD  right-hand operand {@code B}
     * @param tensorD2 output {@code C}; previous content is discarded
     * @return {@code tensorD2}
     */
    public TensorD mult(double d, TensorD tensorD, TensorD tensorD2) {
        // Validate before clearing so a dimension mismatch cannot wipe the caller's tensor.
        Checks.checkMultAdd(this, tensorD, tensorD2);
        return multAdd(d, tensorD, tensorD2.zeroInplace());
    }

    /**
     * Layer-wise {@code C = A * B}; see {@link #mult(double, TensorD, TensorD)}.
     *
     * @param tensorD  right-hand operand {@code B}
     * @param tensorD2 output {@code C}; previous content is discarded
     * @return {@code tensorD2}
     */
    public TensorD mult(TensorD tensorD, TensorD tensorD2) {
        return mult(ALPHA_ONE, tensorD, tensorD2);
    }

    /**
     * Element-wise product {@code C = A o B} over the layers common to all three tensors.
     * Elements of {@code tensorD2} in layers beyond that common depth are left untouched.
     *
     * @param tensorD  second factor {@code B}; must pass {@code Checks.checkEqualDimension} against this tensor
     * @param tensorD2 output {@code C}; must pass {@code Checks.checkEqualDimension} against this tensor
     * @return {@code tensorD2}
     */
    public TensorD hadamard(TensorD tensorD, TensorD tensorD2) {
        Checks.checkEqualDimension(this, tensorD);
        Checks.checkEqualDimension(this, tensorD2);
        int commonDepth = Math.min(Math.min(this.depth, tensorD.depth), tensorD2.depth);
        int count = commonDepth * stride();
        double[] lhs = this.a;
        double[] rhs = tensorD.getArrayUnsafe();
        double[] out = tensorD2.getArrayUnsafe();
        for (int k = OFFS; k < count; k++) {
            out[k] = lhs[k] * rhs[k];
        }
        return tensorD2;
    }

    /**
     * Element-wise product {@code A o B} returned in a newly allocated tensor whose depth is the
     * smaller of the two operand depths.
     *
     * @param tensorD second factor {@code B}; must pass {@code Checks.checkEqualDimension} against this tensor
     * @return the new result tensor
     */
    public TensorD hadamard(TensorD tensorD) {
        Checks.checkEqualDimension(this, tensorD);
        return hadamard(tensorD, create(this.rows, this.cols, Math.min(this.depth, tensorD.depth)));
    }

    /**
     * Element-wise product of this tensor with the layer-wise transpose of {@code tensorD}:
     * {@code result[r, c, l] = this[r, c, l] * tensorD[c, r, l]}.
     *
     * @param tensorD second factor; must pass {@code Checks.checkTrans} against this tensor
     * @return a new tensor with this tensor's row/column counts and the smaller of the two depths
     */
    public TensorD hadamardTransposed(TensorD tensorD) {
        Checks.checkTrans(this, tensorD);
        int rowCount = this.rows;
        int colCount = this.cols;
        int layers = Math.min(this.depth, tensorD.depth);
        TensorD result = create(rowCount, colCount, layers);
        double[] lhs = this.a;
        double[] rhs = tensorD.getArrayUnsafe();
        double[] out = result.getArrayUnsafe();
        for (int layer = OFFS; layer < layers; layer++) {
            for (int col = OFFS; col < colCount; col++) {
                for (int row = OFFS; row < rowCount; row++) {
                    // The same position is used for 'this' and 'result': both share rows/cols.
                    int pos = idx(row, col, layer);
                    out[pos] = lhs[pos] * rhs[tensorD.idx(col, row, layer)];
                }
            }
        }
        return result;
    }

    /**
     * Element-wise product of the layer-wise transpose of this tensor with {@code tensorD}:
     * {@code result[r, c, l] = tensorD[r, c, l] * this[c, r, l]}.
     *
     * @param tensorD second factor; must pass {@code Checks.checkTrans} against this tensor
     * @return a new tensor with {@code tensorD}'s row/column counts and the smaller of the two depths
     */
    public TensorD transposedHadamard(TensorD tensorD) {
        Checks.checkTrans(this, tensorD);
        int rowCount = tensorD.numRows();
        int colCount = tensorD.numColumns();
        int layers = Math.min(this.depth, tensorD.depth);
        TensorD result = create(rowCount, colCount, layers);
        double[] lhs = this.a;
        double[] rhs = tensorD.getArrayUnsafe();
        double[] out = result.getArrayUnsafe();
        for (int layer = OFFS; layer < layers; layer++) {
            for (int col = OFFS; col < colCount; col++) {
                for (int row = OFFS; row < rowCount; row++) {
                    // The same position is used for 'tensorD' and 'result': both share rows/cols.
                    int pos = tensorD.idx(row, col, layer);
                    out[pos] = rhs[pos] * lhs[idx(col, row, layer)];
                }
            }
        }
        return result;
    }

    /**
     * Layer-wise {@code A * B} returned in a newly allocated tensor.
     *
     * @param tensorD right-hand operand {@code B}
     * @return a new {@code rows x B.numColumns()} tensor with the smaller of the two depths
     */
    public TensorD times(TensorD tensorD) {
        return mult(tensorD, create(this.rows, tensorD.numColumns(), Math.min(this.depth, tensorD.depth)));
    }

    /**
     * Layer-wise {@code A * B^T} returned in a newly allocated tensor.
     *
     * @param tensorD right-hand operand {@code B}
     * @return a new {@code rows x B.numRows()} tensor with the smaller of the two depths
     */
    public TensorD timesTransposed(TensorD tensorD) {
        return transBmult(tensorD, create(this.rows, tensorD.numRows(), Math.min(this.depth, tensorD.depth)));
    }

    /**
     * Layer-wise {@code A^T * B} returned in a newly allocated tensor.
     *
     * @param tensorD right-hand operand {@code B}
     * @return a new {@code cols x B.numColumns()} tensor with the smaller of the two depths
     */
    public TensorD transposedTimes(TensorD tensorD) {
        return transAmult(tensorD, create(this.cols, tensorD.numColumns(), Math.min(this.depth, tensorD.depth)));
    }

    /**
     * Sets every element of the backing array to {@code 0.0}.
     *
     * @return this tensor
     */
    public TensorD zeroInplace() {
        Arrays.fill(this.a, 0.0d);
        return this;
    }

    /**
     * Multiplies every element by {@code d}. A factor of exactly {@code 0.0} is handled by
     * {@link #zeroInplace()} (so NaN and infinite elements become {@code 0.0} as well), and a
     * factor of exactly {@code 1.0} is a no-op.
     *
     * @param d scaling factor
     * @return this tensor
     */
    public TensorD scaleInplace(double d) {
        if (d == 0.0d) {
            return zeroInplace();
        }
        if (d == 1.0d) {
            return this;
        }
        double[] values = this.a;
        for (int k = OFFS; k < values.length; k++) {
            values[k] *= d;
        }
        return this;
    }

    /**
     * Clamps every element to {@code [d, d2]} using {@code Math.min(Math.max(x, d), d2)}.
     * NaN elements stay NaN; if {@code d > d2} every non-NaN element ends up as {@code d2}.
     *
     * @param d  lower bound
     * @param d2 upper bound
     * @return this tensor
     */
    public TensorD clampInplace(double d, double d2) {
        double[] values = this.a;
        for (int k = OFFS; k < values.length; k++) {
            values[k] = Math.min(Math.max(values[k], d), d2);
        }
        return this;
    }

    /**
     * Returns the live backing array (no copy). Writes through the returned reference change this
     * tensor; the reference becomes stale once {@link #append(MatrixD)} replaces the array.
     *
     * @return the backing array
     */
    public double[] getArrayUnsafe() {
        return this.a;
    }

    /**
     * Creates an independent copy of this tensor.
     *
     * @return the copy
     */
    public TensorD copy() {
        return new TensorD(this);
    }

    /** Allocates a zero-filled tensor with the given rows, columns and depth. */
    private TensorD create(int i, int i2, int i3) {
        return new TensorD(i, i2, i3);
    }

    /** Allocates the enlarged array needed to append one layer and copies the current content into it. */
    private double[] growAndCopyForAppend(Dimensions dimensions) {
        return copyForAppend(new double[checkNewArrayLength(dimensions)]);
    }

    /** Copies the first {@code length} elements of the current backing array into {@code dArr}. */
    private double[] copyForAppend(double[] dArr) {
        System.arraycopy(this.a, OFFS, dArr, OFFS, this.length);
        return dArr;
    }

    /**
     * Shared back end of the four {@code *multAdd(double, TensorD, TensorD)} methods: issues one
     * {@code dgemm_multi} call that accumulates the scaled product into {@code c} for the first
     * {@code min(depth)} layers. Performs no dimension checks; callers validate beforehand.
     *
     * @param transA whether each layer of this tensor is used transposed
     * @param transB whether each layer of {@code b} is used transposed
     * @param k      inner dimension of the product (columns of op(A) / rows of op(B))
     * @param alpha  scalar applied to the product
     * @param b      right-hand operand
     * @param c      accumulator, updated in place
     * @return {@code c}
     */
    private TensorD batchedGemmAdd(TTrans transA, TTrans transB, int k, double alpha, TensorD b, TensorD c) {
        int layers = Math.min(Math.min(this.depth, b.depth), c.depth);
        Matrices.getBlas().dgemm_multi(
                transA.val(), transB.val(),
                c.numRows(), c.numColumns(), k,
                alpha,
                this.a, OFFS, Math.max(1, this.rows),
                b.getArrayUnsafe(), OFFS, Math.max(1, b.numRows()),
                BETA,
                c.getArrayUnsafe(), OFFS, Math.max(1, c.numRows()),
                layers, stride(), b.stride(), c.stride());
        return c;
    }
}
