/*
 * Decompiled with CFR 0.152.
 */
package org.apache.wayang.tensorflow.model;

import java.util.Arrays;
import org.apache.wayang.basic.model.op.ArgMax;
import org.apache.wayang.basic.model.op.Cast;
import org.apache.wayang.basic.model.op.Eq;
import org.apache.wayang.basic.model.op.Get;
import org.apache.wayang.basic.model.op.Input;
import org.apache.wayang.basic.model.op.Mean;
import org.apache.wayang.basic.model.op.Op;
import org.apache.wayang.basic.model.op.Reshape;
import org.apache.wayang.basic.model.op.Slice;
import org.apache.wayang.basic.model.op.Transpose;
import org.apache.wayang.basic.model.op.ZeroLike;
import org.apache.wayang.basic.model.op.nn.BatchNorm2D;
import org.apache.wayang.basic.model.op.nn.BatchNorm3D;
import org.apache.wayang.basic.model.op.nn.Conv2D;
import org.apache.wayang.basic.model.op.nn.Conv3D;
import org.apache.wayang.basic.model.op.nn.ConvLSTM2D;
import org.apache.wayang.basic.model.op.nn.CrossEntropyLoss;
import org.apache.wayang.basic.model.op.nn.Linear;
import org.apache.wayang.basic.model.op.nn.MSELoss;
import org.apache.wayang.basic.model.op.nn.ReLU;
import org.apache.wayang.basic.model.op.nn.Sigmoid;
import org.apache.wayang.basic.model.op.nn.Softmax;
import org.apache.wayang.basic.model.optimizer.Adam;
import org.apache.wayang.basic.model.optimizer.GradientDescent;
import org.apache.wayang.basic.model.optimizer.Optimizer;
import org.apache.wayang.tensorflow.model.op.nn.TensorflowBatchNorm2D;
import org.apache.wayang.tensorflow.model.op.nn.TensorflowBatchNorm3D;
import org.apache.wayang.tensorflow.model.op.nn.TensorflowConv2D;
import org.apache.wayang.tensorflow.model.op.nn.TensorflowConv3D;
import org.apache.wayang.tensorflow.model.op.nn.TensorflowConvLSTM2D;
import org.apache.wayang.tensorflow.model.op.nn.TensorflowLinear;
import org.tensorflow.Graph;
import org.tensorflow.Operand;
import org.tensorflow.Tensor;
import org.tensorflow.ndarray.BooleanNdArray;
import org.tensorflow.ndarray.ByteNdArray;
import org.tensorflow.ndarray.DoubleNdArray;
import org.tensorflow.ndarray.FloatNdArray;
import org.tensorflow.ndarray.IntNdArray;
import org.tensorflow.ndarray.LongNdArray;
import org.tensorflow.ndarray.NdArray;
import org.tensorflow.ndarray.Shape;
import org.tensorflow.op.Ops;
import org.tensorflow.op.core.OneHot;
import org.tensorflow.op.core.Placeholder;
import org.tensorflow.op.core.ReduceSum;
import org.tensorflow.op.core.ZerosLike;
import org.tensorflow.op.dtypes.Cast;
import org.tensorflow.op.math.Equal;
import org.tensorflow.op.math.Mean;
import org.tensorflow.types.TBool;
import org.tensorflow.types.TFloat32;
import org.tensorflow.types.TFloat64;
import org.tensorflow.types.TInt32;
import org.tensorflow.types.TInt64;
import org.tensorflow.types.TUint8;

/**
 * Static helpers that translate Wayang model operators and optimizers into TensorFlow graph
 * operations, plus a small {@code NdArray}-to-{@code Tensor} adapter.
 *
 * <p>Operands arrive as {@code Operand<?>}, so their element types are unknown at compile time.
 * Where a TensorFlow factory method needs several operands to share one element type, or needs
 * a numeric element type, the operand is passed as a raw {@code Operand}. Those methods carry a
 * method-scoped {@code @SuppressWarnings}; element-type compatibility is not checked by the
 * compiler there.
 *
 * <p>Inside this class the simple names {@code Cast}, {@code Mean}, {@code Slice} and
 * {@code Transpose} are only used for the Wayang operator classes; the TensorFlow operations of
 * the same names are reached through {@link Ops} or by fully qualified name.
 */
