package org.apache.flink.ml.feature.standardscaler;

import java.io.IOException;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import org.apache.flink.api.common.state.ListState;
import org.apache.flink.api.common.state.ListStateDescriptor;
import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.java.tuple.Tuple3;
import org.apache.flink.api.java.typeutils.TupleTypeInfo;
import org.apache.flink.iteration.operator.OperatorStateUtils;
import org.apache.flink.ml.api.Estimator;
import org.apache.flink.ml.linalg.BLAS;
import org.apache.flink.ml.linalg.DenseVector;
import org.apache.flink.ml.linalg.Vector;
import org.apache.flink.ml.linalg.Vectors;
import org.apache.flink.ml.param.Param;
import org.apache.flink.ml.util.ParamUtils;
import org.apache.flink.ml.util.ReadWriteUtils;
import org.apache.flink.runtime.state.StateInitializationContext;
import org.apache.flink.runtime.state.StateSnapshotContext;
import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
import org.apache.flink.streaming.api.operators.BoundedOneInput;
import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
import org.apache.flink.table.api.Table;
import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
import org.apache.flink.table.api.internal.TableImpl;
import org.apache.flink.types.Row;
import org.apache.flink.util.Preconditions;

/* loaded from: input_file:org/apache/flink/ml/feature/standardscaler/StandardScaler.class */
/**
 * An {@link Estimator} that computes, for the vector column named by {@code getInputCol()}, the
 * per-dimension mean and sample standard deviation (divisor {@code n - 1}) of the training data
 * and wraps them as the model data of a {@link StandardScalerModel}.
 *
 * <p>Statistics are computed in two stages: each {@link ComputeMetaOperator} subtask accumulates a
 * partial (sum, sum of squares, count) triple over its share of the input, and a single
 * {@link BuildModelOperator} instance (parallelism 1) merges the partials and derives mean and std.
 */
public class StandardScaler implements Estimator<StandardScaler, StandardScalerModel>, StandardScalerParams<StandardScaler> {
    private final Map<Param<?>, Object> paramMap = new HashMap<>();

