package org.apache.flink.ml.feature.elementwiseproduct;

import java.io.IOException;
import java.util.HashMap;
import java.util.Map;
import org.apache.commons.lang3.ArrayUtils;
import org.apache.flink.api.common.functions.MapFunction;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.java.typeutils.RowTypeInfo;
import org.apache.flink.ml.api.Transformer;
import org.apache.flink.ml.common.datastream.TableUtils;
import org.apache.flink.ml.linalg.BLAS;
import org.apache.flink.ml.linalg.Vector;
import org.apache.flink.ml.linalg.typeinfo.VectorTypeInfo;
import org.apache.flink.ml.param.Param;
import org.apache.flink.ml.util.ParamUtils;
import org.apache.flink.ml.util.ReadWriteUtils;
import org.apache.flink.table.api.Table;
import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
import org.apache.flink.table.api.internal.TableImpl;
import org.apache.flink.types.Row;
import org.apache.flink.util.Preconditions;

/**
 * A {@link Transformer} that multiplies each input vector element-wise by a fixed scaling vector
 * and appends the result to every row as a new trailing column.
 *
 * <p>The input column name, output column name and scaling vector are read from the parameters
 * exposed through {@link ElementwiseProductParams}. Rows whose input vector is {@code null} get a
 * {@code null} value in the output column. An input vector whose size differs from the scaling
 * vector's size causes an {@link IllegalArgumentException} when the row is processed.
 */
public class ElementwiseProduct
        implements Transformer<ElementwiseProduct>, ElementwiseProductParams<ElementwiseProduct> {

    /** Values of this stage's parameters; populated with their defaults by the constructor. */
    private final Map<Param<?>, Object> paramMap = new HashMap<>();

    /** Per-row function that appends the scaled copy of the input vector to the row. */
    private static class ElementwiseProductFunction implements MapFunction<Row, Row> {
        private final String inputCol;
        private final Vector scalingVec;

        /**
         * @param inputCol name of the row field that holds the input vector
         * @param scalingVec vector that every input vector is multiplied by element-wise
         */
        public ElementwiseProductFunction(String inputCol, Vector scalingVec) {
            this.inputCol = inputCol;
            this.scalingVec = scalingVec;
        }

        /**
         * Returns {@code row} extended by one trailing field holding the scaled input vector, or
         * {@code null} in that field when the input vector is {@code null}.
         *
         * @param row input row containing the field named {@code inputCol}
         * @return the joined row with the extra output field
         * @throws IllegalArgumentException if the input vector's size differs from the scaling
         *     vector's size
         */
        @Override
        public Row map(Row row) {
            Vector inputVec = row.getFieldAs(inputCol);
            if (inputVec == null) {
                // Still append a (null) field so every output row matches the output schema.
                return Row.join(row, Row.of((Object) null));
            }
            if (scalingVec.size() != inputVec.size()) {
                throw new IllegalArgumentException(
                        "The scaling vector size is "
                                + scalingVec.size()
                                + ", which is not equal input vector size("
                                + inputVec.size()
                                + ").");
            }
            // Work on a copy so the vector stored in the input row is left untouched.
            // NOTE(review): assumes BLAS.hDot writes its result into its second argument, which
            // is why the copy is passed second and then returned — confirm against BLAS docs.
            Vector result = inputVec.clone();
            BLAS.hDot(scalingVec, result);
            return Row.join(row, Row.of(result));
        }
    }

    /** Creates the transformer with every parameter set to its default value. */
    public ElementwiseProduct() {
        ParamUtils.initializeMapWithDefaultValues(this.paramMap, this);
    }

    /**
     * Applies the element-wise product to a single input table.
     *
     * @param inputs exactly one table containing the input vector column
     * @return a single table with all input columns plus the vector output column
     * @throws IllegalArgumentException if {@code inputs} does not contain exactly one table
     */
    @Override
    public Table[] transform(Table... inputs) {
        Preconditions.checkArgument(inputs.length == 1);
        Table input = inputs[0];
        StreamTableEnvironment tEnv =
                (StreamTableEnvironment) ((TableImpl) input).getTableEnvironment();

        // Output schema = input schema + one trailing vector column named getOutputCol().
        RowTypeInfo inputTypeInfo = TableUtils.getRowTypeInfo(input.getResolvedSchema());
        RowTypeInfo outputTypeInfo =
                new RowTypeInfo(
                        ArrayUtils.addAll(
                                inputTypeInfo.getFieldTypes(),
                                new TypeInformation<?>[] {VectorTypeInfo.INSTANCE}),
                        ArrayUtils.addAll(
                                inputTypeInfo.getFieldNames(), new String[] {getOutputCol()}));

        ElementwiseProductFunction function =
                new ElementwiseProductFunction(getInputCol(), getScalingVec());
        Table output =
                tEnv.fromDataStream(tEnv.toDataStream(input).map(function, outputTypeInfo));
        return new Table[] {output};
    }

    /**
     * Saves this stage's metadata to {@code path} via {@link ReadWriteUtils#saveMetadata}.
     *
     * @param path destination path
     * @throws IOException if writing the metadata fails
     */
    @Override
    public void save(String path) throws IOException {
        ReadWriteUtils.saveMetadata(this, path);
    }

    /**
     * Loads an {@code ElementwiseProduct} previously written by {@link #save(String)}.
     *
     * @param tEnv table environment; not used by this implementation
     * @param path path the stage was saved to
     * @return the loaded stage
     * @throws IOException if reading the metadata fails
     */
    public static ElementwiseProduct load(StreamTableEnvironment tEnv, String path)
            throws IOException {
        return ReadWriteUtils.loadStageParam(path);
    }

    /** Returns the live map of this stage's parameter values. */
    @Override
    public Map<Param<?>, Object> getParamMap() {
        return this.paramMap;
    }
}
