package org.apache.flink.ml.feature.standardscaler;

import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import org.apache.commons.lang3.ArrayUtils;
import org.apache.flink.api.common.state.ListState;
import org.apache.flink.api.common.state.ListStateDescriptor;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.common.typeinfo.Types;
import org.apache.flink.api.java.typeutils.RowTypeInfo;
import org.apache.flink.iteration.operator.OperatorStateUtils;
import org.apache.flink.metrics.Gauge;
import org.apache.flink.metrics.MetricGroup;
import org.apache.flink.ml.api.Model;
import org.apache.flink.ml.common.datastream.TableUtils;
import org.apache.flink.ml.linalg.BLAS;
import org.apache.flink.ml.linalg.DenseVector;
import org.apache.flink.ml.linalg.Vector;
import org.apache.flink.ml.linalg.typeinfo.VectorTypeInfo;
import org.apache.flink.ml.param.Param;
import org.apache.flink.ml.util.ParamUtils;
import org.apache.flink.ml.util.ReadWriteUtils;
import org.apache.flink.runtime.state.StateInitializationContext;
import org.apache.flink.runtime.state.StateSnapshotContext;
import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
import org.apache.flink.streaming.runtime.streamrecord.StreamElementSerializer;
import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
import org.apache.flink.table.api.Table;
import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
import org.apache.flink.table.api.internal.TableImpl;
import org.apache.flink.types.Row;
import org.apache.flink.util.Preconditions;

/**
 * A {@link Model} that standardizes the vector column {@code inputCol} of each input row using
 * the most recent {@link StandardScalerModelData} received on a broadcast model-data stream.
 *
 * <p>Each output row is the input row with the standardized vector appended as {@code outputCol}.
 * When {@code modelVersionCol} is set, the version of the model used is also appended.
 *
 * <p>A record is predicted only once the current model's timestamp is at least
 * {@code recordTimestamp - maxAllowedModelDelayMs}. Until then, the record is buffered in operator
 * state.
 */
