package org.apache.flink.ml.classification.naivebayes;

import java.io.IOException;
import java.lang.invoke.SerializedLambda;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.apache.flink.api.common.functions.FlatMapFunction;
import org.apache.flink.api.common.functions.MapFunction;
import org.apache.flink.api.common.functions.MapPartitionFunction;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.api.java.tuple.Tuple3;
import org.apache.flink.api.java.tuple.Tuple4;
import org.apache.flink.ml.api.Estimator;
import org.apache.flink.ml.common.datastream.DataStreamUtils;
import org.apache.flink.ml.linalg.Vector;
import org.apache.flink.ml.linalg.Vectors;
import org.apache.flink.ml.param.Param;
import org.apache.flink.ml.util.ParamUtils;
import org.apache.flink.ml.util.ReadWriteUtils;
import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.table.api.Table;
import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
import org.apache.flink.table.api.internal.TableImpl;
import org.apache.flink.types.Row;
import org.apache.flink.util.Collector;
import org.apache.flink.util.Preconditions;

/* loaded from: input_file:org/apache/flink/ml/classification/naivebayes/NaiveBayes.class */
public class NaiveBayes implements Estimator<NaiveBayes, NaiveBayesModel>, NaiveBayesParams<NaiveBayes> {
    private final Map<Param<?>, Object> paramMap = new HashMap();

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:org/apache/flink/ml/classification/naivebayes/NaiveBayes$AggregateIntoArrayFunction.class */
    public static class AggregateIntoArrayFunction implements MapPartitionFunction<Tuple4<Double, Integer, Map<Double, Double>, Integer>, Tuple3<Double, Integer, Map<Double, Double>[]>> {
        private AggregateIntoArrayFunction() {
        }

        public void mapPartition(Iterable<Tuple4<Double, Integer, Map<Double, Double>, Integer>> iterable, Collector<Tuple3<Double, Integer, Map<Double, Double>[]>> collector) {
            HashMap hashMap = new HashMap();
            for (Tuple4<Double, Integer, Map<Double, Double>, Integer> tuple4 : iterable) {
                ((List) hashMap.computeIfAbsent(tuple4.f0, d -> {
                    return new ArrayList();
                })).add(tuple4);
            }
            for (List<Tuple4> list : hashMap.values()) {
                int intValue = ((Integer) list.stream().map(tuple42 -> {
                    return (Integer) tuple42.f1;
                }).max((v0, v1) -> {
                    return v0.compareTo(v1);
                }).orElse(-1)).intValue() + 1;
                Preconditions.checkArgument(((Integer) list.stream().map(tuple43 -> {
                    return (Integer) tuple43.f3;
                }).min((v0, v1) -> {
                    return v0.compareTo(v1);
                }).orElse(Integer.MAX_VALUE)).intValue() == ((Integer) list.stream().map(tuple44 -> {
                    return (Integer) tuple44.f3;
                }).max((v0, v1) -> {
                    return v0.compareTo(v1);
                }).orElse(Integer.MIN_VALUE)).intValue(), "Feature vectors should be of equal length.");
                HashMap hashMap2 = new HashMap();
                HashMap hashMap3 = new HashMap();
                for (Tuple4 tuple45 : list) {
                    Map[] mapArr = (Map[]) hashMap3.computeIfAbsent(tuple45.f0, d2 -> {
                        return new HashMap[intValue];
                    });
                    hashMap2.put(tuple45.f0, tuple45.f3);
                    mapArr[((Integer) tuple45.f1).intValue()] = (Map) tuple45.f2;
                }
                Iterator it = hashMap3.keySet().iterator();
                while (it.hasNext()) {
                    double doubleValue = ((Double) it.next()).doubleValue();
                    collector.collect(new Tuple3(Double.valueOf(doubleValue), hashMap2.get(Double.valueOf(doubleValue)), hashMap3.get(Double.valueOf(doubleValue))));
                }
            }
        }
    }

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:org/apache/flink/ml/classification/naivebayes/NaiveBayes$ExtractFeatureFunction.class */
    public static class ExtractFeatureFunction implements FlatMapFunction<Tuple2<Vector, Double>, Tuple3<Double, Integer, Double>> {
        private ExtractFeatureFunction() {
        }

        public void flatMap(Tuple2<Vector, Double> tuple2, Collector<Tuple3<Double, Integer, Double>> collector) {
            Preconditions.checkNotNull(tuple2.f1);
            for (int i = 0; i < ((Vector) tuple2.f0).size(); i++) {
                collector.collect(new Tuple3(tuple2.f1, Integer.valueOf(i), Double.valueOf(((Vector) tuple2.f0).get(i))));
            }
        }

