/*
 * Decompiled with CFR 0.152.
 */
package org.apache.wayang.tensorflow.model.op.nn;

import org.apache.wayang.basic.model.op.nn.BatchNorm2D;
import org.apache.wayang.basic.model.op.nn.BatchNorm3D;
import org.apache.wayang.tensorflow.model.op.nn.TensorflowBatchNorm2D;
import org.tensorflow.Graph;
import org.tensorflow.Operand;
import org.tensorflow.op.Ops;
import org.tensorflow.op.core.Reshape;
import org.tensorflow.types.TBool;
import org.tensorflow.types.family.TNumber;

/**
 * TensorFlow implementation of {@link BatchNorm3D}.
 *
 * <p>The operator does not normalize directly. Instead, it:
 * <ol>
 *   <li>reshapes the input so that every axis after the third is folded into one trailing axis;</li>
 *   <li>passes that 4-D view through a {@link TensorflowBatchNorm2D}, which is configured with the
 *       same number of features, epsilon, momentum and dtype;</li>
 *   <li>reshapes the result back to the input's shape.</li>
 * </ol>
 *
 * <p>NOTE(review): this presumably assumes a channels-first {@code (N, C, D, H, W)} layout, with
 * {@link TensorflowBatchNorm2D} normalizing per channel on axis 1. Confirm this against that class.
 * Axes 0-2 are preserved by the reshape in any case.
 *
 * @param <T> numeric element type of the processed tensors
 */
public class TensorflowBatchNorm3D<T extends TNumber> {
    private final Ops tf;
    private final BatchNorm3D op;
    /** Delegate that performs the normalization on the 4-D view of the input. */
    private final TensorflowBatchNorm2D<T> batchNorm2D;

    /**
     * Creates the operator and its underlying 2-D batch-norm delegate.
     *
     * @param graph  graph handed unchanged to the {@link TensorflowBatchNorm2D} delegate
     * @param tf     op builder used to create the reshape ops
     * @param op     source operator; provides hyper-parameters and the output node name
     * @param tClass element class handed unchanged to the delegate
     */
    public TensorflowBatchNorm3D(Graph graph, Ops tf, BatchNorm3D op, Class<T> tClass) {
        this.tf = tf;
        this.op = op;
        BatchNorm2D op2 = new BatchNorm2D(op.getNumFeatures(), op.getEpsilon(), op.getMomentum(), op.getDType());
        this.batchNorm2D = new TensorflowBatchNorm2D<>(graph, tf, op2, tClass);
    }

    /**
     * Applies batch normalization to {@code input}.
     *
     * @param input        tensor of known rank &gt;= 3 (presumably {@code (N, C, D, H, W)})
     * @param trainingMode boolean operand forwarded unchanged to the 2-D delegate
     * @return the normalized tensor, reshaped to the runtime shape of {@code input}; the final
     *         reshape node is named after the source operator
     * @throws IllegalArgumentException if the rank of {@code input} is unknown or less than 3
     */
    public Operand<T> call(Operand<T> input, Operand<TBool> trainingMode) {
        long[] s = input.shape().asArray();
        if (s == null || s.length < 3) {
            throw new IllegalArgumentException(
                    "TensorflowBatchNorm3D requires an input of known rank >= 3, got shape " + input.shape());
        }
        // Fold every axis after the third into one. When all of those axes are statically known,
        // use their product instead of -1. Otherwise an unknown leading (batch) axis would put a
        // second -1 into the shape, which Reshape rejects.
        long trailing = 1L;
        for (int i = 3; i < s.length; i++) {
            if (s[i] < 0) {
                trailing = -1L;
                break;
            }
            trailing *= s[i];
        }
        Operand<T> input2D = this.tf.reshape(input, this.tf.array(new long[]{s[0], s[1], s[2], trailing}));
        Operand<T> output = this.batchNorm2D.call(input2D, trainingMode);
        // Restore the shape from the input's runtime shape. This stays valid even when several
        // static dimensions are unknown, whereas a constant built from `s` could contain several -1s.
        return this.tf.withName(this.op.getName()).reshape(output, this.tf.shape(input));
    }
}

