package org.apache.flink.ml.feature.hashingtf;

import java.io.IOException;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Iterator;
import java.util.Map;
import org.apache.commons.lang3.ArrayUtils;
import org.apache.flink.api.common.functions.MapFunction;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.java.typeutils.RowTypeInfo;
import org.apache.flink.ml.api.Transformer;
import org.apache.flink.ml.common.datastream.TableUtils;
import org.apache.flink.ml.linalg.Vectors;
import org.apache.flink.ml.linalg.typeinfo.SparseVectorTypeInfo;
import org.apache.flink.ml.param.Param;
import org.apache.flink.ml.util.ParamUtils;
import org.apache.flink.ml.util.ReadWriteUtils;
import org.apache.flink.shaded.guava30.com.google.common.hash.HashFunction;
import org.apache.flink.shaded.guava30.com.google.common.hash.Hashing;
import org.apache.flink.table.api.Table;
import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
import org.apache.flink.table.api.internal.TableImpl;
import org.apache.flink.types.Row;
import org.apache.flink.util.Preconditions;

/* loaded from: input_file:org/apache/flink/ml/feature/hashingtf/HashingTF.class */
/**
 * A Transformer that maps a sequence of terms (an array or an {@link Iterable}) to a sparse
 * vector of size {@code numFeatures} using the hashing trick.
 *
 * <p>Each term is hashed with a 32-bit Murmur3 function (seed 0) and reduced modulo
 * {@code numFeatures} to obtain a vector index. The value stored at an index is the number of
 * input terms that landed on it, or always {@code 1.0} when the {@code binary} parameter is set.
 * The resulting vector is appended to the input row as a new column named by {@code outputCol}.
 */
public class HashingTF implements Transformer<HashingTF>, HashingTFParams<HashingTF> {
    /** Hash function shared by all instances: 32-bit Murmur3 with seed 0. */
    private static final HashFunction HASH_FUNC = Hashing.murmur3_32(0);

    /** Parameter values of this stage, pre-populated with defaults in the constructor. */
    private final Map<Param<?>, Object> paramMap = new HashMap<>();

    /**
     * Row-wise function that reads the terms from the input column, hashes them into buckets and
     * appends the resulting sparse vector to the row.
     */
    public static class HashTFFunction implements MapFunction<Row, Row> {
        private final String inputCol;
        private final boolean binary;
        private final int numFeatures;

        /**
         * @param inputCol name of the row field that holds the terms
         * @param binary if {@code true}, every non-empty bucket gets the value 1 instead of a count
         * @param numFeatures size of the output vector, i.e. the number of hash buckets
         */
        public HashTFFunction(String inputCol, boolean binary, int numFeatures) {
            this.inputCol = inputCol;
            this.binary = binary;
            this.numFeatures = numFeatures;
        }

        /**
         * Hashes the terms found in the input column and appends the sparse vector of bucket
         * counts to the row.
         *
         * @param row input row; the field named {@code inputCol} must be an object array or an
         *     {@link Iterable}
         * @return the input row joined with one extra field holding the sparse vector
         * @throws NullPointerException if the input column value is null
         * @throws IllegalArgumentException if the input column value is neither an array nor an
         *     {@link Iterable}
         * @throws UnsupportedOperationException if a term has a type that {@link #hash} rejects
         */
        @Override
        public Row map(Row row) throws Exception {
            Object field = row.getField(inputCol);
            if (field == null) {
                // Same exception type the dereference below would raise, but with a message
                // that tells the user which column is at fault.
                throw new NullPointerException(
                        "Input column " + inputCol + " must not be null.");
            }

            Iterable<?> terms;
            if (field.getClass().isArray()) {
                terms = Arrays.asList((Object[]) field);
            } else if (field instanceof Iterable) {
                terms = (Iterable<?>) field;
            } else {
                throw new IllegalArgumentException(
                        "Input format "
                                + field.getClass().getCanonicalName()
                                + " is not supported for input column "
                                + inputCol
                                + ". Supported options are Array and Iterable.");
            }

            // Bucket index -> number of terms hashed to it (pinned to 1 in binary mode).
            Map<Integer, Integer> counts = new HashMap<>();
            for (Object term : terms) {
                int index = HashingTF.nonNegativeMod(HashingTF.hash(term), numFeatures);
                if (binary) {
                    counts.put(index, 1);
                } else {
                    counts.merge(index, 1, Integer::sum);
                }
            }

            int[] indices = new int[counts.size()];
            double[] values = new double[counts.size()];
            int pos = 0;
            for (Map.Entry<Integer, Integer> entry : counts.entrySet()) {
                indices[pos] = entry.getKey();
                values[pos] = entry.getValue();
                pos++;
            }
            return Row.join(row, Row.of(Vectors.sparse(numFeatures, indices, values)));
        }
    }