        public /* bridge */ /* synthetic */ void flatMap(Object obj, Collector collector) throws Exception {
            flatMap((Tuple2<Vector, Double>) obj, (Collector<Tuple3<Double, Integer, Double>>) collector);
        }
    }

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:org/apache/flink/ml/classification/naivebayes/NaiveBayes$GenerateFeatureWeightMapFunction.class */
    public static class GenerateFeatureWeightMapFunction implements MapPartitionFunction<Tuple3<Double, Integer, Double>, Tuple4<Double, Integer, Map<Double, Double>, Integer>> {
        private GenerateFeatureWeightMapFunction() {
        }

        public void mapPartition(Iterable<Tuple3<Double, Integer, Double>> iterable, Collector<Tuple4<Double, Integer, Map<Double, Double>, Integer>> collector) {
            ArrayList<Tuple3> arrayList = new ArrayList();
            Iterator<Tuple3<Double, Integer, Double>> it = iterable.iterator();
            arrayList.getClass();
            it.forEachRemaining((v1) -> {
                r1.add(v1);
            });
            HashMap hashMap = new HashMap();
            HashMap hashMap2 = new HashMap();
            for (Tuple3 tuple3 : arrayList) {
                Tuple2 tuple2 = new Tuple2(tuple3.f0, tuple3.f1);
                Map map = (Map) hashMap.computeIfAbsent(tuple2, tuple22 -> {
                    return new HashMap();
                });
                map.put(tuple3.f2, Double.valueOf(((Double) map.getOrDefault(tuple3.f2, Double.valueOf(0.0d))).doubleValue() + 1.0d));
                hashMap2.put(tuple2, Integer.valueOf(((Integer) hashMap2.getOrDefault(tuple2, 0)).intValue() + 1));
            }
            for (Map.Entry entry : hashMap.entrySet()) {
                collector.collect(new Tuple4(((Tuple2) entry.getKey()).f0, ((Tuple2) entry.getKey()).f1, entry.getValue(), hashMap2.get(entry.getKey())));
            }
        }
    }

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:org/apache/flink/ml/classification/naivebayes/NaiveBayes$GenerateModelFunction.class */
    public static class GenerateModelFunction implements MapPartitionFunction<Tuple3<Double, Integer, Map<Double, Double>[]>, NaiveBayesModelData> {
        private final double smoothing;

        private GenerateModelFunction(double d) {
            this.smoothing = d;
        }

        public void mapPartition(Iterable<Tuple3<Double, Integer, Map<Double, Double>[]>> iterable, Collector<NaiveBayesModelData> collector) {
            ArrayList arrayList = new ArrayList();
            Iterator<Tuple3<Double, Integer, Map<Double, Double>[]>> it = iterable.iterator();
            arrayList.getClass();
            it.forEachRemaining((v1) -> {
                r1.add(v1);
            });
            int length = ((Map[]) ((Tuple3) arrayList.get(0)).f2).length;
            Iterator it2 = arrayList.iterator();
            while (it2.hasNext()) {
                Preconditions.checkArgument(length == ((Map[]) ((Tuple3) it2.next()).f2).length, "Feature vectors should be of equal length.");
            }
            double[] dArr = new double[length];
            HashSet[] hashSetArr = new HashSet[length];
            for (int i = 0; i < length; i++) {
                hashSetArr[i] = new HashSet();
            }
            Iterator it3 = arrayList.iterator();
            while (it3.hasNext()) {
                Tuple3 tuple3 = (Tuple3) it3.next();
                for (int i2 = 0; i2 < length; i2++) {
                    int i3 = i2;
                    dArr[i3] = dArr[i3] + ((Integer) tuple3.f1).intValue();
                    hashSetArr[i2].addAll(((Map[]) tuple3.f2)[i2].keySet());
                }
            }
            int[] iArr = new int[length];
            double d = 0.0d;
            int size = arrayList.size();
            for (int i4 = 0; i4 < length; i4++) {
                iArr[i4] = hashSetArr[i4].size();
                d += dArr[i4];
            }
            double log = Math.log(d + (size * this.smoothing));
            HashMap[][] hashMapArr = new HashMap[size][length];
            double[] dArr2 = new double[size];
            double[] dArr3 = new double[size];
            for (int i5 = 0; i5 < size; i5++) {
                Map[] mapArr = (Map[]) ((Tuple3) arrayList.get(i5)).f2;
                for (int i6 = 0; i6 < length; i6++) {
                    HashMap hashMap = new HashMap();
                    double log2 = Math.log((((Integer) ((Tuple3) arrayList.get(i5)).f1).intValue() * 1.0d) + (this.smoothing * iArr[i6]));
                    Iterator it4 = hashSetArr[i6].iterator();
                    while (it4.hasNext()) {
                        Double d2 = (Double) it4.next();
                        hashMap.put(d2, Double.valueOf(Math.log(((Double) mapArr[i6].getOrDefault(d2, Double.valueOf(0.0d))).doubleValue() + this.smoothing) - log2));
                    }
                    hashMapArr[i5][i6] = hashMap;
                }
                dArr3[i5] = ((Double) ((Tuple3) arrayList.get(i5)).f0).doubleValue();
                dArr2[i5] = Math.log((((Integer) ((Tuple3) arrayList.get(i5)).f1).intValue() * length) + this.smoothing) - log;
            }
            collector.collect(new NaiveBayesModelData(hashMapArr, Vectors.dense(dArr2), Vectors.dense(dArr3)));
        }
    }

