package org.apache.flink.ml.feature.countvectorizer;

import java.io.IOException;
import java.lang.invoke.SerializedLambda;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import org.apache.flink.api.common.functions.AggregateFunction;
import org.apache.flink.api.common.functions.MapFunction;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.common.typeinfo.Types;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.ml.api.Estimator;
import org.apache.flink.ml.common.datastream.DataStreamUtils;
import org.apache.flink.ml.param.Param;
import org.apache.flink.ml.util.ParamUtils;
import org.apache.flink.ml.util.ReadWriteUtils;
import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.table.api.Table;
import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
import org.apache.flink.table.api.internal.TableImpl;
import org.apache.flink.types.Row;
import org.apache.flink.util.Preconditions;

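/**
 * Estimator that aggregates arrays of string terms from one input table into a vocabulary and
 * wraps the result in a {@link CountVectorizerModel}.
 *
 * <p>Decompiled from: input_file:org/apache/flink/ml/feature/countvectorizer/CountVectorizer.class
 */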
public class CountVectorizer implements Estimator<CountVectorizer, CountVectorizerModel>, CountVectorizerParams<CountVectorizer> {
    /** Parameter values of this stage; filled with defaults by the constructor. */
    private final Map<Param<?>, Object> paramMap = new HashMap<>();

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:org/apache/flink/ml/feature/countvectorizer/CountVectorizer$VocabularyAggregator.class */
    public static class VocabularyAggregator implements AggregateFunction<String[], Tuple2<Long, Map<String, Tuple2<Long, Long>>>, CountVectorizerModelData> {
        private final double minDF;
        private final double maxDF;
        private final int vocabularySize;

        public VocabularyAggregator(double d, double d2, int i) {
            this.minDF = d;
            this.maxDF = d2;
            this.vocabularySize = i;
        }

        /* renamed from: createAccumulator, reason: merged with bridge method [inline-methods] */
        public Tuple2<Long, Map<String, Tuple2<Long, Long>>> m42createAccumulator() {
            return Tuple2.of(0L, new HashMap());
        }

        public Tuple2<Long, Map<String, Tuple2<Long, Long>>> add(String[] strArr, Tuple2<Long, Map<String, Tuple2<Long, Long>>> tuple2) {
            HashMap hashMap = new HashMap();
            Arrays.stream(strArr).forEach(str -> {
                if (hashMap.containsKey(str)) {
                    hashMap.put(str, Long.valueOf(((Long) hashMap.get(str)).longValue() + 1));
                } else {
                    hashMap.put(str, 1L);
                }
            });
            Map map = (Map) tuple2.f1;
            hashMap.forEach((str2, l) -> {
                if (!map.containsKey(str2)) {
                    map.put(str2, Tuple2.of(l, 1L));
                    return;
                }
                Tuple2 tuple22 = (Tuple2) map.get(str2);
                tuple22.f0 = Long.valueOf(((Long) tuple22.f0).longValue() + l.longValue());
                Tuple2 tuple23 = (Tuple2) map.get(str2);
                tuple23.f1 = Long.valueOf(((Long) tuple23.f1).longValue() + 1);
            });
            tuple2.f0 = Long.valueOf(((Long) tuple2.f0).longValue() + 1);
            return tuple2;
        }

        public CountVectorizerModelData getResult(Tuple2<Long, Map<String, Tuple2<Long, Long>>> tuple2) {
            Preconditions.checkState(((Long) tuple2.f0).longValue() > 0, "The training set is empty.");
            if ((((Double) CountVectorizerParams.MIN_DF.defaultValue).equals(Double.valueOf(this.minDF)) && ((Double) CountVectorizerParams.MAX_DF.defaultValue).equals(Double.valueOf(this.maxDF))) ? false : true) {
                long longValue = ((Long) tuple2.f0).longValue();
                double d = this.minDF >= 1.0d ? this.minDF : this.minDF * longValue;
                double d2 = this.maxDF >= 1.0d ? this.maxDF : this.maxDF * longValue;
                Preconditions.checkState(d2 >= d, "maxDF must be >= minDF.");
                tuple2.f1 = ((Map) tuple2.f1).entrySet().stream().filter(entry -> {
                    return ((double) ((Long) ((Tuple2) entry.getValue()).f1).longValue()) >= d && ((double) ((Long) ((Tuple2) entry.getValue()).f1).longValue()) <= d2;
                }).collect(Collectors.toMap((v0) -> {
                    return v0.getKey();
                }, (v0) -> {
                    return v0.getValue();
                }));
            }
            ArrayList arrayList = new ArrayList(((Map) tuple2.f1).entrySet());
            arrayList.sort((entry2, entry3) -> {
                return ((Long) ((Tuple2) entry3.getValue()).f1).compareTo((Long) ((Tuple2) entry2.getValue()).f1);
            });
            List list = (List) arrayList.stream().map((v0) -> {
                return v0.getKey();
            }).collect(Collectors.toList());
            return new CountVectorizerModelData((String[]) list.subList(0, Math.min(list.size(), this.vocabularySize)).toArray(new String[0]));
        }

        public Tuple2<Long, Map<String, Tuple2<Long, Long>>> merge(Tuple2<Long, Map<String, Tuple2<Long, Long>>> tuple2, Tuple2<Long, Map<String, Tuple2<Long, Long>>> tuple22) {
            if (((Long) tuple2.f0).longValue() == 0) {
                return tuple22;
            }
            if (((Long) tuple22.f0).longValue() == 0) {
                return tuple2;
            }
            tuple22.f0 = Long.valueOf(((Long) tuple22.f0).longValue() + ((Long) tuple2.f0).longValue());
            ((Map) tuple2.f1).forEach((str, tuple23) -> {
                if (((Map) tuple22.f1).containsKey(str)) {
                    ((Map) tuple22.f1).put(str, Tuple2.of(Long.valueOf(((Long) tuple23.f0).longValue() + ((Long) ((Tuple2) ((Map) tuple22.f1).get(str)).f0).longValue()), Long.valueOf(((Long) tuple23.f1).longValue() + ((Long) ((Tuple2) ((Map) tuple22.f1).get(str)).f1).longValue())));
                } else {
                    ((Map) tuple22.f1).put(str, tuple23);
                }
            });
            return tuple22;
        }
    }

    /**
     * Creates an estimator and lets {@link ParamUtils#initializeMapWithDefaultValues} populate its
     * parameter map.
     */
    public CountVectorizer() {
        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
    }

    /**
     * Computes a vocabulary from the input table and returns a model carrying it.
     *
     * <p>The column named by {@code getInputCol()} is read as {@code String[]} from every row,
     * the arrays are aggregated with a {@link VocabularyAggregator}, and the resulting stream is
     * handed to a new {@link CountVectorizerModel}. This estimator's parameters are then copied
     * onto the model via {@link ParamUtils#updateExistingParams}.
     *
     * @param tableArr exactly one input table
     * @return the model built from the aggregated vocabulary
     * @throws IllegalArgumentException (via {@link Preconditions#checkArgument}) if the number of
     *     tables is not one, or if {@code minDF} and {@code maxDF} are on the same side of 1.0 and
     *     {@code maxDF < minDF}
     */
    @Override
    public CountVectorizerModel fit(Table... tableArr) {
        Preconditions.checkArgument(tableArr.length == 1);

        final double minDF = getMinDF();
        final double maxDF = getMaxDF();
        // The bounds can only be compared up front when both are on the same side of 1.0;
        // otherwise the aggregator compares them once the number of rows is known.
        if ((minDF >= 1.0d && maxDF >= 1.0d) || (minDF < 1.0d && maxDF < 1.0d)) {
            Preconditions.checkArgument(maxDF >= minDF, "maxDF must be >= minDF.");
        }

        final String inputCol = getInputCol();
        final Table input = tableArr[0];
        StreamTableEnvironment tEnv =
                (StreamTableEnvironment) ((TableImpl) input).getTableEnvironment();

        DataStream<String[]> termArrays =
                tEnv.toDataStream(input).map(row -> (String[]) row.getField(inputCol));

        DataStream<CountVectorizerModelData> modelData =
                DataStreamUtils.aggregate(
                        termArrays,
                        new VocabularyAggregator(minDF, maxDF, getVocabularySize()),
                        Types.TUPLE(
                                Types.LONG,
                                Types.MAP(Types.STRING, Types.TUPLE(Types.LONG, Types.LONG))),
                        TypeInformation.of(CountVectorizerModelData.class));

        CountVectorizerModel model =
                new CountVectorizerModel().m43setModelData(tEnv.fromDataStream(modelData));
        ParamUtils.updateExistingParams(model, getParamMap());
        return model;
    }

    /**
     * Alias kept for code that still uses the decompiler-assigned name.
     *
     * @deprecated use {@link #fit(Table...)}
     */
    @Deprecated
    public CountVectorizerModel m41fit(Table... tableArr) {
        return fit(tableArr);
    }

    /** Returns the parameter map held by this estimator (the map itself, not a copy). */
    @Override
    public Map<Param<?>, Object> getParamMap() {
        return paramMap;
    }

    /**
     * Saves this estimator by delegating to {@link ReadWriteUtils#saveMetadata}.
     *
     * @param str target path passed through to {@code saveMetadata}
     * @throws IOException if {@code saveMetadata} fails
     */
    @Override
    public void save(String str) throws IOException {
        ReadWriteUtils.saveMetadata(this, str);
    }

    /**
     * Restores an estimator by delegating to {@link ReadWriteUtils#loadStageParam}.
     *
     * @param streamTableEnvironment not used by this method
     * @param str path passed through to {@code loadStageParam}
     * @return the restored estimator
     * @throws IOException if {@code loadStageParam} fails
     */
    public static CountVectorizer load(StreamTableEnvironment streamTableEnvironment, String str) throws IOException {
        final CountVectorizer restored = ReadWriteUtils.loadStageParam(str);
        return restored;
    }

    /**
     * Recreates the row-to-terms mapping lambda used in {@code fit} from its serialized form.
     *
     * <p>The serialized lambda is accepted only if its implementation method name, method kind,
     * functional interface and signatures all match the values below; the single captured
     * argument is the input column name.
     *
     * <p>NOTE(review): this method is marked synthetic, i.e. it is the decompiled form of a
     * method that compilers emit for serializable lambdas. The hard-coded implementation method
     * name is presumably not stable across recompilation, so this method should probably be
     * deleted when the file is built from source — confirm before relying on it.
     *
     * @param serializedLambda serialized form of the lambda to restore
     * @return a {@code MapFunction<Row, String[]>} reading the captured column from each row
     * @throws IllegalArgumentException if the serialized lambda does not match
     */
    private static /* synthetic */ Object $deserializeLambda$(SerializedLambda serializedLambda) {
        boolean isFitRowMapper =
                "lambda$fit$826503a0$1".equals(serializedLambda.getImplMethodName())
                        // 6 == java.lang.invoke.MethodHandleInfo.REF_invokeStatic
                        && serializedLambda.getImplMethodKind() == 6
                        && serializedLambda
                                .getFunctionalInterfaceClass()
                                .equals("org/apache/flink/api/common/functions/MapFunction")
                        && serializedLambda.getFunctionalInterfaceMethodName().equals("map")
                        && serializedLambda
                                .getFunctionalInterfaceMethodSignature()
                                .equals("(Ljava/lang/Object;)Ljava/lang/Object;")
                        && serializedLambda
                                .getImplClass()
                                .equals("org/apache/flink/ml/feature/countvectorizer/CountVectorizer")
                        && serializedLambda
                                .getImplMethodSignature()
                                .equals("(Ljava/lang/String;Lorg/apache/flink/types/Row;)[Ljava/lang/String;");

        if (isFitRowMapper) {
            String inputCol = (String) serializedLambda.getCapturedArg(0);
            // An explicit functional-interface type is required; a lambda cannot target Object.
            MapFunction<Row, String[]> rowMapper = row -> (String[]) row.getField(inputCol);
            return rowMapper;
        }
        throw new IllegalArgumentException("Invalid lambda deserialization");
    }
}
