package org.apache.flink.ml.feature.imputer;

import java.io.IOException;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;
import org.apache.flink.api.common.functions.AggregateFunction;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.common.typeinfo.Types;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.ml.api.Estimator;
import org.apache.flink.ml.common.datastream.DataStreamUtils;
import org.apache.flink.ml.common.util.QuantileSummary;
import org.apache.flink.ml.param.Param;
import org.apache.flink.ml.util.ParamUtils;
import org.apache.flink.ml.util.ReadWriteUtils;
import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.table.api.DataTypes;
import org.apache.flink.table.api.Schema;
import org.apache.flink.table.api.Table;
import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
import org.apache.flink.table.api.internal.TableImpl;
import org.apache.flink.types.Row;
import org.apache.flink.util.FlinkRuntimeException;
import org.apache.flink.util.Preconditions;

/* loaded from: input_file:org/apache/flink/ml/feature/imputer/Imputer.class */
/**
 * An Estimator that computes, for each input column, a surrogate value used to replace missing
 * values, and wraps the result in an {@link ImputerModel}.
 *
 * <p>Values that are null, NaN, or equal to {@code missingValue} are ignored when computing
 * surrogates. Depending on the configured strategy, the surrogate of a column is one of:
 *
 * <ul>
 *   <li>its mean;
 *   <li>its approximate median, computed with a {@link QuantileSummary} at the configured
 *       relative error;
 *   <li>its most frequent value, with ties broken in favor of the smallest value.
 * </ul>
 *
 * <p>If a column contains no valid value at all, fitting fails with a {@link
 * FlinkRuntimeException} instead of producing a meaningless (NaN) surrogate.
 */
public class Imputer implements Estimator<Imputer, ImputerModel>, ImputerParams<Imputer> {
    private final Map<Param<?>, Object> paramMap = new HashMap<>();

    public Imputer() {
        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
    }

    /**
     * Computes the surrogate of every input column of the single input table.
     *
     * @param inputs exactly one table containing the columns named by {@code inputCols}
     * @return a model whose model data is a one-column table {@code surrogates} of type
     *     MAP&lt;STRING, DOUBLE&gt;, mapping each input column name to its surrogate
     * @throws IllegalArgumentException if not exactly one table is given, or if the numbers of
     *     input and output columns differ
     * @throws RuntimeException if the configured strategy is not supported
     */
    @Override
    public ImputerModel fit(Table... inputs) {
        Preconditions.checkArgument(inputs.length == 1);
        Preconditions.checkArgument(
                getInputCols().length == getOutputCols().length,
                "Num of input columns and output columns are inconsistent.");

        StreamTableEnvironment tEnv =
                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
        DataStream<Row> inputData = tEnv.toDataStream(inputs[0]);

        DataStream<ImputerModelData> surrogates;
        String strategy = getStrategy();
        switch (strategy) {
            case ImputerParams.MEAN:
                surrogates =
                        DataStreamUtils.aggregate(
                                inputData,
                                new MeanStrategyAggregator(getInputCols(), getMissingValue()),
                                Types.MAP(Types.STRING, Types.TUPLE(Types.DOUBLE, Types.LONG)),
                                ImputerModelData.TYPE_INFO);
                break;
            case ImputerParams.MEDIAN:
                surrogates =
                        DataStreamUtils.aggregate(
                                inputData,
                                new MedianStrategyAggregator(
                                        getInputCols(), getMissingValue(), getRelativeError()),
                                Types.MAP(Types.STRING, TypeInformation.of(QuantileSummary.class)),
                                ImputerModelData.TYPE_INFO);
                break;
            case ImputerParams.MOST_FREQUENT:
                surrogates =
                        DataStreamUtils.aggregate(
                                inputData,
                                new MostFrequentStrategyAggregator(
                                        getInputCols(), getMissingValue()),
                                Types.MAP(Types.STRING, Types.MAP(Types.DOUBLE, Types.LONG)),
                                ImputerModelData.TYPE_INFO);
                break;
            default:
                throw new RuntimeException("Unsupported strategy of Imputer: " + getStrategy());
        }

        Schema modelSchema =
                Schema.newBuilder()
                        .column("surrogates", DataTypes.MAP(DataTypes.STRING(), DataTypes.DOUBLE()))
                        .build();
        ImputerModel model =
                new ImputerModel().setModelData(tEnv.fromDataStream(surrogates, modelSchema));
        ParamUtils.updateExistingParams(model, getParamMap());
        return model;
    }

    @Override
    public Map<Param<?>, Object> getParamMap() {
        return paramMap;
    }

    @Override
    public void save(String path) throws IOException {
        ReadWriteUtils.saveMetadata(this, path);
    }

    public static Imputer load(StreamTableEnvironment tEnv, String path) throws IOException {
        return (Imputer) ReadWriteUtils.loadStageParam(path);
    }

