/*
 * Decompiled with CFR 0.152.
 */
package org.apache.wayang.tensorflow.model.op.nn;

import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Set;
import org.apache.wayang.basic.model.op.nn.BatchNorm2D;
import org.tensorflow.ConcreteFunction;
import org.tensorflow.Graph;
import org.tensorflow.GraphOperation;
import org.tensorflow.Operand;
import org.tensorflow.Signature;
import org.tensorflow.op.Op;
import org.tensorflow.op.Ops;
import org.tensorflow.op.core.Assign;
import org.tensorflow.op.core.Identity;
import org.tensorflow.op.core.If;
import org.tensorflow.op.core.Variable;
import org.tensorflow.op.dtypes.Cast;
import org.tensorflow.op.math.Add;
import org.tensorflow.op.nn.FusedBatchNorm;
import org.tensorflow.op.random.TruncatedNormal;
import org.tensorflow.types.TBool;
import org.tensorflow.types.family.TNumber;

/**
 * TensorFlow implementation of the {@link BatchNorm2D} layer.
 *
 * <p>The layer owns four variables of length {@code op.getNumFeatures()}:
 * <ul>
 *   <li>a learnable scale ({@code weight});</li>
 *   <li>a learnable offset ({@code bias});</li>
 *   <li>a running mean;</li>
 *   <li>a running variance.</li>
 * </ul>
 * Inputs are normalized with TensorFlow's {@code FusedBatchNorm} using the {@code "NCHW"} data
 * format. Two branch functions are built, one for training and one for inference, and
 * {@link #call} selects between them at run time with an {@code If} op.
 *
 * @param <T> numeric element type of the input and of the layer's variables
 */
public class TensorflowBatchNorm2D<T extends TNumber> {
    private final Graph graph;
    private final Ops tf;
    private final BatchNorm2D op;
    private final Class<T> tClass;
    private final Variable<T> weight;
    private final Variable<T> bias;
    private final Variable<T> runningMean;
    private final Variable<T> runningVar;

    /**
     * Creates the layer and its variables in {@code graph}.
     *
     * @param graph  graph the layer's operations are added to
     * @param tf     op builder bound to {@code graph}
     * @param op     layer description (feature count, epsilon, momentum, name)
     * @param tClass element type of the input and of the created variables
     */
    public TensorflowBatchNorm2D(Graph graph, Ops tf, BatchNorm2D op, Class<T> tClass) {
        this.graph = graph;
        this.tf = tf;
        this.op = op;
        this.tClass = tClass;
        int numFeatures = op.getNumFeatures();
        // NOTE(review): scale and offset start from a truncated normal. Batch norm is commonly
        // initialized with scale = 1 and offset = 0. Kept as-is; confirm this is intended.
        this.weight = tf.withName("BatchNorm2DWeight")
                .variable(tf.random.truncatedNormal(tf.array(new int[] {numFeatures}), tClass));
        this.bias = tf.withName("BatchNorm2DBias")
                .variable(tf.random.truncatedNormal(tf.array(new int[] {numFeatures}), tClass));
        this.runningMean = tf.withName("BatchNorm2DRunningMean")
                .variable(tf.zeros(tf.array(new int[] {numFeatures}), tClass));
        this.runningVar = tf.withName("BatchNorm2DRunningVar")
                .variable(tf.ones(tf.array(new int[] {numFeatures}), tClass));
    }

    /**
     * Applies batch normalization to {@code input}.
     *
     * @param input        tensor to normalize
     * @param trainingMode when true, the training branch runs: it normalizes with batch
     *                     statistics and updates the running statistics. Otherwise the
     *                     inference branch runs and normalizes with the running statistics.
     * @return the first output of the {@code If} op, named after {@code op.getName()}
     */
    public Operand<T> call(Operand<T> input, Operand<TBool> trainingMode) {
        List<Operand<?>> placeholders = getPlaceholders(input);
        ConcreteFunction training = training(input, placeholders);
        ConcreteFunction inference = inference(input, placeholders);
        Operand<?> first = tf.withName(op.getName())
                .ifOp(trainingMode, placeholders, Collections.singletonList(tClass), training, inference)
                .iterator()
                .next();
        // The If op is declared with the single output type tClass, so the cast holds.
        @SuppressWarnings("unchecked")
        Operand<T> out = (Operand<T>) first;
        return out;
    }

