package org.apache.mahout.classifier.df.tools;

import java.util.ArrayList;
import java.util.List;
import java.util.Random;
import org.apache.mahout.classifier.df.DecisionForest;
import org.apache.mahout.classifier.df.builder.DecisionTreeBuilder;
import org.apache.mahout.classifier.df.data.Data;
import org.apache.mahout.classifier.df.data.DataLoader;
import org.apache.mahout.classifier.df.data.Dataset;
import org.apache.mahout.classifier.df.node.CategoricalNode;
import org.apache.mahout.classifier.df.node.Leaf;
import org.apache.mahout.classifier.df.node.Node;
import org.apache.mahout.classifier.df.node.NumericalNode;
import org.apache.mahout.common.MahoutTestCase;
import org.apache.mahout.common.RandomUtils;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;

/* loaded from: input_file:org/apache/mahout/classifier/df/tools/VisualizerTest.class */
/**
 * Tests for {@link TreeVisualizer} and {@link ForestVisualizer}.
 *
 * <p>The fixtures are small comma-separated rows with five columns named by
 * {@link #ATTR_NAMES}; the last column ({@code play}) holds {@code yes}/{@code no}
 * in the training rows and {@code -} in the test rows.
 */
public final class VisualizerTest extends MahoutTestCase {

    /** Training rows: outlook, temperature, humidity, windy, play. */
    private static final String[] TRAIN_DATA = {
        "sunny,85,85,FALSE,no",
        "sunny,80,90,TRUE,no",
        "overcast,83,86,FALSE,yes",
        "rainy,70,96,FALSE,yes",
        "rainy,68,80,FALSE,yes",
        "rainy,65,70,TRUE,no",
        "overcast,64,65,TRUE,yes",
        "sunny,72,95,FALSE,no",
        "sunny,69,70,FALSE,yes",
        "rainy,75,80,FALSE,yes",
        "sunny,75,70,TRUE,yes",
        "overcast,72,90,TRUE,yes",
        "overcast,81,75,FALSE,yes",
        "rainy,71,91,TRUE,no"
    };

    /** Rows used by the prediction-trace test; the last column is the placeholder {@code -}. */
    private static final String[] TEST_DATA = {
        "rainy,70,96,TRUE,-",
        "overcast,64,65,TRUE,-",
        "sunny,75,90,TRUE,-"
    };

    /** Human-readable column names, in the same order as the columns of the data rows. */
    private static final String[] ATTR_NAMES = {"outlook", "temperature", "humidity", "windy", "play"};

    private Random rng;
    private Data data;
    private Data testData;

    /**
     * Obtains the random generator from {@link RandomUtils} and loads both fixtures
     * against a single {@link Dataset} generated from the training rows.
     *
     * @throws Exception if the parent set-up or the data loading fails
     */
    @Override
    @Before
    public void setUp() throws Exception {
        super.setUp();
        this.rng = RandomUtils.getRandom();
        // Descriptor has one token per column; presumably C = categorical, N = numerical,
        // L = label -- TODO confirm against the DataLoader documentation.
        Dataset dataset = DataLoader.generateDataset("C N N C L", false, TRAIN_DATA);
        this.data = DataLoader.loadData(dataset, TRAIN_DATA);
        this.testData = DataLoader.loadData(dataset, TEST_DATA);
    }

    /**
     * Builds a tree from the training data and checks its textual rendering
     * when attribute names are supplied.
     *
     * @throws Exception if building or rendering the tree fails
     */
    @Test
    public void testTreeVisualize() throws Exception {
        DecisionTreeBuilder builder = new DecisionTreeBuilder();
        builder.setM(this.data.getDataset().nbAttributes() - 1);
        Node tree = builder.build(this.rng, this.data);

        String expected = "\noutlook = rainy\n"
            + "|   windy = FALSE : yes\n"
            + "|   windy = TRUE : no\n"
            + "outlook = sunny\n"
            + "|   humidity < 85 : yes\n"
            + "|   humidity >= 85 : no\n"
            + "outlook = overcast : yes";

        // JUnit contract: expected value first, actual value second.
        assertEquals(expected, TreeVisualizer.toString(tree, this.data.getDataset(), ATTR_NAMES));
    }

    /**
     * Builds a tree from the training data and checks the decision path reported
     * for each row of the test data.
     *
     * @throws Exception if building the tree or tracing the predictions fails
     */
    @Test
    public void testPredictTrace() throws Exception {
        DecisionTreeBuilder builder = new DecisionTreeBuilder();
        builder.setM(this.data.getDataset().nbAttributes() - 1);
        Node tree = builder.build(this.rng, this.data);

        String[] expected = {
            "outlook = rainy -> windy = TRUE -> no",
            "outlook = overcast -> yes",
            "outlook = sunny -> (humidity = 90) >= 85 -> no"
        };

        // JUnit contract: expected array first, actual array second.
        Assert.assertArrayEquals(expected, TreeVisualizer.predictTrace(tree, this.testData, ATTR_NAMES));
    }

    /**
     * Renders a single-tree forest built by hand, once without attribute names
     * (attributes are printed by index) and once with {@link #ATTR_NAMES}.
     *
     * @throws Exception if rendering the forest fails
     */
    @Test
    public void testForestVisualize() throws Exception {
        // Hand-built tree: numerical split on attribute 2 at 90, whose second child is a
        // categorical split on attribute 0 with a nested numerical split on attribute 1 at 71.
        Node temperatureSplit = new NumericalNode(1, 71.0d, new Leaf(0.0d), new Leaf(1.0d));
        Node outlookSplit = new CategoricalNode(
            0,
            new double[]{0.0d, 1.0d, 2.0d},
            new Node[]{temperatureSplit, new Leaf(1.0d), new Leaf(0.0d)});
        Node root = new NumericalNode(2, 90.0d, new Leaf(0.0d), outlookSplit);

        List<Node> trees = new ArrayList<Node>();
        trees.add(root);
        DecisionForest forest = new DecisionForest(trees);

        String expectedByIndex = "Tree[1]:\n"
            + "2 < 90 : yes\n"
            + "2 >= 90\n"
            + "|   0 = rainy\n"
            + "|   |   1 < 71 : yes\n"
            + "|   |   1 >= 71 : no\n"
            + "|   0 = sunny : no\n"
            + "|   0 = overcast : yes\n";
        // The cast on null keeps the call bound to the String[] parameter, as in the original.
        assertEquals(expectedByIndex,
            ForestVisualizer.toString(forest, this.data.getDataset(), (String[]) null));

        String expectedByName = "Tree[1]:\n"
            + "humidity < 90 : yes\n"
            + "humidity >= 90\n"
            + "|   outlook = rainy\n"
            + "|   |   temperature < 71 : yes\n"
            + "|   |   temperature >= 71 : no\n"
            + "|   outlook = sunny : no\n"
            + "|   outlook = overcast : yes\n";
        assertEquals(expectedByName,
            ForestVisualizer.toString(forest, this.data.getDataset(), ATTR_NAMES));
    }
}
