package org.apache.flink.ml.feature.stopwordsremover;

import java.io.IOException;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Locale;
import java.util.Map;
import java.util.Set;
import java.util.function.Predicate;
import java.util.stream.Collectors;
import org.apache.flink.ml.api.Transformer;
import org.apache.flink.ml.param.Param;
import org.apache.flink.ml.util.ParamUtils;
import org.apache.flink.ml.util.ReadWriteUtils;
import org.apache.flink.table.api.Expressions;
import org.apache.flink.table.api.Table;
import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
import org.apache.flink.table.expressions.Expression;
import org.apache.flink.table.functions.FunctionContext;
import org.apache.flink.table.functions.ScalarFunction;
import org.apache.flink.util.Preconditions;

/**
 * A {@link Transformer} that removes stop words from string-array columns.
 *
 * <p>For every configured input column, {@link #transform(Table...)} appends an output column that
 * holds the input array with all stop words filtered out. Matching is either exact or
 * case-insensitive (using the configured locale), depending on the {@code caseSensitive} param.
 */
public class StopWordsRemover implements Transformer<StopWordsRemover>, StopWordsRemoverParams<StopWordsRemover> {
    private final Map<Param<?>, Object> paramMap = new HashMap<>();

    /** Scalar function that drops stop words from an array of strings. */
    public static class RemoveStopWordsFunction extends ScalarFunction {
        /** Stop words to drop; stored lower-cased when matching is case-insensitive. */
        private final Set<String> stopWords;
        /** Locale used for lower-casing; only consulted when matching is case-insensitive. */
        private final Locale locale;
        private final boolean caseSensitive;
        /** Returns true for words to keep. Transient, so it is (re)built in {@link #open}. */
        private transient Predicate<String> predicate;

        /**
         * Creates the function.
         *
         * @param stopWords words to remove; may contain {@code null}
         * @param locale locale used to lower-case words when {@code caseSensitive} is false
         * @param caseSensitive whether stop words must match exactly
         */
        public RemoveStopWordsFunction(Set<String> stopWords, Locale locale, boolean caseSensitive) {
            this.locale = locale;
            this.caseSensitive = caseSensitive;
            if (caseSensitive) {
                // Defensive copy: later changes to the caller's set must not affect this
                // function, and HashSet is serializable and accepts null elements.
                this.stopWords = new HashSet<>(stopWords);
            } else {
                // Collect into an explicit HashSet rather than Collectors.toSet(), whose
                // concrete type (and therefore null/serialization support) is unspecified.
                this.stopWords = stopWords.stream()
                        .map(word -> word == null ? null : word.toLowerCase(locale))
                        .collect(Collectors.toCollection(HashSet::new));
            }
        }

        /**
         * Builds the keep-predicate matching the configured case sensitivity.
         *
         * @param context function context, forwarded to the superclass
         * @throws Exception if the superclass initialization fails
         */
        @Override
        public void open(FunctionContext context) throws Exception {
            super.open(context);
            if (caseSensitive) {
                predicate = word -> !stopWords.contains(word);
            } else {
                // Null words are looked up as-is; everything else is lower-cased the same way
                // the stop words were in the constructor.
                predicate = word -> !stopWords.contains(word == null ? null : word.toLowerCase(locale));
            }
        }

        /**
         * Filters stop words out of the given array, preserving the order of remaining words.
         *
         * @param words input words; may be {@code null} and may contain {@code null} elements
         * @return a new array without stop words, or {@code null} if {@code words} is {@code null}
         */
        public String[] eval(String[] words) {
            if (words == null) {
                // Null in, null out instead of failing with a NullPointerException.
                return null;
            }
            return Arrays.stream(words).filter(predicate).toArray(String[]::new);
        }
    }

    /** Creates a remover whose params are initialized with their default values. */
    public StopWordsRemover() {
        ParamUtils.initializeMapWithDefaultValues(this.paramMap, this);
    }

    /**
     * Appends, for each input column, an output column with the stop words removed.
     *
     * @param inputs exactly one input table
     * @return a single-element array holding the input table plus the output columns
     * @throws IllegalArgumentException if the number of tables is not one or the numbers of input
     *     and output columns differ
     */
    @Override
    public Table[] transform(Table... inputs) {
        Preconditions.checkArgument(inputs.length == 1, "StopWordsRemover expects exactly one input table.");
        String[] inputCols = getInputCols();
        String[] outputCols = getOutputCols();
        Preconditions.checkArgument(
                inputCols.length == outputCols.length,
                "The number of input columns must equal the number of output columns.");

        RemoveStopWordsFunction function =
                new RemoveStopWordsFunction(
                        new HashSet<>(Arrays.asList(getStopWords())),
                        toLocale(getLocale()),
                        getCaseSensitive());

        // Keep every existing column ("*") and add one filtered column per input column.
        Expression[] columns = new Expression[inputCols.length + 1];
        columns[0] = Expressions.$("*");
        for (int i = 0; i < inputCols.length; i++) {
            columns[i + 1] = Expressions.call(function, Expressions.$(inputCols[i])).as(outputCols[i]);
        }
        return new Table[] {inputs[0].select(columns)};
    }

    /**
     * Converts a locale name such as {@code "tr"}, {@code "tr_TR"} or {@code "no_NO_NY"} into a
     * {@link Locale}.
     *
     * <p>The single-argument {@code Locale} constructor would treat the whole string as the
     * language, so {@code "tr_TR"} would get language {@code "tr_tr"} and the language-specific
     * lower-casing rules of {@link String#toLowerCase(Locale)} would never apply. Splitting on
     * {@code '_'} keeps the language separate. Anything after the second underscore (including a
     * script/extension suffix like {@code "#Latn"}) is passed through as the variant, which does
     * not affect case conversion.
     */
    private static Locale toLocale(String name) {
        String[] parts = name.split("_", 3);
        switch (parts.length) {
            case 1:
                return new Locale(parts[0]);
            case 2:
                return new Locale(parts[0], parts[1]);
            default:
                return new Locale(parts[0], parts[1], parts[2]);
        }
    }

    /**
     * Saves this stage's metadata to the given path.
     *
     * @param path target path
     * @throws IOException if saving fails
     */
    @Override
    public void save(String path) throws IOException {
        ReadWriteUtils.saveMetadata(this, path);
    }

    /**
     * Loads a {@code StopWordsRemover} from the given path.
     *
     * @param tEnv table environment; not used by this method
     * @param path path the stage was saved to
     * @return the loaded stage
     * @throws IOException if loading fails
     */
    public static StopWordsRemover load(StreamTableEnvironment tEnv, String path) throws IOException {
        return (StopWordsRemover) ReadWriteUtils.loadStageParam(path);
    }

    /** Returns the map holding this stage's params and their values. */
    @Override
    public Map<Param<?>, Object> getParamMap() {
        return this.paramMap;
    }

    /**
     * Delegates to {@code StopWordsRemoverUtils.loadDefaultStopWords}.
     *
     * @param language argument forwarded unchanged to the utility method
     * @return the array returned by the utility method
     */
    public static String[] loadDefaultStopWords(String language) {
        return StopWordsRemoverUtils.loadDefaultStopWords(language);
    }

    /** Delegates to {@code StopWordsRemoverUtils.getDefaultOrUS}. */
    public static String getDefaultOrUS() {
        return StopWordsRemoverUtils.getDefaultOrUS();
    }

    /** Delegates to {@code StopWordsRemoverUtils.getAvailableLocales}. */
    public static Set<String> getAvailableLocales() {
        return StopWordsRemoverUtils.getAvailableLocales();
    }
}
