package org.apache.flink.ml.classification;

import java.util.Arrays;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import org.apache.commons.collections.IteratorUtils;
import org.apache.commons.lang3.RandomUtils;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.common.typeinfo.Types;
import org.apache.flink.api.java.typeutils.RowTypeInfo;
import org.apache.flink.ml.classification.linearsvc.LinearSVC;
import org.apache.flink.ml.classification.linearsvc.LinearSVCModel;
import org.apache.flink.ml.classification.linearsvc.LinearSVCModelData;
import org.apache.flink.ml.linalg.DenseVector;
import org.apache.flink.ml.linalg.SparseVector;
import org.apache.flink.ml.linalg.Vector;
import org.apache.flink.ml.linalg.Vectors;
import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo;
import org.apache.flink.ml.util.ParamUtils;
import org.apache.flink.ml.util.TestUtils;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.table.api.Table;
import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
import org.apache.flink.test.util.AbstractTestBase;
import org.apache.flink.types.Row;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.TemporaryFolder;

/**
 * Tests for {@link LinearSVC} and {@link LinearSVCModel}.
 *
 * <p>The training set is separable on its first feature: rows whose first feature is at most 5 are
 * labelled 0, rows whose first feature is 11 or more are labelled 1.
 */
public class LinearSVCTest extends AbstractTestBase {

    @Rule public final TemporaryFolder tempFolder = new TemporaryFolder();

    /** Training rows of (features, label, weight). */
    private static final List<Row> TRAIN_DATA =
            Arrays.asList(
                    Row.of(Vectors.dense(1.0, 2.0, 3.0, 4.0), 0.0, 1.0),
                    Row.of(Vectors.dense(2.0, 2.0, 3.0, 4.0), 0.0, 2.0),
                    Row.of(Vectors.dense(3.0, 2.0, 3.0, 4.0), 0.0, 3.0),
                    Row.of(Vectors.dense(4.0, 2.0, 3.0, 4.0), 0.0, 4.0),
                    Row.of(Vectors.dense(5.0, 2.0, 3.0, 4.0), 0.0, 5.0),
                    Row.of(Vectors.dense(11.0, 2.0, 3.0, 4.0), 1.0, 1.0),
                    Row.of(Vectors.dense(12.0, 2.0, 3.0, 4.0), 1.0, 2.0),
                    Row.of(Vectors.dense(13.0, 2.0, 3.0, 4.0), 1.0, 3.0),
                    Row.of(Vectors.dense(14.0, 2.0, 3.0, 4.0), 1.0, 4.0),
                    Row.of(Vectors.dense(15.0, 2.0, 3.0, 4.0), 1.0, 5.0));

    /** Approximate coefficients expected from fitting TRAIN_DATA with the weight column and reg = 0. */
    private static final double[] EXPECTED_COEFFICIENT = {0.47, -0.273, -0.41, -0.546};

    private static final double TOLERANCE = 1e-7;

    private StreamExecutionEnvironment env;
    private StreamTableEnvironment tEnv;
    private Table trainDataTable;

    @Before
    public void before() {
        env = TestUtils.getExecutionEnvironment();
        tEnv = StreamTableEnvironment.create(env);
        // Shuffle so that no assertion can silently depend on the order of the training rows.
        Collections.shuffle(TRAIN_DATA);
        trainDataTable = createTable(TRAIN_DATA);
    }

    /**
     * Wraps {@code rows} in a table with columns {@code features} (dense vector), {@code label}
     * (double) and {@code weight} (double).
     */
    private Table createTable(List<Row> rows) {
        RowTypeInfo rowType =
                new RowTypeInfo(
                        new TypeInformation[] {
                            DenseVectorTypeInfo.INSTANCE, Types.DOUBLE, Types.DOUBLE
                        },
                        new String[] {"features", "label", "weight"});
        return tEnv.fromDataStream(env.fromCollection(rows, rowType));
    }