public class OnlineStandardScalerModel
        implements Model<OnlineStandardScalerModel>,
                OnlineStandardScalerModelParams<OnlineStandardScalerModel> {
    private final Map<Param<?>, Object> paramMap = new HashMap<>();
    private Table modelDataTable;

    /**
     * Two-input operator.
     *
     * <p>Input 1 carries data rows to standardize. Input 2 carries (broadcast) model data updates.
     */
    private static class PredictionOperator extends AbstractStreamOperator<Row>
            implements TwoInputStreamOperator<Row, StandardScalerModelData, Row> {
        private final RowTypeInfo inputTypeInfo;
        private final String inputCol;
        private final boolean withMean;
        private final boolean withStd;
        private final long maxAllowedModelDelayMs;
        /** Name of the appended model-version column, or {@code null} to not emit one. */
        private final String modelVersionCol;

        /**
         * Data records waiting for a sufficiently recent model.
         *
         * <p>The element type is raw because {@link StreamElementSerializer} is a serializer of
         * {@code StreamElement}, not of {@code StreamRecord<Row>}.
         */
        @SuppressWarnings("rawtypes")
        private ListState<StreamRecord> bufferedPointsState;

        private ListState<StandardScalerModelData> modelDataState;
        /** Latest model data as received. It is never mutated, so it is safe to snapshot. */
        private StandardScalerModelData modelData;

        private DenseVector mean;
        /** Element-wise reciprocal of the model's std (0 where std is 0); only set if withStd. */
        private DenseVector scale;
        private long modelVersion;
        private long modelTimeStamp;

        public PredictionOperator(
                RowTypeInfo inputTypeInfo,
                String inputCol,
                boolean withMean,
                boolean withStd,
                long maxAllowedModelDelayMs,
                String modelVersionCol) {
            this.inputTypeInfo = inputTypeInfo;
            this.inputCol = inputCol;
            this.withMean = withMean;
            this.withStd = withStd;
            this.maxAllowedModelDelayMs = maxAllowedModelDelayMs;
            this.modelVersionCol = modelVersionCol;
        }

        /** Restores buffered records and the last model data, if any, from operator state. */
        @Override
        @SuppressWarnings({"unchecked", "rawtypes"})
        public void initializeState(StateInitializationContext context) throws Exception {
            super.initializeState(context);
            bufferedPointsState =
                    context.getOperatorStateStore()
                            .getListState(
                                    new ListStateDescriptor(
                                            "bufferedPoints",
                                            new StreamElementSerializer(
                                                    inputTypeInfo.createSerializer(
                                                            getExecutionConfig()))));
            modelDataState =
                    context.getOperatorStateStore()
                            .getListState(
                                    new ListStateDescriptor<>(
                                            "modelData",
                                            TypeInformation.of(StandardScalerModelData.class)));
            modelData =
                    OperatorStateUtils.getUniqueElement(modelDataState, "modelData").orElse(null);
            if (modelData != null) {
                initializeModelData(modelData);
            } else {
                // -1 marks "no model yet" for the exported metrics.
                modelTimeStamp = -1L;
                modelVersion = -1L;
            }
        }

        /** Persists the latest (unmodified) model data, if one has been received. */
        @Override
        public void snapshotState(StateSnapshotContext context) throws Exception {
            super.snapshotState(context);
            if (modelData != null) {
                modelDataState.clear();
                modelDataState.add(modelData);
            }
        }

        /** Registers gauges exposing the current model timestamp and version. */
        @Override
        public void open() throws Exception {
            super.open();
            MetricGroup metricGroup =
                    getRuntimeContext()
                            .getMetricGroup()
                            .addGroup("ml")
                            .addGroup("model", OnlineStandardScalerModel.class.getSimpleName());
            metricGroup.gauge("timestamp", (Gauge<Long>) () -> modelTimeStamp);
            metricGroup.gauge("version", (Gauge<Long>) () -> modelVersion);
        }

        /**
         * Predicts a data record immediately if a sufficiently recent model is available.
         * Otherwise, the record is buffered.
         */
        @Override
        public void processElement1(StreamRecord<Row> streamRecord) throws Exception {
            boolean modelTooOld =
                    streamRecord.getTimestamp() - maxAllowedModelDelayMs > modelTimeStamp;
            if (modelTooOld || mean == null) {
                bufferedPointsState.add(streamRecord);
            } else {
                doPrediction(streamRecord);
            }
        }

        /**
         * Installs new model data, then flushes every buffered record that the new model is
         * recent enough to serve.
         */
        @Override
        @SuppressWarnings({"unchecked", "rawtypes"})
        public void processElement2(StreamRecord<StandardScalerModelData> streamRecord)
                throws Exception {
            modelData = streamRecord.getValue();
            initializeModelData(modelData);

            List<StreamRecord> stillBuffered = new ArrayList<>();
            boolean predictedAny = false;
            for (StreamRecord<Row> buffered : bufferedPointsState.get()) {
                if (buffered.getTimestamp() - maxAllowedModelDelayMs <= modelTimeStamp) {
                    doPrediction(buffered);
                    predictedAny = true;
                } else {
                    stillBuffered.add(buffered);
                }
            }
            // Only rewrite the state when something was actually drained from it.
            if (predictedAny) {
                bufferedPointsState.clear();
                if (!stillBuffered.isEmpty()) {
                    bufferedPointsState.update(stillBuffered);
                }
            }
        }

        /**
         * Derives mean/scale/metadata from {@code newModelData} without mutating it.
         *
         * <p>{@code newModelData} is retained in {@link #modelData} and written to state on
         * snapshot. Inverting its std in place would store 1/std in the checkpoint, and a restore
         * would then invert it back to std.
         */
        private void initializeModelData(StandardScalerModelData newModelData) {
            modelTimeStamp = newModelData.timestamp;
            modelVersion = newModelData.version;
            mean = newModelData.mean;
            if (withStd) {
                scale = newModelData.std.clone();
                double[] values = scale.values;
                for (int i = 0; i < values.length; i++) {
                    // Zero-variance features map to 0 instead of dividing by zero.
                    values[i] = values[i] == 0.0d ? 0.0d : 1.0d / values[i];
                }
            }
        }

        /** Standardizes one record with the current model and emits it with its timestamp. */
        private void doPrediction(StreamRecord<Row> streamRecord) {
            Row input = streamRecord.getValue();
            // Clone so the vector inside the input row is never modified in place.
            Vector outputVec =
                    ((Vector) Objects.requireNonNull(input.getField(inputCol))).clone();
            if (withMean) {
                DenseVector centered = outputVec.toDense();
                BLAS.axpy(-1.0d, mean, centered);
                outputVec = centered;
            }
            if (withStd) {
                BLAS.hDot(scale, outputVec);
            }
            Row appended =
                    modelVersionCol == null
                            ? Row.of(outputVec)
                            : Row.of(outputVec, modelVersion);
            output.collect(
                    new StreamRecord<>(Row.join(input, appended), streamRecord.getTimestamp()));
        }
    }

    public OnlineStandardScalerModel() {
        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
    }

    /**
     * Appends the standardized vector (and optionally the model version) to each input row.
     *
     * @param inputs exactly one table containing {@code inputCol}
     * @return a single-element array with the output table
     * @throws IllegalArgumentException if {@code inputs.length != 1}
     * @throws NullPointerException if model data has not been set
     */
    @Override
    public Table[] transform(Table... inputs) {
        Preconditions.checkArgument(inputs.length == 1);
        Preconditions.checkNotNull(
                modelDataTable, "Model data must be set via setModelData before transform.");
        StreamTableEnvironment tEnv =
                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
        RowTypeInfo inputTypeInfo = TableUtils.getRowTypeInfo(inputs[0].getResolvedSchema());
        String modelVersionCol = getModelVersionCol();

        TypeInformation<?>[] extraTypes;
        String[] extraNames;
        if (modelVersionCol == null) {
            extraTypes = new TypeInformation<?>[] {VectorTypeInfo.INSTANCE};
            extraNames = new String[] {getOutputCol()};
        } else {
            extraTypes = new TypeInformation<?>[] {VectorTypeInfo.INSTANCE, Types.LONG};
            extraNames = new String[] {getOutputCol(), modelVersionCol};
        }
        RowTypeInfo outputTypeInfo =
                new RowTypeInfo(
                        ArrayUtils.addAll(inputTypeInfo.getFieldTypes(), extraTypes),
                        ArrayUtils.addAll(inputTypeInfo.getFieldNames(), extraNames));

        PredictionOperator operator =
                new PredictionOperator(
                        inputTypeInfo,
                        getInputCol(),
                        getWithMean(),
                        getWithStd(),
                        getMaxAllowedModelDelayMs(),
                        modelVersionCol);
        Table outputTable =
                tEnv.fromDataStream(
                        tEnv.toDataStream(inputs[0])
                                .connect(
                                        StandardScalerModelData.getModelDataStream(modelDataTable)
                                                .broadcast())
                                .transform("PredictionOperator", outputTypeInfo, operator));
        return new Table[] {outputTable};
    }

    /** Saves this model's metadata (parameters) to {@code path}. Model data is not saved. */
    @Override
    public void save(String path) throws IOException {
        ReadWriteUtils.saveMetadata(this, path);
    }

    /** Loads a model's parameters previously written by {@link #save(String)}. */
    public static OnlineStandardScalerModel load(StreamTableEnvironment tEnv, String path)
            throws IOException {
        return ReadWriteUtils.loadStageParam(path);
    }

    @Override
    public Map<Param<?>, Object> getParamMap() {
        return paramMap;
    }

    /**
     * Sets the model data table.
     *
     * @param inputs exactly one table of model data
     * @return this model
     * @throws IllegalArgumentException if {@code inputs.length != 1}
     */
    @Override
    public OnlineStandardScalerModel setModelData(Table... inputs) {
        Preconditions.checkArgument(inputs.length == 1);
        modelDataTable = inputs[0];
        return this;
    }

    /**
     * Decompiler-generated alias of {@link #setModelData(Table...)}.
     *
     * @deprecated use {@link #setModelData(Table...)}; kept only for source compatibility.
     */
    @Deprecated
    public OnlineStandardScalerModel m102setModelData(Table... inputs) {
        return setModelData(inputs);
    }

    @Override
    public Table[] getModelData() {
        return new Table[] {modelDataTable};
    }
}
