package mikera.matrixx;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import mikera.arrayz.INDArray;
import mikera.indexz.Index;
import mikera.matrixx.impl.ADiagonalMatrix;
import mikera.matrixx.impl.ColumnMatrix;
import mikera.matrixx.impl.DiagonalMatrix;
import mikera.matrixx.impl.IdentityMatrix;
import mikera.matrixx.impl.ScalarMatrix;
import mikera.matrixx.impl.SparseColumnMatrix;
import mikera.matrixx.impl.SparseRowMatrix;
import mikera.matrixx.impl.StridedMatrix;
import mikera.matrixx.impl.ZeroMatrix;
import mikera.util.Rand;
import mikera.vectorz.AVector;
import mikera.vectorz.Tools;
import mikera.vectorz.Vector3;
import mikera.vectorz.Vectorz;
import mikera.vectorz.impl.SparseIndexedVector;
import mikera.vectorz.util.ErrorMessages;
import mikera.vectorz.util.VectorzException;
import us.bpsm.edn.parser.Parser;
import us.bpsm.edn.parser.Parsers;

/**
 * Static factory and utility methods for creating, converting and inverting matrices.
 *
 * <p>{@link #newMatrix(int, int)} chooses the storage by size: {@code Matrix11}, {@code Matrix22}
 * or {@code Matrix33} for 1x1, 2x2 and 3x3; a sparse row matrix when the element count exceeds
 * {@link #SPARSE_ELEMENT_THRESHOLD}; otherwise a dense {@link Matrix}. Many factories here build
 * on it.
 */
public class Matrixx {
    /**
     * Element count above which {@link #newMatrix(int, int)} returns a sparse matrix instead of a
     * dense {@link Matrix}.
     */
    private static final long SPARSE_ELEMENT_THRESHOLD = 100000;

    /**
     * Creates a {@code dimension x dimension} identity matrix.
     *
     * <p>Delegates to {@link #createImmutableIdentityMatrix(int)}. Use
     * {@link #createMutableIdentityMatrix(int)} when an explicitly mutable matrix is needed.
     *
     * @param dimension number of rows and columns
     * @return the identity matrix
     */
    public static AMatrix createIdentityMatrix(int dimension) {
        return createImmutableIdentityMatrix(dimension);
    }

    /**
     * Creates an {@link IdentityMatrix} of the given size via {@link IdentityMatrix#create(int)}.
     *
     * @param dimension number of rows and columns
     * @return the identity matrix
     */
    public static IdentityMatrix createImmutableIdentityMatrix(int dimension) {
        return IdentityMatrix.create(dimension);
    }

    /**
     * Creates a {@code dimension x dimension} matrix from {@link #newMatrix(int, int)} with 1.0 set
     * on every diagonal element.
     *
     * <p>Off-diagonal elements are left as {@code newMatrix} initialised them (assumed zero).
     *
     * @param dimension number of rows and columns
     * @return a newly allocated identity matrix
     */
    public static AMatrix createMutableIdentityMatrix(int dimension) {
        AMatrix result = newMatrix(dimension, dimension);
        for (int d = 0; d < dimension; d++) {
            result.unsafeSet(d, d, 1.0d);
        }
        return result;
    }

    /**
     * Coerces an arbitrary object into a matrix.
     *
     * <ul>
     *   <li>An {@link AMatrix} is returned unchanged.</li>
     *   <li>An {@link AVector} is wrapped with {@link ColumnMatrix#wrap(AVector)}.</li>
     *   <li>Any other {@link Iterable} is treated as a sequence of rows. Each row is converted with
     *       {@link Vectorz#toVector(Object)} and the rows are assembled with
     *       {@link #createFromVectors(List)}.</li>
     * </ul>
     *
     * @param obj the object to convert
     * @return a matrix view or copy of {@code obj}
     * @throws UnsupportedOperationException if {@code obj} is none of the supported types
     */
    public static AMatrix toMatrix(Object obj) {
        if (obj instanceof AMatrix) {
            return (AMatrix) obj;
        }
        if (obj instanceof AVector) {
            return ColumnMatrix.wrap((AVector) obj);
        }
        if (!(obj instanceof Iterable)) {
            throw new UnsupportedOperationException("Can't convert to matrix: " + obj.getClass());
        }
        List<AVector> rows = new ArrayList<AVector>();
        for (Object row : (Iterable<?>) obj) {
            rows.add(Vectorz.toVector(row));
        }
        return createFromVectors(rows);
    }

