package org.apache.mahout.classifier.df.mapreduce.inmem;

import java.util.List;
import java.util.Random;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.mapreduce.TaskAttemptContext;
import org.apache.mahout.classifier.df.mapreduce.Builder;
import org.apache.mahout.classifier.df.mapreduce.inmem.InMemInputFormat;
import org.apache.mahout.common.MahoutTestCase;
import org.apache.mahout.common.RandomUtils;
import org.junit.Test;

/* loaded from: input_file:org/apache/mahout/classifier/df/mapreduce/inmem/InMemInputFormatTest.class */
/**
 * Tests {@link InMemInputFormat}: how it partitions a number of trees into input splits, and how
 * its record reader enumerates the tree ids of a single split.
 */
public final class InMemInputFormatTest extends MahoutTestCase {

    /** Number of randomized trials each test performs. */
    private static final int NUM_TRIALS = 1;

    /** Upper bound (inclusive) of the randomly chosen number of splits. */
    private static final int MAX_NUM_SPLITS = 100;

    /** Upper bound (inclusive) of the randomly chosen number of trees. */
    private static final int MAX_NB_TREES = 1000;

    /**
     * Checks the splits returned by {@code getSplits}:
     * <ul>
     *   <li>the requested number of splits is returned;</li>
     *   <li>every split is an {@code InMemInputSplit};</li>
     *   <li>first ids are contiguous, starting at 0;</li>
     *   <li>every split but the last holds {@code nbTrees / numSplits} trees;</li>
     *   <li>the last split holds whatever trees remain.</li>
     * </ul>
     */
    @Test
    public void testSplits() throws Exception {
        Random rng = RandomUtils.getRandom();
        for (int trial = 0; trial < NUM_TRIALS; trial++) {
            int numSplits = rng.nextInt(MAX_NUM_SPLITS) + 1;
            int nbTrees = rng.nextInt(MAX_NB_TREES) + 1;
            List<?> splits = createSplits(nbTrees, numSplits);

            assertEquals(numSplits, splits.size());

            int nbTreesPerSplit = nbTrees / numSplits;
            // Number of trees covered by the splits checked so far; this is also the
            // first tree id the next split is expected to start at.
            int treesSoFar = 0;
            for (int index = 0; index < numSplits; index++) {
                assertTrue(splits.get(index) instanceof InMemInputFormat.InMemInputSplit);
                InMemInputFormat.InMemInputSplit split = (InMemInputFormat.InMemInputSplit) splits.get(index);

                assertEquals(treesSoFar, split.getFirstId());
                if (index < numSplits - 1) {
                    assertEquals(nbTreesPerSplit, split.getNbTrees());
                } else {
                    // The last split absorbs the remainder of the integer division.
                    assertEquals(nbTrees - treesSoFar, split.getNbTrees());
                }
                treesSoFar += split.getNbTrees();
            }
        }
    }

    /**
     * Checks that, for every split, {@code InMemRecordReader} yields exactly the keys
     * {@code firstId, firstId + 1, ..., firstId + nbTrees - 1}, and that {@code nextKeyValue()}
     * then returns {@code false}.
     */
    @Test
    public void testRecordReader() throws Exception {
        Random rng = RandomUtils.getRandom();
        for (int trial = 0; trial < NUM_TRIALS; trial++) {
            int numSplits = rng.nextInt(MAX_NUM_SPLITS) + 1;
            int nbTrees = rng.nextInt(MAX_NB_TREES) + 1;
            List<?> splits = createSplits(nbTrees, numSplits);

            for (int index = 0; index < numSplits; index++) {
                InMemInputFormat.InMemInputSplit split = (InMemInputFormat.InMemInputSplit) splits.get(index);
                InMemInputFormat.InMemRecordReader reader = new InMemInputFormat.InMemRecordReader(split);
                reader.initialize(split, (TaskAttemptContext) null);

                for (int tree = 0; tree < split.getNbTrees(); tree++) {
                    assertTrue(reader.nextKeyValue());
                    assertEquals(split.getFirstId() + tree, reader.getCurrentKey().get());
                }
                // The former in-loop check compared against a condition that was always true,
                // so exhaustion was never verified; check it explicitly here.
                assertFalse(reader.nextKeyValue());
            }
        }
    }

    /**
     * Builds a configuration holding {@code nbTrees} trees and asks {@link InMemInputFormat}
     * to split it.
     *
     * @param nbTrees   total number of trees to configure
     * @param numSplits number of splits requested from the input format
     * @return the splits returned by {@code getSplits}
     */
    private static List<?> createSplits(int nbTrees, int numSplits) throws Exception {
        Configuration conf = new Configuration();
        Builder.setNbTrees(conf, nbTrees);
        return new InMemInputFormat().getSplits(conf, numSplits);
    }
}