    /**
     * Reads the named field of {@code row} as a double. Returns {@code null} when the value must
     * be ignored for surrogate computation, i.e. when it is null, NaN, or equal to {@code
     * missingValue}.
     *
     * <p>The value is parsed from its string form rather than via {@code Number#doubleValue()}.
     * For example, a {@code Float} 0.1f is read as the double 0.1.
     */
    private static Double parseValidValue(Row row, String columnName, double missingValue) {
        Object field = row.getField(columnName);
        if (field == null) {
            return null;
        }
        Double value = Double.valueOf(field.toString());
        if (value.equals(missingValue) || value.isNaN()) {
            return null;
        }
        return value;
    }

    /** Builds the error raised when a column has no valid value to compute a surrogate from. */
    private static FlinkRuntimeException noValidValueException(String columnName) {
        return new FlinkRuntimeException(
                String.format(
                        "Surrogate cannot be computed. All the values in column [%s] are null, NaN or missingValue.",
                        columnName));
    }

    /**
     * Computes the mean of each column. The accumulator maps each column name to a (sum, count)
     * pair of its valid values.
     */
    public static class MeanStrategyAggregator
            implements AggregateFunction<Row, Map<String, Tuple2<Double, Long>>, ImputerModelData> {
        private final String[] columnNames;
        private final double missingValue;

        public MeanStrategyAggregator(String[] columnNames, double missingValue) {
            this.columnNames = columnNames;
            this.missingValue = missingValue;
        }

        @Override
        public Map<String, Tuple2<Double, Long>> createAccumulator() {
            Map<String, Tuple2<Double, Long>> accumulator = new HashMap<>();
            for (String columnName : columnNames) {
                accumulator.put(columnName, Tuple2.of(0.0, 0L));
            }
            return accumulator;
        }

        @Override
        public Map<String, Tuple2<Double, Long>> add(
                Row row, Map<String, Tuple2<Double, Long>> accumulator) {
            accumulator.forEach(
                    (columnName, sumAndCount) -> {
                        Double value = parseValidValue(row, columnName, missingValue);
                        if (value != null) {
                            sumAndCount.f0 = sumAndCount.f0 + value;
                            sumAndCount.f1 = sumAndCount.f1 + 1;
                        }
                    });
            return accumulator;
        }

        /**
         * Computes each column's mean.
         *
         * @throws IllegalStateException if no column has any valid value
         * @throws FlinkRuntimeException if some column has no valid value
         */
        @Override
        public ImputerModelData getResult(Map<String, Tuple2<Double, Long>> accumulator) {
            Preconditions.checkState(
                    accumulator.values().stream().anyMatch(sumAndCount -> sumAndCount.f1 > 0),
                    "The training set is empty or does not contains valid data.");
            Map<String, Double> surrogates = new HashMap<>();
            accumulator.forEach(
                    (columnName, sumAndCount) -> {
                        // Without this check the surrogate would silently become 0.0 / 0 = NaN.
                        if (sumAndCount.f1 == 0) {
                            throw noValidValueException(columnName);
                        }
                        surrogates.put(columnName, sumAndCount.f0 / sumAndCount.f1);
                    });
            return new ImputerModelData(surrogates);
        }

        /** Adds the (sum, count) pairs of {@code acc1} into {@code acc2} and returns {@code acc2}. */
        @Override
        public Map<String, Tuple2<Double, Long>> merge(
                Map<String, Tuple2<Double, Long>> acc1, Map<String, Tuple2<Double, Long>> acc2) {
            Preconditions.checkArgument(acc1.size() == acc2.size());
            acc1.forEach(
                    (columnName, partial) -> {
                        Tuple2<Double, Long> target = acc2.get(columnName);
                        target.f0 = target.f0 + partial.f0;
                        target.f1 = target.f1 + partial.f1;
                    });
            return acc2;
        }
    }