    /**
     * Creates a sparse copy of {@code source}.
     *
     * @param source the matrix to copy
     * @return a {@link ZeroMatrix} if either dimension is 0, otherwise a {@link SparseRowMatrix}
     */
    public static AMatrix createSparse(AMatrix source) {
        int rowCount = source.rowCount();
        int columnCount = source.columnCount();
        if (rowCount == 0 || columnCount == 0) {
            return ZeroMatrix.create(rowCount, columnCount);
        }
        return SparseRowMatrix.create(source);
    }

    /**
     * Creates an empty sparse-row matrix of the given shape.
     *
     * @param rowCount number of rows
     * @param columnCount number of columns
     * @return the new sparse matrix
     */
    public static AMatrix createSparse(int rowCount, int columnCount) {
        return SparseRowMatrix.create(rowCount, columnCount);
    }

    /**
     * Creates a sparse-row matrix from a sequence of row vectors.
     *
     * @param rows the rows; must contain at least one vector
     * @return the new sparse matrix
     * @see #createSparseRows(Iterator)
     */
    public static SparseRowMatrix createSparseRows(Iterable<AVector> rows) {
        return createSparseRows(rows.iterator());
    }

    /**
     * Creates a sparse-row matrix from an iterator of row vectors.
     *
     * <p>Each non-zero row is stored as a sparse clone, so the result never aliases the caller's
     * vectors. All-zero rows are not stored at all. The column count is taken from the first row.
     *
     * @param rows iterator over the rows; must yield at least one vector
     * @return the new sparse matrix
     * @throws java.util.NoSuchElementException if {@code rows} is empty
     * @throws IllegalArgumentException if a row's length differs from the first row's
     */
    public static SparseRowMatrix createSparseRows(Iterator<AVector> rows) {
        AVector first = rows.next();
        int columnCount = first.length();
        HashMap<Integer, AVector> rowMap = new HashMap<Integer, AVector>();
        // The first row gets the same clone/skip-if-zero treatment as every later row.
        putSparseRow(rowMap, 0, first, columnCount);
        int rowCount = 1;
        while (rows.hasNext()) {
            putSparseRow(rowMap, rowCount, rows.next(), columnCount);
            rowCount++;
        }
        return SparseRowMatrix.wrap(rowMap, rowCount, columnCount);
    }

    /**
     * Stores a sparse clone of {@code row} under {@code index}, skipping all-zero rows, after
     * checking that the row has {@code columnCount} elements.
     */
    private static void putSparseRow(HashMap<Integer, AVector> rowMap, int index, AVector row, int columnCount) {
        if (row.length() != columnCount) {
            throw new IllegalArgumentException(
                    "Row " + index + " has length " + row.length() + " but expected " + columnCount);
        }
        if (!row.isZero()) {
            rowMap.put(Integer.valueOf(index), row.sparseClone().asVector());
        }
    }

    /**
     * Creates a sparse-row matrix from per-row index/value pairs.
     *
     * @param columnCount number of columns (the length of every row vector)
     * @param indexes for each row, the column indexes of its stored values
     * @param values for each row, the values at the corresponding {@code indexes}
     * @return the new sparse matrix with {@code indexes.length} rows
     * @throws IllegalArgumentException if {@code indexes} and {@code values} differ in length
     */
    public static AMatrix createSparse(int columnCount, Index[] indexes, AVector[] values) {
        int rowCount = indexes.length;
        if (rowCount != values.length) {
            throw new IllegalArgumentException("Length of indexes array must match length of weights array");
        }
        SparseRowMatrix result = SparseRowMatrix.create(rowCount, columnCount);
        for (int row = 0; row < rowCount; row++) {
            // NOTE(review): mo2clone() looks like a decompiler-renamed Index.clone(); presumably it
            // copies the index so the new row does not alias the caller's Index. Confirm in Index.
            result.replaceRow(row, SparseIndexedVector.wrap(columnCount, indexes[row].mo2clone(),
                    values[row].toDoubleArray()));
        }
        return result;
    }