    public NaiveBayes() {
        ParamUtils.initializeMapWithDefaultValues(this.paramMap, this);
    }

    /* renamed from: fit, reason: merged with bridge method [inline-methods] */
    public NaiveBayesModel m15fit(Table... tableArr) {
        Preconditions.checkArgument(tableArr.length == 1);
        final String featuresCol = getFeaturesCol();
        final String labelCol = getLabelCol();
        double doubleValue = getSmoothing().doubleValue();
        StreamTableEnvironment tableEnvironment = ((TableImpl) tableArr[0]).getTableEnvironment();
        DataStream mapPartition = DataStreamUtils.mapPartition(DataStreamUtils.mapPartition(DataStreamUtils.mapPartition(tableEnvironment.toDataStream(tableArr[0]).map(new MapFunction<Row, Tuple2<Vector, Double>>() { // from class: org.apache.flink.ml.classification.naivebayes.NaiveBayes.1
            public Tuple2<Vector, Double> map(Row row) throws Exception {
                Number number = (Number) row.getField(labelCol);
                Preconditions.checkNotNull(number, "Input data should contain label value.");
                Preconditions.checkArgument(((double) number.intValue()) == number.doubleValue(), "Label value should be indexed number.");
                return new Tuple2<>((Vector) row.getField(featuresCol), Double.valueOf(number.doubleValue()));
            }
        }).flatMap(new ExtractFeatureFunction()).keyBy(tuple3 -> {
            return Integer.valueOf(new Tuple2(tuple3.f0, tuple3.f1).hashCode());
        }), new GenerateFeatureWeightMapFunction()).keyBy(tuple4 -> {
            return (Double) tuple4.f0;
        }), new AggregateIntoArrayFunction()), new GenerateModelFunction(doubleValue));
        mapPartition.getTransformation().setParallelism(1);
        NaiveBayesModel m16setModelData = new NaiveBayesModel().m16setModelData(tableEnvironment.fromDataStream(mapPartition));
        ReadWriteUtils.updateExistingParams(m16setModelData, this.paramMap);
        return m16setModelData;
    }

    public void save(String str) throws IOException {
        ReadWriteUtils.saveMetadata(this, str);
    }

    public static NaiveBayes load(StreamTableEnvironment streamTableEnvironment, String str) throws IOException {
        return ReadWriteUtils.loadStageParam(str);
    }

    public Map<Param<?>, Object> getParamMap() {
        return this.paramMap;
    }

    private static /* synthetic */ Object $deserializeLambda$(SerializedLambda serializedLambda) {
        String implMethodName = serializedLambda.getImplMethodName();
        boolean z = -1;
        switch (implMethodName.hashCode()) {
            case -1754051916:
                if (implMethodName.equals("lambda$fit$ae0522b0$1")) {
                    z = true;
                    break;
                }
                break;
            case 1120142134:
                if (implMethodName.equals("lambda$fit$fcf4ef7d$1")) {
                    z = false;
                    break;
                }
                break;
        }
        switch (z) {
            case false:
                if (serializedLambda.getImplMethodKind() == 6 && serializedLambda.getFunctionalInterfaceClass().equals("org/apache/flink/api/java/functions/KeySelector") && serializedLambda.getFunctionalInterfaceMethodName().equals("getKey") && serializedLambda.getFunctionalInterfaceMethodSignature().equals("(Ljava/lang/Object;)Ljava/lang/Object;") && serializedLambda.getImplClass().equals("org/apache/flink/ml/classification/naivebayes/NaiveBayes") && serializedLambda.getImplMethodSignature().equals("(Lorg/apache/flink/api/java/tuple/Tuple3;)Ljava/lang/Integer;")) {
                    return tuple3 -> {
                        return Integer.valueOf(new Tuple2(tuple3.f0, tuple3.f1).hashCode());
                    };
                }
                break;
            case true:
                if (serializedLambda.getImplMethodKind() == 6 && serializedLambda.getFunctionalInterfaceClass().equals("org/apache/flink/api/java/functions/KeySelector") && serializedLambda.getFunctionalInterfaceMethodName().equals("getKey") && serializedLambda.getFunctionalInterfaceMethodSignature().equals("(Ljava/lang/Object;)Ljava/lang/Object;") && serializedLambda.getImplClass().equals("org/apache/flink/ml/classification/naivebayes/NaiveBayes") && serializedLambda.getImplMethodSignature().equals("(Lorg/apache/flink/api/java/tuple/Tuple4;)Ljava/lang/Double;")) {
                    return tuple4 -> {
                        return (Double) tuple4.f0;
                    };
                }
                break;
        }
        throw new IllegalArgumentException("Invalid lambda deserialization");
    }
}
