/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.basicmodelzoo.cv.object_detection.ssd;

import ai.djl.MalformedModelException;
import ai.djl.modality.cv.MultiBoxPrior;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDArrays;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.DataType;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.AbstractBlock;
import ai.djl.nn.Activation;
import ai.djl.nn.Block;
import ai.djl.nn.LambdaBlock;
import ai.djl.nn.SequentialBlock;
import ai.djl.nn.convolutional.Conv2d;
import ai.djl.nn.norm.BatchNorm;
import ai.djl.nn.pooling.Pool;
import ai.djl.training.ParameterStore;
import ai.djl.util.PairList;
import java.io.DataInputStream;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;

/**
 * A block that runs its input through a chain of feature blocks and, after each feature
 * block, generates anchor boxes and computes class and bounding-box predictions from that
 * block's output.
 *
 * <p>The forward pass returns three arrays, in this order: the concatenated anchors, the
 * class predictions, and the bounding-box predictions.
 */
public final class SingleShotDetection extends AbstractBlock {
    /** Current metadata encoding version written/accepted by this block. */
    private static final byte VERSION = 2;

    private final List<Block> features;
    private final List<Block> classPredictionBlocks;
    private final List<Block> anchorPredictionBlocks;
    private final List<MultiBoxPrior> multiBoxPriors;
    private final int numClasses;

    /**
     * Creates the block from fully prepared, index-aligned lists (entry {@code i} of every
     * list belongs to feature block {@code i}).
     *
     * @param features the feature blocks, applied in order
     * @param classPredictionBlocks one class predictor per feature block
     * @param anchorPredictionBlocks one bounding-box predictor per feature block
     * @param multiBoxPriors one anchor generator per feature block
     * @param numClasses the number of classes (the background slot is added internally)
     */
    private SingleShotDetection(
            List<Block> features,
            List<Block> classPredictionBlocks,
            List<Block> anchorPredictionBlocks,
            List<MultiBoxPrior> multiBoxPriors,
            int numClasses) {
        super(VERSION);
        this.features = features;
        this.classPredictionBlocks = classPredictionBlocks;
        this.anchorPredictionBlocks = anchorPredictionBlocks;
        this.multiBoxPriors = multiBoxPriors;
        this.numClasses = numClasses;
        // Registration order is kept as features, class predictors, box predictors.
        // NOTE(review): presumably this order affects how child parameters are
        // named/serialized - keep it stable.
        for (Block block : features) {
            addChildBlock(block.getClass().getSimpleName(), block);
        }
        for (Block block : classPredictionBlocks) {
            addChildBlock(block.getClass().getSimpleName(), block);
        }
        for (Block block : anchorPredictionBlocks) {
            addChildBlock(block.getClass().getSimpleName(), block);
        }
    }

    /**
     * Runs the feature chain and collects anchors and predictions after every feature block.
     *
     * @param parameterStore the parameter store handed to every child block
     * @param inputs the network input
     * @param training whether the children run in training mode
     * @param params optional extra parameters (not used by this block)
     * @return a list of anchors, class predictions reshaped to
     *     {@code (size(0), -1, numClasses + 1)}, and bounding-box predictions
     */
    protected NDList forwardInternal(
            ParameterStore parameterStore,
            NDList inputs,
            boolean training,
            PairList<String, Object> params) {
        int featureCount = features.size();
        NDList networkOutput = inputs;
        NDArray[] anchorsOutputs = new NDArray[featureCount];
        NDArray[] classOutputs = new NDArray[featureCount];
        NDArray[] boundingBoxOutputs = new NDArray[featureCount];
        for (int i = 0; i < featureCount; ++i) {
            // Each feature block consumes the previous feature block's output.
            networkOutput = features.get(i).forward(parameterStore, networkOutput, training);
            anchorsOutputs[i] =
                    multiBoxPriors.get(i).generateAnchorBoxes(networkOutput.singletonOrThrow());
            classOutputs[i] =
                    classPredictionBlocks
                            .get(i)
                            .forward(parameterStore, networkOutput, training)
                            .singletonOrThrow();
            boundingBoxOutputs[i] =
                    anchorPredictionBlocks
                            .get(i)
                            .forward(parameterStore, networkOutput, training)
                            .singletonOrThrow();
        }
        NDArray anchors = NDArrays.concat(new NDList(anchorsOutputs), 1);
        NDArray classPredictions = concatPredictions(new NDList(classOutputs));
        NDArray boundingBoxPredictions = concatPredictions(new NDList(boundingBoxOutputs));
        // One row of (numClasses + 1) scores per anchor; the extra slot is presumably the
        // background class - TODO confirm.
        classPredictions =
                classPredictions.reshape(classPredictions.size(0), -1, numClasses + 1);
        return new NDList(anchors, classPredictions, boundingBoxPredictions);
    }

