package ai.djl.basicdataset;

import ai.djl.Application;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.repository.Artifact;
import ai.djl.repository.MRL;
import ai.djl.repository.Repository;
import ai.djl.repository.Resource;
import ai.djl.training.dataset.Dataset;
import ai.djl.training.dataset.RandomAccessDataset;
import ai.djl.training.dataset.Record;
import ai.djl.translate.TranslateException;
import ai.djl.util.Progress;
import java.io.BufferedReader;
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.apache.commons.csv.CSVFormat;
import org.apache.commons.csv.CSVParser;
import org.apache.commons.csv.CSVRecord;

/* loaded from: input_file:ai/djl/basicdataset/AirfoilRandomAccess.class */
public final class AirfoilRandomAccess extends RandomAccessDataset {
    /** Artifact id the builder uses by default. */
    private static final String ARTIFACT_ID = "airfoil";
    /** Feature column names; the constructor maps each name to its position in this array. */
    private static final String[] FEATURE_ARRAY = {"freq", "aoa", "chordlen", "freestreamvel", "ssdt"};
    /** Feature names currently selected; these are the ones emitted by getFeatureNDArray. */
    private Set<String> features;
    /** Feature names that are known but not currently selected. */
    private Set<String> availableFeatures;
    /** Name of the label column. */
    private String label;
    /** Parsed rows of the data file; assigned by prepare(Progress). */
    private List<CSVRecord> csvRecords;
    /** Usage (split) requested through the builder. */
    private Dataset.Usage usage;
    /** Feature values, indexed as [row][column]. */
    private float[][] data;
    /** Label value for each row. */
    private float[] labelArray;
    /** Column name to column index: features occupy 0..FEATURE_ARRAY.length-1, the label FEATURE_ARRAY.length. */
    private Map<String, Integer> stringToIndex;
    /** Resource handle that prepare(Progress) uses to resolve the directory holding the data file. */
    private Resource resource;
    /** Set to true once prepare(Progress) has loaded the data; later calls then return immediately. */
    private boolean prepared;

    /* renamed from: ai.djl.basicdataset.AirfoilRandomAccess$1, reason: invalid class name */
    /* loaded from: input_file:ai/djl/basicdataset/AirfoilRandomAccess$1.class */
    /**
     * Lookup table from {@code Dataset.Usage} ordinals to small integer codes
     * (1 = TRAIN, 2 = TEST, 3 = VALIDATION). Ordinals that are never assigned keep the
     * array's default value of 0.
     *
     * <p>NOTE(review): the decompiler marks this class as synthetic, so it is most likely
     * compiler output for a switch over the enum rather than hand-written code.
     */
    static /* synthetic */ class AnonymousClass1 {
        // One slot per enum constant, indexed by ordinal().
        static final /* synthetic */ int[] $SwitchMap$ai$djl$training$dataset$Dataset$Usage = new int[Dataset.Usage.values().length];

        static {
            // Each assignment is guarded separately so that a constant missing at run time
            // (NoSuchFieldError) only leaves its own slot unset instead of failing class init.
            try {
                $SwitchMap$ai$djl$training$dataset$Dataset$Usage[Dataset.Usage.TRAIN.ordinal()] = 1;
            } catch (NoSuchFieldError e) {
                // deliberately ignored: TRAIN keeps no code
            }
            try {
                $SwitchMap$ai$djl$training$dataset$Dataset$Usage[Dataset.Usage.TEST.ordinal()] = 2;
            } catch (NoSuchFieldError e2) {
                // deliberately ignored: TEST keeps no code
            }
            try {
                $SwitchMap$ai$djl$training$dataset$Dataset$Usage[Dataset.Usage.VALIDATION.ordinal()] = 3;
            } catch (NoSuchFieldError e3) {
                // deliberately ignored: VALIDATION keeps no code
            }
        }
    }

    /* loaded from: input_file:ai/djl/basicdataset/AirfoilRandomAccess$Builder.class */
    /**
     * Builder for {@link AirfoilRandomAccess}.
     *
     * <p>Defaults: repository {@code BasicDatasets.REPOSITORY}, group id
     * {@code BasicDatasets.GROUP_ID}, artifact id {@code "airfoil"}, usage {@code TRAIN}.
     */
    public static final class Builder extends RandomAccessDataset.BaseBuilder<Builder> {
        Repository repository = BasicDatasets.REPOSITORY;
        String groupId = BasicDatasets.GROUP_ID;
        String artifactId = AirfoilRandomAccess.ARTIFACT_ID;
        Dataset.Usage usage = Dataset.Usage.TRAIN;

        Builder() {
        }