    /**
     * Merges the partial statistics emitted by {@link ComputeMetaOperator} and, when its input
     * ends, emits exactly one {@link StandardScalerModelData} holding the mean and the sample
     * standard deviation.
     *
     * <p>Accumulated totals are checkpointed in operator list state so they survive restarts.
     * (JADX notes this class was originally private.)
     */
    public static class BuildModelOperator extends AbstractStreamOperator<StandardScalerModelData>
            implements OneInputStreamOperator<Tuple3<DenseVector, DenseVector, Long>, StandardScalerModelData>,
                    BoundedOneInput {
        private ListState<DenseVector> sumState;
        private ListState<DenseVector> squaredSumState;
        private ListState<Long> numElementsState;
        /** Merged element-wise sum; overwritten in place with the mean in {@link #endInput()}. */
        private DenseVector sum;
        /** Merged element-wise sum of squares; overwritten in place with the std in {@link #endInput()}. */
        private DenseVector squaredSum;
        /** Total number of input vectors represented by the merged partials. */
        private long numElements;

        private BuildModelOperator() {
        }

        /**
         * Derives mean and sample standard deviation from the merged totals and emits them.
         *
         * @throws RuntimeException if no element was ever counted (empty training set)
         */
        @Override
        public void endInput() {
            if (numElements <= 0) {
                throw new RuntimeException("The training set is empty.");
            }
            BLAS.scal(1.0d / numElements, sum);
            double[] mean = sum.values;
            // Same backing array as squaredSum: each slot is read before it is overwritten.
            double[] std = squaredSum.values;
            if (numElements > 1) {
                for (int i = 0; i < mean.length; i++) {
                    double variance =
                            (squaredSum.values[i] - numElements * mean[i] * mean[i]) / (numElements - 1);
                    // Cancellation in (sum of squares - n * mean^2) can leave a tiny negative value
                    // for constant features; clamp so sqrt yields 0 instead of NaN.
                    std[i] = Math.sqrt(Math.max(variance, 0.0d));
                }
            } else {
                // A single sample has no spread, and n - 1 would be zero.
                Arrays.fill(std, 0.0d);
            }
            output.collect(
                    new StreamRecord<>(new StandardScalerModelData(Vectors.dense(mean), Vectors.dense(std))));
        }

        /** Folds one partial (sum, sum of squares, count) triple into the running totals. */
        @Override
        public void processElement(StreamRecord<Tuple3<DenseVector, DenseVector, Long>> streamRecord) {
            Tuple3<DenseVector, DenseVector, Long> partial = streamRecord.getValue();
            if (numElements == 0) {
                // First partial: adopt its vectors directly as the accumulators.
                sum = partial.f0;
                squaredSum = partial.f1;
                numElements = partial.f2;
            } else {
                BLAS.axpy(1.0d, (Vector) partial.f0, sum);
                BLAS.axpy(1.0d, (Vector) partial.f1, squaredSum);
                numElements += partial.f2;
            }
        }

        /** Registers the list states and restores any previously checkpointed totals. */
        @Override
        public void initializeState(StateInitializationContext context) throws Exception {
            super.initializeState(context);
            sumState = context.getOperatorStateStore()
                    .getListState(new ListStateDescriptor<>("sumState", TypeInformation.of(DenseVector.class)));
            squaredSumState = context.getOperatorStateStore()
                    .getListState(new ListStateDescriptor<>("squaredSumState", TypeInformation.of(DenseVector.class)));
            numElementsState = context.getOperatorStateStore()
                    .getListState(new ListStateDescriptor<>("numElementsState", BasicTypeInfo.LONG_TYPE_INFO));
            sum = OperatorStateUtils.getUniqueElement(sumState, "sumState").orElse(null);
            squaredSum = OperatorStateUtils.getUniqueElement(squaredSumState, "squaredSumState").orElse(null);
            numElements = OperatorStateUtils.getUniqueElement(numElementsState, "numElementsState").orElse(0L);
        }

        /** Writes the running totals to state; nothing is written before the first element. */
        @Override
        public void snapshotState(StateSnapshotContext context) throws Exception {
            super.snapshotState(context);
            if (numElements > 0) {
                sumState.update(Collections.singletonList(sum));
                squaredSumState.update(Collections.singletonList(squaredSum));
                numElementsState.update(Collections.singletonList(numElements));
            }
        }
    }