    /**
     * Creates a sparse-column copy of {@code source}. Each column is converted with
     * {@link Vectorz#createSparse(AVector)}.
     *
     * @param source the matrix to copy
     * @return the new sparse-column matrix
     */
    public static SparseColumnMatrix createSparseColumns(AMatrix source) {
        int columnCount = source.columnCount();
        AVector[] columns = new AVector[columnCount];
        for (int c = 0; c < columnCount; c++) {
            columns[c] = Vectorz.createSparse(source.getColumn(c));
        }
        return SparseColumnMatrix.wrap(columns);
    }

    /**
     * Creates a sparse-row copy of {@code source}.
     *
     * @param source the matrix to copy
     * @return a {@link ZeroMatrix} if {@code source} has no rows, otherwise a {@link SparseRowMatrix}
     */
    public static AMatrix createSparseRows(AMatrix source) {
        if (source.rowCount() == 0) {
            return ZeroMatrix.create(0, source.columnCount());
        }
        return SparseRowMatrix.create(source);
    }

    /**
     * Creates a sparse-row copy of a 2-dimensional array. Only non-zero slices are stored.
     *
     * @param array the array to copy; must have dimensionality 2
     * @return the new sparse-row matrix
     * @throws IllegalArgumentException if {@code array} is not 2-dimensional
     */
    public static SparseRowMatrix createSparseRows(INDArray array) {
        if (array.dimensionality() != 2) {
            throw new IllegalArgumentException(ErrorMessages.incompatibleShape(array));
        }
        int rowCount = array.getShape(0);
        SparseRowMatrix result = SparseRowMatrix.create(rowCount, array.getShape(1));
        for (int row = 0; row < rowCount; row++) {
            AVector sparseRow = array.slice(row).sparseClone().asVector();
            if (!sparseRow.isZero()) {
                result.replaceRow(row, sparseRow);
            }
        }
        return result;
    }

    /**
     * Creates a {@link ZeroMatrix} of the given shape.
     *
     * @param rowCount number of rows
     * @param columnCount number of columns
     * @return the zero matrix
     */
    public static ZeroMatrix createImmutableZeroMatrix(int rowCount, int columnCount) {
        return ZeroMatrix.create(rowCount, columnCount);
    }

    /**
     * Creates a {@link DiagonalMatrix} with every diagonal element set to {@code factor}.
     *
     * @param dimension number of rows and columns
     * @param factor value placed on the diagonal
     * @return the new diagonal matrix
     */
    public static ADiagonalMatrix createScaleMatrix(int dimension, double factor) {
        DiagonalMatrix result = new DiagonalMatrix(dimension);
        for (int d = 0; d < dimension; d++) {
            result.unsafeSet(d, d, factor);
        }
        return result;
    }

    /**
     * Creates a scalar (uniform diagonal) matrix.
     *
     * @param dimension number of rows and columns
     * @param value diagonal value
     * @return an {@link IdentityMatrix} when {@code value == 1.0}, otherwise a {@link ScalarMatrix}
     */
    public static ADiagonalMatrix createScalarMatrix(int dimension, double value) {
        if (value == 1.0d) {
            return IdentityMatrix.create(dimension);
        }
        return ScalarMatrix.create(dimension, value);
    }

    /**
     * Creates a {@link DiagonalMatrix} whose diagonal holds {@code factors} in order.
     *
     * @param factors diagonal values; the array length is the matrix dimension
     * @return the new diagonal matrix
     */
    public static DiagonalMatrix createScaleMatrix(double... factors) {
        int dimension = factors.length;
        DiagonalMatrix result = new DiagonalMatrix(dimension);
        for (int d = 0; d < dimension; d++) {
            result.unsafeSet(d, d, factors[d]);
        }
        return result;
    }

    /**
     * Creates a {@link DiagonalMatrix} whose leading diagonal is set from {@code values}.
     *
     * @param values diagonal values; the array length is the matrix dimension
     * @return the new diagonal matrix
     */
    public static DiagonalMatrix createDiagonalMatrix(double... values) {
        DiagonalMatrix result = new DiagonalMatrix(values.length);
        result.getLeadingDiagonal().setValues(values);
        return result;
    }

    /**
     * Creates a {@link DiagonalMatrix} wrapping a copy of {@code diagonal}'s elements (obtained via
     * {@link AVector#toDoubleArray()}).
     *
     * @param diagonal the diagonal values
     * @return the new diagonal matrix
     */
    public static DiagonalMatrix createDiagonalMatrix(AVector diagonal) {
        return DiagonalMatrix.wrap(diagonal.toDoubleArray());
    }