        /**
         * Returns this builder.
         *
         * @return this builder
         */
        /* renamed from: self, reason: merged with bridge method [inline-methods] */
        public Builder m3self() {
            return this;
        }

        /**
         * Sets which usage (split) the dataset should load.
         *
         * @param usage the requested usage
         * @return this builder
         */
        public Builder optUsage(Dataset.Usage usage) {
            this.usage = usage;
            return m3self();
        }

        /**
         * Sets the repository the dataset is resolved from.
         *
         * @param repository the repository to use
         * @return this builder
         */
        public Builder optRepository(Repository repository) {
            this.repository = repository;
            return m3self();
        }

        /**
         * Sets the group id used to locate the dataset.
         *
         * @param str the group id
         * @return this builder
         */
        public Builder optGroupId(String str) {
            this.groupId = str;
            return m3self();
        }

        /**
         * Sets the artifact id, optionally together with the group id.
         *
         * <p>A value containing {@code ':'} is read as {@code groupId:artifactId}; the first
         * two colon-separated tokens are used. A value without a colon sets only the
         * artifact id.
         *
         * @param str the artifact id, or {@code groupId:artifactId}
         * @return this builder
         * @throws IllegalArgumentException if {@code str} contains a colon but does not
         *     yield both a group id and an artifact id (for example {@code "group:"})
         */
        public Builder optArtifactId(String str) {
            if (str.contains(":")) {
                String[] split = str.split(":");
                // String.split drops trailing empty tokens, so "group:" yields one token and
                // ":" yields none; reject those instead of failing on the array index below.
                if (split.length < 2) {
                    throw new IllegalArgumentException(
                            "Expected \"groupId:artifactId\" but got: " + str);
                }
                this.groupId = split[0];
                this.artifactId = split[1];
            } else {
                this.artifactId = str;
            }
            return m3self();
        }

        /**
         * Creates the dataset from the current builder settings.
         *
         * @return a new {@link AirfoilRandomAccess}
         */
        public AirfoilRandomAccess build() {
            return new AirfoilRandomAccess(this, null);
        }
    }

    private AirfoilRandomAccess(Builder builder) {
        super(builder);
        this.resource = new Resource(builder.repository, MRL.dataset(Application.Tabular.ANY, builder.groupId, builder.artifactId), "1.0");
        this.usage = builder.usage;
        this.features = new HashSet();
        this.availableFeatures = new HashSet(Arrays.asList(FEATURE_ARRAY));
        this.label = "ssoundpres";
        this.stringToIndex = new HashMap();
        for (int i = 0; i < FEATURE_ARRAY.length; i++) {
            this.stringToIndex.put(FEATURE_ARRAY[i], Integer.valueOf(i));
        }
        this.stringToIndex.put(this.label, Integer.valueOf(FEATURE_ARRAY.length));
    }

    /**
     * Standardizes every feature column and the label column in place.
     *
     * <p>For each column the mean and the population standard deviation (divisor = number
     * of records) are computed over all parsed records, and every value is replaced by
     * {@code (value - mean) / std} in {@code data} / {@code labelArray}. A column whose
     * values are all equal has a standard deviation of zero and therefore produces
     * non-finite results.
     *
     * @throws IOException if preparing the dataset fails
     * @throws TranslateException if preparing the dataset fails
     */
    public void whitenAll() throws IOException, TranslateException {
        prepare();
        int columnCount = FEATURE_ARRAY.length + 1;
        int labelIndex = this.stringToIndex.get(this.label).intValue();
        // All sums below run over every entry of csvRecords, so that list's length is the
        // matching divisor for both the mean and the standard deviation.
        int recordCount = this.csvRecords.size();

        // Pass 1: per-column mean.
        float[] mean = new float[columnCount];
        for (CSVRecord record : this.csvRecords) {
            for (String feature : FEATURE_ARRAY) {
                int column = this.stringToIndex.get(feature).intValue();
                mean[column] += getRecordFloat(record, feature);
            }
            mean[labelIndex] += getRecordFloat(record, this.label);
        }
        for (int column = 0; column < columnCount; column++) {
            mean[column] /= recordCount;
        }

        // Pass 2: per-column population standard deviation.
        float[] std = new float[columnCount];
        for (CSVRecord record : this.csvRecords) {
            for (String feature : FEATURE_ARRAY) {
                int column = this.stringToIndex.get(feature).intValue();
                std[column] += (float) Math.pow(getRecordFloat(record, feature) - mean[column], 2.0d);
            }
            std[labelIndex] += (float) Math.pow(getRecordFloat(record, this.label) - mean[labelIndex], 2.0d);
        }
        for (int column = 0; column < columnCount; column++) {
            std[column] = (float) Math.sqrt(std[column] / recordCount);
        }

        // Pass 3: write the standardized values. The row width must be FEATURE_ARRAY.length
        // because columns are addressed through stringToIndex, independent of how many
        // features are currently selected.
        float[][] whitened = new float[recordCount][FEATURE_ARRAY.length];
        float[] whitenedLabels = new float[recordCount];
        for (int row = 0; row < recordCount; row++) {
            CSVRecord record = this.csvRecords.get(row);
            for (String feature : FEATURE_ARRAY) {
                int column = this.stringToIndex.get(feature).intValue();
                whitened[row][column] = (getRecordFloat(record, feature) - mean[column]) / std[column];
            }
            whitenedLabels[row] = (getRecordFloat(record, this.label) - mean[labelIndex]) / std[labelIndex];
        }
        this.data = whitened;
        this.labelArray = whitenedLabels;
    }