public class Convertor {

    /**
     * Dispatches {@code op} to the {@code convert} overload matching its runtime type.
     *
     * @param graph  graph handed to the batch-normalization overloads; not used by the others
     * @param tf     operation factory handed to the selected overload
     * @param op     Wayang operator to convert
     * @param inputs operands consumed by {@code op} in positional order; depending on the
     *               operator the first one or two entries are read, none for {@link Input}
     * @return the operand produced by the selected overload
     * @throws RuntimeException if no overload exists for the operator's class
     */
    @SuppressWarnings("unchecked")
    public static Operand<?> convert(Graph graph, Ops tf, Op op, Operand<?>... inputs) {
        if (op instanceof ArgMax) {
            return Convertor.convert(tf, (ArgMax) op, inputs[0]);
        }
        if (op instanceof BatchNorm2D) {
            // The explicit Operand<TBool> cast selects the five-argument overload. With a plain
            // Operand<?> the call matches only this varargs method and recurses without end.
            return Convertor.convert(graph, tf, (BatchNorm2D) op, inputs[0], (Operand<TBool>) inputs[1]);
        }
        if (op instanceof BatchNorm3D) {
            // Same overload-selection concern as for BatchNorm2D.
            return Convertor.convert(graph, tf, (BatchNorm3D) op, inputs[0], (Operand<TBool>) inputs[1]);
        }
        if (op instanceof Cast) {
            return Convertor.convert(tf, (Cast) op, inputs[0]);
        }
        if (op instanceof ConvLSTM2D) {
            return Convertor.convert(tf, (ConvLSTM2D) op, inputs[0]);
        }
        if (op instanceof Conv2D) {
            return Convertor.convert(tf, (Conv2D) op, inputs[0]);
        }
        if (op instanceof Conv3D) {
            return Convertor.convert(tf, (Conv3D) op, inputs[0]);
        }
        if (op instanceof CrossEntropyLoss) {
            return Convertor.convert(tf, (CrossEntropyLoss) op, inputs[0], inputs[1]);
        }
        if (op instanceof Eq) {
            return Convertor.convert(tf, (Eq) op, inputs[0], inputs[1]);
        }
        if (op instanceof Get) {
            return Convertor.convert(tf, (Get) op, inputs[0]);
        }
        if (op instanceof Input) {
            return Convertor.convert(tf, (Input) op);
        }
        if (op instanceof Linear) {
            return Convertor.convert(tf, (Linear) op, inputs[0]);
        }
        if (op instanceof Mean) {
            return Convertor.convert(tf, (Mean) op, inputs[0]);
        }
        if (op instanceof MSELoss) {
            return Convertor.convert(tf, (MSELoss) op, inputs[0], inputs[1]);
        }
        if (op instanceof ReLU) {
            return Convertor.convert(tf, (ReLU) op, inputs[0]);
        }
        if (op instanceof Reshape) {
            return Convertor.convert(tf, (Reshape) op, inputs[0]);
        }
        if (op instanceof Sigmoid) {
            return Convertor.convert(tf, (Sigmoid) op, inputs[0]);
        }
        if (op instanceof Slice) {
            return Convertor.convert(tf, (Slice) op, inputs[0]);
        }
        if (op instanceof Softmax) {
            return Convertor.convert(tf, (Softmax) op, inputs[0]);
        }
        if (op instanceof Transpose) {
            return Convertor.convert(tf, (Transpose) op, inputs[0]);
        }
        if (op instanceof ZeroLike) {
            return Convertor.convert(tf, (ZeroLike) op, inputs[0]);
        }
        throw new RuntimeException("Unsupported operator: " + op.getClass());
    }