    /**
     * Creates a 3x3 rotation matrix about {@code axis}.
     *
     * @param axis rotation axis (need not be normalised)
     * @param angle rotation angle in radians
     * @return the rotation matrix
     * @see #createRotationMatrix(double, double, double, double)
     */
    public static Matrix33 createRotationMatrix(Vector3 axis, double angle) {
        return createRotationMatrix(axis.x, axis.y, axis.z, angle);
    }

    /**
     * Creates a 3x3 rotation matrix for a rotation of {@code angle} radians about the axis
     * ({@code x}, {@code y}, {@code z}), using Rodrigues' rotation formula.
     *
     * <p>The axis is normalised first unless it already has length exactly 1. A zero-length axis
     * yields the identity matrix.
     *
     * @param x axis x component
     * @param y axis y component
     * @param z axis z component
     * @param angle rotation angle in radians
     * @return the rotation matrix
     */
    public static Matrix33 createRotationMatrix(double x, double y, double z, double angle) {
        double length = Math.sqrt((x * x) + (y * y) + (z * z));
        if (length == 0.0d) {
            return Matrix33.createIdentityMatrix();
        }
        double ux = x;
        double uy = y;
        double uz = z;
        if (length != 1.0d) {
            double inverseLength = 1.0d / length;
            ux = x * inverseLength;
            uy = y * inverseLength;
            uz = z * inverseLength;
        }
        double cos = Math.cos(angle);
        double sin = Math.sin(angle);
        double oneMinusCos = 1.0d - cos;
        return new Matrix33(
                (ux * ux) + ((1.0d - (ux * ux)) * cos),
                ((ux * uy) * oneMinusCos) - (uz * sin),
                ((ux * uz) * oneMinusCos) + (uy * sin),
                ((ux * uy) * oneMinusCos) + (uz * sin),
                (uy * uy) + ((1.0d - (uy * uy)) * cos),
                ((uy * uz) * oneMinusCos) - (ux * sin),
                ((ux * uz) * oneMinusCos) - (uy * sin),
                ((uy * uz) * oneMinusCos) + (ux * sin),
                (uz * uz) + ((1.0d - (uz * uz)) * cos));
    }

    /**
     * Creates a 3x3 rotation matrix about a 3-element axis vector.
     *
     * @param axis rotation axis; must have length 3 (need not be normalised)
     * @param angle rotation angle in radians
     * @return the rotation matrix
     * @throws VectorzException if {@code axis} does not have exactly 3 elements
     */
    public static Matrix33 createRotationMatrix(AVector axis, double angle) {
        if (axis.length() != 3) {
            throw new VectorzException("Rotation matrix requires a 3d axis vector");
        }
        return createRotationMatrix(axis.unsafeGet(0), axis.unsafeGet(1), axis.unsafeGet(2), angle);
    }

    /** Creates a rotation of {@code angle} radians about the X axis (1, 0, 0). */
    public static Matrix33 createXAxisRotationMatrix(double angle) {
        return createRotationMatrix(1.0d, 0.0d, 0.0d, angle);
    }

    /** Creates a rotation of {@code angle} radians about the Y axis (0, 1, 0). */
    public static Matrix33 createYAxisRotationMatrix(double angle) {
        return createRotationMatrix(0.0d, 1.0d, 0.0d, angle);
    }

    /** Creates a rotation of {@code angle} radians about the Z axis (0, 0, 1). */
    public static Matrix33 createZAxisRotationMatrix(double angle) {
        return createRotationMatrix(0.0d, 0.0d, 1.0d, angle);
    }

    /** Creates a 2D rotation matrix via {@link Matrix22#createRotationMatrix(double)}. */
    public static Matrix22 create2DRotationMatrix(double angle) {
        return Matrix22.createRotationMatrix(angle);
    }

    /**
     * Creates a dense square {@link Matrix} filled by {@link #fillRandomValues(AMatrix)}.
     *
     * @param dimension number of rows and columns
     * @return the random matrix
     */
    public static Matrix createRandomSquareMatrix(int dimension) {
        Matrix result = createSquareMatrix(dimension);
        fillRandomValues(result);
        return result;
    }

    /**
     * Creates a matrix from {@link #newMatrix(int, int)} filled by {@link #fillRandomValues(AMatrix)}.
     *
     * @param rowCount number of rows
     * @param columnCount number of columns
     * @return the random matrix
     */
    public static AMatrix createRandomMatrix(int rowCount, int columnCount) {
        AMatrix result = newMatrix(rowCount, columnCount);
        fillRandomValues(result);
        return result;
    }