    /**
     * Returns the currently selected feature names as a new list, in the iteration order
     * of the underlying set.
     *
     * @return a fresh list of the selected feature names
     */
    public List<String> getFeatureOrder() {
        List<String> order = new ArrayList<>(this.features);
        return order;
    }

    /**
     * Reads one column of a record and parses it as a float.
     *
     * @param cSVRecord the record to read from
     * @param str the column name
     * @return the parsed value
     * @throws NumberFormatException if the column text is not a parsable float
     */
    public float getRecordFloat(CSVRecord cSVRecord, String str) {
        String raw = cSVRecord.get(str);
        return Float.parseFloat(raw);
    }

    /**
     * Keeps only the first {@code i} parsed records and discards the rest.
     *
     * <p>If {@code i} is at least the current number of records, nothing is removed. Only
     * the record list is shortened; {@code data} and {@code labelArray} are left untouched.
     *
     * @param i the number of leading records to keep
     * @throws IOException if preparing the dataset fails
     * @throws TranslateException if preparing the dataset fails
     * @throws IndexOutOfBoundsException if {@code i} is negative
     */
    public void selectFirstN(int i) throws IOException, TranslateException {
        prepare();
        int total = this.csvRecords.size();
        if (i >= total) {
            // Already holding no more than i records; subList(i, total) would throw here.
            return;
        }
        this.csvRecords.subList(i, total).clear();
    }

    /**
     * Creates a builder with default settings.
     *
     * @return a new {@link Builder}
     */
    public static Builder builder() {
        Builder fresh = new Builder();
        return fresh;
    }

    /**
     * Returns the label of one row wrapped in a single-element array.
     *
     * @param i the row index
     * @return a new array holding that row's label value
     */
    public float[] getLabel(int i) {
        float[] wrapped = new float[1];
        wrapped[0] = this.labelArray[i];
        return wrapped;
    }

    /**
     * Builds the record for one row: the selected features as data and the label as labels.
     *
     * @param nDManager manager used to create the arrays
     * @param j the row index
     * @return a record holding one feature array and one label array
     * @throws ArithmeticException if {@code j} does not fit in an {@code int}
     */
    protected Record get(NDManager nDManager, long j) {
        int row = Math.toIntExact(j);
        NDArray featureArray = getFeatureNDArray(nDManager, row);
        NDList dataList = new NDList(new NDArray[]{featureArray});
        NDArray labelValue = nDManager.create(getLabel(row));
        NDList labelList = new NDList(new NDArray[]{labelValue});
        return new Record(dataList, labelList);
    }

    /**
     * Returns the parsed record at the given position.
     *
     * @param i the record index
     * @return the record at that index
     */
    public CSVRecord getCSVRecord(int i) {
        List<CSVRecord> records = this.csvRecords;
        return records.get(i);
    }

    /**
     * Reads one column of a record, parses it as a float and wraps it in a single-element
     * array.
     *
     * @param cSVRecord the record to read from
     * @param str the column name
     * @return a new array holding the parsed value
     * @throws NumberFormatException if the column text is not a parsable float
     */
    public float[] getValueFloat(CSVRecord cSVRecord, String str) {
        String raw = cSVRecord.get(str);
        float parsed = Float.parseFloat(raw);
        return new float[]{parsed};
    }

    /**
     * Returns how many features are currently selected.
     *
     * @return the number of selected features
     */
    public int getFeatureArraySize() {
        Set<String> selected = this.features;
        return selected.size();
    }