    /**
     * Moves axis 1 of every array to the end, flattens each array to
     * {@code (size(0), -1)} and concatenates the results along axis 1.
     *
     * @param output the per-feature-map prediction arrays (4-dimensional, as required by the
     *     transpose)
     * @return the concatenated, flattened predictions
     */
    private NDArray concatPredictions(NDList output) {
        NDArray[] flattenOutput =
                output.stream()
                        .map(array -> array.transpose(0, 2, 3, 1).reshape(array.size(0), -1))
                        .toArray(NDArray[]::new);
        return NDArrays.concat(new NDList(flattenOutput), 1);
    }

    /**
     * Computes the shapes of the three outputs of the forward pass.
     *
     * <p>The shapes are derived by pushing dummy arrays through the same
     * anchor-generation/flatten/concat steps as the forward pass. Those scratch arrays are
     * allocated on a sub-manager that is closed before returning, so nothing allocated here
     * stays attached to the caller's manager.
     *
     * @param manager the manager used for child shape queries and as parent of the scratch
     *     manager
     * @param inputShapes the shapes of the network input
     * @return the anchor, class-prediction and bounding-box-prediction shapes, in that order
     */
    public Shape[] getOutputShapes(NDManager manager, Shape[] inputShapes) {
        int featureCount = features.size();
        Shape[] childInputShapes = inputShapes;
        Shape[] anchorShapes = new Shape[featureCount];
        Shape[] classPredictionShapes = new Shape[featureCount];
        Shape[] anchorPredictionShapes = new Shape[featureCount];
        try (NDManager scratch = manager.newSubManager()) {
            for (int i = 0; i < featureCount; ++i) {
                childInputShapes = features.get(i).getOutputShapes(manager, childInputShapes);
                anchorShapes[i] =
                        multiBoxPriors
                                .get(i)
                                .generateAnchorBoxes(scratch.ones(childInputShapes[0]))
                                .getShape();
                classPredictionShapes[i] =
                        classPredictionBlocks.get(i).getOutputShapes(manager, childInputShapes)[0];
                anchorPredictionShapes[i] =
                        anchorPredictionBlocks.get(i).getOutputShapes(manager, childInputShapes)[0];
            }

            Shape anchorOutputShape = new Shape();
            for (Shape shape : anchorShapes) {
                anchorOutputShape = concatShape(anchorOutputShape, shape, 1);
            }

            NDList classPredictions = new NDList();
            for (Shape shape : classPredictionShapes) {
                classPredictions.add(scratch.ones(shape));
            }
            NDArray classPredictionOutput = concatPredictions(classPredictions);
            Shape classPredictionOutputShape =
                    classPredictionOutput
                            .reshape(classPredictionOutput.size(0), -1, numClasses + 1)
                            .getShape();

            NDList anchorPredictions = new NDList();
            for (Shape shape : anchorPredictionShapes) {
                anchorPredictions.add(scratch.ones(shape));
            }
            Shape anchorPredictionOutputShape = concatPredictions(anchorPredictions).getShape();

            return new Shape[] {
                anchorOutputShape, classPredictionOutputShape, anchorPredictionOutputShape
            };
        }
    }

    /**
     * Computes the shape that results from concatenating two shapes along one axis.
     *
     * @param shape the accumulated shape; a zero-dimensional shape means "nothing yet" and
     *     makes the method return {@code concat} unchanged
     * @param concat the shape to append
     * @param axis the axis along which the sizes are added
     * @return the concatenated shape
     * @throws IllegalArgumentException if the two shapes have a different number of dimensions
     * @throws UnsupportedOperationException if the shapes differ on an axis other than
     *     {@code axis}
     */
    private Shape concatShape(Shape shape, Shape concat, int axis) {
        if (shape.dimension() == 0) {
            return concat;
        }
        if (shape.dimension() != concat.dimension()) {
            throw new IllegalArgumentException("Shapes must have same dimensions");
        }
        long[] dimensions = new long[shape.dimension()];
        for (int i = 0; i < shape.dimension(); ++i) {
            if (axis == i) {
                dimensions[i] = shape.get(i) + concat.get(i);
            } else {
                if (shape.get(i) != concat.get(i)) {
                    throw new UnsupportedOperationException(
                            "These shapes cannot be concatenated along axis " + i);
                }
                dimensions[i] = shape.get(i);
            }
        }
        return new Shape(dimensions);
    }

