package hex.genmodel.algos.gbm;

import com.google.common.io.ByteStreams;
import hex.genmodel.ModelMojoReader;
import hex.genmodel.MojoReaderBackend;
import hex.genmodel.easy.EasyPredictModelWrapper;
import hex.genmodel.easy.RowData;
import hex.genmodel.easy.prediction.BinomialModelPrediction;
import java.io.BufferedReader;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.nio.charset.StandardCharsets;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;

/* loaded from: input_file:hex/genmodel/algos/gbm/GbmMojoModelTest.class */
public class GbmMojoModelTest {
    private GbmMojoModel mojo;

    /* loaded from: input_file:hex/genmodel/algos/gbm/GbmMojoModelTest$ClasspathReaderBackend.class */
    private static class ClasspathReaderBackend implements MojoReaderBackend {
        private ClasspathReaderBackend() {
        }

        public BufferedReader getTextFile(String str) throws IOException {
            return new BufferedReader(new InputStreamReader(GbmMojoModelTest.class.getResourceAsStream("calibrated/" + str)));
        }

        public byte[] getBinaryFile(String str) throws IOException {
            return ByteStreams.toByteArray(GbmMojoModelTest.class.getResourceAsStream("calibrated/" + str));
        }

        public boolean exists(String str) {
            return true;
        }
    }

    @Before
    public void setup() throws Exception {
        this.mojo = ModelMojoReader.readFrom(new ClasspathReaderBackend());
        Assert.assertNotNull(this.mojo);
    }

    @Test
    public void testScore0() throws Exception {
        Assert.assertArrayEquals(new double[]{1.0d, 0.5416688d, 0.4583312d}, this.mojo.score0(new double[]{18.7d, 1.51d, 1.003d, 132.53d, 1.15d, 0.2d, 1.153d, 8.3d, 0.34d, 0.0d, 0.0d}, new double[3]), 1.0E-5d);
    }

    @Test
    public void scoreSingleTree() throws Exception {
        double[] dArr = {18.7d, 1.51d, 1.003d, 132.53d, 1.15d, 0.2d, 1.153d, 8.3d, 0.34d, 0.0d, 0.0d};
        for (int i = 0; i < 10; i++) {
            double[] dArr2 = new double[3];
            this.mojo.scoreSingleTree(dArr, i, dArr2);
            double[] dArr3 = new double[3];
            this.mojo.scoreTreeRange(dArr, i, i + 1, dArr3);
            Assert.assertArrayEquals(dArr3, dArr2, 0.0d);
        }
    }

    @Test
    public void testScoreTreeRange() throws Exception {
        double[] dArr = {18.7d, 1.51d, 1.003d, 132.53d, 1.15d, 0.2d, 1.153d, 8.3d, 0.34d, 0.0d, 0.0d};
        double[] dArr2 = new double[3];
        for (int i = 0; i < 10; i++) {
            double[] dArr3 = new double[dArr2.length];
            this.mojo.scoreTreeRange(dArr, i, i + 1, dArr3);
            for (int i2 = 0; i2 < dArr2.length; i2++) {
                int i3 = i2;
                dArr2[i3] = dArr2[i3] + dArr3[i2];
            }
        }
        this.mojo.unifyPreds(dArr, 0.0d, dArr2);
        Assert.assertArrayEquals(this.mojo.score0(dArr, new double[dArr2.length]), dArr2, 1.0E-8d);
    }

    @Test
    public void testPredict() throws Exception {
        BinomialModelPrediction predict = new EasyPredictModelWrapper(this.mojo).predict(new RowData() { // from class: hex.genmodel.algos.gbm.GbmMojoModelTest.1
            {
                put("SegSumT", Double.valueOf(18.7d));
                put("SegTSeas", Double.valueOf(1.51d));
                put("SegLowFlow", Double.valueOf(1.003d));
                put("DSDist", Double.valueOf(132.53d));
                put("DSMaxSlope", Double.valueOf(1.15d));
                put("USAvgT", Double.valueOf(0.2d));
                put("USRainDays", Double.valueOf(1.153d));
                put("USSlope", Double.valueOf(8.3d));
                put("USNative", Double.valueOf(0.34d));
                put("DSDam", Double.valueOf(0.0d));
                put("Method", "electric");
            }
        });
        Assert.assertEquals(1L, predict.labelIndex);
        Assert.assertEquals("1", predict.label);
        Assert.assertArrayEquals(new double[]{0.5416688d, 0.4583312d}, predict.classProbabilities, 1.0E-5d);
        Assert.assertArrayEquals(new double[]{0.3920402d, 0.6079598d}, predict.calibratedClassProbabilities, 1.0E-5d);
    }
}
