/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.nn.transformer;

import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.DataType;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.AbstractBlock;
import ai.djl.nn.Block;
import ai.djl.nn.LambdaBlock;
import ai.djl.nn.core.Linear;
import ai.djl.training.ParameterStore;
import ai.djl.util.Pair;
import ai.djl.util.PairList;
import java.util.Collections;
import java.util.List;
import java.util.function.Function;

/**
 * A feed-forward block built from a stack of fully connected layers.
 *
 * <p>Each entry of {@code hiddenSizes} contributes a {@link Linear} layer (with bias) followed by a
 * {@link LambdaBlock} wrapping the supplied activation function. A final {@link Linear} output layer
 * (with bias and no activation) produces {@code outputSize} units. Child blocks are applied in the
 * order they were added.
 */
public class PointwiseFeedForwardBlock extends AbstractBlock {

    private static final byte VERSION = 1;

    /** Output shape computed by {@link #initializeChildBlocks}; {@code null} until then. */
    private Shape outputShape;

    /**
     * Creates a pointwise feed-forward block.
     *
     * @param hiddenSizes the number of units of each hidden layer, in order
     * @param outputSize the number of units of the output layer
     * @param activationFunction the activation applied after every hidden layer (it is not applied
     *     after the output layer)
     */
    public PointwiseFeedForwardBlock(
            List<Integer> hiddenSizes,
            int outputSize,
            Function<NDList, NDList> activationFunction) {
        super(VERSION);
        // Hidden layers: linear_i followed by activation_i.
        int count = 0;
        for (int hiddenSize : hiddenSizes) {
            addChildBlock(
                    "linear_" + count, Linear.builder().optBias(true).setUnits(hiddenSize).build());
            addChildBlock("activation_" + count, new LambdaBlock(activationFunction));
            ++count;
        }
        // Output layer without an activation.
        addChildBlock(
                "output_layer", Linear.builder().optBias(true).setUnits(outputSize).build());
    }

    /**
     * Returns the output shape recorded by {@link #initializeChildBlocks}.
     *
     * <p>NOTE(review): {@code inputShapes} is ignored; before initialization the returned array
     * contains a single {@code null} element.
     *
     * @param manager unused
     * @param inputShapes unused
     * @return a one-element array holding the cached output shape
     */
    @Override
    public Shape[] getOutputShapes(NDManager manager, Shape[] inputShapes) {
        return new Shape[] {outputShape};
    }

    /**
     * Validates the input shape and initializes every child block in sequence, feeding each child
     * the output shape of the previous one. The last child's output shape is cached for
     * {@link #getOutputShapes}.
     *
     * @param manager the manager passed to each child's {@code initialize}
     * @param dataType the data type passed to each child's {@code initialize}
     * @param inputShapes exactly one input shape, of dimension at least 2
     * @throws IllegalArgumentException if there is not exactly one input, or it has dimension < 2
     */
    @Override
    public void initializeChildBlocks(NDManager manager, DataType dataType, Shape... inputShapes) {
        // Validate before touching any state so a rejected call leaves the block unchanged.
        if (inputShapes.length != 1) {
            throw new IllegalArgumentException(
                    "Pointwise feed forward blocks can only have one input. Got: "
                            + inputShapes.length);
        }
        Shape inputShape = inputShapes[0];
        if (inputShape.dimension() < 2) {
            throw new IllegalArgumentException(
                    "Pointwise feed forward blocks need an input of at least dimension 2. Got: "
                            + inputShape);
        }
        inputNames = Collections.singletonList("input");

        Shape lastShape = inputShape;
        for (Block child : children.values()) {
            lastShape = child.initialize(manager, dataType, lastShape)[0];
        }
        outputShape = lastShape;
    }

    /**
     * Runs {@code inputs} through every child block in order and returns the final result.
     *
     * @param ps the parameter store forwarded to each child
     * @param inputs the block input
     * @param training whether the forward pass is for training, forwarded to each child
     * @param params extra parameters; not used by this block
     * @return the output of the last child block
     */
    @Override
    protected NDList forwardInternal(
            ParameterStore ps, NDList inputs, boolean training, PairList<String, Object> params) {
        NDList layerResult = inputs;
        // Same iteration as initializeChildBlocks; avoids the raw Pair type and unchecked cast.
        for (Block child : children.values()) {
            layerResult = child.forward(ps, layerResult, training);
        }
        return layerResult;
    }
}