    /**
     * Drains {@code iterator} into a typed list.
     *
     * <p>commons-collections' {@code IteratorUtils.toList} returns a raw {@code List}; the unchecked
     * conversion is safe because every element comes from an {@code Iterator<T>}.
     */
    @SuppressWarnings("unchecked")
    private static <T> List<T> collectAll(Iterator<T> iterator) {
        return IteratorUtils.toList(iterator);
    }

    /** Fits {@code estimator} on the shared training table and collects its model data. */
    private List<LinearSVCModelData> fitAndCollectModelData(LinearSVC estimator)
            throws Exception {
        Table modelDataTable = estimator.fit(trainDataTable).getModelData()[0];
        return collectAll(
                LinearSVCModelData.getModelDataStream(modelDataTable).executeAndCollect());
    }

    /**
     * Asserts that every row of {@code output} is classified consistently with TRAIN_DATA.
     *
     * <p>Rows whose first feature is at most 5 must be predicted as 0 with a negative first raw
     * prediction. All other rows must be predicted as 1 with a positive first raw prediction.
     *
     * @param output table produced by a {@link LinearSVCModel} transform
     * @param featuresCol name of the features column
     * @param predictionCol name of the prediction column
     * @param rawPredictionCol name of the raw prediction column
     */
    private void verifyPredictionResult(
            Table output, String featuresCol, String predictionCol, String rawPredictionCol)
            throws Exception {
        List<Row> rows = collectAll(tEnv.toDataStream(output).executeAndCollect());
        // Without this guard an empty result would pass every assertion below vacuously.
        Assert.assertFalse("No prediction rows were produced", rows.isEmpty());
        for (Row row : rows) {
            DenseVector features = ((Vector) row.getField(featuresCol)).toDense();
            double prediction = (Double) row.getField(predictionCol);
            DenseVector rawPrediction = (DenseVector) row.getField(rawPredictionCol);
            if (features.get(0) <= 5.0) {
                Assert.assertEquals(0.0, prediction, TOLERANCE);
                Assert.assertTrue(rawPrediction.get(0) < 0.0);
            } else {
                Assert.assertEquals(1.0, prediction, TOLERANCE);
                Assert.assertTrue(rawPrediction.get(0) > 0.0);
            }
        }
    }

    @Test
    public void testParam() {
        LinearSVC linearSVC = new LinearSVC();
        Assert.assertEquals("features", linearSVC.getFeaturesCol());
        Assert.assertEquals("label", linearSVC.getLabelCol());
        Assert.assertNull(linearSVC.getWeightCol());
        Assert.assertEquals(20L, linearSVC.getMaxIter());
        Assert.assertEquals(1e-6, linearSVC.getTol(), TOLERANCE);
        Assert.assertEquals(0.1, linearSVC.getLearningRate(), TOLERANCE);
        Assert.assertEquals(32L, linearSVC.getGlobalBatchSize());
        Assert.assertEquals(0.0, linearSVC.getReg(), TOLERANCE);
        Assert.assertEquals(0.0, linearSVC.getElasticNet(), TOLERANCE);
        Assert.assertEquals(0.0, linearSVC.getThreshold().doubleValue(), TOLERANCE);
        Assert.assertEquals("prediction", linearSVC.getPredictionCol());
        Assert.assertEquals("rawPrediction", linearSVC.getRawPredictionCol());

        linearSVC
                .setFeaturesCol("test_features")
                .setLabelCol("test_label")
                .setWeightCol("test_weight")
                .setMaxIter(1000)
                .setTol(0.001)
                .setLearningRate(0.5)
                .setGlobalBatchSize(1000)
                .setReg(0.1)
                .setElasticNet(0.5)
                .setThreshold(0.5)
                .setPredictionCol("test_predictionCol")
                .setRawPredictionCol("test_rawPredictionCol");

        Assert.assertEquals("test_features", linearSVC.getFeaturesCol());
        Assert.assertEquals("test_label", linearSVC.getLabelCol());
        Assert.assertEquals("test_weight", linearSVC.getWeightCol());
        Assert.assertEquals(1000L, linearSVC.getMaxIter());
        Assert.assertEquals(0.001, linearSVC.getTol(), TOLERANCE);
        Assert.assertEquals(0.5, linearSVC.getLearningRate(), TOLERANCE);
        Assert.assertEquals(1000L, linearSVC.getGlobalBatchSize());
        Assert.assertEquals(0.1, linearSVC.getReg(), TOLERANCE);
        Assert.assertEquals(0.5, linearSVC.getElasticNet(), TOLERANCE);
        Assert.assertEquals(0.5, linearSVC.getThreshold().doubleValue(), TOLERANCE);
        Assert.assertEquals("test_predictionCol", linearSVC.getPredictionCol());
        Assert.assertEquals("test_rawPredictionCol", linearSVC.getRawPredictionCol());
    }