    /**
     * Initializes every feature block in order, feeding each one's output shapes to the next
     * feature block and to the matching class and bounding-box predictors.
     *
     * @param manager the manager used for initialization
     * @param dataType the data type of the parameters
     * @param inputShapes the shapes of the network input
     * @return the output shapes as computed by {@link #getOutputShapes(NDManager, Shape[])}
     */
    public Shape[] initialize(NDManager manager, DataType dataType, Shape... inputShapes) {
        beforeInitialize(inputShapes);
        Shape[] shapes = inputShapes;
        for (int i = 0; i < features.size(); ++i) {
            shapes = features.get(i).initialize(manager, dataType, shapes);
            classPredictionBlocks.get(i).initialize(manager, dataType, shapes);
            anchorPredictionBlocks.get(i).initialize(manager, dataType, shapes);
        }
        return getOutputShapes(manager, inputShapes);
    }

    /**
     * Loads block metadata for the given encoding version.
     *
     * <p>For the current {@link #VERSION} the input shapes are read from the stream. Version
     * {@code 1} is accepted and reads nothing here. Any other version is rejected.
     *
     * @param version the encoding version found in the stream
     * @param is the stream to read from
     * @throws IOException if reading from the stream fails
     * @throws MalformedModelException if the version is not supported
     */
    public void loadMetadata(byte version, DataInputStream is)
            throws IOException, MalformedModelException {
        if (version == VERSION) {
            readInputShapes(is);
        } else if (version != 1) {
            throw new MalformedModelException("Unsupported encoding version: " + version);
        }
    }

    /**
     * Builds a down-sampling block: two rounds of 3x3 convolution (padding 1), batch
     * normalization and ReLU, followed by a 2x2 max pool with stride 2 and no padding.
     *
     * @param numFilters the number of filters of each convolution
     * @return the assembled block
     */
    public static SequentialBlock getDownSamplingBlock(int numFilters) {
        SequentialBlock sequentialBlock = new SequentialBlock();
        for (int i = 0; i < 2; ++i) {
            sequentialBlock
                    .add(
                            Conv2d.builder()
                                    .setKernelShape(new Shape(3, 3))
                                    .setFilters(numFilters)
                                    .optPadding(new Shape(1, 1))
                                    .build())
                    .add(BatchNorm.builder().build())
                    .add(Activation::relu);
        }
        sequentialBlock.add(
                Pool.maxPool2dBlock(new Shape(2, 2), new Shape(2, 2), new Shape(0, 0)));
        return sequentialBlock;
    }

    /**
     * Builds the class predictor: a 3x3 convolution (padding 1) with
     * {@code (numClasses + 1) * numAnchors} filters.
     *
     * @param numAnchors the number of anchors per position
     * @param numClasses the number of classes
     * @return the convolution block
     */
    public static Conv2d getClassPredictionBlock(int numAnchors, int numClasses) {
        return Conv2d.builder()
                .setKernelShape(new Shape(3, 3))
                .setFilters((numClasses + 1) * numAnchors)
                .optPadding(new Shape(1, 1))
                .build();
    }

    /**
     * Builds the bounding-box predictor: a 3x3 convolution (padding 1) with
     * {@code 4 * numAnchors} filters.
     *
     * @param numAnchors the number of anchors per position
     * @return the convolution block
     */
    public static Conv2d getAnchorPredictionBlock(int numAnchors) {
        return Conv2d.builder()
                .setKernelShape(new Shape(3, 3))
                .setFilters(4 * numAnchors)
                .optPadding(new Shape(1, 1))
                .build();
    }

    /**
     * Creates a builder for {@link SingleShotDetection}.
     *
     * @return a new builder
     */
    public static Builder builder() {
        return new Builder();
    }

    /** Builder for {@link SingleShotDetection}. */
    public static class Builder {
        private Block network;
        private int numFeatures = -1;
        private List<Block> features;
        private List<List<Float>> sizes;
        private List<List<Float>> ratios;
        private int numClasses;
        private boolean globalPool = true;

        Builder() {
        }

        /**
         * Sets the anchor sizes, one list per feature map.
         *
         * @param sizes the sizes
         * @return this builder
         */
        public Builder setSizes(List<List<Float>> sizes) {
            this.sizes = sizes;
            return this;
        }