    /**
     * Computes the inverse of a square matrix by LU decomposition with scaled partial pivoting,
     * followed by solving against each column of the identity.
     *
     * <p>NOTE(review): the decompiler reports this method was package-private originally; it is
     * kept public so existing callers of this class keep compiling.
     *
     * @param source the matrix to invert; it is copied, not modified
     * @return a new dense matrix holding the inverse
     * @throws IllegalArgumentException if {@code source} is not square or is detected as singular
     */
    public static Matrix createInverse(AMatrix source) {
        if (!source.isSquare()) {
            throw new IllegalArgumentException("Matrix must be square for inverse!");
        }
        int dimension = source.rowCount();
        Matrix lu = new Matrix(source);
        int[] pivots = new int[dimension];
        decomposeLU(lu, pivots);
        return backSubstituteLU(lu, pivots);
    }

    /**
     * Replaces {@code matrix} in place with its LU decomposition, using Crout's method with
     * implicit (row-scaled) partial pivoting. U, including the diagonal, is stored on and above
     * the diagonal; L's sub-diagonal multipliers are stored below it (L's unit diagonal is
     * implicit).
     *
     * @param matrix square dense matrix, overwritten with the decomposition
     * @param pivots output: {@code pivots[k]} is the row swapped with row {@code k} at step {@code k}
     * @throws IllegalArgumentException if the matrix is detected as singular
     */
    private static void decomposeLU(Matrix matrix, int[] pivots) {
        int n = pivots.length;
        // Relies on swapRows mutating matrix.data in place, so this reference stays valid.
        double[] data = matrix.data;
        double[] rowScale = new double[n];
        calcRowFactors(data, rowScale);
        for (int col = 0; col < n; col++) {
            // Elements of U above the diagonal in this column.
            for (int row = 0; row < col; row++) {
                int idx = (n * row) + col;
                double sum = data[idx];
                for (int k = 0; k < row; k++) {
                    sum -= data[(n * row) + k] * data[(n * k) + col];
                }
                data[idx] = sum;
            }
            // Diagonal and below, choosing the pivot with the largest scaled magnitude.
            // Starting at 'col' (not 0) ensures an all-NaN column never swaps an earlier row.
            int pivotRow = col;
            double bestScaled = Double.NEGATIVE_INFINITY;
            for (int row = col; row < n; row++) {
                int idx = (n * row) + col;
                double sum = data[idx];
                for (int k = 0; k < col; k++) {
                    sum -= data[(n * row) + k] * data[(n * k) + col];
                }
                data[idx] = sum;
                double scaled = rowScale[row] * Math.abs(sum);
                if (scaled > bestScaled) {
                    bestScaled = scaled;
                    pivotRow = row;
                }
            }
            if (col != pivotRow) {
                matrix.swapRows(col, pivotRow);
                rowScale[pivotRow] = rowScale[col];
            }
            pivots[col] = pivotRow;
            if (data[(n * col) + col] == 0.0d) {
                throw new IllegalArgumentException(ErrorMessages.singularMatrix());
            }
            // Divide the sub-diagonal entries of this column by the pivot to form L's multipliers.
            double inversePivot = 1.0d / data[(n * col) + col];
            for (int row = col + 1; row < n; row++) {
                data[(n * row) + col] *= inversePivot;
            }
        }
    }

    /**
     * Computes the reciprocal of each row's largest absolute element, used for implicit pivoting.
     *
     * @param data row-major square matrix data of size {@code rowScale.length} squared
     * @param rowScale output array, one entry per row
     * @throws IllegalArgumentException if a row is entirely zero (the matrix is singular)
     */
    private static void calcRowFactors(double[] data, double[] rowScale) {
        int n = rowScale.length;
        for (int row = 0; row < n; row++) {
            double largest = 0.0d;
            for (int col = 0; col < n; col++) {
                largest = Math.max(largest, Math.abs(data[(row * n) + col]));
            }
            if (largest == 0.0d) {
                throw new IllegalArgumentException(ErrorMessages.singularMatrix());
            }
            rowScale[row] = 1.0d / largest;
        }
    }

