/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.basicmodelzoo.cv.classification;

import ai.djl.ndarray.types.Shape;
import ai.djl.nn.Activation;
import ai.djl.nn.Block;
import ai.djl.nn.Blocks;
import ai.djl.nn.SequentialBlock;
import ai.djl.nn.convolutional.Conv2d;
import ai.djl.nn.norm.Dropout;
import ai.djl.nn.pooling.Pool;
import java.util.Objects;

/**
 * Network-in-Network (NiN) image classification network.
 *
 * <p>The network is made of four "NiN blocks". Each block is a configurable convolution followed by
 * two 1x1 convolutions, with a ReLU after every convolution. The first three blocks are each
 * followed by a 3x3 max pooling with stride 2. Dropout is applied before the fourth block, whose
 * output goes through global average pooling and a batch flatten. The fourth block's filter count
 * (default 10) is presumably the number of classes — confirm against the caller.
 *
 * <p>Networks are created through {@link #builder()}.
 */
public final class NiN {
    private NiN() {
    }

    /**
     * Builds the NiN network configured by {@code builder}.
     *
     * @param builder supplies the four per-block filter counts and the dropout rate
     * @return a {@link SequentialBlock} containing the whole network
     */
    public static Block niN(Builder builder) {
        NiN nin = new NiN();
        return new SequentialBlock()
                .add(nin.niNBlock(
                        builder.numChannels[0], new Shape(11, 11), new Shape(4, 4), new Shape(0, 0)))
                .add(maxPool3x3Stride2())
                .add(nin.niNBlock(
                        builder.numChannels[1], new Shape(5, 5), new Shape(1, 1), new Shape(2, 2)))
                .add(maxPool3x3Stride2())
                .add(nin.niNBlock(
                        builder.numChannels[2], new Shape(3, 3), new Shape(1, 1), new Shape(1, 1)))
                .add(maxPool3x3Stride2())
                .add(Dropout.builder().optRate(builder.dropOutRate).build())
                .add(nin.niNBlock(
                        builder.numChannels[3], new Shape(3, 3), new Shape(1, 1), new Shape(1, 1)))
                .add(Pool.globalAvgPool2dBlock())
                .add(Blocks.batchFlattenBlock());
    }

    /**
     * Creates a builder preloaded with the default configuration.
     *
     * <p>The defaults are channels {@code {96, 256, 384, 10}} and dropout rate {@code 0.5}.
     *
     * @return a new {@link Builder}
     */
    public static Builder builder() {
        return new Builder();
    }

    /**
     * Creates one NiN block.
     *
     * <p>The block is a convolution with the given kernel, stride and padding, followed by two 1x1
     * convolutions. Every convolution has {@code numChannels} filters and is followed by a ReLU.
     *
     * @param numChannels number of filters used by each of the three convolutions
     * @param kernelShape kernel shape of the first convolution
     * @param strideShape stride of the first convolution
     * @param paddingShape padding of the first convolution
     * @return the block as a {@link SequentialBlock}
     */
    public SequentialBlock niNBlock(
            int numChannels, Shape kernelShape, Shape strideShape, Shape paddingShape) {
        return new SequentialBlock()
                .add(Conv2d.builder()
                        .setKernelShape(kernelShape)
                        .optStride(strideShape)
                        .optPadding(paddingShape)
                        .setFilters(numChannels)
                        .build())
                .add(Activation::relu)
                .add(pointwiseConv(numChannels))
                .add(Activation::relu)
                .add(pointwiseConv(numChannels))
                .add(Activation::relu);
    }

    /** Creates a 1x1 convolution with {@code numChannels} filters and builder-default stride/padding. */
    private static Block pointwiseConv(int numChannels) {
        return Conv2d.builder().setKernelShape(new Shape(1, 1)).setFilters(numChannels).build();
    }

    /** Creates the 3x3, stride-2 max pooling that follows each of the first three NiN blocks. */
    private static Block maxPool3x3Stride2() {
        return Pool.maxPool2dBlock(new Shape(3, 3), new Shape(2, 2));
    }

    /** Mutable configuration for {@link NiN#niN(Builder)}. */
    public static final class Builder {
        // Number of NiN blocks; setNumChannels requires exactly this many channel counts.
        int numLayers = 4;
        // Filter count of each NiN block, in order.
        int[] numChannels = new int[]{96, 256, 384, 10};
        // Rate passed to the Dropout layer placed before the last NiN block.
        float dropOutRate = 0.5f;

        Builder() {
        }

        /**
         * Sets the dropout rate applied before the last NiN block.
         *
         * @param dropOutRate the rate forwarded to {@code Dropout.Builder.optRate}
         * @return this builder
         */
        public Builder setDropOutRate(float dropOutRate) {
            this.dropOutRate = dropOutRate;
            return this;
        }

        /**
         * Sets the number of filters used by each NiN block.
         *
         * @param numChannels one filter count per block; its length must equal {@code numLayers}
         * @return this builder
         * @throws NullPointerException if {@code numChannels} is null
         * @throws IllegalArgumentException if {@code numChannels} does not have {@code numLayers}
         *     entries
         */
        public Builder setNumChannels(int[] numChannels) {
            Objects.requireNonNull(numChannels, "numChannels");
            if (numChannels.length != this.numLayers) {
                throw new IllegalArgumentException(
                        "number of channels must be equal to "
                                + this.numLayers
                                + ", but got "
                                + numChannels.length);
            }
            // Copy so that later writes to the caller's array cannot alter this configuration.
            this.numChannels = numChannels.clone();
            return this;
        }

        /**
         * Builds the NiN network from this configuration.
         *
         * @return the network block
         */
        public Block build() {
            return NiN.niN(this);
        }
    }
}