        /**
         * Sets the anchor ratios, one list per feature map.
         *
         * @param ratios the ratios
         * @return this builder
         */
        public Builder setRatios(List<List<Float>> ratios) {
            this.ratios = ratios;
            return this;
        }

        /**
         * Sets the number of classes.
         *
         * @param numClasses the number of classes
         * @return this builder
         */
        public Builder setNumClasses(int numClasses) {
            this.numClasses = numClasses;
            return this;
        }

        /**
         * Sets the base network, used as the first feature block when no explicit feature
         * list is given.
         *
         * @param network the base network
         * @return this builder
         */
        public Builder setBaseNetwork(Block network) {
            this.network = network;
            return this;
        }

        /**
         * Sets how many down-sampling blocks are appended after the base network when no
         * explicit feature list is given.
         *
         * @param numFeatures the number of down-sampling blocks
         * @return this builder
         */
        public Builder setNumFeatures(int numFeatures) {
            this.numFeatures = numFeatures;
            return this;
        }

        /**
         * Sets an explicit list of feature blocks, replacing the base network plus
         * down-sampling blocks. The list is copied at build time and never modified.
         *
         * @param features the feature blocks
         * @return this builder
         */
        public Builder optFeatures(List<Block> features) {
            this.features = features;
            return this;
        }

        /**
         * Sets whether a global average pooling block is appended as the last feature block
         * (enabled by default).
         *
         * @param globalPool {@code true} to append the pooling block
         * @return this builder
         */
        public Builder optGlobalPool(boolean globalPool) {
            this.globalPool = globalPool;
            return this;
        }

        /**
         * Builds the {@link SingleShotDetection} block.
         *
         * <p>All derived lists are created locally, so the builder and any list passed to
         * {@link #optFeatures(List)} are left untouched and {@code build()} can be called
         * more than once.
         *
         * @return the new block
         * @throws IllegalArgumentException if neither features nor numFeatures is set, if the
         *     base network is required but missing, or if sizes/ratios are missing or do not
         *     have one entry per feature map
         */
        public SingleShotDetection build() {
            if (features == null && numFeatures < 0) {
                throw new IllegalArgumentException("Either numFeatures or features must be set");
            }
            List<Block> featureBlocks;
            if (features == null) {
                if (network == null) {
                    throw new IllegalArgumentException(
                            "Base network must be set when features are not provided");
                }
                featureBlocks = new ArrayList<>();
                featureBlocks.add(network);
                for (int i = 0; i < numFeatures; ++i) {
                    featureBlocks.add(SingleShotDetection.getDownSamplingBlock(128));
                }
            } else {
                // Copy so that appending the pooling block does not alter the caller's list.
                featureBlocks = new ArrayList<>(features);
            }
            if (globalPool) {
                featureBlocks.add(
                        LambdaBlock.singleton(
                                array -> {
                                    NDArray result = Pool.globalAvgPool2d(array);
                                    // Append two trailing axes of size 1 to the pooled result.
                                    return result.reshape(result.getShape().add(1, 1));
                                }));
            }
            int numberOfFeatureMaps = featureBlocks.size();
            if (sizes == null
                    || ratios == null
                    || sizes.size() != ratios.size()
                    || sizes.size() != numberOfFeatureMaps) {
                throw new IllegalArgumentException(
                        "Sizes and ratios must be of size: " + numberOfFeatureMaps);
            }
            List<Block> classPredictionBlocks = new ArrayList<>(numberOfFeatureMaps);
            List<Block> anchorPredictionBlocks = new ArrayList<>(numberOfFeatureMaps);
            List<MultiBoxPrior> multiBoxPriors = new ArrayList<>(numberOfFeatureMaps);
            for (int i = 0; i < numberOfFeatureMaps; ++i) {
                List<Float> size = sizes.get(i);
                List<Float> ratio = ratios.get(i);
                int numAnchors = size.size() + ratio.size() - 1;
                classPredictionBlocks.add(
                        SingleShotDetection.getClassPredictionBlock(numAnchors, numClasses));
                anchorPredictionBlocks.add(
                        SingleShotDetection.getAnchorPredictionBlock(numAnchors));
                multiBoxPriors.add(
                        MultiBoxPrior.builder().setSizes(size).setRatios(ratio).build());
            }
            return new SingleShotDetection(
                    featureBlocks,
                    classPredictionBlocks,
                    anchorPredictionBlocks,
                    multiBoxPriors,
                    numClasses);
        }
    }
}

