package org.apache.flink.ml.feature.binarizer;

import java.io.IOException;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;
import org.apache.commons.lang3.ArrayUtils;
import org.apache.flink.api.common.functions.MapFunction;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.common.typeinfo.Types;
import org.apache.flink.api.java.typeutils.RowTypeInfo;
import org.apache.flink.ml.api.Transformer;
import org.apache.flink.ml.common.datastream.TableUtils;
import org.apache.flink.ml.linalg.DenseVector;
import org.apache.flink.ml.linalg.SparseVector;
import org.apache.flink.ml.linalg.Vector;
import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo;
import org.apache.flink.ml.linalg.typeinfo.SparseVectorTypeInfo;
import org.apache.flink.ml.linalg.typeinfo.VectorTypeInfo;
import org.apache.flink.ml.param.Param;
import org.apache.flink.ml.util.ParamUtils;
import org.apache.flink.ml.util.ReadWriteUtils;
import org.apache.flink.table.api.Table;
import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
import org.apache.flink.table.api.internal.TableImpl;
import org.apache.flink.types.Row;
import org.apache.flink.util.Preconditions;

/* loaded from: input_file:org/apache/flink/ml/feature/binarizer/Binarizer.class */
/**
 * A {@link Transformer} that binarizes the configured input columns against per-column thresholds.
 *
 * <p>For every input column a new output column is appended to the row. A value strictly greater
 * than the column's threshold becomes {@code 1.0}; any other value becomes {@code 0.0}. Dense
 * vectors, sparse vectors and scalar values are supported (see {@link BinarizeFunction}).
 */
public class Binarizer implements Transformer<Binarizer>, BinarizerParams<Binarizer> {
    /** Parameter values of this stage, initialized with the declared defaults in the constructor. */
    private final Map<Param<?>, Object> paramMap = new HashMap<>();

    /**
     * Row-wise function that appends one binarized field per configured input column to each row.
     */
    private static class BinarizeFunction implements MapFunction<Row, Row> {
        private final String[] inputCols;
        private final Double[] thresholds;

        /**
         * @param strArr names of the columns to binarize
         * @param dArr thresholds, aligned by index with {@code strArr}
         */
        public BinarizeFunction(String[] strArr, Double[] dArr) {
            this.inputCols = strArr;
            this.thresholds = dArr;
        }

        /**
         * Binarizes every configured input column of {@code row}.
         *
         * @param row the input row; may be {@code null}
         * @return the input row joined with the binarized fields (in input-column order), or
         *     {@code null} when {@code row} is {@code null}
         */
        @Override
        public Row map(Row row) {
            if (row == null) {
                return null;
            }
            Row binarized = new Row(inputCols.length);
            for (int col = 0; col < inputCols.length; col++) {
                Object value = row.getField(inputCols[col]);
                binarized.setField(col, binarizerFunc(value, thresholds[col]));
            }
            return Row.join(row, binarized);
        }

        /**
         * Binarizes a single field value.
         *
         * <ul>
         *   <li>{@link DenseVector}: returns a copy in which every element is replaced by
         *       {@code 1.0} if it is strictly greater than {@code d}, otherwise {@code 0.0}.
         *   <li>{@link SparseVector}: returns a sparse vector of the same size that keeps only
         *       the stored entries whose value is strictly greater than {@code d}, each set to
         *       {@code 1.0}.
         *   <li>Anything else: the value's string form is parsed as a double and compared.
         * </ul>
         *
         * @param obj the field value to binarize
         * @param d the threshold
         * @return the binarized value, of the same kind (dense vector, sparse vector or Double)
         */
        private Object binarizerFunc(Object obj, double d) {
            if (obj instanceof DenseVector) {
                DenseVector input = (DenseVector) obj;
                // Work on a copy so the vector held by the input row is left untouched.
                DenseVector result = input.clone();
                for (int i = 0; i < result.size(); i++) {
                    result.values[i] = input.get(i) > d ? 1.0 : 0.0;
                }
                return result;
            }
            if (obj instanceof SparseVector) {
                SparseVector input = (SparseVector) obj;
                int[] keptIndices = new int[input.indices.length];
                int kept = 0;
                // NOTE: only explicitly stored entries are compared; entries that are not stored
                // are never emitted, even if the threshold is negative.
                for (int i = 0; i < input.indices.length; i++) {
                    if (input.values[i] > d) {
                        keptIndices[kept++] = input.indices[i];
                    }
                }
                double[] ones = new double[kept];
                Arrays.fill(ones, 1.0);
                return new SparseVector(input.size(), Arrays.copyOf(keptIndices, kept), ones);
            }
            // Scalar path: relies on the value's string form being parseable as a double.
            return Double.parseDouble(obj.toString()) > d ? 1.0 : 0.0;
        }
    }