    /**
     * Collects the graph placeholders that {@code input} depends on.
     *
     * <p>These become the inputs of the branch functions. The method adds an {@code Identity}
     * node on {@code input} to the graph, because {@code subgraphTo} needs an operand set.
     *
     * @param input operand whose upstream subgraph is searched
     * @return output 0 of every {@code Placeholder} / {@code PlaceholderWithDefault} operation
     *         found, in the iteration order of the subgraph set
     */
    public List<Operand<?>> getPlaceholders(Operand<?> input) {
        Set<GraphOperation> operations = graph.subgraphTo(Collections.singleton(tf.identity(input)));
        List<Operand<?>> inputs = new ArrayList<>();
        for (GraphOperation operation : operations) {
            String type = operation.type();
            if (type.equals("Placeholder") || type.equals("PlaceholderWithDefault")) {
                inputs.add(operation.output(0));
            }
        }
        return inputs;
    }

    /**
     * Registers each placeholder as a signature input, keyed by its operation name.
     *
     * @param builder      signature builder to extend
     * @param placeholders operands to register as inputs
     * @return {@code builder}, for chaining
     */
    public Signature.Builder addPlaceholders(Signature.Builder builder, List<Operand<?>> placeholders) {
        for (Operand<?> placeholder : placeholders) {
            builder.input(placeholder.op().name(), placeholder);
        }
        return builder;
    }

    /**
     * Builds the training branch.
     *
     * <p>The branch normalizes with the statistics of the current batch. It also updates the
     * running statistics as {@code running = (1 - momentum) * running + momentum * batch}. The
     * output {@code y} carries control dependencies on both updates.
     *
     * @param input        tensor to normalize
     * @param placeholders signature inputs of the function
     * @return function with output {@code "y"}
     */
    public ConcreteFunction training(Operand<T> input, List<Operand<?>> placeholders) {
        FusedBatchNorm<T, T> batchNorm = fusedBatchNorm(input, true);
        Operand<T> newMean = movingAverage(runningMean, batchNorm.batchMean());
        Operand<T> newVar = movingAverage(runningVar, batchNorm.batchVariance());
        // Route y through an Identity that depends on both assignments, so evaluating the
        // training branch always updates the running statistics.
        Operand<T> y = tf.withControlDependencies(new Op[] {
                        tf.assign(runningMean, newMean),
                        tf.assign(runningVar, newVar)})
                .identity(batchNorm.y());
        return toFunction(placeholders, y);
    }

    /**
     * Builds the inference branch, which normalizes with the stored running mean and variance.
     *
     * @param input        tensor to normalize
     * @param placeholders signature inputs of the function
     * @return function with output {@code "y"}
     */
    public ConcreteFunction inference(Operand<T> input, List<Operand<?>> placeholders) {
        return toFunction(placeholders, fusedBatchNorm(input, false).y());
    }

    /** Creates the fused batch-norm op over the layer's variables in training or inference mode. */
    private FusedBatchNorm<T, T> fusedBatchNorm(Operand<T> input, boolean isTraining) {
        return tf.nn.fusedBatchNorm(
                input,
                weight,
                bias,
                runningMean,
                runningVar,
                FusedBatchNorm.epsilon(op.getEpsilon())
                        // 1.0 is the kernel default. In training mode it makes batchMean() and
                        // batchVariance() the raw statistics of the current batch instead of an
                        // already-blended running estimate. That way the moving average is applied
                        // exactly once, in movingAverage(). Inference mode ignores this factor.
                        .exponentialAvgFactor(1.0f)
                        .dataFormat("NCHW")
                        .isTraining(isTraining));
    }

    /** Returns {@code (1 - momentum) * running + momentum * batchStatistic}. */
    private Operand<T> movingAverage(Variable<T> running, Operand<T> batchStatistic) {
        float momentum = op.getMomentum();
        return tf.math.add(
                tf.math.mul(scalar(1.0f - momentum), tf.stopGradient(running)),
                tf.math.mul(scalar(momentum), batchStatistic));
    }

    /** Returns {@code value} as a scalar constant of the layer's element type. */
    private Operand<T> scalar(float value) {
        return tf.dtypes.cast(tf.constant(value), tClass);
    }

    /** Wraps {@code y} into a function over {@code placeholders} with the single output {@code "y"}. */
    private ConcreteFunction toFunction(List<Operand<?>> placeholders, Operand<T> y) {
        Signature signature = addPlaceholders(Signature.builder(), placeholders).output("y", y).build();
        return ConcreteFunction.create(signature, graph);
    }
}

