/*
 * Decompiled with CFR 0.152.
 */
package com.aliasi.test.unit.stats;

import com.aliasi.stats.RegressionPrior;
import com.aliasi.test.unit.Asserts;
import com.aliasi.util.AbstractExternalizable;
import java.io.IOException;
import junit.framework.Assert;
import org.junit.Test;

/**
 * Unit tests for {@link RegressionPrior}: mean shifting, elastic-net mixing of Laplace and
 * Gaussian priors, argument validation for {@code elasticNet}, and serialization round trips.
 */
public class RegressionPriorTest {

    /** Tolerance for comparing computed prior values, modes and gradients. */
    private static final double TOLERANCE = 1.0E-4;

    /** Tolerance for comparing a prior against its serialized-then-deserialized copy. */
    private static final double SERIALIZATION_TOLERANCE = 1.0E-5;

    /** Weight given to the Laplace component of the elastic-net priors under test. */
    private static final double LAPLACE_WEIGHT = 0.3;

    /** Elastic-net scale used together with the component priors below. */
    private static final double ELASTIC_NET_SCALE = 2.0;

    /**
     * Checks that {@code shiftMeans} moves each dimension's mode by the given offset, that the
     * gradient vanishes at the shifted mode, and that successive shifts add up.
     */
    @Test
    public void testMeans() {
        RegressionPrior prior1 = RegressionPrior.gaussian(1.0, true);
        Assert.assertEquals(0.0, prior1.mode(0), TOLERANCE);
        Assert.assertEquals(0.0, prior1.mode(1), TOLERANCE);

        RegressionPrior prior2 = RegressionPrior.shiftMeans(new double[] {1.0, 2.0, -3.0}, prior1);
        Assert.assertEquals(1.0, prior2.mode(0), TOLERANCE);
        Assert.assertEquals(2.0, prior2.mode(1), TOLERANCE);
        Assert.assertEquals(-3.0, prior2.mode(2), TOLERANCE);
        Assert.assertEquals(0.0, prior2.gradient(1.0, 0), TOLERANCE);
        Assert.assertEquals(0.0, prior2.gradient(2.0, 1), TOLERANCE);
        Assert.assertEquals(0.0, prior2.gradient(-3.0, 2), TOLERANCE);

        // Shifting an already-shifted prior accumulates the offsets: {1,2,-3} + {2,1,3}.
        RegressionPrior prior3 = RegressionPrior.shiftMeans(new double[] {2.0, 1.0, 3.0}, prior2);
        Assert.assertEquals(3.0, prior3.mode(0), TOLERANCE);
        Assert.assertEquals(3.0, prior3.mode(1), TOLERANCE);
        Assert.assertEquals(0.0, prior3.mode(2), TOLERANCE);
        Assert.assertEquals(0.0, prior3.gradient(3.0, 0), TOLERANCE);
        Assert.assertEquals(0.0, prior3.gradient(3.0, 1), TOLERANCE);
        Assert.assertEquals(0.0, prior3.gradient(0.0, 2), TOLERANCE);
    }

    /**
     * Checks that an elastic-net prior behaves as the weighted combination of its Laplace and
     * Gaussian components, both with and without a noninformative intercept (dimension 0).
     */
    @Test
    public void testElasticNet() {
        RegressionPrior prior = RegressionPrior.elasticNet(LAPLACE_WEIGHT, ELASTIC_NET_SCALE, true);
        RegressionPrior laplacePrior = RegressionPrior.laplace(1.0 / Math.sqrt(2.0), true);
        RegressionPrior gaussianPrior = RegressionPrior.gaussian(Math.sqrt(2.0) / 2.0, true);
        for (int i = -5; i < 5; ++i) {
            for (int dim = 0; dim < 3; ++dim) {
                assertElasticNetMixture(prior, laplacePrior, gaussianPrior, i, dim);
            }
            // A noninformative intercept contributes neither log density nor gradient.
            Assert.assertEquals(0.0, prior.log2Prior(i, 0), TOLERANCE);
            Assert.assertEquals(0.0, prior.gradient(i, 0), TOLERANCE);
        }

        // Previously this prior was built but never exercised; test it against components that
        // likewise treat the intercept as an ordinary dimension.
        RegressionPrior priorNonInt =
            RegressionPrior.elasticNet(LAPLACE_WEIGHT, ELASTIC_NET_SCALE, false);
        RegressionPrior laplaceNonInt = RegressionPrior.laplace(1.0 / Math.sqrt(2.0), false);
        RegressionPrior gaussianNonInt = RegressionPrior.gaussian(Math.sqrt(2.0) / 2.0, false);
        for (int i = -5; i < 5; ++i) {
            for (int dim = 0; dim < 3; ++dim) {
                assertElasticNetMixture(priorNonInt, laplaceNonInt, gaussianNonInt, i, dim);
            }
            // The intercept flag must not change any non-intercept dimension.
            Assert.assertEquals(prior.log2Prior(i, 2), priorNonInt.log2Prior(i, 2), TOLERANCE);
            Assert.assertEquals(prior.gradient(i, 1), priorNonInt.gradient(i, 1), TOLERANCE);
        }
    }

