package org.apache.mahout.df.mapreduce.partial;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Random;
import org.apache.commons.lang.ArrayUtils;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.SequenceFile;
import org.apache.hadoop.io.Writable;
import org.apache.hadoop.mapreduce.Job;
import org.apache.mahout.common.MahoutTestCase;
import org.apache.mahout.common.RandomUtils;
import org.apache.mahout.df.builder.DefaultTreeBuilder;
import org.apache.mahout.df.builder.TreeBuilder;
import org.apache.mahout.df.callback.PredictionCallback;
import org.apache.mahout.df.mapreduce.MapredOutput;
import org.apache.mahout.df.node.Leaf;
import org.apache.mahout.df.node.Node;
import org.junit.Assert;
import org.junit.Test;

/* loaded from: input_file:org/apache/mahout/df/mapreduce/partial/PartialBuilderTest.class */
public final class PartialBuilderTest extends MahoutTestCase {
    private static final int NUM_MAPS = 5;
    private static final int NUM_TREES = 32;
    private static final int NUM_INSTANCES = 20;

    /* loaded from: input_file:org/apache/mahout/df/mapreduce/partial/PartialBuilderTest$PartialBuilderChecker.class */
    static class PartialBuilderChecker extends PartialBuilder {
        private final Long seed;
        private final TreeBuilder treeBuilder;
        private final Path datasetPath;

        PartialBuilderChecker(TreeBuilder treeBuilder, Path path, Path path2, Long l) {
            super(treeBuilder, path, path2, l);
            this.seed = l;
            this.treeBuilder = treeBuilder;
            this.datasetPath = path2;
        }

        protected boolean runJob(Job job) throws IOException {
            Configuration configuration = job.getConfiguration();
            Assert.assertEquals(this.seed, getRandomSeed(configuration));
            Assert.assertEquals(1L, configuration.getInt("mapred.map.tasks", -1));
            Assert.assertEquals(32L, getNbTrees(configuration));
            Assert.assertFalse(isOutput(configuration));
            Assert.assertTrue(isOobEstimate(configuration));
            Assert.assertEquals(this.treeBuilder, getTreeBuilder(configuration));
            Assert.assertEquals(this.datasetPath, getDistributedCacheFile(configuration, 0));
            return true;
        }
    }

    /* loaded from: input_file:org/apache/mahout/df/mapreduce/partial/PartialBuilderTest$TestCallback.class */
    static class TestCallback implements PredictionCallback {
        private final TreeID[] keys;
        private final MapredOutput[] values;

        TestCallback(TreeID[] treeIDArr, MapredOutput[] mapredOutputArr) {
            this.keys = treeIDArr;
            this.values = mapredOutputArr;
        }

        public void prediction(int i, int i2, int i3) {
            Assert.assertTrue("key not found", ArrayUtils.indexOf(this.keys, new TreeID(i2 / PartialBuilderTest.NUM_INSTANCES, i)) >= 0);
            Assert.assertEquals(this.values[r0].getPredictions()[i2 % PartialBuilderTest.NUM_INSTANCES], i3);
        }
    }

    @Test
    public void testProcessOutput() throws Exception {
        Configuration configuration = new Configuration();
        configuration.setInt("mapred.map.tasks", NUM_MAPS);
        Random random = RandomUtils.getRandom();
        Writable[] writableArr = new TreeID[NUM_TREES];
        Writable[] writableArr2 = new MapredOutput[NUM_TREES];
        int[] iArr = new int[NUM_MAPS];
        randomKeyValues(random, writableArr, writableArr2, iArr);
        Path testTempDirPath = getTestTempDirPath("testdata");
        SequenceFile.Writer createWriter = SequenceFile.createWriter(testTempDirPath.getFileSystem(configuration), configuration, new Path(testTempDirPath, "PartialBuilderTest.seq"), TreeID.class, MapredOutput.class);
        for (int i = 0; i < NUM_TREES; i++) {
            createWriter.append(writableArr[i], writableArr2[i]);
        }
        createWriter.close();
        TreeID[] treeIDArr = new TreeID[NUM_TREES];
        Node[] nodeArr = new Node[NUM_TREES];
        PartialBuilder.processOutput(new Job(configuration), testTempDirPath, iArr, treeIDArr, nodeArr, new TestCallback(writableArr, writableArr2));
        for (int i2 = 0; i2 < NUM_TREES; i2++) {
            assertEquals(writableArr2[i2].getTree(), nodeArr[i2]);
        }
        assertTrue("keys not equal", Arrays.deepEquals(writableArr, treeIDArr));
    }

    @Test
    public void testConfigure() {
        new PartialBuilderChecker(new DefaultTreeBuilder(), new Path("notUsedDataPath"), new Path("notUsedDatasetPath"), 5L);
    }

    private static void randomKeyValues(Random random, TreeID[] treeIDArr, MapredOutput[] mapredOutputArr, int[] iArr) {
        int nextInt;
        int i = 0;
        int i2 = 0;
        ArrayList arrayList = new ArrayList();
        for (int i3 = 0; i3 < NUM_MAPS; i3++) {
            do {
                nextInt = random.nextInt(NUM_MAPS);
            } while (arrayList.contains(Integer.valueOf(nextInt)));
            arrayList.add(Integer.valueOf(nextInt));
            int nbTrees = Step1Mapper.nbTrees(NUM_MAPS, NUM_TREES, nextInt);
            for (int i4 = 0; i4 < nbTrees; i4++) {
                Leaf leaf = new Leaf(random.nextInt(100));
                treeIDArr[i] = new TreeID(nextInt, i4);
                mapredOutputArr[i] = new MapredOutput(leaf, nextIntArray(random, NUM_INSTANCES));
                i++;
            }
            iArr[i3] = i2;
            i2 += NUM_INSTANCES;
        }
    }

    private static int[] nextIntArray(Random random, int i) {
        int[] iArr = new int[i];
        for (int i2 = 0; i2 < i; i2++) {
            iArr[i2] = random.nextInt(101) - 1;
        }
        return iArr;
    }
}