    /**
     * Computes the approximate median of each column. Each column has its own {@link
     * QuantileSummary} built with the configured relative error.
     */
    public static class MedianStrategyAggregator
            implements AggregateFunction<Row, Map<String, QuantileSummary>, ImputerModelData> {
        private final String[] columnNames;
        private final double missingValue;
        private final double relativeError;

        public MedianStrategyAggregator(
                String[] columnNames, double missingValue, double relativeError) {
            this.columnNames = columnNames;
            this.missingValue = missingValue;
            this.relativeError = relativeError;
        }

        @Override
        public Map<String, QuantileSummary> createAccumulator() {
            Map<String, QuantileSummary> accumulator = new HashMap<>();
            for (String columnName : columnNames) {
                accumulator.put(columnName, new QuantileSummary(relativeError));
            }
            return accumulator;
        }

        @Override
        public Map<String, QuantileSummary> add(Row row, Map<String, QuantileSummary> accumulator) {
            accumulator.forEach(
                    (columnName, summary) -> {
                        Double value = parseValidValue(row, columnName, missingValue);
                        if (value != null) {
                            summary.insert(value);
                        }
                    });
            return accumulator;
        }

        /**
         * Queries the 0.5 quantile of each column's compressed summary.
         *
         * @throws FlinkRuntimeException if some column has no valid value
         */
        @Override
        public ImputerModelData getResult(Map<String, QuantileSummary> accumulator) {
            Map<String, Double> surrogates = new HashMap<>();
            accumulator.forEach(
                    (columnName, summary) -> {
                        QuantileSummary compressed = summary.compress();
                        if (compressed.isEmpty()) {
                            throw noValidValueException(columnName);
                        }
                        surrogates.put(columnName, compressed.query(0.5));
                    });
            return new ImputerModelData(surrogates);
        }

        /**
         * Merges each summary of {@code acc1} into the matching summary of {@code acc2}. Both sides
         * are compressed before merging. Returns {@code acc2}.
         */
        @Override
        public Map<String, QuantileSummary> merge(
                Map<String, QuantileSummary> acc1, Map<String, QuantileSummary> acc2) {
            Preconditions.checkArgument(acc1.size() == acc2.size());
            acc1.forEach(
                    (columnName, summary) ->
                            acc2.put(
                                    columnName,
                                    acc2.get(columnName).compress().merge(summary.compress())));
            return acc2;
        }
    }

    /**
     * Computes the most frequent valid value of each column. The accumulator maps each column name
     * to a value -&gt; occurrence-count map.
     */
    public static class MostFrequentStrategyAggregator
            implements AggregateFunction<Row, Map<String, Map<Double, Long>>, ImputerModelData> {
        private final String[] columnNames;
        private final double missingValue;

        public MostFrequentStrategyAggregator(String[] columnNames, double missingValue) {
            this.columnNames = columnNames;
            this.missingValue = missingValue;
        }

        @Override
        public Map<String, Map<Double, Long>> createAccumulator() {
            Map<String, Map<Double, Long>> accumulator = new HashMap<>();
            for (String columnName : columnNames) {
                accumulator.put(columnName, new HashMap<>());
            }
            return accumulator;
        }

        @Override
        public Map<String, Map<Double, Long>> add(
                Row row, Map<String, Map<Double, Long>> accumulator) {
            accumulator.forEach(
                    (columnName, counts) -> {
                        Double value = parseValidValue(row, columnName, missingValue);
                        if (value != null) {
                            counts.merge(value, 1L, Long::sum);
                        }
                    });
            return accumulator;
        }

        /**
         * Picks each column's most frequent value.
         *
         * @throws IllegalStateException if no column has any valid value
         * @throws FlinkRuntimeException if some column has no valid value
         */
        @Override
        public ImputerModelData getResult(Map<String, Map<Double, Long>> accumulator) {
            Preconditions.checkState(
                    accumulator.values().stream().anyMatch(counts -> !counts.isEmpty()),
                    "The training set is empty or does not contains valid data.");
            Map<String, Double> surrogates = new HashMap<>();
            accumulator.forEach(
                    (columnName, counts) -> {
                        // Without this check the surrogate would silently become NaN.
                        if (counts.isEmpty()) {
                            throw noValidValueException(columnName);
                        }
                        surrogates.put(columnName, mostFrequentValue(counts));
                    });
            return new ImputerModelData(surrogates);
        }

        /** Merges the per-value counts of {@code acc1} into {@code acc2} and returns {@code acc2}. */
        @Override
        public Map<String, Map<Double, Long>> merge(
                Map<String, Map<Double, Long>> acc1, Map<String, Map<Double, Long>> acc2) {
            Preconditions.checkArgument(acc1.size() == acc2.size());
            acc1.forEach(
                    (columnName, partialCounts) -> {
                        Map<Double, Long> target = acc2.get(columnName);
                        partialCounts.forEach((value, count) -> target.merge(value, count, Long::sum));
                    });
            return acc2;
        }

        /**
         * Returns the key with the highest count. Among keys tied for the highest count, it returns
         * the smallest key according to {@link Math#min(double, double)}. Returns NaN for an empty
         * map.
         */
        private static double mostFrequentValue(Map<Double, Long> counts) {
            long bestCount = Long.MIN_VALUE;
            double bestValue = Double.NaN;
            for (Map.Entry<Double, Long> entry : counts.entrySet()) {
                long count = entry.getValue();
                double value = entry.getKey();
                if (count > bestCount) {
                    bestCount = count;
                    bestValue = value;
                } else if (count == bestCount) {
                    bestValue = Math.min(value, bestValue);
                }
            }
            return bestValue;
        }
    }
}
