/*
 * Decompiled with CFR 0.152.
 */
package org.apache.wayang.tensorflow.model.op.nn;

import org.apache.wayang.basic.model.op.nn.Linear;
import org.tensorflow.Operand;
import org.tensorflow.op.Ops;
import org.tensorflow.op.core.Variable;
import org.tensorflow.op.linalg.MatMul;
import org.tensorflow.op.random.TruncatedNormal;
import org.tensorflow.types.family.TNumber;

/**
 * TensorFlow-Java implementation of a fully connected {@link Linear} layer.
 *
 * <p>The layer owns a weight variable of shape {@code [inFeatures, outFeatures]} and, only when
 * {@link Linear#getBias()} is {@code true}, a bias variable of shape {@code [outFeatures]}. Both
 * are initialized with {@code tf.random.truncatedNormal} using its default options.
 *
 * @param <T> numeric element type of the layer's parameters, input and output
 */
public class TensorflowLinear<T extends TNumber> {
    private final Ops tf;
    private final Linear op;
    /** Weight matrix, shape {@code [inFeatures, outFeatures]}. */
    private final Variable<T> weights;
    /** Bias vector, shape {@code [outFeatures]}; {@code null} when the layer has no bias. */
    private final Variable<T> bias;

    /**
     * Creates the layer's weight (and optional bias) variables with {@code tf}.
     *
     * @param tf     op builder used to create the variables and, later, the layer output
     * @param op     layer description supplying the feature counts, the bias flag and the
     *               name given to the output op
     * @param tClass element type of the weight and bias tensors
     */
    public TensorflowLinear(Ops tf, Linear op, Class<T> tClass) {
        this.tf = tf;
        this.op = op;
        // Shapes are built as int32 constants; explicit int[] keeps the same array(...) overload.
        this.weights = tf.withName("LinearWeights").variable(
                tf.random.truncatedNormal(
                        tf.array(new int[] {op.getInFeatures(), op.getOutFeatures()}), tClass));
        this.bias = op.getBias()
                ? tf.withName("LinearBias").variable(
                        tf.random.truncatedNormal(tf.array(new int[] {op.getOutFeatures()}), tClass))
                : null;
    }

    /**
     * Applies the layer: {@code matMul(input, weights)}, plus {@code bias} when the layer has one.
     *
     * <p>{@code linalg.matMul} builds a TensorFlow {@code MatMul} op, which multiplies rank-2
     * tensors, so {@code input} is presumably {@code [batch, inFeatures]} — confirm with callers.
     * When present, the {@code [outFeatures]} bias is combined via {@code math.add}, relying on
     * TensorFlow broadcasting over the leading dimension.
     *
     * @param input layer input
     * @return the layer output; the outermost op (the matMul, or the add when a bias exists) is
     *     created under the name returned by {@link Linear#getName()}
     */
    public Operand<T> call(Operand<T> input) {
        Ops named = tf.withName(op.getName());
        // Branch on the field actually dereferenced rather than re-querying the op's flag.
        if (bias == null) {
            return named.linalg.matMul(input, weights);
        }
        return named.math.add(tf.linalg.matMul(input, weights), bias);
    }
}

