/*
 * Decompiled with CFR 0.152.
 */
package org.apache.wayang.tensorflow.model.op.nn;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import org.apache.wayang.basic.model.op.nn.Conv2D;
import org.apache.wayang.basic.model.op.nn.ConvLSTM2D;
import org.apache.wayang.tensorflow.model.op.nn.TensorflowConv2D;
import org.tensorflow.Operand;
import org.tensorflow.op.Ops;
import org.tensorflow.op.core.Concat;
import org.tensorflow.op.core.Gather;
import org.tensorflow.op.core.Stack;
import org.tensorflow.op.math.Add;
import org.tensorflow.op.math.Mul;
import org.tensorflow.op.math.Sigmoid;
import org.tensorflow.op.math.Tanh;
import org.tensorflow.types.family.TNumber;

/**
 * TensorFlow implementation of {@link ConvLSTM2D}: a convolutional LSTM whose recurrence is
 * unrolled over the time axis of the input while the graph is being built.
 *
 * <p>The same {@link Cell} instance (and therefore the same {@link TensorflowConv2D}) is used for
 * every time step. Presumably this shares the convolution weights across steps; that depends on
 * how {@code TensorflowConv2D} creates its variables.
 *
 * @param <T> numeric element type of the input and of the recurrent state
 */
public class TensorflowConvLSTM2D<T extends TNumber> {
    /** Output key selecting the stacked hidden states of every time step. */
    private static final String OUTPUT_SEQUENCE = "output";
    /** Output key selecting the hidden state after the last time step. */
    private static final String OUTPUT_HIDDEN = "hidden";
    /** Output key selecting the cell state after the last time step. */
    private static final String OUTPUT_CELL = "cell";

    private final Ops tf;
    private final ConvLSTM2D op;
    private final Cell<T> cell;
    private final Class<T> tClass;

    /**
     * Creates the layer together with the single {@link Cell} that is reused for every step.
     *
     * @param tf TensorFlow op builder used for all graph construction
     * @param op layer description (dimensions, kernel, stride, bias, requested output)
     * @param tClass element type of the zero-initialised recurrent state
     */
    public TensorflowConvLSTM2D(Ops tf, ConvLSTM2D op, Class<T> tClass) {
        this.tf = tf;
        this.op = op;
        this.tClass = tClass;
        this.cell = new Cell<T>(tf, op, tClass);
    }

    /**
     * Builds the unrolled ConvLSTM graph for {@code input}.
     *
     * <p>The input axes are used as follows:
     * <ul>
     *   <li>Axis 0 (batch) and axes 3 and 4 (height, width) size the zero-initialised hidden and
     *       cell state.</li>
     *   <li>Axis 1 (time) is iterated statically, one graph step per element.</li>
     *   <li>Each step slice is concatenated with the hidden state on its axis 1.</li>
     * </ul>
     *
     * @param input rank-5 input tensor, assumed to be laid out as
     *     (batch, time, channels, height, width); the time dimension must be statically known
     * @return depending on {@code op.getOutput()}:
     *     <ul>
     *       <li>{@code "output"}: the hidden states of all steps stacked on axis 1;</li>
     *       <li>{@code "hidden"}: the hidden state after the last step;</li>
     *       <li>{@code "cell"}: the cell state after the last step.</li>
     *     </ul>
     * @throws IllegalArgumentException if the output key is not recognised, if {@code input} is
     *     not rank 5, if its time dimension is unknown, or if {@code "output"} is requested for
     *     an empty sequence
     */
    public Operand<?> call(Operand<T> input) {
        String outKey = this.op.getOutput();
        boolean wantSequence = OUTPUT_SEQUENCE.equals(outKey);
        // Reject an unknown key before any op is added to the graph.
        if (!wantSequence && !OUTPUT_HIDDEN.equals(outKey) && !OUTPUT_CELL.equals(outKey)) {
            throw new IllegalArgumentException("Unrecognized output: " + outKey);
        }

        int rank = input.shape().numDimensions();
        if (rank != 5) {
            throw new IllegalArgumentException(
                    "ConvLSTM2D expects a rank-5 input (batch, time, channels, height, width), got rank "
                            + rank);
        }
        long batchSize = input.shape().get(0);
        long seqLen = input.shape().get(1);
        long height = input.shape().get(3);
        long width = input.shape().get(4);
        // The recurrence is unrolled in Java, so the number of steps must be known now.
        if (seqLen < 0) {
            throw new IllegalArgumentException(
                    "ConvLSTM2D requires a statically known time dimension (axis 1), got shape "
                            + input.shape());
        }
        if (wantSequence && seqLen == 0) {
            throw new IllegalArgumentException(
                    "Cannot build the \"output\" sequence for an input with an empty time dimension");
        }

        // NOTE(review): an unknown (-1) batch/height/width dimension is passed straight into
        // tf.zeros here; presumably callers always supply fully-defined shapes. Confirm.
        Operand<T> zeroState = this.tf.zeros(
                this.tf.array(new long[] {batchSize, this.op.getHiddenDim(), height, width}),
                this.tClass);
        // Graph operands are immutable, so the hidden and cell state can start from one zeros op.
        Operand<T> h = zeroState;
        Operand<T> c = zeroState;

        List<Operand<T>> hiddenSequence =
                new ArrayList<Operand<T>>(wantSequence ? (int) seqLen : 0);
        for (long t = 0L; t < seqLen; ++t) {
            // A scalar index drops the gathered axis, leaving the 4-D slice for time step t.
            Operand<T> step = this.tf.gather(input, this.tf.constant(t), this.tf.constant(1));
            Operand<T>[] hc = this.cell.call(step, h, c);
            h = hc[0];
            c = hc[1];
            if (wantSequence) {
                hiddenSequence.add(h);
            }
        }

        if (wantSequence) {
            // Put the per-step hidden states back on the time axis.
            return this.tf.stack(hiddenSequence, Stack.axis(1L));
        }
        return OUTPUT_HIDDEN.equals(outKey) ? h : c;
    }

