package org.apache.flink.ml.feature.randomsplitter;

import java.io.IOException;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;
import java.util.Random;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.api.java.typeutils.RowTypeInfo;
import org.apache.flink.ml.api.AlgoOperator;
import org.apache.flink.ml.common.datastream.TableUtils;
import org.apache.flink.ml.param.Param;
import org.apache.flink.ml.util.ParamUtils;
import org.apache.flink.ml.util.ReadWriteUtils;
import org.apache.flink.streaming.api.datastream.SingleOutputStreamOperator;
import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
import org.apache.flink.table.api.Table;
import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
import org.apache.flink.table.api.internal.TableImpl;
import org.apache.flink.types.Row;
import org.apache.flink.util.OutputTag;
import org.apache.flink.util.Preconditions;

/**
 * An AlgoOperator that randomly splits one input table into several output tables.
 *
 * <p>The number of output tables equals the length of the {@code weights} parameter. Every input
 * row is routed to exactly one output table; output {@code k} receives a row with probability
 * {@code weights[k] / sum(weights)}. Output 0 is produced from the operator's main output, the
 * remaining outputs from side outputs.
 *
 * <p>Randomness comes from a {@link Random} seeded with the {@code seed} parameter combined with
 * the parallel subtask index, so the same seed, parallelism and per-subtask input order yield the
 * same assignment.
 */
