/*
 * Decompiled with CFR 0.152.
 */
package com.aliasi.test.unit.classify;

import com.aliasi.classify.Classification;
import com.aliasi.classify.Classified;
import com.aliasi.classify.LogisticRegressionClassifier;
import com.aliasi.corpus.XValidatingObjectCorpus;
import com.aliasi.stats.AnnealingSchedule;
import com.aliasi.stats.RegressionPrior;
import com.aliasi.tokenizer.RegExTokenizerFactory;
import com.aliasi.tokenizer.TokenFeatureExtractor;
import java.io.IOException;
import java.util.Random;
import junit.framework.Assert;
import org.junit.Test;

public class LogisticRegressionClassifierTest {

    /** Number of synthetic categories, labelled {@code cat_a} .. {@code cat_d}. */
    private static final int NUM_CATEGORIES = 4;

    /** Training examples generated for each category. */
    private static final int TRAINING_EXAMPLES_PER_CATEGORY = 100;

    /** Fresh (unseen) examples generated for each category when checking a trained model. */
    private static final int TEST_EXAMPLES_PER_CATEGORY = 10;

    /** Number of space-separated tokens in each synthetic example. */
    private static final int TOKENS_PER_EXAMPLE = 100;

    /** Noise tokens are drawn uniformly from the first this-many letters ({@code a} .. {@code j}). */
    private static final int NOISE_ALPHABET_SIZE = 10;

    /** Fixed seed so that any failure can be reproduced exactly. */
    private static final long SEED = 42L;

    /**
     * Trains a logistic regression classifier on synthetic data and checks it on freshly
     * generated examples. It then retrains hot-started from the first model, with a smaller
     * prior block size and a 1000x tighter convergence threshold, and checks that second
     * model the same way.
     *
     * @throws IOException if training over the corpus fails
     */
    @Test
    public void test1() throws IOException {
        Random random = new Random(SEED);

        int numFolds = 10;
        XValidatingObjectCorpus<Classified<CharSequence>> corpus
            = new XValidatingObjectCorpus<Classified<CharSequence>>(numFolds);
        for (int j = 0; j < NUM_CATEGORIES; ++j) {
            Classification c = new Classification(categoryName(j));
            for (int i = 0; i < TRAINING_EXAMPLES_PER_CATEGORY; ++i) {
                CharSequence input = generateExample(j, random);
                corpus.handle(new Classified<CharSequence>(input, c));
            }
        }
        corpus.permuteCorpus(random);

        TokenFeatureExtractor featureExtractor
            = new TokenFeatureExtractor(new RegExTokenizerFactory("\\S+"));
        boolean addIntercept = true;
        RegressionPrior prior = RegressionPrior.noninformative();
        int priorBlockSize = 4;
        double initLearningRate = 0.01;
        double annealingRate = 500.0;
        double minImprovement = 0.001;
        int minEpochs = 2;
        int maxEpochs = 10000;
        int minFeatureCount = 2;
        int rollingAverageSize = 5;
        AnnealingSchedule annealingSchedule
            = AnnealingSchedule.inverse(initLearningRate, annealingRate);

        // Cold start: no hot-start model, no per-epoch handler, no reporter.
        LogisticRegressionClassifier<CharSequence> classifier
            = LogisticRegressionClassifier.train(corpus,
                                                 featureExtractor,
                                                 minFeatureCount,
                                                 addIntercept,
                                                 prior,
                                                 priorBlockSize,
                                                 null,
                                                 annealingSchedule,
                                                 minImprovement,
                                                 rollingAverageSize,
                                                 minEpochs,
                                                 maxEpochs,
                                                 null,
                                                 null);
        assertClassifiesCorrectly(classifier, random);

        // Hot start from the first model with different block size and a tighter
        // convergence threshold; the result must still classify correctly.
        int hotStartPriorBlockSize = 2;
        double hotStartMinImprovement = minImprovement / 1000.0;
        LogisticRegressionClassifier<CharSequence> classifier2
            = LogisticRegressionClassifier.train(corpus,
                                                 featureExtractor,
                                                 minFeatureCount,
                                                 addIntercept,
                                                 prior,
                                                 hotStartPriorBlockSize,
                                                 classifier,
                                                 annealingSchedule,
                                                 hotStartMinImprovement,
                                                 rollingAverageSize,
                                                 minEpochs,
                                                 maxEpochs,
                                                 null,
                                                 null);
        assertClassifiesCorrectly(classifier2, random);
    }

    /**
     * Asserts that {@code classifier} picks the generating category as best category for
     * freshly generated examples of every category.
     *
     * @param classifier the trained classifier under test
     * @param random source of randomness for generating the check examples
     */
    private static void assertClassifiesCorrectly(
            LogisticRegressionClassifier<CharSequence> classifier, Random random) {
        for (int j = 0; j < NUM_CATEGORIES; ++j) {
            String expected = categoryName(j);
            for (int i = 0; i < TEST_EXAMPLES_PER_CATEGORY; ++i) {
                StringBuilder example = generateExample(j, random);
                Assert.assertEquals(expected, classifier.classify(example).bestCategory());
            }
        }
    }

    /**
     * Returns the category label for category index {@code j}.
     *
     * @param j category index, 0-based
     * @return {@code "cat_"} followed by the letter {@code 'a' + j}
     */
    private static String categoryName(int j) {
        return "cat_" + (char) ('a' + j);
    }

    /**
     * Generates a synthetic example for category {@code j} using a new, unseeded random
     * number generator. Kept for existing callers; prefer
     * {@link #generateExample(int, Random)} when reproducible output is needed.
     *
     * @param j category index, 0-based
     * @return the generated example
     */
    static StringBuilder generateExample(int j) {
        return generateExample(j, new Random());
    }

    /**
     * Generates a synthetic example for category {@code j}. The example has 100
     * single-character tokens separated by single spaces. Each token is the category's own
     * letter ({@code 'a' + j}) with probability 1/2. Otherwise it is a letter drawn
     * uniformly from {@code 'a'} .. {@code 'j'}, which may itself be the category letter.
     *
     * @param j category index, 0-based
     * @param random source of randomness; pass a seeded instance for reproducible output
     * @return the generated example
     */
    static StringBuilder generateExample(int j, Random random) {
        char categoryLetter = (char) ('a' + j);
        StringBuilder sb = new StringBuilder(2 * TOKENS_PER_EXAMPLE);
        for (int k = 0; k < TOKENS_PER_EXAMPLE; ++k) {
            if (k > 0) {
                sb.append(' ');
            }
            if (random.nextBoolean()) {
                sb.append(categoryLetter);
            } else {
                sb.append((char) ('a' + random.nextInt(NOISE_ALPHABET_SIZE)));
            }
        }
        return sb;
    }
}