    /**
     * Adds an arg-max operation over the dimension {@code op.getDim()}, named after the operator.
     *
     * @return the operand holding the indices as 32-bit integers
     */
    @SuppressWarnings({"unchecked", "rawtypes"})
    public static Operand<TInt32> convert(Ops tf, ArgMax op, Operand<?> input) {
        return tf.withName(op.getName()).math.argMax((Operand) input, (Operand) tf.constant(op.getDim()), TInt32.class);
    }

    /**
     * Builds a {@link TensorflowBatchNorm2D} layer for the operator's floating-point type and
     * applies it to {@code input}.
     *
     * @param trainingMode boolean operand forwarded to the layer's {@code call}
     * @throws RuntimeException if the operator's type is neither FLOAT32 nor FLOAT64
     */
    public static Operand<?> convert(Graph graph, Ops tf, BatchNorm2D op, Operand<?> input, Operand<TBool> trainingMode) {
        Op.DType dtype = op.getDType();
        if (dtype == Op.DType.FLOAT32) {
            return new TensorflowBatchNorm2D<TFloat32>(graph, tf, op, TFloat32.class).call(input, trainingMode);
        }
        if (dtype == Op.DType.FLOAT64) {
            return new TensorflowBatchNorm2D<TFloat64>(graph, tf, op, TFloat64.class).call(input, trainingMode);
        }
        throw new RuntimeException("Unsupported DType: " + dtype);
    }

    /**
     * Builds a {@link TensorflowBatchNorm3D} layer for the operator's floating-point type and
     * applies it to {@code input}.
     *
     * @param trainingMode boolean operand forwarded to the layer's {@code call}
     * @throws RuntimeException if the operator's type is neither FLOAT32 nor FLOAT64
     */
    public static Operand<?> convert(Graph graph, Ops tf, BatchNorm3D op, Operand<?> input, Operand<TBool> trainingMode) {
        Op.DType dtype = op.getDType();
        if (dtype == Op.DType.FLOAT32) {
            return new TensorflowBatchNorm3D<TFloat32>(graph, tf, op, TFloat32.class).call(input, trainingMode);
        }
        if (dtype == Op.DType.FLOAT64) {
            return new TensorflowBatchNorm3D<TFloat64>(graph, tf, op, TFloat64.class).call(input, trainingMode);
        }
        throw new RuntimeException("Unsupported DType: " + dtype);
    }

    /**
     * Adds a cast of {@code input} to the TensorFlow type matching {@code op.getDType()},
     * named after the operator.
     *
     * @throws RuntimeException if the target type is not one of INT32, INT64, FLOAT32, FLOAT64,
     *                          BYTE or BOOL
     */
    public static Operand<?> convert(Ops tf, Cast op, Operand<?> input) {
        Op.DType dtype = op.getDType();
        if (dtype == Op.DType.INT32) {
            return tf.withName(op.getName()).dtypes.cast(input, TInt32.class);
        }
        if (dtype == Op.DType.INT64) {
            return tf.withName(op.getName()).dtypes.cast(input, TInt64.class);
        }
        if (dtype == Op.DType.FLOAT32) {
            return tf.withName(op.getName()).dtypes.cast(input, TFloat32.class);
        }
        if (dtype == Op.DType.FLOAT64) {
            return tf.withName(op.getName()).dtypes.cast(input, TFloat64.class);
        }
        if (dtype == Op.DType.BYTE) {
            return tf.withName(op.getName()).dtypes.cast(input, TUint8.class);
        }
        if (dtype == Op.DType.BOOL) {
            return tf.withName(op.getName()).dtypes.cast(input, TBool.class);
        }
        throw new RuntimeException("Unsupported DType: " + dtype);
    }