    /**
     * Accumulates, per subtask, the element-wise sum, element-wise sum of squares and count of the
     * vectors found in column {@code inputCol}, and emits them as one {@code Tuple3} when input
     * ends. A subtask that saw no rows emits nothing.
     *
     * <p>(JADX notes this class was originally private.)
     */
    public static class ComputeMetaOperator extends AbstractStreamOperator<Tuple3<DenseVector, DenseVector, Long>>
            implements OneInputStreamOperator<Row, Tuple3<DenseVector, DenseVector, Long>>, BoundedOneInput {
        private ListState<DenseVector> sumState;
        private ListState<DenseVector> squaredSumState;
        private ListState<Long> numElementsState;
        private DenseVector sum;
        private DenseVector squaredSum;
        private long numElements;
        /** Name of the row field holding the input {@link Vector}. */
        private final String inputCol;

        public ComputeMetaOperator(String inputCol) {
            this.inputCol = inputCol;
        }

        /** Emits this subtask's partial statistics, if it processed at least one row. */
        @Override
        public void endInput() {
            if (numElements > 0) {
                output.collect(new StreamRecord<>(Tuple3.of(sum, squaredSum, numElements)));
            }
        }

        /** Adds one input vector (and its element-wise square) to the running totals. */
        @Override
        public void processElement(StreamRecord<Row> streamRecord) {
            Vector inputVec = (Vector) streamRecord.getValue().getField(inputCol);
            if (numElements == 0) {
                // Accumulators are sized from the first vector seen.
                sum = new DenseVector(inputVec.size());
                squaredSum = new DenseVector(inputVec.size());
            }
            BLAS.axpy(1.0d, inputVec, sum);
            // NOTE(review): hDot's result is discarded, so it is assumed to square inputVec in
            // place; the sum update above must therefore stay before it. This mutates the vector
            // held by the incoming Row — confirm no other consumer relies on its original values.
            BLAS.hDot(inputVec, inputVec);
            BLAS.axpy(1.0d, inputVec, squaredSum);
            numElements++;
        }

        /** Registers the list states and restores any previously checkpointed totals. */
        @Override
        public void initializeState(StateInitializationContext context) throws Exception {
            super.initializeState(context);
            sumState = context.getOperatorStateStore()
                    .getListState(new ListStateDescriptor<>("sumState", TypeInformation.of(DenseVector.class)));
            squaredSumState = context.getOperatorStateStore()
                    .getListState(new ListStateDescriptor<>("squaredSumState", TypeInformation.of(DenseVector.class)));
            numElementsState = context.getOperatorStateStore()
                    .getListState(new ListStateDescriptor<>("numElementsState", BasicTypeInfo.LONG_TYPE_INFO));
            sum = OperatorStateUtils.getUniqueElement(sumState, "sumState").orElse(null);
            squaredSum = OperatorStateUtils.getUniqueElement(squaredSumState, "squaredSumState").orElse(null);
            numElements = OperatorStateUtils.getUniqueElement(numElementsState, "numElementsState").orElse(0L);
        }

        /** Writes the running totals to state; nothing is written before the first element. */
        @Override
        public void snapshotState(StateSnapshotContext context) throws Exception {
            super.snapshotState(context);
            if (numElements > 0) {
                sumState.update(Collections.singletonList(sum));
                squaredSumState.update(Collections.singletonList(squaredSum));
                numElementsState.update(Collections.singletonList(numElements));
            }
        }
    }

    /** Creates an estimator with every parameter set to its default value. */
    public StandardScaler() {
        ParamUtils.initializeMapWithDefaultValues(this.paramMap, this);
    }

    /**
     * Trains a {@link StandardScalerModel} on exactly one input table.
     *
     * <p>Decompiled name of {@code fit} (JADX: merged with bridge method).
     *
     * @param tableArr a single table containing the vector column named by {@code getInputCol()}
     * @return a model whose model data holds the mean and std, with this estimator's params copied
     * @throws IllegalArgumentException if not exactly one table is given
     */
    public StandardScalerModel m62fit(Table... tableArr) {
        Preconditions.checkArgument(tableArr.length == 1);
        StreamTableEnvironment tEnv =
                (StreamTableEnvironment) ((TableImpl) tableArr[0]).getTableEnvironment();

        // Type of the per-subtask partial: (sum, sum of squares, count).
        TypeInformation<Tuple3<DenseVector, DenseVector, Long>> partialStatsType =
                new TupleTypeInfo<>(
                        TypeInformation.of(DenseVector.class),
                        TypeInformation.of(DenseVector.class),
                        BasicTypeInfo.LONG_TYPE_INFO);

        Table modelData =
                tEnv.fromDataStream(
                        tEnv.toDataStream(tableArr[0])
                                .transform("computeMeta", partialStatsType, new ComputeMetaOperator(getInputCol()))
                                .transform(
                                        "buildModel",
                                        TypeInformation.of(StandardScalerModelData.class),
                                        new BuildModelOperator())
                                // All partials must meet in one instance to be merged.
                                .setParallelism(1));

        StandardScalerModel model = new StandardScalerModel().m63setModelData(modelData);
        ReadWriteUtils.updateExistingParams(model, this.paramMap);
        return model;
    }

    /** Saves this estimator's metadata (its params) under {@code str}. */
    @Override
    public void save(String str) throws IOException {
        ReadWriteUtils.saveMetadata(this, str);
    }

    /** Loads an estimator whose params were previously written by {@link #save(String)}. */
    public static StandardScaler load(StreamTableEnvironment streamTableEnvironment, String str) throws IOException {
        return ReadWriteUtils.loadStageParam(str);
    }

    @Override
    public Map<Param<?>, Object> getParamMap() {
        return this.paramMap;
    }
}