    /** Creates a Binarizer whose parameter map is populated with the default parameter values. */
    public Binarizer() {
        ParamUtils.initializeMapWithDefaultValues(this.paramMap, this);
    }

    /**
     * Appends one binarized output column per input column to the given table.
     *
     * @param tableArr exactly one input table
     * @return a single-element array holding the table with the output columns appended
     * @throws IllegalArgumentException if the number of tables is not one, or if the numbers of
     *     input columns, thresholds and output columns do not all match
     */
    @Override // org.apache.flink.ml.api.AlgoOperator
    public Table[] transform(Table... tableArr) {
        Preconditions.checkArgument(
                tableArr.length == 1,
                "Binarizer expects exactly one input table, but got %s.",
                tableArr.length);
        Table input = tableArr[0];
        StreamTableEnvironment tEnv = ((TableImpl) input).getTableEnvironment();
        RowTypeInfo inputTypeInfo = TableUtils.getRowTypeInfo(input.getResolvedSchema());

        String[] inputCols = getInputCols();
        Double[] thresholds = getThresholds();
        String[] outputCols = getOutputCols();
        Preconditions.checkArgument(
                inputCols.length == thresholds.length,
                "The number of input columns (%s) must equal the number of thresholds (%s).",
                inputCols.length,
                thresholds.length);
        Preconditions.checkArgument(
                outputCols != null && outputCols.length == inputCols.length,
                "The number of output columns must equal the number of input columns (%s).",
                inputCols.length);

        TypeInformation<?>[] outputTypes = new TypeInformation<?>[inputCols.length];
        for (int i = 0; i < inputCols.length; i++) {
            int fieldIndex = inputTypeInfo.getFieldIndex(inputCols[i]);
            outputTypes[i] = resolveOutputType(inputTypeInfo.getTypeAt(fieldIndex).getTypeClass());
        }

        RowTypeInfo outputTypeInfo =
                new RowTypeInfo(
                        ArrayUtils.addAll(inputTypeInfo.getFieldTypes(), outputTypes),
                        ArrayUtils.addAll(inputTypeInfo.getFieldNames(), outputCols));
        return new Table[] {
            tEnv.fromDataStream(
                    tEnv.toDataStream(input)
                            .map(new BinarizeFunction(inputCols, thresholds), outputTypeInfo))
        };
    }

    /**
     * Picks the output type for an input column: vector columns keep their vector type, every
     * other column is emitted as a double.
     */
    private static TypeInformation<?> resolveOutputType(Class<?> inputClass) {
        if (SparseVector.class.equals(inputClass)) {
            return SparseVectorTypeInfo.INSTANCE;
        }
        if (DenseVector.class.equals(inputClass)) {
            return DenseVectorTypeInfo.INSTANCE;
        }
        if (Vector.class.equals(inputClass)) {
            return VectorTypeInfo.INSTANCE;
        }
        return Types.DOUBLE;
    }

    /**
     * Saves this stage's metadata to the given path.
     *
     * @param str the path to save to
     * @throws IOException if the metadata cannot be written
     */
    @Override // org.apache.flink.ml.api.Stage
    public void save(String str) throws IOException {
        ReadWriteUtils.saveMetadata(this, str);
    }

    /**
     * Loads a Binarizer from the given path.
     *
     * @param streamTableEnvironment not used by this method
     * @param str the path to load from
     * @return the loaded Binarizer
     * @throws IOException if the stage cannot be read
     */
    public static Binarizer load(StreamTableEnvironment streamTableEnvironment, String str) throws IOException {
        return ReadWriteUtils.loadStageParam(str);
    }

    /** Returns the live parameter map backing this stage. */
    @Override // org.apache.flink.ml.param.WithParams
    public Map<Param<?>, Object> getParamMap() {
        return this.paramMap;
    }
}
