package hex.genmodel.algos.kmeans;

import com.google.common.io.ByteStreams;
import hex.genmodel.MojoModel;
import hex.genmodel.MojoReaderBackend;
import hex.genmodel.easy.EasyPredictModelWrapper;
import hex.genmodel.easy.RowData;
import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.nio.charset.StandardCharsets;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;

/* loaded from: input_file:hex/genmodel/algos/kmeans/KMeansMojoModelTest.class */
public class KMeansMojoModelTest {
    private MojoModel _mojo;
    private double[][] _rows;
    private RowData[] _rowData;

    /* loaded from: input_file:hex/genmodel/algos/kmeans/KMeansMojoModelTest$ClasspathReaderBackend.class */
    private static class ClasspathReaderBackend implements MojoReaderBackend {
        private ClasspathReaderBackend() {
        }

        public BufferedReader getTextFile(String str) throws IOException {
            return new BufferedReader(new InputStreamReader(KMeansMojoModelTest.class.getResourceAsStream(str)));
        }

        public byte[] getBinaryFile(String str) throws IOException {
            return ByteStreams.toByteArray(KMeansMojoModelTest.class.getResourceAsStream(str));
        }

        public boolean exists(String str) {
            return true;
        }
    }

    /* JADX WARN: Type inference failed for: r1v3, types: [double[], double[][]] */
    @Before
    public void setup() throws IOException {
        this._mojo = KMeansMojoReader.readFrom(new ClasspathReaderBackend());
        this._rows = new double[]{new double[]{2.0d, 1.0d, 22.0d, 1.0d, 0.0d}, new double[]{2.0d, 1.0d, 2.0d, 3.0d, 1.0d}, new double[]{2.0d, 0.0d, 27.0d, 0.0d, 2.0d}};
        this._rowData = new RowData[this._rows.length];
        for (int i = 0; i < this._rows.length; i++) {
            this._rowData[i] = toRowData(this._mojo, this._rows[i]);
        }
    }

    @Test
    public void testPredict() throws Exception {
        EasyPredictModelWrapper easyPredictModelWrapper = new EasyPredictModelWrapper(this._mojo);
        for (int i = 0; i < 3; i++) {
            Assert.assertEquals(i, easyPredictModelWrapper.predict(this._rowData[i]).cluster);
            double[] dArr = new double[1];
            this._mojo.score0(this._rows[i], dArr);
            Assert.assertEquals(i, dArr[0], 0.0d);
        }
    }

    private static RowData toRowData(MojoModel mojoModel, double[] dArr) {
        RowData rowData = new RowData();
        for (String str : mojoModel._names) {
            int colIdx = mojoModel.getColIdx(str);
            String[] domainValues = mojoModel.getDomainValues(colIdx);
            if (domainValues != null) {
                rowData.put(str, domainValues[(int) dArr[colIdx]]);
            } else {
                rowData.put(str, Double.valueOf(dArr[colIdx]));
            }
        }
        return rowData;
    }
}
