/*
 * Decompiled with CFR 0.152.
 */
package org.apache.wayang.tensorflow.model.op.nn;

import java.util.Arrays;
import java.util.List;
import org.apache.wayang.basic.model.op.nn.Conv2D;
import org.tensorflow.Operand;
import org.tensorflow.op.Ops;
import org.tensorflow.op.core.Variable;
import org.tensorflow.op.linalg.Transpose;
import org.tensorflow.op.nn.BiasAdd;
import org.tensorflow.op.nn.Conv2d;
import org.tensorflow.op.random.TruncatedNormal;
import org.tensorflow.types.family.TNumber;

/**
 * TensorFlow (Java API) implementation of Wayang's {@link Conv2D} operator.
 *
 * <p>The layer owns its trainable state:
 * <ul>
 *   <li>a kernel variable of shape {@code [kernelH, kernelW, inChannels, outChannels]};</li>
 *   <li>a bias variable of shape {@code [outChannels]}, only when {@link Conv2D#getBias()} is
 *       {@code true}. Otherwise the bias is {@code null}.</li>
 * </ul>
 * Both are initialised from {@code tf.random.truncatedNormal} with its default parameters.
 * NOTE(review): no variance scaling is applied to the initial values; confirm this is intended.
 *
 * <p>Inputs are treated as channels-first (NCHW). {@link #call} transposes to NHWC, convolves,
 * and transposes back. {@link #callV1} convolves in NCHW directly.
 *
 * @param <T> numeric element type of the input, kernel and bias
 */
public class TensorflowConv2D<T extends TNumber> {
    private final Ops tf;
    private final Conv2D op;
    private final Variable<T> kernel;
    /** {@code null} when the operator was configured without a bias. */
    private final Variable<T> bias;

    /**
     * Creates the kernel (and optionally bias) variables for {@code op} in {@code tf}'s graph.
     *
     * @param tf     TensorFlow op builder used to create all graph nodes
     * @param op     Wayang operator describing kernel size, channels, stride, padding and bias
     * @param tClass element type of the created variables
     * @throws IllegalArgumentException if the kernel size has neither 1 nor 2 entries
     */
    public TensorflowConv2D(Ops tf, Conv2D op, Class<T> tClass) {
        this.tf = tf;
        this.op = op;
        TruncatedNormal<T> kernelInit = tf.random.truncatedNormal(tf.array(kernelShape()), tClass);
        this.kernel = tf.withName("Conv2DKernel").variable(kernelInit);
        if (op.getBias()) {
            TruncatedNormal<T> biasInit =
                    tf.random.truncatedNormal(tf.array(op.getOutChannels()), tClass);
            this.bias = tf.withName("Conv2DBias").variable(biasInit);
        } else {
            this.bias = null;
        }
    }

    /**
     * Normalises a 1- or 2-element size spec into an explicit {@code {height, width}} pair.
     * A single value is used for both dimensions.
     */
    private static int[] toHeightWidth(int[] values, String errorPrefix) {
        if (values.length == 1) {
            return new int[]{values[0], values[0]};
        }
        if (values.length == 2) {
            return new int[]{values[0], values[1]};
        }
        throw new IllegalArgumentException(errorPrefix + Arrays.toString(values));
    }

    /** Kernel variable shape in TF filter layout: {@code [kH, kW, inChannels, outChannels]}. */
    private int[] kernelShape() {
        int[] hw = toHeightWidth(op.getKernelSize(), "Unsupported Kernel: ");
        return new int[]{hw[0], hw[1], op.getInChannels(), op.getOutChannels()};
    }

    /** Per-dimension strides for NCHW data: {@code [1, 1, sH, sW]}. */
    private List<Long> strideShape() {
        int[] hw = toHeightWidth(op.getStride(), "Unsupported Stride: ");
        return Arrays.asList(1L, 1L, (long) hw[0], (long) hw[1]);
    }

    /**
     * Per-dimension strides for NHWC data: {@code [1, sH, sW, 1]}.
     * TF requires the batch and channel strides to be 1, so the spatial strides must sit in
     * positions 1 and 2 for this layout.
     */
    private List<Long> strideShapeNhwc() {
        int[] hw = toHeightWidth(op.getStride(), "Unsupported Stride: ");
        return Arrays.asList(1L, (long) hw[0], (long) hw[1], 1L);
    }

    /**
     * Applies the convolution (plus bias, if configured) directly in NCHW layout.
     * NOTE(review): some TF builds (notably CPU kernels) may not support NCHW Conv2D. This is
     * presumably why {@link #call} goes through NHWC instead; confirm before relying on this.
     *
     * @param input channels-first input tensor
     * @return output of the final op, named after {@link Conv2D#getName()}
     */
    public Operand<T> callV1(Operand<T> input) {
        if (!op.getBias()) {
            return tf.withName(op.getName()).nn.conv2d(
                    input, kernel, strideShape(), op.getPadding(), Conv2d.dataFormat("NCHW"));
        }
        Conv2d<T> conv = tf.nn.conv2d(
                input, kernel, strideShape(), op.getPadding(), Conv2d.dataFormat("NCHW"));
        return tf.withName(op.getName()).nn.biasAdd(conv, bias, BiasAdd.dataFormat("NCHW"));
    }

    /**
     * Applies the convolution (plus bias, if configured) to a channels-first input.
     * Internally the input is transposed to NHWC, convolved, then transposed back.
     *
     * @param input rank-4 tensor in NCHW order (implied by the {@code [0, 2, 3, 1]} permutation)
     * @return NCHW output; the final transpose is named after {@link Conv2D#getName()}
     */
    public Operand<T> call(Operand<T> input) {
        // NCHW -> NHWC: move the channel axis last.
        Transpose<T> nhwcInput = tf.linalg.transpose(input, tf.array(0, 2, 3, 1));
        // The strides must follow the NHWC layout of the data actually being convolved.
        Operand<T> out = tf.nn.conv2d(
                nhwcInput, kernel, strideShapeNhwc(), op.getPadding(), Conv2d.dataFormat("NHWC"));
        if (op.getBias()) {
            out = tf.nn.biasAdd(out, bias, BiasAdd.dataFormat("NHWC"));
        }
        // NHWC -> NCHW: restore the caller's channels-first layout.
        return tf.withName(op.getName()).linalg.transpose(out, tf.array(0, 3, 1, 2));
    }
}