public class RandomSplitter
        implements AlgoOperator<RandomSplitter>, RandomSplitterParams<RandomSplitter> {
    private final Map<Param<?>, Object> paramMap = new HashMap<>();

    /**
     * Stream operator that assigns every incoming row to one of the split outputs.
     *
     * <p>Bucket 0 is emitted on the regular output; bucket {@code k > 0} is emitted on the side
     * output {@code outputTags[k - 1]}.
     */
    private static class SplitterOperator extends AbstractStreamOperator<Row>
            implements OneInputStreamOperator<Row, Row> {
        private final long initSeed;
        private final OutputTag<Row>[] outputTags;
        /** Normalized cumulative weights: non-decreasing, last entry exactly 1.0. */
        private final double[] fractions;
        /** Created in {@link #open()} so that each parallel subtask gets its own seed. */
        private Random random;

        /**
         * @param outputTags side-output tags for buckets 1..n-1 (bucket 0 is the main output)
         * @param weights relative weights of the buckets; must be non-negative with a positive,
         *     finite sum
         * @param initSeed user-provided seed, combined with the subtask index in {@link #open()}
         */
        public SplitterOperator(OutputTag<Row>[] outputTags, Double[] weights, long initSeed) {
            this.initSeed = initSeed;
            this.outputTags = outputTags;
            this.fractions = new double[weights.length];

            double weightSum = 0.0;
            for (Double weight : weights) {
                // A negative (or NaN) weight would make the cumulative fractions non-monotonic,
                // which breaks the binary search in findBucket.
                Preconditions.checkArgument(
                        weight >= 0.0, "Weights must be non-negative, but got %s.", weight);
                weightSum += weight;
            }
            Preconditions.checkArgument(
                    weightSum > 0.0 && Double.isFinite(weightSum),
                    "The sum of weights must be positive and finite, but got %s.",
                    weightSum);

            // The running sum is accumulated in exactly the same order as weightSum, so the last
            // fraction is exactly 1.0 and always exceeds Random#nextDouble() (which is < 1.0).
            double runningSum = 0.0;
            for (int i = 0; i < weights.length; i++) {
                runningSum += weights[i];
                fractions[i] = runningSum / weightSum;
            }
        }

        @Override
        public void open() throws Exception {
            super.open();
            // Give every parallel subtask a distinct but deterministic seed.
            int subtaskIndex = getRuntimeContext().getIndexOfThisSubtask();
            random = new Random(Tuple2.of(initSeed, subtaskIndex).hashCode());
        }

        @Override
        public void processElement(StreamRecord<Row> streamRecord) throws Exception {
            int bucket = findBucket(fractions, random.nextDouble());
            if (bucket == 0) {
                output.collect(streamRecord);
            } else {
                output.collect(outputTags[bucket - 1], streamRecord);
            }
        }

        /**
         * Returns the index of the first entry of {@code cumulativeFractions} that is strictly
         * greater than {@code value}, i.e. the bucket whose half-open interval {@code
         * [cumulativeFractions[k - 1], cumulativeFractions[k])} contains {@code value}.
         *
         * <p>Callers must pass a non-decreasing array whose last entry is greater than {@code
         * value}.
         */
        static int findBucket(double[] cumulativeFractions, double value) {
            int pos = Arrays.binarySearch(cumulativeFractions, value);
            if (pos < 0) {
                // Not found: the insertion point is the first entry greater than value.
                return -pos - 1;
            }
            // Exact hit on a boundary: the value belongs to the next bucket. Step over repeated
            // boundaries (from zero or negligibly small weights) so an empty bucket is never
            // selected; binarySearch may return any of several equal entries.
            while (pos < cumulativeFractions.length && cumulativeFractions[pos] == value) {
                pos++;
            }
            return pos;
        }
    }

    /** Creates a RandomSplitter with all parameters set to their default values. */
    public RandomSplitter() {
        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
    }

    /**
     * Randomly splits the single input table according to the configured weights.
     *
     * @param inputs exactly one input table
     * @return {@code weights.length} tables; table {@code k} holds the rows assigned to bucket
     *     {@code k}
     * @throws IllegalArgumentException if not exactly one input is given or the weights are
     *     invalid
     */
    @Override
    public Table[] transform(Table... inputs) {
        Preconditions.checkArgument(
                inputs.length == 1,
                "RandomSplitter expects exactly one input table, but got %s.",
                inputs.length);
        StreamTableEnvironment tEnv = ((TableImpl) inputs[0]).getTableEnvironment();
        RowTypeInfo rowTypeInfo = TableUtils.getRowTypeInfo(inputs[0].getResolvedSchema());
        Double[] weights = getWeights();
        Preconditions.checkArgument(weights.length > 0, "At least one weight is required.");

        // The type information is passed explicitly, so no anonymous subclass is needed to
        // capture the generic type (which would also capture a reference to this stage).
        @SuppressWarnings("unchecked")
        OutputTag<Row>[] outputTags = new OutputTag[weights.length - 1];
        for (int i = 0; i < outputTags.length; i++) {
            outputTags[i] = new OutputTag<Row>("outputTag_" + i, rowTypeInfo);
        }

        SingleOutputStreamOperator<Row> splitStream =
                tEnv.toDataStream(inputs[0])
                        .transform(
                                "SplitterOperator",
                                rowTypeInfo,
                                new SplitterOperator(outputTags, weights, getSeed()));

        Table[] outputs = new Table[weights.length];
        outputs[0] = tEnv.fromDataStream(splitStream);
        for (int i = 0; i < outputTags.length; i++) {
            outputs[i + 1] = tEnv.fromDataStream(splitStream.getSideOutput(outputTags[i]));
        }
        return outputs;
    }

    /** Saves this stage's parameter metadata to {@code path}. */
    @Override
    public void save(String path) throws IOException {
        ReadWriteUtils.saveMetadata(this, path);
    }

    /**
     * Loads a RandomSplitter whose parameters were previously written by {@link #save(String)}.
     *
     * @param tEnv table environment (unused here; kept for the common load signature)
     * @param path directory the stage was saved to
     */
    public static RandomSplitter load(StreamTableEnvironment tEnv, String path)
            throws IOException {
        return (RandomSplitter) ReadWriteUtils.loadStageParam(path);
    }

    @Override
    public Map<Param<?>, Object> getParamMap() {
        return paramMap;
    }
}