    /**
     * Solves {@code A X = I} column by column, given the LU decomposition of {@code A} produced by
     * {@link #decomposeLU(Matrix, int[])}. The result {@code X} is the inverse of {@code A}.
     *
     * @param lu the decomposed matrix
     * @param pivots the pivot rows recorded during decomposition
     * @return a new dense matrix holding the inverse
     */
    private static Matrix backSubstituteLU(Matrix lu, int[] pivots) {
        int n = pivots.length;
        double[] a = lu.data;
        Matrix result = new Matrix(createImmutableIdentityMatrix(n));
        double[] b = result.data;
        for (int col = 0; col < n; col++) {
            // Forward substitution with L, undoing the row permutation as we go. Leading zeros of
            // the right-hand side are skipped until the first non-zero entry.
            int firstNonZero = -1;
            for (int row = 0; row < n; row++) {
                int pivot = pivots[row];
                double sum = b[(n * pivot) + col];
                b[(n * pivot) + col] = b[(n * row) + col];
                if (firstNonZero >= 0) {
                    for (int k = firstNonZero; k < row; k++) {
                        sum -= a[(row * n) + k] * b[(n * k) + col];
                    }
                } else if (sum != 0.0d) {
                    firstNonZero = row;
                }
                b[(n * row) + col] = sum;
            }
            // Back substitution with U, from the last row upwards.
            for (int row = n - 1; row >= 0; row--) {
                int rowBase = n * row;
                double acc = 0.0d;
                for (int k = n - 1; k > row; k--) {
                    acc += a[rowBase + k] * b[(n * k) + col];
                }
                b[rowBase + col] = (b[rowBase + col] - acc) / a[rowBase + row];
            }
        }
        return result;
    }

    /**
     * Allocates a new matrix of the given shape, choosing the implementation by size.
     *
     * @param rowCount number of rows
     * @param columnCount number of columns
     * @return {@code Matrix11}/{@code Matrix22}/{@code Matrix33} for square sizes 1-3; a sparse
     *     matrix when {@code rowCount * columnCount} exceeds {@link #SPARSE_ELEMENT_THRESHOLD};
     *     otherwise a dense {@link Matrix}
     */
    public static AMatrix newMatrix(int rowCount, int columnCount) {
        if (rowCount == columnCount) {
            if (rowCount == 1) {
                return new Matrix11();
            }
            if (rowCount == 2) {
                return new Matrix22();
            }
            if (rowCount == 3) {
                return new Matrix33();
            }
        }
        // Multiply as long so large dimensions cannot overflow the comparison.
        if (((long) rowCount) * ((long) columnCount) > SPARSE_ELEMENT_THRESHOLD) {
            return createSparse(rowCount, columnCount);
        }
        return Matrix.create(rowCount, columnCount);
    }

    /**
     * Creates a dense matrix and copies the leading elements of {@code source} into its backing
     * array, starting at index 0. At most {@code rowCount * columnCount} elements are copied.
     *
     * @param source the vector supplying the elements
     * @param rowCount number of rows
     * @param columnCount number of columns
     * @return the new dense matrix
     */
    public static Matrix createFromVector(AVector source, int rowCount, int columnCount) {
        Matrix result = Matrix.create(rowCount, columnCount);
        source.copyTo(0, result.data, 0, Math.min(rowCount * columnCount, source.length()));
        return result;
    }

    /** Creates a dense {@code dimension x dimension} {@link Matrix}. */
    private static Matrix createSquareMatrix(int dimension) {
        return Matrix.create(dimension, dimension);
    }

    /**
     * Copies the lower-triangular part (including the diagonal) of the leading
     * {@code rowCount x rowCount} square of {@code source} into a new matrix.
     *
     * @param source the matrix to read; must have at least as many columns as rows
     * @return a new {@code rowCount x rowCount} matrix from {@link #newMatrix(int, int)}
     * @throws IllegalArgumentException if {@code source} has fewer columns than rows
     */
    public static AMatrix extractLowerTriangular(AMatrix source) {
        int rowCount = source.rowCount();
        if (rowCount > source.columnCount()) {
            throw new IllegalArgumentException("Too few columns in matrix");
        }
        AMatrix result = newMatrix(rowCount, rowCount);
        for (int row = 0; row < rowCount; row++) {
            for (int col = 0; col <= row; col++) {
                result.unsafeSet(row, col, source.unsafeGet(row, col));
            }
        }
        return result;
    }

