package org.apache.mahout.classifier.bayes;

import java.io.IOException;
import java.util.List;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.mahout.classifier.ClassifierData;
import org.apache.mahout.classifier.ClassifierResult;
import org.apache.mahout.classifier.ResultAnalyzer;
import org.apache.mahout.classifier.bayes.algorithm.BayesAlgorithm;
import org.apache.mahout.classifier.bayes.algorithm.CBayesAlgorithm;
import org.apache.mahout.classifier.bayes.common.BayesParameters;
import org.apache.mahout.classifier.bayes.datastore.InMemoryBayesDatastore;
import org.apache.mahout.classifier.bayes.exceptions.InvalidDatastoreException;
import org.apache.mahout.classifier.bayes.mapreduce.bayes.BayesClassifierDriver;
import org.apache.mahout.classifier.bayes.model.ClassifierContext;
import org.apache.mahout.common.MahoutTestCase;
import org.apache.mahout.common.nlp.NGrams;

/* loaded from: input_file:org/apache/mahout/classifier/bayes/BayesClassifierSelfTest.class */
/**
 * Self-test for the Bayes and complementary Bayes classifiers.
 *
 * <p>Each test trains a model on {@link ClassifierData#DATA}, classifies the same documents
 * with an in-memory datastore, and then runs {@link TestClassifier#classifyParallel} over
 * the same input. In both cases the confusion matrix is expected to be perfectly diagonal.
 */
public class BayesClassifierSelfTest extends MahoutTestCase {
    /** Directory the sample documents are written to; used as training input and as {@code testDirPath}. */
    private static final String INPUT_DIR = "testdata/bayesinput";

    /** Glob handed to {@link BayesClassifierDriver#readResult} after the parallel run. */
    private static final String PARALLEL_OUTPUT_GLOB = "testdata/bayesinput-output/part*";

    /** Value stored under the {@code defaultCat} parameter. */
    private static final String DEFAULT_CATEGORY = "unknown";

    /** Number of ranked results expected per document, and the size of the matrix region checked. */
    private static final int EXPECTED_LABEL_COUNT = 3;

    /** Expected value of every diagonal confusion-matrix cell. */
    private static final int EXPECTED_DOCS_PER_LABEL = 4;

    /** Upper bound on ranked results requested from the classifier. */
    private static final int MAX_RESULTS = 100;

    /**
     * Writes the sample documents to {@link #INPUT_DIR} before every test.
     *
     * @throws Exception if the superclass set-up or writing the data fails
     */
    @Override
    public void setUp() throws Exception {
        super.setUp();
        ClassifierData.writeDataToFile(INPUT_DIR, ClassifierData.DATA);
    }

    /**
     * Trains a naive Bayes model and verifies sequential and parallel classification.
     *
     * @throws InvalidDatastoreException if the datastore cannot be initialised or queried
     * @throws IOException if the parallel job or reading its output fails
     */
    public void testSelfTestBayes() throws InvalidDatastoreException, IOException {
        String modelPath = "testdata/bayesmodel";
        BayesParameters params = newTrainingParameters();
        TrainClassifier.trainNaiveBayes(INPUT_DIR, modelPath, params);
        configureForClassification(params, modelPath, "bayes");

        ClassifierContext classifier =
            new ClassifierContext(new BayesAlgorithm(), new InMemoryBayesDatastore(params));
        verifySequentialClassification(classifier, params);
        verifyParallelClassification(params);
    }

    /**
     * Trains a complementary naive Bayes model and verifies sequential and parallel classification.
     *
     * @throws InvalidDatastoreException if the datastore cannot be initialised or queried
     * @throws IOException if the parallel job or reading its output fails
     */
    public void testSelfTestCBayes() throws InvalidDatastoreException, IOException {
        String modelPath = "testdata/cbayesmodel";
        BayesParameters params = newTrainingParameters();
        TrainClassifier.trainCNaiveBayes(INPUT_DIR, modelPath, params);
        configureForClassification(params, modelPath, "cbayes");

        ClassifierContext classifier =
            new ClassifierContext(new CBayesAlgorithm(), new InMemoryBayesDatastore(params));
        verifySequentialClassification(classifier, params);
        verifyParallelClassification(params);
    }