    /**
     * Builds a {@link TensorflowConvLSTM2D} layer for the operator's floating-point type and
     * applies it to {@code input}.
     *
     * @throws RuntimeException if the operator's type is neither FLOAT32 nor FLOAT64
     */
    public static Operand<?> convert(Ops tf, ConvLSTM2D op, Operand<?> input) {
        Op.DType dtype = op.getDType();
        if (dtype == Op.DType.FLOAT32) {
            return new TensorflowConvLSTM2D<TFloat32>(tf, op, TFloat32.class).call(input);
        }
        if (dtype == Op.DType.FLOAT64) {
            return new TensorflowConvLSTM2D<TFloat64>(tf, op, TFloat64.class).call(input);
        }
        throw new RuntimeException("Unsupported DType: " + dtype);
    }

    /**
     * Builds a {@link TensorflowConv2D} layer for the operator's floating-point type and applies
     * it to {@code input}.
     *
     * @throws RuntimeException if the operator's type is neither FLOAT32 nor FLOAT64
     */
    public static Operand<?> convert(Ops tf, Conv2D op, Operand<?> input) {
        Op.DType dtype = op.getDType();
        if (dtype == Op.DType.FLOAT32) {
            return new TensorflowConv2D<TFloat32>(tf, op, TFloat32.class).call(input);
        }
        if (dtype == Op.DType.FLOAT64) {
            return new TensorflowConv2D<TFloat64>(tf, op, TFloat64.class).call(input);
        }
        throw new RuntimeException("Unsupported DType: " + dtype);
    }

    /**
     * Builds a {@link TensorflowConv3D} layer for the operator's floating-point type and applies
     * it to {@code input}.
     *
     * @throws RuntimeException if the operator's type is neither FLOAT32 nor FLOAT64
     */
    public static Operand<?> convert(Ops tf, Conv3D op, Operand<?> input) {
        Op.DType dtype = op.getDType();
        if (dtype == Op.DType.FLOAT32) {
            return new TensorflowConv3D<TFloat32>(tf, op, TFloat32.class).call(input);
        }
        if (dtype == Op.DType.FLOAT64) {
            return new TensorflowConv3D<TFloat64>(tf, op, TFloat64.class).call(input);
        }
        throw new RuntimeException("Unsupported DType: " + dtype);
    }

    /**
     * Adds a softmax cross-entropy loss: {@code labels} are one-hot encoded with depth
     * {@code op.getLabels()}, multiplied with the log-softmax of {@code predicted}, summed over
     * axis 1, negated and averaged over axis 0. Only the final mean carries the operator's name.
     *
     * @param predicted unnormalized scores; axis 1 is reduced as the class axis
     * @param labels    class indices to one-hot encode
     * @throws RuntimeException if the operator's type is neither FLOAT32 nor FLOAT64
     */
    @SuppressWarnings({"unchecked", "rawtypes"})
    public static Operand<?> convert(Ops tf, CrossEntropyLoss op, Operand<?> predicted, Operand<?> labels) {
        Op.DType dtype = op.getDType();
        OneHot oneHot;
        if (dtype == Op.DType.FLOAT32) {
            oneHot = tf.oneHot((Operand) labels, (Operand) tf.constant(op.getLabels()), (Operand) tf.constant(1.0f), (Operand) tf.constant(0.0f));
        } else if (dtype == Op.DType.FLOAT64) {
            oneHot = tf.oneHot((Operand) labels, (Operand) tf.constant(op.getLabels()), (Operand) tf.constant(1.0), (Operand) tf.constant(0.0));
        } else {
            throw new RuntimeException("Unsupported DType: " + dtype);
        }
        Ops named = tf.withName(op.getName());
        // logSoftmax instead of log(softmax(x)): once a probability underflows to zero the
        // composed form yields -inf, and -inf * 0 from the one-hot mask turns the loss into NaN.
        Operand logProbabilities = tf.nn.logSoftmax((Operand) predicted);
        Operand targetLogProbability = tf.reduceSum((Operand) tf.math.mul(logProbabilities, (Operand) oneHot), (Operand) tf.array(new int[]{1}));
        Operand perSampleLoss = tf.math.neg(targetLogProbability);
        return named.math.mean(perSampleLoss, (Operand) tf.array(new int[]{0}));
    }