    /**
     * Copies the upper-triangular part (including the diagonal) of the leading
     * {@code rowCount x rowCount} square of {@code source} into a new matrix.
     *
     * @param source the matrix to read; must have at least as many columns as rows
     * @return a new {@code rowCount x rowCount} matrix from {@link #newMatrix(int, int)}
     * @throws IllegalArgumentException if {@code source} has fewer columns than rows
     */
    public static AMatrix extractUpperTriangular(AMatrix source) {
        int rowCount = source.rowCount();
        // Columns up to rowCount - 1 are read, so the source needs at least rowCount columns.
        // The previous check compared rowCount with itself and could never fire.
        if (rowCount > source.columnCount()) {
            throw new IllegalArgumentException("Too few columns in matrix");
        }
        AMatrix result = newMatrix(rowCount, rowCount);
        for (int row = 0; row < rowCount; row++) {
            for (int col = row; col < rowCount; col++) {
                result.unsafeSet(row, col, source.unsafeGet(row, col));
            }
        }
        return result;
    }

    /**
     * Creates a dense copy of {@code source}.
     *
     * @param source the matrix to copy
     * @return a new dense {@link Matrix}
     */
    public static Matrix create(AMatrix source) {
        return new Matrix(source);
    }

    /**
     * Creates a dense matrix from a list of row objects. Each row is converted with
     * {@link Vectorz#create(Object)}, and the first row determines the column count.
     *
     * @param rows the rows; an empty list yields a 0x0 matrix
     * @return the new dense matrix
     */
    public static Matrix create(List<Object> rows) {
        int rowCount = rows.size();
        if (rowCount == 0) {
            return Matrix.create(0, 0);
        }
        AVector first = Vectorz.create(rows.get(0));
        Matrix result = Matrix.create(rowCount, first.length());
        result.setRow(0, first);
        for (int row = 1; row < rowCount; row++) {
            result.setRow(row, Vectorz.create(rows.get(row)));
        }
        return result;
    }

    /**
     * Copies an {@link IMatrix} element by element into a matrix from {@link #newMatrix(int, int)}.
     *
     * @param source the matrix to copy
     * @return the new matrix
     */
    public static AMatrix create(IMatrix source) {
        int rowCount = source.rowCount();
        int columnCount = source.columnCount();
        AMatrix result = newMatrix(rowCount, columnCount);
        for (int row = 0; row < rowCount; row++) {
            for (int col = 0; col < columnCount; col++) {
                result.unsafeSet(row, col, source.get(row, col));
            }
        }
        return result;
    }

    /**
     * Overwrites every element of {@code target} with a value from {@link Rand#nextDouble()}.
     *
     * @param target the matrix to fill
     */
    public static void fillRandomValues(AMatrix target) {
        int rowCount = target.rowCount();
        int columnCount = target.columnCount();
        for (int row = 0; row < rowCount; row++) {
            for (int col = 0; col < columnCount; col++) {
                target.unsafeSet(row, col, Rand.nextDouble());
            }
        }
    }

    /**
     * Creates a dense matrix whose rows are copies of {@code rows}. The first row determines the
     * column count.
     *
     * @param rows the row vectors; none yields a 0x0 matrix
     * @return the new dense matrix
     */
    public static Matrix createFromVectors(AVector... rows) {
        int rowCount = rows.length;
        int columnCount = (rowCount == 0) ? 0 : rows[0].length();
        Matrix result = Matrix.create(rowCount, columnCount);
        for (int row = 0; row < rowCount; row++) {
            result.getRowView(row).set(rows[row]);
        }
        return result;
    }

    /**
     * Creates a matrix from {@link #newMatrix(int, int)} whose rows are copies of {@code rows}. The
     * first row determines the column count.
     *
     * @param rows the row vectors; an empty list yields a 0x0 matrix
     * @return the new matrix
     */
    public static AMatrix createFromVectors(List<AVector> rows) {
        int rowCount = rows.size();
        int columnCount = (rowCount == 0) ? 0 : rows.get(0).length();
        AMatrix result = newMatrix(rowCount, columnCount);
        for (int row = 0; row < rowCount; row++) {
            result.getRowView(row).set(rows.get(row));
        }
        return result;
    }

    /** Returns the EDN parser configuration used by {@link #parse(String)}. */
    private static Parser.Config getMatrixParserConfig() {
        return Parsers.defaultConfiguration();
    }

