package ai.djl.basicdataset.utils;

import ai.djl.basicdataset.TextDataset;
import ai.djl.training.dataset.RandomAccessDataset;
import ai.djl.training.dataset.Sampler;
import ai.djl.util.RandomUtils;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.NoSuchElementException;
import java.util.Set;
import java.util.TreeSet;

/* loaded from: input_file:ai/djl/basicdataset/utils/FixedBucketSampler.class */
public class FixedBucketSampler implements Sampler {
    private Set<Bucket> buckets;
    private int numBuckets;
    private int batchSize;
    private boolean dropLast;
    private boolean shuffle;

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:ai/djl/basicdataset/utils/FixedBucketSampler$Bucket.class */
    public static class Bucket {
        Set<Sample> samples;
        int index;

        public Bucket(int i, Set<Sample> set) {
            this.index = i;
            this.samples = set;
        }
    }

    /* loaded from: input_file:ai/djl/basicdataset/utils/FixedBucketSampler$Iterate.class */
    private class Iterate implements Iterator<List<Long>> {
        private long current;
        private long size;

        public Iterate(RandomAccessDataset randomAccessDataset) {
            if (FixedBucketSampler.this.dropLast) {
                this.size = randomAccessDataset.size() / FixedBucketSampler.this.batchSize;
            } else {
                this.size = ((randomAccessDataset.size() + FixedBucketSampler.this.batchSize) - 1) / FixedBucketSampler.this.batchSize;
            }
            if (!(randomAccessDataset instanceof TextDataset)) {
                throw new IllegalStateException("FixedBucketSampler can only be used with TextDataset");
            }
            if (FixedBucketSampler.this.buckets != null) {
                return;
            }
            ArrayList arrayList = new ArrayList();
            for (int i = 0; i < randomAccessDataset.size(); i++) {
                arrayList.add(new Sample(i, ((TextDataset) randomAccessDataset).getProcessedText(i, true).size()));
            }
            arrayList.sort(Comparator.comparingInt(sample -> {
                return sample.sentenceLength;
            }));
            FixedBucketSampler.this.buckets = new TreeSet(Comparator.comparingInt(bucket -> {
                return bucket.index;
            }));
            int size = arrayList.size() / FixedBucketSampler.this.numBuckets;
            int i2 = 0;
            int i3 = 0;
            while (true) {
                int i4 = i3;
                if (i4 >= arrayList.size()) {
                    return;
                }
                int i5 = i4 + size;
                if (i5 > arrayList.size()) {
                    i5 = arrayList.size();
                }
                int i6 = i2;
                i2++;
                FixedBucketSampler.this.buckets.add(new Bucket(i6, new HashSet(arrayList.subList(i4, i5))));
                i3 = i4 + size;
            }
        }

        @Override // java.util.Iterator
        public boolean hasNext() {
            return this.current < this.size;
        }

        /* JADX WARN: Can't rename method to resolve collision */
        @Override // java.util.Iterator
        public List<Long> next() {
            int i = 0;
            ArrayList arrayList = new ArrayList();
            Iterator<Bucket> it = FixedBucketSampler.this.buckets.iterator();
            Bucket firstBucket = firstBucket(it);
            while (true) {
                Bucket bucket = firstBucket;
                if (i >= FixedBucketSampler.this.batchSize) {
                    break;
                }
                Set<Sample> set = bucket.samples;
                ArrayList arrayList2 = new ArrayList();
                Iterator<Sample> it2 = set.iterator();
                while (it2.hasNext()) {
                    arrayList2.add(it2.next());
                    i++;
                    if (i >= FixedBucketSampler.this.batchSize) {
                        break;
                    }
                }
                Iterator it3 = arrayList2.iterator();
                while (it3.hasNext()) {
                    set.remove((Sample) it3.next());
                }
                arrayList.addAll(arrayList2);
                if (i >= FixedBucketSampler.this.batchSize) {
                    break;
                }
                if (!it.hasNext()) {
                    if (!FixedBucketSampler.this.shuffle) {
                        throw new IllegalStateException("Code should never reach here");
                    }
                    it = FixedBucketSampler.this.buckets.iterator();
                }
                firstBucket = it.next();
            }
            ArrayList arrayList3 = new ArrayList();
            Iterator it4 = arrayList.iterator();
            while (it4.hasNext()) {
                arrayList3.add(Long.valueOf(((Sample) it4.next()).index));
            }
            this.current++;
            return arrayList3;
        }

        private Bucket firstBucket(Iterator<Bucket> it) {
            if (!FixedBucketSampler.this.shuffle) {
                return it.next();
            }
            int nextInt = RandomUtils.nextInt(FixedBucketSampler.this.buckets.size());
            for (int i = 0; i < nextInt - 1; i++) {
                it.next();
            }
            return it.next();
        }
    }

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:ai/djl/basicdataset/utils/FixedBucketSampler$Sample.class */
    public static class Sample {
        int sentenceLength;
        long index;

        public Sample(int i, int i2) {
            this.index = i;
            this.sentenceLength = i2;
        }
    }

    public FixedBucketSampler(int i, int i2, boolean z, boolean z2) {
        this.numBuckets = i2;
        this.batchSize = i;
        this.dropLast = z;
        this.shuffle = z2;
    }

    public FixedBucketSampler(int i, int i2) {
        this(i2, i, false, true);
    }

    public FixedBucketSampler(int i) {
        this(10, i);
    }

    public Iterator<List<Long>> sample(RandomAccessDataset randomAccessDataset) {
        return new Iterate(randomAccessDataset);
    }

    public int getBatchSize() {
        return this.batchSize;
    }
}
