/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.basicmodelzoo.nlp;

import ai.djl.modality.nlp.Decoder;
import ai.djl.modality.nlp.embedding.TrainableTextEmbedding;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDArrays;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.Block;
import ai.djl.nn.SequentialBlock;
import ai.djl.nn.core.Linear;
import ai.djl.nn.recurrent.RecurrentBlock;
import ai.djl.training.ParameterStore;
import ai.djl.util.PairList;

public class SimpleTextDecoder
extends Decoder {
    private static final byte VERSION = 1;
    private RecurrentBlock recurrentBlock;

    public SimpleTextDecoder(RecurrentBlock recurrentBlock, int vocabSize) {
        this(null, recurrentBlock, vocabSize);
    }

    public SimpleTextDecoder(TrainableTextEmbedding trainableTextEmbedding, RecurrentBlock recurrentBlock, long vocabSize) {
        super((byte)1, SimpleTextDecoder.getBlock(trainableTextEmbedding, recurrentBlock, vocabSize));
        this.recurrentBlock = recurrentBlock;
    }

    private static Block getBlock(TrainableTextEmbedding trainableTextEmbedding, RecurrentBlock recurrentBlock, long vocabSize) {
        SequentialBlock sequentialBlock = new SequentialBlock();
        sequentialBlock.add((Block)trainableTextEmbedding).add((Block)recurrentBlock).add((Block)Linear.builder().setUnits(vocabSize).build());
        return sequentialBlock;
    }

    public void initState(NDList encoderStates) {
        this.recurrentBlock.setBeginStates(encoderStates);
    }

    public NDList forward(ParameterStore parameterStore, NDList inputs, boolean training, PairList<String, Object> params) {
        if (training) {
            return this.block.forward(parameterStore, inputs, true, params);
        }
        Shape inputShape = ((NDArray)inputs.get(0)).getShape();
        if (inputShape.get(1) != 1L) {
            throw new IllegalArgumentException("Input sequence length must be 1 during prediction");
        }
        NDList output = new NDList();
        for (int i = 0; i < 10; ++i) {
            inputs = this.block.forward(parameterStore, inputs, false);
            inputs = new NDList(new NDArray[]{inputs.head().argMax(2)});
            output.add((Object)inputs.head().transpose(new int[]{1, 0}));
        }
        return new NDList(new NDArray[]{NDArrays.stack((NDList)output).transpose(new int[]{2, 1, 0})});
    }
}