    /**
     * Collects the selected feature values of one row into an array.
     *
     * <p>Values appear in the iteration order of the selected-feature set; each one is
     * looked up in {@code data} through the feature's column index.
     *
     * @param nDManager manager used to create the array
     * @param i the row index
     * @return an array with one value per selected feature
     */
    public NDArray getFeatureNDArray(NDManager nDManager, int i) {
        float[] selected = new float[getFeatureArraySize()];
        int position = 0;
        for (String name : this.features) {
            selected[position] = this.data[i][this.stringToIndex.get(name).intValue()];
            position++;
        }
        return nDManager.create(selected);
    }

    /** Deselects every feature, moving all selected names back to the available set. */
    public void removeAllFeatures() {
        Set<String> selected = this.features;
        this.availableFeatures.addAll(selected);
        selected.clear();
    }

    /** Selects every available feature, leaving the available set empty. */
    public void addAllFeatures() {
        Set<String> unselected = this.availableFeatures;
        this.features.addAll(unselected);
        unselected.clear();
    }

    /**
     * Selects a feature by name, ignoring case. Names that are not currently available
     * (unknown or already selected) are ignored.
     *
     * @param str the feature name
     */
    public void addFeature(String str) {
        String key = str.toLowerCase();
        // Set.remove reports whether the name was present, which replaces the explicit
        // contains check.
        if (this.availableFeatures.remove(key)) {
            this.features.add(key);
        }
    }

    /**
     * Deselects a feature by name, ignoring case. Names that are not currently selected
     * are ignored.
     *
     * @param str the feature name
     */
    public void removeFeature(String str) {
        String key = str.toLowerCase();
        // Set.remove reports whether the name was present, which replaces the explicit
        // contains check.
        if (this.features.remove(key)) {
            this.availableFeatures.add(key);
        }
    }

    /**
     * Loads the data file and fills the in-memory feature and label arrays.
     *
     * <p>Returns immediately if the dataset is already prepared. Only the {@code TRAIN}
     * usage is backed by data; any other usage is rejected. The file
     * {@code airfoil_self_noise.dat} is read as tab-delimited text with the fixed column
     * names freq, aoa, chordlen, freestreamvel, ssdt and ssoundpres.
     *
     * @param progress not used by this implementation
     * @throws IOException if preparing the resource or reading the data file fails
     * @throws UnsupportedOperationException if the configured usage is not {@code TRAIN}
     */
    public void prepare(Progress progress) throws IOException {
        if (this.prepared) {
            return;
        }
        Artifact defaultArtifact = this.resource.getDefaultArtifact();
        this.resource.prepare(defaultArtifact);
        Path resourceDirectory = this.resource.getRepository().getResourceDirectory(defaultArtifact);

        Path dataFile;
        switch (this.usage) {
            case TRAIN:
                dataFile = resourceDirectory.resolve("airfoil_self_noise.dat");
                break;
            case TEST:
                throw new UnsupportedOperationException("Test data not available.");
            case VALIDATION:
            default:
                throw new UnsupportedOperationException("Validation data not available.");
        }

        CSVFormat format = CSVFormat.TDF.withHeader(new String[]{"freq", "aoa", "chordlen", "freestreamvel", "ssdt", "ssoundpres"}).withIgnoreHeaderCase().withTrim();
        // Both resources are closed in reverse order even when parsing fails.
        try (BufferedReader reader = Files.newBufferedReader(dataFile);
                CSVParser parser = new CSVParser(reader, format)) {
            this.csvRecords = parser.getRecords();
        }

        // Size the arrays from the record list itself so that every row written below is
        // in range, whatever size() reports.
        int recordCount = this.csvRecords.size();
        float[][] values = new float[recordCount][FEATURE_ARRAY.length];
        float[] labels = new float[recordCount];
        for (int row = 0; row < recordCount; row++) {
            CSVRecord record = this.csvRecords.get(row);
            for (String feature : FEATURE_ARRAY) {
                values[row][this.stringToIndex.get(feature).intValue()] = getRecordFloat(record, feature);
            }
            labels[row] = getRecordFloat(record, this.label);
        }
        // Publish the arrays only after every value parsed successfully.
        this.data = values;
        this.labelArray = labels;
        this.prepared = true;
    }

    /**
     * Returns the number of parsed records currently held.
     *
     * @return the length of the record list
     */
    protected long availableSize() {
        List<CSVRecord> records = this.csvRecords;
        return records.size();
    }

    /**
     * Package-visible constructor that simply delegates to the private
     * single-argument constructor. The second argument is never read; it only gives this
     * overload a distinct signature.
     *
     * @param builder the builder holding the dataset settings
     * @param anonymousClass1 ignored
     */
    /* synthetic */ AirfoilRandomAccess(Builder builder, AnonymousClass1 anonymousClass1) {
        this(builder);
    }
}