    /** Creates a HashingTF whose parameters are initialised to their default values. */
    public HashingTF() {
        ParamUtils.initializeMapWithDefaultValues(this.paramMap, this);
    }

    /**
     * Appends the hashed term vector column to the single input table.
     *
     * @param inputs exactly one table containing the configured input column
     * @return a one-element array holding the input table extended with the output column
     * @throws IllegalArgumentException if the number of input tables is not one
     */
    @Override
    public Table[] transform(Table... inputs) {
        Preconditions.checkArgument(inputs.length == 1);
        // getTableEnvironment() is declared to return the generic TableEnvironment; the
        // DataStream bridge methods used below only exist on StreamTableEnvironment.
        StreamTableEnvironment tEnv =
                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();

        RowTypeInfo inputTypeInfo = TableUtils.getRowTypeInfo(inputs[0].getResolvedSchema());
        RowTypeInfo outputTypeInfo =
                new RowTypeInfo(
                        ArrayUtils.addAll(
                                inputTypeInfo.getFieldTypes(), SparseVectorTypeInfo.INSTANCE),
                        ArrayUtils.addAll(inputTypeInfo.getFieldNames(), getOutputCol()));

        HashTFFunction hashFunction =
                new HashTFFunction(getInputCol(), getBinary(), getNumFeatures());
        return new Table[] {
            tEnv.fromDataStream(tEnv.toDataStream(inputs[0]).map(hashFunction, outputTypeInfo))
        };
    }

    /**
     * Persists the metadata of this stage under the given path.
     *
     * @param path destination path
     * @throws IOException if writing the metadata fails
     */
    @Override
    public void save(String path) throws IOException {
        ReadWriteUtils.saveMetadata(this, path);
    }

    /** Returns the live parameter map backing this stage. */
    @Override
    public Map<Param<?>, Object> getParamMap() {
        return this.paramMap;
    }

    /**
     * Restores a HashingTF from metadata previously written by {@link #save(String)}.
     *
     * @param tEnv table environment; not used by this implementation
     * @param path path the stage was saved to
     * @return the restored stage
     * @throws IOException if reading the metadata fails
     */
    public static HashingTF load(StreamTableEnvironment tEnv, String path) throws IOException {
        return (HashingTF) ReadWriteUtils.loadStageParam(path);
    }

    /**
     * Computes the 32-bit Murmur3 hash of a term.
     *
     * <p>{@code null} hashes to 0. Boolean, Byte, Short and Integer values are hashed as ints,
     * Long as a long, Float and Double through their IEEE 754 bit patterns, and String through
     * its UTF-16 chars.
     *
     * @param obj the term to hash; may be null
     * @return the hash code of the term
     * @throws UnsupportedOperationException if the term is of any other type
     */
    public static int hash(Object obj) {
        if (obj == null) {
            return 0;
        }
        if (obj instanceof Boolean) {
            return HASH_FUNC.hashInt((Boolean) obj ? 1 : 0).asInt();
        }
        if (obj instanceof Byte) {
            return HASH_FUNC.hashInt((Byte) obj).asInt();
        }
        if (obj instanceof Short) {
            return HASH_FUNC.hashInt((Short) obj).asInt();
        }
        if (obj instanceof Integer) {
            return HASH_FUNC.hashInt((Integer) obj).asInt();
        }
        if (obj instanceof Long) {
            return HASH_FUNC.hashLong((Long) obj).asInt();
        }
        if (obj instanceof Float) {
            return HASH_FUNC.hashInt(Float.floatToIntBits((Float) obj)).asInt();
        }
        if (obj instanceof Double) {
            return HASH_FUNC.hashLong(Double.doubleToLongBits((Double) obj)).asInt();
        }
        if (obj instanceof String) {
            return HASH_FUNC.hashUnencodedChars((String) obj).asInt();
        }
        throw new UnsupportedOperationException(
                "HashingTF does not support type "
                        + obj.getClass().getCanonicalName()
                        + " of input data.");
    }

    /**
     * Returns {@code x mod mod} mapped into {@code [0, mod)} for a positive {@code mod}, unlike
     * Java's {@code %}, which yields a negative remainder for a negative {@code x}.
     *
     * @param x the value to reduce, typically a hash code
     * @param mod the modulus
     * @return the non-negative remainder
     */
    public static int nonNegativeMod(int x, int mod) {
        int remainder = x % mod;
        return remainder < 0 ? remainder + mod : remainder;
    }
}
