/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.nn.recurrent;

import ai.djl.MalformedModelException;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.LayoutType;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.AbstractBlock;
import ai.djl.nn.Block;
import ai.djl.nn.Parameter;
import ai.djl.nn.ParameterType;
import ai.djl.nn.recurrent.RNN;
import java.io.DataInputStream;
import java.io.IOException;

/**
 * Base class for recurrent blocks that share a common parameter naming and shape scheme.
 *
 * <p>For every layer and direction the block registers an input-to-hidden ({@code i2h}) and a
 * hidden-to-hidden ({@code h2h}) weight, plus the matching biases when {@code hasBiases} is
 * enabled. Parameter names follow the pattern {@code <direction>_<layer>_<gate>_<TYPE>}, for
 * example {@code l_0_i2h_WEIGHT} or {@code r_1_h2h_BIAS}.
 *
 * <p>{@link #gates} is not assigned in this class; presumably each concrete subclass sets it to
 * the number of gates of its cell (NOTE(review): confirm in subclasses).
 */
public abstract class RecurrentBlock
extends AbstractBlock {
    private static final byte VERSION = 2;
    private static final LayoutType[] EXPECTED_LAYOUT = new LayoutType[]{LayoutType.BATCH, LayoutType.TIME, LayoutType.CHANNEL};
    protected long stateSize;
    protected float dropRate;
    protected int numLayers;
    protected int gates;
    protected boolean batchFirst;
    protected boolean hasBiases;
    protected boolean bidirectional;
    protected boolean returnState;

    /**
     * Creates a recurrent block from the builder configuration and registers its parameters.
     *
     * @param builder the builder holding state size, layer count, direction and bias options
     */
    public RecurrentBlock(BaseBuilder<?> builder) {
        super(VERSION);
        this.stateSize = builder.stateSize;
        this.dropRate = builder.dropRate;
        this.numLayers = builder.numLayers;
        this.batchFirst = builder.batchFirst;
        this.hasBiases = builder.hasBiases;
        this.bidirectional = builder.bidirectional;
        this.returnState = builder.returnState;

        // Only register bias parameters when the block is configured to use biases, so that the
        // registered parameters agree with the hasBiases flag.
        ParameterType[] parameterTypes = this.hasBiases
                ? new ParameterType[]{ParameterType.WEIGHT, ParameterType.BIAS}
                : new ParameterType[]{ParameterType.WEIGHT};
        String[] directions = this.bidirectional ? new String[]{"l", "r"} : new String[]{"l"};
        String[] gateStrings = new String[]{"i2h", "h2h"};

        // Keep the iteration order (type, layer, direction, gate) unchanged: registration order
        // presumably determines how parameters are saved/loaded (NOTE(review): confirm).
        for (ParameterType parameterType : parameterTypes) {
            for (int layer = 0; layer < this.numLayers; ++layer) {
                for (String direction : directions) {
                    for (String gateString : gateStrings) {
                        String name = direction + '_' + layer + '_' + gateString + '_' + parameterType.name();
                        this.addParameter(new Parameter(name, this, parameterType));
                    }
                }
            }
        }
    }

    /**
     * Computes the output shapes for the given input shapes.
     *
     * <p>The first output keeps the first two dimensions of the input and has
     * {@code stateSize * numDirections} channels. When {@code returnState} is set, a second shape
     * {@code (numLayers * numDirections, batch, stateSize)} is appended, where the batch
     * dimension is taken from input axis 0 if {@code batchFirst} and axis 1 otherwise.
     *
     * @param manager the manager (unused here)
     * @param inputs the input shapes; only {@code inputs[0]} is read
     * @return the output shape, optionally followed by the state shape
     */
    @Override
    public Shape[] getOutputShapes(NDManager manager, Shape[] inputs) {
        Shape inputShape = inputs[0];
        long numDirections = this.getNumDirections();
        Shape outputShape = new Shape(inputShape.get(0), inputShape.get(1), this.stateSize * numDirections);
        if (!this.returnState) {
            return new Shape[]{outputShape};
        }
        long batchSize = inputShape.get(this.batchFirst ? 0 : 1);
        Shape stateShape = new Shape((long) this.numLayers * numDirections, batchSize, this.stateSize);
        return new Shape[]{outputShape, stateShape};
    }

    /**
     * Validates that the first input's layout is compatible with (BATCH, TIME, CHANNEL) before
     * delegating the rest of the initialization to the parent class.
     *
     * @param inputs the input shapes; only {@code inputs[0]} is validated
     */
    @Override
    public void beforeInitialize(Shape[] inputs) {
        super.beforeInitialize(inputs);
        Shape inputShape = inputs[0];
        Block.validateLayout(EXPECTED_LAYOUT, inputShape.getLayout());
    }

    /**
     * Returns the shape of a parameter registered by the constructor.
     *
     * <p>The layer index is parsed from the second {@code _}-separated segment of the name.
     * Layer 0 consumes the input channel size ({@code inputShapes[0].get(2)}); deeper layers
     * consume the previous layer's output, {@code stateSize * numDirections}.
     *
     * @param name the parameter name, in the form {@code <direction>_<layer>_<gate>_<TYPE>}
     * @param inputShapes the block's input shapes
     * @return {@code (gates * stateSize)} for biases, {@code (gates * stateSize, inputs)} for
     *     {@code i2h} weights and {@code (gates * stateSize, stateSize)} for {@code h2h} weights
     * @throws IllegalArgumentException if the name matches none of the known parameter kinds
     */
    @Override
    public Shape getParameterShape(String name, Shape[] inputShapes) {
        int layer = Integer.parseInt(name.split("_")[1]);
        long gateRows = (long) this.gates * this.stateSize;
        if (name.contains("BIAS")) {
            return new Shape(gateRows);
        }
        if (name.contains("i2h")) {
            long inputs = layer > 0
                    ? this.stateSize * (long) this.getNumDirections()
                    : inputShapes[0].get(2);
            return new Shape(gateRows, inputs);
        }
        if (name.contains("h2h")) {
            return new Shape(gateRows, this.stateSize);
        }
        throw new IllegalArgumentException("Invalid parameter name: " + name);
    }

    /**
     * Reads block metadata written by a previous save.
     *
     * <p>Version {@link #VERSION} includes the input shapes; version 1 carries no extra
     * metadata. Any other version is rejected.
     *
     * @param version the encoding version read from the stream
     * @param is the stream to read from
     * @throws IOException if reading the input shapes fails
     * @throws MalformedModelException if the version is not supported
     */
    @Override
    public void loadMetadata(byte version, DataInputStream is) throws IOException, MalformedModelException {
        if (version == VERSION) {
            this.readInputShapes(is);
        } else if (version != 1) {
            throw new MalformedModelException("Unsupported encoding version: " + version);
        }
    }

    /**
     * Returns the number of directions the sequence is processed in.
     *
     * @return 2 if the block is bidirectional, 1 otherwise
     */
    protected int getNumDirections() {
        return this.bidirectional ? 2 : 1;
    }

    /**
     * Shared builder for recurrent blocks.
     *
     * <p>Defaults: {@code batchFirst = true}, {@code hasBiases = true}; all other options are
     * zero/false/null until set.
     *
     * @param <T> the concrete builder type, returned from every setter for fluent chaining
     */
    public static abstract class BaseBuilder<T extends BaseBuilder<T>> {
        protected float dropRate;
        protected long stateSize;
        protected int numLayers;
        protected boolean batchFirst = true;
        protected boolean hasBiases = true;
        protected boolean bidirectional;
        protected boolean returnState;
        protected RNN.Activation activation;

        /**
         * Sets the drop rate.
         *
         * @param dropRate the drop rate
         * @return this builder
         */
        public T optDropRate(float dropRate) {
            this.dropRate = dropRate;
            return this.self();
        }

        /**
         * Sets the hidden state size.
         *
         * @param stateSize the number of hidden units
         * @return this builder
         */
        public T setStateSize(int stateSize) {
            this.stateSize = stateSize;
            return this.self();
        }

        /**
         * Sets the number of stacked recurrent layers.
         *
         * @param numLayers the number of layers
         * @return this builder
         */
        public T setNumLayers(int numLayers) {
            this.numLayers = numLayers;
            return this.self();
        }

        /**
         * Sets whether the block processes the sequence in both directions.
         *
         * @param useBidirectional {@code true} for a bidirectional block
         * @return this builder
         */
        public T optBidirectional(boolean useBidirectional) {
            this.bidirectional = useBidirectional;
            return this.self();
        }

        /**
         * Sets whether the batch dimension comes first in the input.
         *
         * @param batchFirst {@code true} if the batch is axis 0, {@code false} if it is axis 1
         * @return this builder
         */
        public T optBatchFirst(boolean batchFirst) {
            this.batchFirst = batchFirst;
            return this.self();
        }

        /**
         * Sets whether the block uses bias parameters.
         *
         * @param hasBiases {@code true} to register and use bias parameters
         * @return this builder
         */
        public T optHasBiases(boolean hasBiases) {
            this.hasBiases = hasBiases;
            return this.self();
        }

        /**
         * Sets whether the block also returns its final state.
         *
         * @param returnState {@code true} to append the state to the outputs
         * @return this builder
         */
        public T optReturnState(boolean returnState) {
            this.returnState = returnState;
            return this.self();
        }

        /**
         * Returns this builder typed as the concrete subclass.
         *
         * @return this builder
         */
        protected abstract T self();
    }
}

