package org.apache.mahout.classifier.bayes;

import com.google.common.base.Charsets;
import com.google.common.io.Files;
import java.io.BufferedWriter;
import java.io.File;
import java.util.List;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.Path;
import org.apache.mahout.classifier.ClassifierData;
import org.apache.mahout.classifier.ClassifierResult;
import org.apache.mahout.classifier.ResultAnalyzer;
import org.apache.mahout.classifier.bayes.algorithm.BayesAlgorithm;
import org.apache.mahout.classifier.bayes.algorithm.CBayesAlgorithm;
import org.apache.mahout.classifier.bayes.common.BayesParameters;
import org.apache.mahout.classifier.bayes.datastore.InMemoryBayesDatastore;
import org.apache.mahout.classifier.bayes.mapreduce.bayes.BayesClassifierDriver;
import org.apache.mahout.classifier.bayes.model.ClassifierContext;
import org.apache.mahout.common.MahoutTestCase;
import org.apache.mahout.common.nlp.NGrams;
import org.junit.Before;
import org.junit.Test;

/**
 * Self-test for the naive Bayes and complementary naive Bayes classifiers.
 *
 * <p>Each test trains a model on {@link ClassifierData#DATA} and then classifies those same
 * documents twice: in memory through a {@link ClassifierContext}, and through the parallel
 * classifier ({@code TestClassifier.classifyParallel}). Both runs must yield a purely diagonal
 * confusion matrix, i.e. every training document is classified back to its own label.
 */
public final class BayesClassifierSelfTest extends MahoutTestCase {

    /** Number of distinct labels the assertions expect in {@link ClassifierData#DATA}. */
    private static final int NUM_LABELS = 3;

    /** Number of documents per label the assertions expect in {@link ClassifierData#DATA}. */
    private static final long DOCS_PER_LABEL = 4L;

    /**
     * Writes {@link ClassifierData#DATA} to a local temp file (one {@code label TAB text} line per
     * entry, UTF-8) and copies it to the test temp path used as training/test input.
     */
    @Override
    @Before
    public void setUp() throws Exception {
        super.setUp();
        File localInput = getTestTempFile("bayesinput");
        BufferedWriter writer = Files.newWriter(localInput, Charsets.UTF_8);
        try {
            for (String[] entry : ClassifierData.DATA) {
                writer.write(entry[0] + '\t' + entry[1] + '\n');
            }
        } finally {
            // Close even if a write fails so the file handle is not leaked.
            writer.close();
        }
        Path inputPath = getTestTempFilePath("bayesinput");
        inputPath.getFileSystem(new Configuration())
            .copyFromLocalFile(new Path(localInput.getAbsolutePath()), inputPath);
    }

    /** Standard naive Bayes must classify its own training data perfectly. */
    @Test
    public void testSelfTestBayes() throws Exception {
        runSelfTest(false, "bayesmodel");
    }

    /** Complementary naive Bayes must classify its own training data perfectly. */
    @Test
    public void testSelfTestCBayes() throws Exception {
        runSelfTest(true, "cbayesmodel");
    }

    /**
     * Trains a model, classifies the training data in memory and in parallel, and asserts that
     * both runs produce a diagonal confusion matrix.
     *
     * @param complementary {@code true} for complementary naive Bayes ({@code cbayes}),
     *                      {@code false} for standard naive Bayes ({@code bayes})
     * @param modelDirName  name of the temp directory the trained model is written to
     */
    private void runSelfTest(boolean complementary, String modelDirName) throws Exception {
        BayesParameters params = new BayesParameters();
        params.setGramSize(1);
        params.set("alpha_i", "1.0");
        params.set("dataSource", "hdfs");
        Path inputPath = getTestTempFilePath("bayesinput");
        Path modelPath = getTestTempDirPath(modelDirName);
        if (complementary) {
            TrainClassifier.trainCNaiveBayes(inputPath, modelPath, params);
        } else {
            TrainClassifier.trainNaiveBayes(inputPath, modelPath, params);
        }

        // Classification settings. alpha_i and dataSource are re-applied with the same values
        // so the classifier sees them regardless of how training used the shared parameters.
        params.set("verbose", "true");
        params.setBasePath(modelPath.toString());
        params.set("classifierType", complementary ? "cbayes" : "bayes");
        params.set("dataSource", "hdfs");
        params.set("defaultCat", "unknown");
        params.set("encoding", "UTF-8");
        params.set("alpha_i", "1.0");

        InMemoryBayesDatastore datastore = new InMemoryBayesDatastore(params);
        ClassifierContext context = complementary
            ? new ClassifierContext(new CBayesAlgorithm(), datastore)
            : new ClassifierContext(new BayesAlgorithm(), datastore);
        context.initialize();

        String defaultCategory = params.get("defaultCat");
        ResultAnalyzer analyzer = new ResultAnalyzer(context.getLabels(), defaultCategory);
        for (String[] entry : ClassifierData.DATA) {
            List<String> ngrams =
                new NGrams(entry[1], params.getGramSize()).generateNGramsWithoutLabel();
            String[] document = ngrams.toArray(new String[ngrams.size()]);
            // Asking for up to 100 results must return one result per known label.
            assertEquals(NUM_LABELS, context.classifyDocument(document, defaultCategory, 100).length);
            ClassifierResult result = context.classifyDocument(document, defaultCategory);
            assertEquals(entry[0], result.getLabel());
            analyzer.addInstance(entry[0], result);
        }
        assertDiagonalConfusion(analyzer.getConfusionMatrix().getConfusionMatrix());

        // Repeat the classification with the parallel classifier over the same input file.
        params.set("testDirPath", inputPath.toString());
        TestClassifier.classifyParallel(params);
        assertDiagonalConfusion(BayesClassifierDriver.readResult(
            getTestTempFilePath("bayesinput-output/part*"), new Configuration(), params)
            .getConfusionMatrix());
    }

    /**
     * Asserts that the leading {@code NUM_LABELS x NUM_LABELS} block of {@code matrix} holds
     * {@code DOCS_PER_LABEL} on the diagonal and zero everywhere else.
     */
    private void assertDiagonalConfusion(int[][] matrix) {
        for (int row = 0; row < NUM_LABELS; row++) {
            for (int col = 0; col < NUM_LABELS; col++) {
                long expected = row == col ? DOCS_PER_LABEL : 0L;
                assertEquals(expected, matrix[row][col]);
            }
        }
    }
}