    /**
     * Adds a mean-squared-error loss: the element-wise squared difference of {@code predicted}
     * and {@code labels}, averaged over axis 0 and named after the operator.
     */
    @SuppressWarnings({"unchecked", "rawtypes"})
    public static Operand<?> convert(Ops tf, MSELoss op, Operand<?> predicted, Operand<?> labels) {
        // Raw operands: the two wildcard types cannot be unified by the compiler.
        return tf.withName(op.getName()).math.mean((Operand) tf.math.squaredDifference((Operand) predicted, (Operand) labels), (Operand) tf.array(new int[]{0}));
    }

    /**
     * Adds an equality comparison of {@code left} and {@code right}, named after the operator.
     *
     * @return the boolean operand holding the comparison result
     */
    @SuppressWarnings({"unchecked", "rawtypes"})
    public static Operand<TBool> convert(Ops tf, Eq op, Operand<?> left, Operand<?> right) {
        // Raw operands: the two wildcard types cannot be unified by the compiler.
        return tf.withName(op.getName()).math.equal((Operand) left, (Operand) right);
    }

    /**
     * Adds a tensor-map lookup of the operator's key in {@code input}, typed by
     * {@code op.getDType()} and named after the operator. Only {@code String} keys are supported.
     *
     * @throws RuntimeException if the key is {@code null} or not a {@code String}, or if the
     *                          value type is not one of INT32, INT64, FLOAT32, FLOAT64, BYTE, BOOL
     */
    @SuppressWarnings({"unchecked", "rawtypes"})
    public static Operand<?> convert(Ops tf, Get op, Operand<?> input) {
        Object rawKey = op.getKey();
        if (!(rawKey instanceof String)) {
            // Report a null key through the same message instead of failing on getClass().
            throw new RuntimeException("Unsupported Key Type: " + (rawKey == null ? "null" : rawKey.getClass().getName()));
        }
        String key = (String) rawKey;
        Op.DType dtype = op.getDType();
        if (dtype == Op.DType.INT32) {
            // NOTE(review): only this branch returns the lookup's value() handle rather than the
            // op itself; both are operands for the same result - confirm the asymmetry is intended.
            return tf.withName(op.getName()).tensorMapLookup(input, (Operand) tf.constant(key), TInt32.class).value();
        }
        if (dtype == Op.DType.INT64) {
            return tf.withName(op.getName()).tensorMapLookup(input, (Operand) tf.constant(key), TInt64.class);
        }
        if (dtype == Op.DType.FLOAT32) {
            return tf.withName(op.getName()).tensorMapLookup(input, (Operand) tf.constant(key), TFloat32.class);
        }
        if (dtype == Op.DType.FLOAT64) {
            return tf.withName(op.getName()).tensorMapLookup(input, (Operand) tf.constant(key), TFloat64.class);
        }
        if (dtype == Op.DType.BYTE) {
            return tf.withName(op.getName()).tensorMapLookup(input, (Operand) tf.constant(key), TUint8.class);
        }
        if (dtype == Op.DType.BOOL) {
            return tf.withName(op.getName()).tensorMapLookup(input, (Operand) tf.constant(key), TBool.class);
        }
        throw new RuntimeException("Unsupported DType: " + dtype);
    }

