/*
 * Decompiled with CFR 0.152.
 */
package org.apache.ignite.ml.preprocessing.encoding.frequency;

import java.io.Serializable;
import java.util.Map;
import java.util.Set;
import org.apache.ignite.ml.math.exceptions.preprocessing.UnknownCategorialValueException;
import org.apache.ignite.ml.math.primitives.vector.VectorUtils;
import org.apache.ignite.ml.preprocessing.Preprocessor;
import org.apache.ignite.ml.preprocessing.encoding.EncoderPreprocessor;
import org.apache.ignite.ml.structures.LabeledVector;

/**
 * Preprocessor that replaces the values of selected categorical columns with entries looked up in
 * a per-column table of frequencies ({@link #encodingFrequencies}).
 *
 * <p>Columns whose index is contained in {@code handledIndices} are encoded through the lookup
 * table; every other column is passed through unchanged and is expected to already hold a
 * {@link Double} value.
 *
 * @param <K> Type of a key in the upstream data.
 * @param <V> Type of a value in the upstream data.
 */
public class FrequencyEncoderPreprocessor<K, V>
extends EncoderPreprocessor<K, V> {
    protected static final long serialVersionUID = 6237711236382623488L;

    /** Lookup key under which the encoded value for missing ({@code NaN}) inputs is stored. */
    private static final String MISSING_VALUE_KEY = "";

    /**
     * Per-column lookup tables, indexed by column number: each maps a categorical value to the
     * {@code double} that replaces it.
     */
    protected final Map<String, Double>[] encodingFrequencies;

    /**
     * Creates a frequency encoder.
     *
     * @param encodingFrequencies Per-column tables mapping a categorical value to its encoded value.
     * @param basePreprocessor Preprocessor whose output is encoded by this one.
     * @param handledIndices Indices of the columns that must be encoded.
     */
    public FrequencyEncoderPreprocessor(Map<String, Double>[] encodingFrequencies, Preprocessor<K, V> basePreprocessor, Set<Integer> handledIndices) {
        // The superclass integer encoding table is deliberately left null: this encoder keeps its
        // own Double-valued tables in encodingFrequencies and must never read the inherited one.
        super(null, basePreprocessor, handledIndices);
        this.encodingFrequencies = encodingFrequencies;
    }

    /**
     * Applies the base preprocessor and then encodes every handled column through its lookup table.
     *
     * @param k Upstream key.
     * @param v Upstream value.
     * @return Vector where handled columns hold their encoded values and other columns are copied.
     * @throws UnknownCategorialValueException If a handled column holds a value missing from its table.
     */
    @Override
    public LabeledVector apply(K k, V v) {
        LabeledVector tmp = (LabeledVector)this.basePreprocessor.apply(k, v);
        double[] res = new double[tmp.size()];
        for (int i = 0; i < res.length; ++i) {
            Serializable tmpObj = tmp.getRaw(i);
            if (!this.handledIndices.contains(i)) {
                // Not a categorical column: copy the numeric value as is.
                res[i] = (Double)tmpObj;
                continue;
            }
            Map<String, Double> frequencies = this.encodingFrequencies[i];
            if (tmpObj.equals(Double.NaN) && frequencies.containsKey(MISSING_VALUE_KEY)) {
                // Missing value: read it from the same frequency table that was just checked
                // (the inherited encodingValues table is null for this encoder).
                res[i] = frequencies.get(MISSING_VALUE_KEY);
            } else if (frequencies.containsKey(tmpObj)) {
                res[i] = frequencies.get(tmpObj);
            } else {
                throw new UnknownCategorialValueException(tmpObj.toString());
            }
        }
        return new LabeledVector(VectorUtils.of(res), tmp.label());
    }
}

