/*
 * Decompiled with CFR 0.152.
 */
package org.apache.wayang.tensorflow.model.op.nn;

import java.util.Arrays;
import java.util.List;
import org.apache.wayang.basic.model.op.nn.Conv3D;
import org.tensorflow.Operand;
import org.tensorflow.op.Ops;
import org.tensorflow.op.core.Variable;
import org.tensorflow.op.nn.Conv3d;
import org.tensorflow.op.random.TruncatedNormal;
import org.tensorflow.types.family.TNumber;

/**
 * TensorFlow implementation of Wayang's {@link Conv3D} operator.
 *
 * <p>On construction this creates a kernel variable named {@code Conv3DKernel} with shape
 * {@code [kD, kH, kW, inChannels, outChannels]}. If {@link Conv3D#getBias()} is set, it also
 * creates a bias variable named {@code Conv3DBias} with shape {@code [outChannels]}. Both are
 * initialized with {@code tf.random.truncatedNormal}. {@link #call(Operand)} applies the
 * convolution using the {@code NCDHW} data format.
 *
 * @param <T> numeric element type of the input, kernel and bias
 */
public class TensorflowConv3D<T extends TNumber> {
    /** Data format passed to conv3d: batch, channels, depth, height, width. */
    // NOTE(review): TensorFlow's CPU Conv3D kernel may only accept NDHWC -- confirm the
    // intended execution device supports NCDHW.
    private static final String DATA_FORMAT = "NCDHW";

    private final Ops tf;
    private final Conv3D op;
    private final Variable<T> kernel;
    /** Bias variable, or {@code null} when the operator is configured without bias. */
    private final Variable<T> bias;

    /**
     * Creates the kernel variable and, if requested by {@code op}, the bias variable.
     *
     * @param tf     TensorFlow op builder used to create all graph nodes
     * @param op     Wayang operator providing channels, kernel size, stride, padding, bias flag and name
     * @param tClass element type of the created variables
     * @throws IllegalArgumentException if the kernel size has neither 1 nor 3 entries
     */
    public TensorflowConv3D(Ops tf, Conv3D op, Class<T> tClass) {
        this.tf = tf;
        this.op = op;
        this.kernel = tf.withName("Conv3DKernel")
                .variable(tf.random.truncatedNormal(tf.array(kernelShape()), tClass));
        this.bias = op.getBias()
                ? tf.withName("Conv3DBias")
                        .variable(tf.random.truncatedNormal(tf.array(new int[]{op.getOutChannels()}), tClass))
                : null;
    }

    /**
     * Builds the conv3d filter shape {@code [kD, kH, kW, inChannels, outChannels]}.
     * A single-element kernel size is reused for all three spatial dimensions.
     *
     * @throws IllegalArgumentException if the kernel size has neither 1 nor 3 entries
     */
    private int[] kernelShape() {
        int[] kernelSize = op.getKernelSize();
        if (kernelSize.length == 1) {
            return new int[]{kernelSize[0], kernelSize[0], kernelSize[0], op.getInChannels(), op.getOutChannels()};
        }
        if (kernelSize.length == 3) {
            return new int[]{kernelSize[0], kernelSize[1], kernelSize[2], op.getInChannels(), op.getOutChannels()};
        }
        throw new IllegalArgumentException("Unsupported Kernel: " + Arrays.toString(kernelSize));
    }

    /**
     * Builds the five-element conv3d stride list for {@link #DATA_FORMAT}.
     * Batch and channel strides are fixed at 1. A single-element stride is reused for depth,
     * height and width.
     *
     * @throws IllegalArgumentException if the stride has neither 1 nor 3 entries
     */
    private List<Long> strideShape() {
        int[] stride = op.getStride();
        // Explicit (long) casts are required: an int would box to Integer, not Long,
        // and List<Long> could not be inferred.
        if (stride.length == 1) {
            long s = stride[0];
            return Arrays.asList(1L, 1L, s, s, s);
        }
        if (stride.length == 3) {
            return Arrays.asList(1L, 1L, (long) stride[0], (long) stride[1], (long) stride[2]);
        }
        throw new IllegalArgumentException("Unsupported Stride: " + Arrays.toString(stride));
    }

    /**
     * Applies the 3D convolution (plus bias, if configured) to {@code input}.
     * The final op in the returned chain is named after {@link Conv3D#getName()}.
     *
     * @param input input tensor laid out as {@code NCDHW}
     * @return convolution output, with bias added per output channel when a bias exists
     */
    public Operand<T> call(Operand<T> input) {
        if (bias == null) {
            return tf.withName(op.getName()).nn.conv3d(
                    input, kernel, strideShape(), op.getPadding(), Conv3d.dataFormat(DATA_FORMAT));
        }
        Operand<T> conv = tf.nn.conv3d(
                input, kernel, strideShape(), op.getPadding(), Conv3d.dataFormat(DATA_FORMAT));
        // Reshape [outChannels] -> [outChannels, 1, 1, 1] so the bias broadcasts over the
        // channel axis (and across depth/height/width) of the NCDHW output.
        Operand<T> channelBias = tf.reshape(bias, tf.array(new int[]{-1, 1, 1, 1}));
        return tf.withName(op.getName()).math.add(conv, channelBias);
    }
}