    /**
     * Adds a placeholder named after the operator, typed by {@code op.getDType()} and carrying
     * the operator's shape when one is set.
     *
     * @throws RuntimeException if the type is not one of INT32, INT64, FLOAT32, FLOAT64, BYTE, BOOL
     */
    public static Operand<?> convert(Ops tf, Input op) {
        // NOTE(review): a missing shape is forwarded as null to Placeholder.shape; presumably
        // TensorFlow then leaves the placeholder's shape unspecified - confirm.
        Shape shape = op.getShape() == null
                ? null
                : Shape.of(Arrays.stream(op.getShape()).mapToLong(e -> e).toArray());
        Op.DType dtype = op.getDType();
        if (dtype == Op.DType.INT32) {
            return tf.withName(op.getName()).placeholder(TInt32.class, Placeholder.shape(shape));
        }
        if (dtype == Op.DType.INT64) {
            return tf.withName(op.getName()).placeholder(TInt64.class, Placeholder.shape(shape));
        }
        if (dtype == Op.DType.FLOAT32) {
            return tf.withName(op.getName()).placeholder(TFloat32.class, Placeholder.shape(shape));
        }
        if (dtype == Op.DType.FLOAT64) {
            return tf.withName(op.getName()).placeholder(TFloat64.class, Placeholder.shape(shape));
        }
        if (dtype == Op.DType.BYTE) {
            return tf.withName(op.getName()).placeholder(TUint8.class, Placeholder.shape(shape));
        }
        if (dtype == Op.DType.BOOL) {
            return tf.withName(op.getName()).placeholder(TBool.class, Placeholder.shape(shape));
        }
        throw new RuntimeException("Unsupported DType: " + dtype);
    }

    /**
     * Builds a {@link TensorflowLinear} layer for the operator's floating-point type and applies
     * it to {@code input}.
     *
     * @throws RuntimeException if the operator's type is neither FLOAT32 nor FLOAT64
     */
    public static Operand<?> convert(Ops tf, Linear op, Operand<?> input) {
        Op.DType dtype = op.getDType();
        if (dtype == Op.DType.FLOAT32) {
            return new TensorflowLinear<TFloat32>(tf, op, TFloat32.class).call(input);
        }
        if (dtype == Op.DType.FLOAT64) {
            return new TensorflowLinear<TFloat64>(tf, op, TFloat64.class).call(input);
        }
        throw new RuntimeException("Unsupported DType: " + dtype);
    }

    /** Adds a mean reduction of {@code input} over {@code op.getDim()}, named after the operator. */
    @SuppressWarnings({"unchecked", "rawtypes"})
    public static Operand<?> convert(Ops tf, Mean op, Operand<?> input) {
        return tf.withName(op.getName()).math.mean(input, (Operand) tf.constant(op.getDim()));
    }

    /** Adds a ReLU activation of {@code input}, named after the operator. */
    @SuppressWarnings({"unchecked", "rawtypes"})
    public static Operand<?> convert(Ops tf, ReLU op, Operand<?> input) {
        return tf.withName(op.getName()).nn.relu((Operand) input);
    }

    /** Adds a reshape of {@code input} to {@code op.getShape()}, named after the operator. */
    @SuppressWarnings({"unchecked", "rawtypes"})
    public static Operand<?> convert(Ops tf, Reshape op, Operand<?> input) {
        return tf.withName(op.getName()).reshape(input, (Operand) tf.constant(op.getShape()));
    }

    /** Adds a sigmoid activation of {@code input}, named after the operator. */
    public static Operand<?> convert(Ops tf, Sigmoid op, Operand<?> input) {
        return tf.withName(op.getName()).math.sigmoid(input);
    }

