package org.apache.flink.ml.common.iteration;

import org.apache.flink.api.common.functions.FlatMapFunction;
import org.apache.flink.iteration.IterationListener;
import org.apache.flink.util.Collector;
import org.apache.flink.util.Preconditions;

/* loaded from: input_file:org/apache/flink/ml/common/iteration/TerminateOnMaxIterOrTol.class */
/**
 * Iteration listener / flat-map function that decides, once per epoch, whether to emit a
 * record based on a maximum iteration count and a loss tolerance.
 *
 * <p>The function receives loss values through {@link #flatMap(Double, Collector)} and stores
 * the one value seen for the current epoch. When the epoch watermark is incremented it emits a
 * single {@code 0} if, and only if, both of the following hold:
 *
 * <ul>
 *   <li>{@code epochWatermark + 1 < maxIter}, and
 *   <li>the stored loss is strictly greater than {@code tol}.
 * </ul>
 *
 * <p>Otherwise nothing is emitted for that epoch. NOTE(review): presumably the enclosing
 * iteration keeps running while this stream emits records and stops once an epoch produces
 * none; confirm against the Flink ML iteration termination-criteria documentation.
 *
 * <p>If no loss value arrives during an epoch, the stored loss remains at the
 * {@link Double#MAX_VALUE} sentinel, so only the {@code maxIter} bound can suppress the record.
 */
public class TerminateOnMaxIterOrTol implements IterationListener<Integer>, FlatMapFunction<Double, Integer> {
    /**
     * Sentinel held in {@link #loss} while no loss value has been recorded for the current
     * epoch. A genuine loss equal to this value is indistinguishable from "no loss recorded".
     */
    private static final double NO_LOSS = Double.MAX_VALUE;

    /** Exclusive upper bound compared against {@code epochWatermark + 1}. */
    private final int maxIter;

    /** Loss threshold; a record is emitted only while the loss is strictly greater than this. */
    private final double tol;

    /** Loss recorded for the current epoch, or {@link #NO_LOSS} if none has been recorded. */
    private double loss = NO_LOSS;

    /**
     * Creates a criterion bounded by both an iteration count and a loss tolerance.
     *
     * @param maxIter iteration bound; must not be {@code null}
     * @param tol loss tolerance; must not be {@code null}
     * @throws NullPointerException if either argument is {@code null} (thrown on unboxing)
     */
    public TerminateOnMaxIterOrTol(Integer maxIter, Double tol) {
        this.maxIter = maxIter;
        this.tol = tol;
    }

    /**
     * Creates a criterion bounded only by a loss tolerance; the iteration bound is set to
     * {@link Integer#MAX_VALUE}.
     *
     * @param tol loss tolerance; must not be {@code null}
     * @throws NullPointerException if {@code tol} is {@code null} (thrown on unboxing)
     */
    public TerminateOnMaxIterOrTol(Double tol) {
        this.maxIter = Integer.MAX_VALUE;
        this.tol = tol;
    }

    /**
     * Records the loss value for the current epoch. Nothing is written to {@code out}.
     *
     * @param value the loss value for the current epoch
     * @param out unused
     * @throws IllegalArgumentException if a loss value has already been recorded since the
     *     last epoch watermark increment
     */
    @Override
    public void flatMap(Double value, Collector<Integer> out) {
        // The stored loss must still be the sentinel; anything else means a second value
        // arrived within the same epoch.
        Preconditions.checkArgument(
                Double.compare(this.loss, NO_LOSS) == 0,
                "Each epoch should contain only one loss value.");
        this.loss = value;
    }

    /**
     * Emits {@code 0} when neither the iteration bound nor the tolerance has been reached,
     * then resets the stored loss so the next epoch starts from the sentinel.
     *
     * @param epochWatermark the epoch watermark that was just reached
     * @param context unused
     * @param collector receives a single {@code 0} when both conditions hold
     */
    @Override // org.apache.flink.iteration.IterationListener
    public void onEpochWatermarkIncremented(
            int epochWatermark, IterationListener.Context context, Collector<Integer> collector) {
        boolean belowMaxIter = epochWatermark + 1 < this.maxIter;
        boolean aboveTolerance = this.loss > this.tol;
        if (belowMaxIter && aboveTolerance) {
            collector.collect(0);
        }
        // Reset so that flatMap accepts exactly one loss value in the next epoch.
        this.loss = NO_LOSS;
    }

    /** No action is taken when the iteration terminates. */
    @Override // org.apache.flink.iteration.IterationListener
    public void onIterationTerminated(IterationListener.Context context, Collector<Integer> collector) {
    }

    // The raw flatMap(Object, Collector) bridge is intentionally not declared here: the
    // compiler generates it for flatMap(Double, Collector<Integer>), and a hand-written copy
    // would duplicate that synthetic method.
}