    /**
     * Asserts that {@code elasticNet}'s log prior and gradient at coefficient {@code beta} in
     * dimension {@code dim} equal the {@link #LAPLACE_WEIGHT}-weighted sum of the given Laplace
     * and Gaussian component priors at the same point.
     */
    private static void assertElasticNetMixture(RegressionPrior elasticNet,
                                                RegressionPrior laplace,
                                                RegressionPrior gaussian,
                                                double beta,
                                                int dim) {
        double gaussianWeight = 1.0 - LAPLACE_WEIGHT;
        Assert.assertEquals(
            LAPLACE_WEIGHT * laplace.log2Prior(beta, dim) + gaussianWeight * gaussian.log2Prior(beta, dim),
            elasticNet.log2Prior(beta, dim),
            TOLERANCE);
        Assert.assertEquals(
            LAPLACE_WEIGHT * laplace.gradient(beta, dim) + gaussianWeight * gaussian.gradient(beta, dim),
            elasticNet.gradient(beta, dim),
            TOLERANCE);
    }

    /**
     * Checks that a mean-shifted prior evaluated at {@code beta} matches the base prior evaluated
     * at {@code beta - mean[dim]}, for both the log prior and the gradient.
     */
    @Test
    public void testMeanOffsets() {
        RegressionPrior basePrior = RegressionPrior.gaussian(1.0, false);
        RegressionPrior prior = RegressionPrior.shiftMeans(new double[] {1.0, -2.0, 3.0}, basePrior);
        // dimension 0: mean 1.0
        Assert.assertEquals(basePrior.log2Prior(0.0, 0), prior.log2Prior(1.0, 0), TOLERANCE);
        Assert.assertEquals(basePrior.log2Prior(1.0, 0), prior.log2Prior(2.0, 0), TOLERANCE);
        Assert.assertEquals(basePrior.log2Prior(-1.0, 0), prior.log2Prior(0.0, 0), TOLERANCE);
        Assert.assertEquals(basePrior.gradient(0.0, 0), prior.gradient(1.0, 0), TOLERANCE);
        Assert.assertEquals(basePrior.gradient(1.0, 0), prior.gradient(2.0, 0), TOLERANCE);
        Assert.assertEquals(basePrior.gradient(-2.0, 0), prior.gradient(-1.0, 0), TOLERANCE);
        // dimension 1: mean -2.0
        Assert.assertEquals(basePrior.log2Prior(3.0, 1), prior.log2Prior(1.0, 1), TOLERANCE);
        // dimension 2: mean 3.0
        Assert.assertEquals(basePrior.log2Prior(7.0, 2), prior.log2Prior(10.0, 2), TOLERANCE);
        Assert.assertEquals(basePrior.gradient(7.0, 2), prior.gradient(10.0, 2), TOLERANCE);
    }

    /** A negative Laplace weight must be rejected. */
    @Test(expected = IllegalArgumentException.class)
    public void testElasticNetEx1() {
        RegressionPrior.elasticNet(-1.0, 2.0, true);
    }

    /** A NaN Laplace weight must be rejected. */
    @Test(expected = IllegalArgumentException.class)
    public void testElasticNetEx2() {
        RegressionPrior.elasticNet(Double.NaN, 2.0, true);
    }