    @Test
    public void testOutputSchema() {
        Table renamedTable = trainDataTable.as("test_features", "test_label", "test_weight");
        LinearSVC linearSVC =
                new LinearSVC()
                        .setFeaturesCol("test_features")
                        .setLabelCol("test_label")
                        .setWeightCol("test_weight")
                        .setPredictionCol("test_predictionCol")
                        .setRawPredictionCol("test_rawPredictionCol");
        // Fit on the renamed table too, so the estimator sees the columns it was configured for.
        Table output = linearSVC.fit(renamedTable).transform(renamedTable)[0];
        Assert.assertEquals(
                Arrays.asList(
                        "test_features",
                        "test_label",
                        "test_weight",
                        "test_predictionCol",
                        "test_rawPredictionCol"),
                output.getResolvedSchema().getColumnNames());
    }

    @Test
    public void testFitAndPredict() throws Exception {
        LinearSVC linearSVC = new LinearSVC().setWeightCol("weight");
        Table output = linearSVC.fit(trainDataTable).transform(trainDataTable)[0];
        verifyPredictionResult(
                output,
                linearSVC.getFeaturesCol(),
                linearSVC.getPredictionCol(),
                linearSVC.getRawPredictionCol());
    }

    @Test
    public void testInputTypeConversion() throws Exception {
        trainDataTable = TestUtils.convertDataTypesToSparseInt(tEnv, trainDataTable);
        Assert.assertArrayEquals(
                new Class[] {SparseVector.class, Integer.class, Integer.class},
                TestUtils.getColumnDataTypes(trainDataTable));
        LinearSVC linearSVC = new LinearSVC().setWeightCol("weight");
        Table output = linearSVC.fit(trainDataTable).transform(trainDataTable)[0];
        verifyPredictionResult(
                output,
                linearSVC.getFeaturesCol(),
                linearSVC.getPredictionCol(),
                linearSVC.getRawPredictionCol());
    }

    @Test
    public void testSaveLoadAndPredict() throws Exception {
        LinearSVC estimator =
                TestUtils.saveAndReload(
                        tEnv,
                        new LinearSVC().setWeightCol("weight"),
                        tempFolder.newFolder().getAbsolutePath(),
                        LinearSVC::load);
        LinearSVCModel model =
                TestUtils.saveAndReload(
                        tEnv,
                        estimator.fit(trainDataTable),
                        tempFolder.newFolder().getAbsolutePath(),
                        LinearSVCModel::load);
        Assert.assertEquals(
                Collections.singletonList("coefficient"),
                model.getModelData()[0].getResolvedSchema().getColumnNames());
        verifyPredictionResult(
                model.transform(trainDataTable)[0],
                estimator.getFeaturesCol(),
                estimator.getPredictionCol(),
                estimator.getRawPredictionCol());
    }

    @Test
    public void testGetModelData() throws Exception {
        List<LinearSVCModelData> modelData =
                fitAndCollectModelData(new LinearSVC().setWeightCol("weight"));
        Assert.assertEquals(1, modelData.size());
        Assert.assertArrayEquals(EXPECTED_COEFFICIENT, modelData.get(0).coefficient.values, 0.1);
    }