    /**
     * One ConvLSTM step.
     *
     * <p>A single convolution over the input concatenated with the previous hidden state yields
     * the pre-activations of all four gates. These are then combined with the usual LSTM
     * update:
     * <pre>{@code
     * c' = f * c + i * g
     * h' = o * tanh(c')
     * }</pre>
     *
     * @param <T> numeric element type of the tensors
     */
    public static class Cell<T extends TNumber> {
        private final Ops tf;
        private final TensorflowConv2D<T> conv;

        /**
         * Creates the gate convolution for the given layer description.
         *
         * @param tf TensorFlow op builder
         * @param op layer description providing dimensions, kernel, stride, bias and dtype
         * @param tClass element type forwarded to the convolution
         */
        public Cell(Ops tf, ConvLSTM2D op, Class<T> tClass) {
            this.tf = tf;
            // Input and hidden state are concatenated on axis 1, so the convolution sees
            // inputDim + hiddenDim channels. It emits 4 * hiddenDim channels, one slice per gate.
            // NOTE(review): with a stride other than 1, "SAME" padding shrinks height/width, so
            // the gates would no longer match the state shape. Presumably stride is always 1;
            // confirm.
            this.conv = new TensorflowConv2D<T>(
                    tf,
                    new Conv2D(
                            op.getInputDim() + op.getHiddenDim(),
                            op.getHiddenDim() * 4,
                            op.getKernelSize(),
                            op.getStride(),
                            "SAME",
                            op.getBias(),
                            op.getDType()),
                    tClass);
        }

        /**
         * Advances the recurrence by one time step.
         *
         * @param input the current time-step slice
         * @param hCur hidden state from the previous step
         * @param cCur cell state from the previous step
         * @return a two-element array {@code {hNext, cNext}}
         */
        @SuppressWarnings("unchecked") // generic array creation for the two-element result
        public Operand<T>[] call(Operand<T> input, Operand<T> hCur, Operand<T> cCur) {
            Concat<T> combined = this.tf.concat(Arrays.asList(input, hCur), this.tf.constant(1));
            Operand<T> combinedConv = this.conv.call(combined);
            // Split the 4 * hiddenDim channels into the i, f, o, g gate pre-activations.
            List<? extends Operand<T>> gates =
                    this.tf.split(this.tf.constant(1), combinedConv, 4L).output();
            Sigmoid<T> i = this.tf.math.sigmoid(gates.get(0));
            Sigmoid<T> f = this.tf.math.sigmoid(gates.get(1));
            Sigmoid<T> o = this.tf.math.sigmoid(gates.get(2));
            Tanh<T> g = this.tf.math.tanh(gates.get(3));
            Add<T> cNext = this.tf.math.add(this.tf.math.mul(f, cCur), this.tf.math.mul(i, g));
            Mul<T> hNext = this.tf.math.mul(o, this.tf.math.tanh(cNext));
            return new Operand[] {hNext, cNext};
        }
    }
}