    /** An infinite Laplace weight must be rejected. */
    @Test(expected = IllegalArgumentException.class)
    public void testElasticNetEx3() {
        RegressionPrior.elasticNet(Double.POSITIVE_INFINITY, 2.0, true);
    }

    /** A negative scale must be rejected. */
    @Test(expected = IllegalArgumentException.class)
    public void testElasticNetEx4() {
        RegressionPrior.elasticNet(0.5, -1.0, true);
    }

    /** A NaN scale must be rejected. */
    @Test(expected = IllegalArgumentException.class)
    public void testElasticNetEx5() {
        RegressionPrior.elasticNet(0.5, Double.NaN, true);
    }

    /** An infinite scale must be rejected. */
    @Test(expected = IllegalArgumentException.class)
    public void testElasticNetEx6() {
        RegressionPrior.elasticNet(0.5, Double.POSITIVE_INFINITY, true);
    }

    /** A zero scale must be rejected. */
    @Test(expected = IllegalArgumentException.class)
    public void testElasticNetEx7() {
        RegressionPrior.elasticNet(0.5, 0.0, true);
    }

    /** Round-trips every prior factory through serialization and compares the results. */
    @Test
    public void testSerialization() throws IOException, ClassNotFoundException {
        double[] priorVariances = new double[] {1.0, 2.0, 3.0};
        double priorVariance = 1.0;
        assertSerialization(RegressionPrior.shiftMeans(new double[] {1.0, -2.0, 3.0},
                                                       RegressionPrior.gaussian(priorVariance, false)),
                            3);
        assertSerialization(RegressionPrior.elasticNet(0.95, 2.0, false), -1);
        assertSerialization(RegressionPrior.cauchy(priorVariances), 3);
        assertSerialization(RegressionPrior.cauchy(priorVariance, true), -1);
        assertSerialization(RegressionPrior.cauchy(priorVariance, false), -1);
        assertSerialization(RegressionPrior.gaussian(priorVariances), 3);
        assertSerialization(RegressionPrior.gaussian(priorVariance, true), -1);
        assertSerialization(RegressionPrior.gaussian(priorVariance, false), -1);
        assertSerialization(RegressionPrior.laplace(priorVariances), 3);
        assertSerialization(RegressionPrior.laplace(priorVariance, true), -1);
        assertSerialization(RegressionPrior.laplace(priorVariance, false), -1);
        assertSerialization(RegressionPrior.noninformative(), -1);
    }

    /**
     * Serializes and deserializes {@code prior}, then checks that the copy agrees with the
     * original on log prior and gradient at a few coefficient values in each dimension.
     *
     * @param prior the prior to round-trip
     * @param dimensionality number of dimensions the prior is defined for, or {@code -1} if it
     *     accepts any dimension (then the first ten dimensions are spot-checked). When positive,
     *     the deserialized copy must also reject the first out-of-range dimension.
     */
    void assertSerialization(RegressionPrior prior, int dimensionality)
        throws IOException, ClassNotFoundException {

        RegressionPrior prior2 = (RegressionPrior) AbstractExternalizable.serializeDeserialize(prior);
        int dimsToCheck = dimensionality == -1 ? 10 : dimensionality;
        for (int i = 0; i < dimsToCheck; ++i) {
            Assert.assertEquals(prior.log2Prior(2.0, i), prior2.log2Prior(2.0, i), SERIALIZATION_TOLERANCE);
            Assert.assertEquals(prior.log2Prior(-1.0, i), prior2.log2Prior(-1.0, i), SERIALIZATION_TOLERANCE);
            Assert.assertEquals(prior.gradient(5.0, i), prior2.gradient(5.0, i), SERIALIZATION_TOLERANCE);
            Assert.assertEquals(prior.gradient(-2.0, i), prior2.gradient(-2.0, i), SERIALIZATION_TOLERANCE);
        }
        if (dimensionality > 0) {
            // Probe the deserialized copy (not the original) at the first invalid index, so the
            // check actually verifies that the dimensionality survived serialization.
            try {
                prior2.gradient(2.0, dimensionality);
                Assert.fail("expected ArrayIndexOutOfBoundsException for dimension " + dimensionality);
            } catch (ArrayIndexOutOfBoundsException e) {
                Asserts.succeed();
            }
        }
    }
}