    /**
     * Parses an EDN nested list of numbers into a matrix, e.g. {@code [[1 2] [3 4]]}. Elements are
     * converted with {@link Tools#toDouble(Object)}.
     *
     * @param text EDN text whose first value is a list of equally sized row lists
     * @return a matrix from {@link #newMatrix(int, int)}
     * @throws IllegalArgumentException if the rows do not all have the same length
     * @throws ClassCastException if the parsed value or one of its rows is not a list
     */
    public static AMatrix parse(String text) {
        List<?> rows = (List<?>) Parsers.newParser(getMatrixParserConfig()).nextValue(Parsers.newParseable(text));
        int rowCount = rows.size();
        int columnCount = (rowCount == 0) ? 0 : ((List<?>) rows.get(0)).size();
        AMatrix result = newMatrix(rowCount, columnCount);
        for (int row = 0; row < rowCount; row++) {
            List<?> elements = (List<?>) rows.get(row);
            // Reject ragged input instead of silently truncating or running off the end.
            if (elements.size() != columnCount) {
                throw new IllegalArgumentException("Inconsistent row length in matrix: row " + row
                        + " has " + elements.size() + " elements, expected " + columnCount);
            }
            for (int col = 0; col < columnCount; col++) {
                result.unsafeSet(row, col, Tools.toDouble(elements.get(col)));
            }
        }
        return result;
    }

    /**
     * Returns a dense copy of {@code source}; equivalent to {@link #create(AMatrix)}.
     *
     * @param source the matrix to copy
     * @return a new dense {@link Matrix}
     */
    public static Matrix deepCopy(AMatrix source) {
        return create(source);
    }

    /**
     * Creates a dense matrix from row objects given as varargs; see {@link #create(List)}.
     *
     * @param rows the rows
     * @return the new dense matrix
     */
    public static AMatrix create(Object... rows) {
        List<Object> rowList = Arrays.asList(rows);
        return create(rowList);
    }

    /**
     * Creates a dense matrix from a 2D array via {@link Matrix#create(double[][])}.
     *
     * @param data row arrays
     * @return the new dense matrix
     */
    public static Matrix create(double[][] data) {
        return Matrix.create(data);
    }

    /**
     * Wraps a strided region of {@code data} as a matrix without copying.
     *
     * <p>When the layout is exactly contiguous row-major over the whole array (offset 0,
     * {@code rowStride == columnCount}, {@code columnStride == 1} and
     * {@code data.length == rowCount * columnCount}), a plain {@link Matrix} wrapper is returned.
     * Otherwise a {@link StridedMatrix} is returned.
     *
     * @param data backing array
     * @param rowCount number of rows
     * @param columnCount number of columns
     * @param offset index of element (0, 0)
     * @param rowStride distance between consecutive rows
     * @param columnStride distance between consecutive columns
     * @return the wrapping matrix
     */
    public static AMatrix wrapStrided(double[] data, int rowCount, int columnCount, int offset, int rowStride, int columnStride) {
        boolean denseRowMajor = offset == 0
                && columnCount == rowStride
                && columnStride == 1
                && data.length == rowCount * columnCount;
        if (denseRowMajor) {
            return Matrix.wrap(rowCount, columnCount, data);
        }
        return StridedMatrix.wrap(data, rowCount, columnCount, offset, rowStride, columnStride);
    }

    /**
     * Creates a sparse-row matrix from a list of 1-dimensional slices, each converted with
     * {@code sparse().asVector()}.
     *
     * @param slices the rows; all must be 1-dimensional with the same length as the first. An empty
     *     list yields a 0x0 {@link ZeroMatrix}.
     * @return the new sparse matrix
     * @throws IllegalArgumentException if a slice is not 1-dimensional or has the wrong length
     */
    public static AMatrix createSparse(List<INDArray> slices) {
        if (slices.isEmpty()) {
            return ZeroMatrix.create(0, 0);
        }
        int columnCount = slices.get(0).sliceCount();
        List<AVector> rows = new ArrayList<AVector>(slices.size());
        for (INDArray slice : slices) {
            if (slice.dimensionality() != 1 || slice.sliceCount() != columnCount) {
                throw new IllegalArgumentException(ErrorMessages.incompatibleShape(slice));
            }
            rows.add(slice.sparse().asVector());
        }
        return SparseRowMatrix.create(rows);
    }
}