    @Test
    public void testSetModelData() throws Exception {
        LinearSVC linearSVC = new LinearSVC().setWeightCol("weight");
        LinearSVCModel fittedModel = linearSVC.fit(trainDataTable);

        LinearSVCModel restoredModel = new LinearSVCModel();
        ParamUtils.updateExistingParams(restoredModel, fittedModel.getParamMap());
        restoredModel.setModelData(fittedModel.getModelData());

        verifyPredictionResult(
                restoredModel.transform(trainDataTable)[0],
                linearSVC.getFeaturesCol(),
                linearSVC.getPredictionCol(),
                linearSVC.getRawPredictionCol());
    }

    @Test
    public void testMoreSubtaskThanData() throws Exception {
        // Only two rows, far fewer than the global batch size of 128.
        Table smallTable =
                createTable(
                        Arrays.asList(
                                Row.of(Vectors.dense(1.0, 2.0, 3.0, 4.0), 0.0, 1.0),
                                Row.of(Vectors.dense(11.0, 2.0, 3.0, 4.0), 1.0, 1.0)));
        LinearSVC linearSVC = new LinearSVC().setWeightCol("weight").setGlobalBatchSize(128);
        Table output = linearSVC.fit(smallTable).transform(smallTable)[0];
        verifyPredictionResult(
                output,
                linearSVC.getFeaturesCol(),
                linearSVC.getPredictionCol(),
                linearSVC.getRawPredictionCol());
    }

    @Test
    public void testRegularization() throws Exception {
        // With reg = 0 the elastic-net mix is expected to be irrelevant, so any random value must
        // reproduce the unregularized coefficients.
        checkRegularization(0.0, RandomUtils.nextDouble(0.0, 1.0), EXPECTED_COEFFICIENT);
        checkRegularization(0.1, 0.0, new double[] {0.437, -0.262, -0.393, -0.524});
        checkRegularization(0.1, 1.0, new double[] {0.426, -0.197, -0.329, -0.463});
        checkRegularization(0.1, 0.5, new double[] {0.419, -0.238, -0.372, -0.505});
    }

    @Test
    public void testThreshold() throws Exception {
        // A threshold below every raw prediction forces label 1; one above all of them forces 0.
        checkThreshold(-Double.MAX_VALUE, 1.0);
        checkThreshold(Double.MAX_VALUE, 0.0);
    }

    /**
     * Fits with the given regularization settings and compares the learned coefficients.
     *
     * @param reg regularization strength
     * @param elasticNet elastic-net mixing parameter
     * @param expectedCoefficient expected coefficients, compared with tolerance 0.001
     */
    private void checkRegularization(double reg, double elasticNet, double[] expectedCoefficient)
            throws Exception {
        LinearSVC linearSVC =
                new LinearSVC().setWeightCol("weight").setReg(reg).setElasticNet(elasticNet);
        List<LinearSVCModelData> modelData = fitAndCollectModelData(linearSVC);
        Assert.assertEquals(1, modelData.size());
        Assert.assertArrayEquals(expectedCoefficient, modelData.get(0).coefficient.values, 0.001);
    }

    /**
     * Fits and transforms TRAIN_DATA with {@code threshold} and asserts every prediction equals
     * {@code expectedPrediction}.
     */
    private void checkThreshold(double threshold, double expectedPrediction) throws Exception {
        LinearSVC linearSVC = new LinearSVC().setWeightCol("weight").setThreshold(threshold);
        Table output = linearSVC.fit(trainDataTable).transform(trainDataTable)[0];
        List<Row> rows = collectAll(tEnv.toDataStream(output).executeAndCollect());
        // Without this guard an empty result would pass the loop below vacuously.
        Assert.assertFalse("No prediction rows were produced", rows.isEmpty());
        for (Row row : rows) {
            Assert.assertEquals(
                    Double.valueOf(expectedPrediction),
                    row.getField(linearSVC.getPredictionCol()));
        }
    }
}