    /** Builds the parameter set handed to the trainer. */
    private BayesParameters newTrainingParameters() {
        BayesParameters params = new BayesParameters(1);
        params.set("alpha_i", "1.0");
        params.set("dataSource", "hdfs");
        return params;
    }

    /**
     * Adds the classification settings to {@code params}.
     *
     * <p>{@code dataSource} and {@code alpha_i} are set again on purpose: whether the trainer
     * leaves the parameters untouched is not visible from this class.
     */
    private void configureForClassification(BayesParameters params, String modelPath, String classifierType) {
        params.set("verbose", "true");
        params.set("basePath", modelPath);
        params.set("classifierType", classifierType);
        params.set("dataSource", "hdfs");
        params.set("defaultCat", DEFAULT_CATEGORY);
        params.set("encoding", "UTF-8");
        params.set("alpha_i", "1.0");
    }

    /**
     * Initialises {@code classifier}, classifies every sample document and checks that each one
     * gets its own label and that the resulting confusion matrix is diagonal.
     */
    private void verifySequentialClassification(ClassifierContext classifier, BayesParameters params)
            throws InvalidDatastoreException, IOException {
        classifier.initialize();
        String defaultCategory = params.get("defaultCat");
        ResultAnalyzer analyzer = new ResultAnalyzer(classifier.getLabels(), defaultCategory);
        int gramSize = Integer.parseInt(params.get("gramSize"));

        for (String[] entry : ClassifierData.DATA) {
            String expectedLabel = entry[0];
            List<String> ngrams = new NGrams(entry[1], gramSize).generateNGramsWithoutLabel();
            // Convert once; the same tokens feed both the ranked and the single-result call.
            String[] document = ngrams.toArray(new String[ngrams.size()]);

            assertEquals(EXPECTED_LABEL_COUNT,
                classifier.classifyDocument(document, defaultCategory, MAX_RESULTS).length);

            ClassifierResult result = classifier.classifyDocument(document, defaultCategory);
            assertEquals(expectedLabel, result.getLabel());
            analyzer.addInstance(expectedLabel, result);
        }

        assertDiagonalConfusionMatrix(analyzer.getConfusionMatrix().getConfusionMatrix());
    }

    /**
     * Runs the parallel classifier over {@link #INPUT_DIR} and checks that the confusion matrix
     * read back from {@link #PARALLEL_OUTPUT_GLOB} is diagonal.
     */
    private void verifyParallelClassification(BayesParameters params) throws IOException {
        params.set("testDirPath", INPUT_DIR);
        TestClassifier.classifyParallel(params);

        Configuration conf = new Configuration();
        Path outputGlob = new Path(PARALLEL_OUTPUT_GLOB);
        FileSystem fs = FileSystem.get(outputGlob.toUri(), conf);
        int[][] matrix = BayesClassifierDriver.readResult(fs, outputGlob, conf, params).getConfusionMatrix();
        assertDiagonalConfusionMatrix(matrix);
    }

    /**
     * Asserts that the top-left {@value #EXPECTED_LABEL_COUNT}x{@value #EXPECTED_LABEL_COUNT}
     * region of {@code matrix} holds {@value #EXPECTED_DOCS_PER_LABEL} on the diagonal and
     * zero everywhere else. Cells outside that region are not inspected.
     */
    private void assertDiagonalConfusionMatrix(int[][] matrix) {
        for (int row = 0; row < EXPECTED_LABEL_COUNT; row++) {
            for (int col = 0; col < EXPECTED_LABEL_COUNT; col++) {
                int expected = (row == col) ? EXPECTED_DOCS_PER_LABEL : 0;
                assertEquals("confusionMatrix[" + row + "][" + col + "]", expected, matrix[row][col]);
            }
        }
    }
}