    /**
     * Adds a slice of {@code input}, named after the operator.
     *
     * <p>Each row of {@code op.getRange()} is read as {@code {begin, end}} for one dimension and
     * turned into TensorFlow's {@code begin}/{@code size} form with {@code size = end - begin}.
     * An end of {@code -1} is passed through as size {@code -1}, which TensorFlow's slice
     * documents as "all remaining elements of that dimension".
     */
    @SuppressWarnings({"unchecked", "rawtypes"})
    public static Operand<?> convert(Ops tf, Slice op, Operand<?> input) {
        int[][] range = op.getRange();
        int rank = range.length;
        int[] begin = new int[rank];
        int[] size = new int[rank];
        for (int dim = 0; dim < rank; ++dim) {
            int start = range[dim][0];
            int end = range[dim][1];
            begin[dim] = start;
            size[dim] = (end == -1) ? -1 : end - start;
        }
        return tf.withName(op.getName()).slice(input, (Operand) tf.constant(begin), (Operand) tf.constant(size));
    }

    /** Adds a softmax of {@code input}, named after the operator. */
    @SuppressWarnings({"unchecked", "rawtypes"})
    public static Operand<?> convert(Ops tf, Softmax op, Operand<?> input) {
        return tf.withName(op.getName()).nn.softmax((Operand) input);
    }

    /** Adds a transpose of {@code input} with permutation {@code op.getPerm()}, named after the operator. */
    @SuppressWarnings({"unchecked", "rawtypes"})
    public static Operand<?> convert(Ops tf, Transpose op, Operand<?> input) {
        return tf.withName(op.getName()).linalg.transpose(input, (Operand) tf.constant(op.getPerm()));
    }

    /** Adds a zeros-like operation on {@code input}, named after the operator. */
    public static Operand<?> convert(Ops tf, ZeroLike op, Operand<?> input) {
        return tf.withName(op.getName()).zerosLike(input);
    }

    /**
     * Creates the TensorFlow optimizer corresponding to a Wayang optimizer description.
     *
     * @param graph     graph the TensorFlow optimizer is constructed with
     * @param optimizer Wayang optimizer; {@link GradientDescent} and {@link Adam} are supported
     * @return a TensorFlow optimizer configured with the name and hyper-parameters of {@code optimizer}
     * @throws RuntimeException if the optimizer's class is not supported
     */
    public static org.tensorflow.framework.optimizers.Optimizer convert(Graph graph, Optimizer optimizer) {
        if (optimizer instanceof GradientDescent) {
            GradientDescent gd = (GradientDescent) optimizer;
            return new org.tensorflow.framework.optimizers.GradientDescent(graph, gd.getName(), gd.getLearningRate());
        }
        if (optimizer instanceof Adam) {
            Adam adam = (Adam) optimizer;
            return new org.tensorflow.framework.optimizers.Adam(graph, adam.getName(), adam.getLearningRate(), adam.getBetaOne(), adam.getBetaTwo(), adam.getEpsilon());
        }
        throw new RuntimeException("Unsupported optimizer: " + optimizer.getClass());
    }

    /**
     * Returns a tensor for {@code array}.
     *
     * <p>An array that already is a {@link Tensor} is returned as is. Otherwise the tensor is
     * created through the {@code tensorOf} factory of the type matching the array's primitive
     * specialization (int, long, float, double, byte, boolean).
     *
     * @throws RuntimeException if the array is none of the supported specializations
     */
    public static Tensor ndArrayToTensor(NdArray<?> array) {
        if (array instanceof Tensor) {
            return (Tensor) array;
        }
        if (array instanceof IntNdArray) {
            return TInt32.tensorOf((IntNdArray) array);
        }
        if (array instanceof LongNdArray) {
            return TInt64.tensorOf((LongNdArray) array);
        }
        if (array instanceof FloatNdArray) {
            return TFloat32.tensorOf((FloatNdArray) array);
        }
        if (array instanceof DoubleNdArray) {
            return TFloat64.tensorOf((DoubleNdArray) array);
        }
        if (array instanceof ByteNdArray) {
            return TUint8.tensorOf((ByteNdArray) array);
        }
        if (array instanceof BooleanNdArray) {
            return TBool.tensorOf((BooleanNdArray) array);
        }
        throw new RuntimeException("Unsupported NdArray type: " + array.getClass().getName());
    }
}

