package org.apache.flink.streaming.api.operators.collect;

import java.io.IOException;
import java.lang.Thread;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.ListIterator;
import java.util.Map;
import java.util.Random;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.TimeUnit;
import org.apache.flink.api.common.JobID;
import org.apache.flink.api.common.accumulators.SerializedListAccumulator;
import org.apache.flink.api.common.typeutils.TypeSerializer;
import org.apache.flink.api.common.typeutils.base.IntSerializer;
import org.apache.flink.api.common.typeutils.base.array.BytePrimitiveArraySerializer;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.runtime.io.disk.iomanager.IOManager;
import org.apache.flink.runtime.io.disk.iomanager.IOManagerAsync;
import org.apache.flink.runtime.jobgraph.OperatorID;
import org.apache.flink.streaming.api.functions.sink.SinkFunction;
import org.apache.flink.streaming.api.operators.StreamingRuntimeContext;
import org.apache.flink.streaming.api.operators.collect.utils.MockFunctionInitializationContext;
import org.apache.flink.streaming.api.operators.collect.utils.MockFunctionSnapshotContext;
import org.apache.flink.streaming.api.operators.collect.utils.MockOperatorEventGateway;
import org.apache.flink.streaming.api.operators.collect.utils.TestJobClient;
import org.apache.flink.streaming.util.MockStreamingRuntimeContext;
import org.apache.flink.util.OptionalFailure;
import org.apache.flink.util.TestLogger;
import org.junit.After;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;

/* loaded from: input_file:org/apache/flink/streaming/api/operators/collect/CollectSinkFunctionTest.class */
/**
 * Tests for {@link CollectSinkFunction}.
 *
 * <p>The protocol tests drive a single {@link CollectSinkFunction} by hand and query it through a
 * {@link CollectSinkOperatorCoordinator}. The randomized tests run a data feeder thread (which
 * writes records and simulates checkpoints and failovers) concurrently with a {@link
 * CollectClient} thread that reads the results through a {@link CollectResultIterator}.
 */
public class CollectSinkFunctionTest extends TestLogger {
    private static final int MAX_RESULTS_PER_BATCH = 3;
    private static final String ACCUMULATOR_NAME = "tableCollectAccumulator";
    private static final int FUTURE_TIMEOUT_MILLIS = 10000;
    private static final int SOCKET_TIMEOUT_MILLIS = 1000;
    private static final int MAX_RETRIES = 100;
    /** Number of independent runs performed by each randomized test. */
    private static final int RANDOM_TEST_ROUNDS = 30;
    /** Number of records ({@code 0 .. n - 1}) fed into the sink in each randomized run. */
    private static final int RANDOM_TEST_DATA_SIZE = 50;
    private static final JobID TEST_JOB_ID = new JobID();
    private static final OperatorID TEST_OPERATOR_ID = new OperatorID();
    private static final TypeSerializer<Integer> serializer = IntSerializer.INSTANCE;

    private CollectSinkFunction<Integer> function;
    private CollectSinkOperatorCoordinator coordinator;
    private MockFunctionInitializationContext functionInitializationContext;
    // Written by the feeder thread (finishJob) and read by the client thread through
    // JobInfoProvider#isJobFinished, so it must be volatile to be visible across threads.
    private volatile boolean jobFinished;
    private IOManager ioManager;
    private StreamingRuntimeContext runtimeContext;
    private MockOperatorEventGateway gateway;

    /**
     * A checkpoint taken by {@link CheckpointedDataFeeder} that is reported as complete after
     * {@code countdown} further calls to {@link #tick()}. {@code data} is a snapshot of the
     * records the feeder had not yet written when the checkpoint was taken.
     */
    private static class CheckpointCountdown {
        private final long id;
        private final List<Integer> data;
        private int countdown;

        private CheckpointCountdown(long id, List<Integer> data, int countdown) {
            this.id = id;
            // Copy: the feeder keeps consuming from the list it passes in.
            this.data = new ArrayList<>(data);
            this.countdown = countdown;
        }

        /**
         * Decrements the countdown if it is still positive.
         *
         * @return {@code true} exactly when this call brings the countdown down to zero
         */
        public boolean tick() {
            if (countdown <= 0) {
                return false;
            }
            countdown--;
            return countdown == 0;
        }
    }

    /**
     * Feeds records into a checkpointed sink while randomly taking checkpoints, completing some
     * of them later, and simulating failovers that replay (in shuffled order) every record not
     * yet covered by the last completed checkpoint.
     */
    private class CheckpointedDataFeeder extends Thread {
        /** Records that still have to be written to the sink. */
        private LinkedList<Integer> data;
        /** Records not yet written as of the last completed checkpoint; replayed on failover. */
        private List<Integer> checkpointedData;
        private long checkpointId;
        private long lastSuccessCheckpointId;
        private final List<CheckpointCountdown> checkpointCountdowns;

        private CheckpointedDataFeeder(List<Integer> data) {
            this.data = new LinkedList<>(data);
            this.checkpointedData = new ArrayList<>(data);
            this.checkpointId = 0L;
            this.lastSuccessCheckpointId = 0L;
            this.checkpointCountdowns = new ArrayList<>();
        }

        @Override
        public void run() {
            Random random = new Random();
            try {
                openFunctionWithState();
                while (!data.isEmpty()) {
                    advanceCheckpointCountdowns();

                    int action = random.nextInt(10);
                    if (action < 6) {
                        // Write between 1 and 9 records (bounded by what is left).
                        int count = Math.min(data.size(), random.nextInt(9) + 1);
                        for (int i = 0; i < count; i++) {
                            function.invoke(data.removeFirst(), (SinkFunction.Context) null);
                        }
                    } else if (action < 9) {
                        // Take a checkpoint; about half of them complete 1 to 3 rounds later,
                        // the rest never complete.
                        checkpointId++;
                        if (random.nextBoolean()) {
                            checkpointCountdowns.add(
                                    new CheckpointCountdown(
                                            checkpointId, data, random.nextInt(3) + 1));
                        }
                        checkpointFunction(checkpointId);
                    } else {
                        // Simulate a failover: pending checkpoints are lost and everything after
                        // the last completed checkpoint is replayed in a different order.
                        checkpointCountdowns.clear();
                        Collections.shuffle(checkpointedData);
                        data = new LinkedList<>(checkpointedData);
                        closeFuntionAbnormally();
                        openFunctionWithState();
                    }

                    if (random.nextBoolean()) {
                        Thread.sleep(random.nextInt(10));
                    }
                }
                finishJob();
            } catch (Exception e) {
                throw new RuntimeException(e);
            }
        }

        /**
         * Drops countdowns older than the last completed checkpoint and completes every
         * checkpoint whose countdown reaches zero in this round.
         */
        private void advanceCheckpointCountdowns() throws Exception {
            ListIterator<CheckpointCountdown> it = checkpointCountdowns.listIterator();
            while (it.hasNext()) {
                CheckpointCountdown countdown = it.next();
                if (countdown.id < lastSuccessCheckpointId) {
                    it.remove();
                } else if (countdown.tick()) {
                    checkpointedData = countdown.data;
                    checkpointComplete(countdown.id);
                    lastSuccessCheckpointId = countdown.id;
                    it.remove();
                }
            }
        }
    }

    /**
     * Reads all results of the sink through a {@link CollectResultIterator} backed by a {@link
     * TestJobClient}, sleeping briefly at random between elements.
     */
    public class CollectClient extends Thread {
        private final List<Integer> results;
        private final CollectResultIterator<Integer> iterator;

        private CollectClient() {
            this.results = new ArrayList<>();
            this.iterator =
                    new CollectResultIterator<>(
                            CompletableFuture.completedFuture(TEST_OPERATOR_ID),
                            serializer,
                            ACCUMULATOR_NAME,
                            0);
            this.iterator.setJobClient(
                    new TestJobClient(
                            TEST_JOB_ID,
                            TEST_OPERATOR_ID,
                            coordinator,
                            new TestJobClient.JobInfoProvider() {
                                @Override
                                public boolean isJobFinished() {
                                    return jobFinished;
                                }

                                @Override
                                public Map<String, OptionalFailure<Object>>
                                        getAccumulatorResults() {
                                    Map<String, OptionalFailure<Object>> accumulatorResults =
                                            new HashMap<>();
                                    accumulatorResults.put(
                                            ACCUMULATOR_NAME,
                                            OptionalFailure.of(
                                                    runtimeContext
                                                            .getAccumulator(ACCUMULATOR_NAME)
                                                            .getLocalValue()));
                                    return accumulatorResults;
                                }
                            }));
        }

        @Override
        public void run() {
            Random random = new Random();
            while (iterator.hasNext()) {
                results.add(iterator.next());
                if (random.nextBoolean()) {
                    try {
                        Thread.sleep(5L);
                    } catch (InterruptedException e) {
                        // Only the failure handler of runFunctionRandomTest interrupts this
                        // thread; the other side has died, so stop consuming.
                        Thread.currentThread().interrupt();
                        break;
                    }
                }
            }
            try {
                iterator.close();
            } catch (Exception e) {
                throw new RuntimeException(e);
            }
        }
    }

    /**
     * Feeds records into a sink without state. Once fewer than half of the records remain, it
     * simulates at most one failover (with probability 1/2) that replays all input in shuffled
     * order.
     */
    private class UncheckpointedDataFeeder extends Thread {
        private LinkedList<Integer> data;
        private final List<Integer> checkpointedData;
        private boolean failedBefore;

        private UncheckpointedDataFeeder(List<Integer> data) {
            this.data = new LinkedList<>(data);
            this.checkpointedData = new ArrayList<>(data);
            this.failedBefore = false;
        }

        @Override
        public void run() {
            Random random = new Random();
            try {
                openFunction();
                while (!data.isEmpty()) {
                    int count = Math.min(data.size(), random.nextInt(9) + 1);
                    for (int i = 0; i < count; i++) {
                        function.invoke(data.removeFirst(), (SinkFunction.Context) null);
                    }

                    if (!failedBefore && data.size() < checkpointedData.size() / 2) {
                        if (random.nextBoolean()) {
                            Collections.shuffle(checkpointedData);
                            data = new LinkedList<>(checkpointedData);
                            closeFuntionAbnormally();
                            openFunction();
                        }
                        // The failover decision is made only once per run.
                        failedBefore = true;
                    }

                    if (random.nextBoolean()) {
                        Thread.sleep(random.nextInt(10));
                    }
                }
                finishJob();
            } catch (Exception e) {
                throw new RuntimeException(e);
            }
        }
    }

    /** Creates a fresh IO manager, runtime context, event gateway and started coordinator. */
    @Before
    public void before() throws Exception {
        this.ioManager = new IOManagerAsync();
        this.runtimeContext = new MockStreamingRuntimeContext(false, 1, 0, this.ioManager);
        this.gateway = new MockOperatorEventGateway();
        this.coordinator = new CollectSinkOperatorCoordinator(SOCKET_TIMEOUT_MILLIS);
        this.coordinator.start();
        this.functionInitializationContext = new MockFunctionInitializationContext();
        this.jobFinished = false;
    }

    /** Closes the coordinator and the IO manager created in {@link #before()}. */
    @After
    public void after() throws Exception {
        this.coordinator.close();
        this.ioManager.close();
    }

    /** Checks request/response offsets and the final accumulator of a sink without state. */
    @Test
    public void testUncheckpointedProtocol() throws Exception {
        openFunction();
        for (int i = 0; i < 6; i++) {
            function.invoke(i, (SinkFunction.Context) null);
        }

        // An initial request with an empty version is used to learn the sink's version.
        CollectCoordinationResponse<Integer> response = sendRequestAndGetValidResponse("", 0L);
        Assert.assertEquals(0L, response.getLastCheckpointedOffset());
        String version = response.getVersion();

        assertResponseEquals(
                sendRequestAndGetValidResponse(version, 0L), version, 0L, Arrays.asList(0, 1, 2));
        assertResponseEquals(
                sendRequestAndGetValidResponse(version, 4L), version, 0L, Arrays.asList(4, 5));
        assertResponseEquals(
                sendRequestAndGetValidResponse(version, 6L),
                version,
                0L,
                Collections.emptyList());

        for (int i = 6; i < 10; i++) {
            function.invoke(i, (SinkFunction.Context) null);
        }
        assertResponseEquals(
                sendRequestAndGetValidResponse(version, 5L),
                version,
                0L,
                Collections.emptyList());
        assertResponseEquals(
                sendRequestAndGetValidResponse(version, 6L), version, 0L, Arrays.asList(6, 7, 8));
        assertResponseEquals(
                sendRequestAndGetValidResponse(version, 6L), version, 0L, Arrays.asList(6, 7, 8));
        assertResponseEquals(
                sendRequestAndGetValidResponse(version, 12L),
                version,
                0L,
                Collections.emptyList());

        for (int i = 10; i < 16; i++) {
            function.invoke(i, (SinkFunction.Context) null);
        }
        assertResponseEquals(
                sendRequestAndGetValidResponse(version, 12L),
                version,
                0L,
                Arrays.asList(12, 13, 14));

        finishJob();
        assertAccumulatorResult(12L, version, 0L, Arrays.asList(12, 13, 14, 15));
    }

    /**
     * Checks request/response offsets of a sink with state across checkpoints, completed
     * checkpoints and simulated failovers, and the final accumulator result.
     */
    @Test
    public void testCheckpointProtocol() throws Exception {
        openFunctionWithState();
        for (int i = 0; i < 2; i++) {
            function.invoke(i, (SinkFunction.Context) null);
        }

        CollectCoordinationResponse<Integer> response = sendRequestAndGetValidResponse("", 0L);
        Assert.assertEquals(0L, response.getLastCheckpointedOffset());
        String version = response.getVersion();
        assertResponseEquals(
                sendRequestAndGetValidResponse(version, 0L), version, 0L, Arrays.asList(0, 1));

        for (int i = 2; i < 6; i++) {
            function.invoke(i, (SinkFunction.Context) null);
        }
        assertResponseEquals(
                sendRequestAndGetValidResponse(version, 3L), version, 0L, Arrays.asList(3, 4, 5));

        checkpointFunction(1L);
        assertResponseEquals(
                sendRequestAndGetValidResponse(version, 4L), version, 0L, Arrays.asList(4, 5));
        checkpointComplete(1L);
        assertResponseEquals(
                sendRequestAndGetValidResponse(version, 4L), version, 3L, Arrays.asList(4, 5));

        for (int i = 6; i < 9; i++) {
            function.invoke(i, (SinkFunction.Context) null);
        }
        assertResponseEquals(
                sendRequestAndGetValidResponse(version, 6L), version, 3L, Arrays.asList(6, 7, 8));

        // Simulated failover: the version is re-read from the restarted sink.
        closeFuntionAbnormally();
        openFunctionWithState();
        for (int i = 9; i < 12; i++) {
            function.invoke(i, (SinkFunction.Context) null);
        }
        response = sendRequestAndGetValidResponse(version, 4L);
        Assert.assertEquals(3L, response.getLastCheckpointedOffset());
        version = response.getVersion();
        assertResponseEquals(
                sendRequestAndGetValidResponse(version, 4L), version, 3L, Arrays.asList(4, 5, 9));
        assertResponseEquals(
                sendRequestAndGetValidResponse(version, 6L),
                version,
                3L,
                Arrays.asList(9, 10, 11));

        checkpointFunction(2L);
        checkpointComplete(2L);
        function.invoke(12, (SinkFunction.Context) null);
        assertResponseEquals(
                sendRequestAndGetValidResponse(version, 7L),
                version,
                6L,
                Arrays.asList(10, 11, 12));

        closeFuntionAbnormally();
        openFunctionWithState();
        response = sendRequestAndGetValidResponse(version, 7L);
        Assert.assertEquals(6L, response.getLastCheckpointedOffset());
        version = response.getVersion();
        assertResponseEquals(
                sendRequestAndGetValidResponse(version, 7L), version, 6L, Arrays.asList(10, 11));
        assertResponseEquals(sendRequest(version, 9L), version, 6L, Collections.emptyList());

        for (int i = 13; i < 17; i++) {
            function.invoke(i, (SinkFunction.Context) null);
        }
        assertResponseEquals(
                sendRequestAndGetValidResponse(version, 9L),
                version,
                6L,
                Arrays.asList(13, 14, 15));

        checkpointFunction(3L);
        checkpointComplete(3L);
        closeFuntionAbnormally();
        openFunctionWithState();
        response = sendRequestAndGetValidResponse(version, 12L);
        Assert.assertEquals(9L, response.getLastCheckpointedOffset());
        version = response.getVersion();
        assertResponseEquals(
                sendRequestAndGetValidResponse(version, 12L),
                version,
                9L,
                Collections.singletonList(16));

        for (int i = 17; i < 20; i++) {
            function.invoke(i, (SinkFunction.Context) null);
        }
        assertResponseEquals(
                sendRequestAndGetValidResponse(version, 12L),
                version,
                9L,
                Arrays.asList(16, 17, 18));

        // Checkpoint 4 is never completed, so the failover below falls back to checkpoint 3.
        checkpointFunction(4L);
        closeFuntionAbnormally();
        openFunctionWithState();
        response = sendRequestAndGetValidResponse(version, 12L);
        Assert.assertEquals(9L, response.getLastCheckpointedOffset());
        version = response.getVersion();
        assertResponseEquals(
                sendRequestAndGetValidResponse(version, 12L),
                version,
                9L,
                Collections.singletonList(16));

        for (int i = 20; i < 23; i++) {
            function.invoke(i, (SinkFunction.Context) null);
        }
        assertResponseEquals(
                sendRequestAndGetValidResponse(version, 12L),
                version,
                9L,
                Arrays.asList(16, 20, 21));

        finishJob();
        assertAccumulatorResult(12L, version, 9L, Arrays.asList(16, 20, 21, 22));
    }

    /** Randomized end-to-end test with a sink without state and at most one failover per run. */
    @Test
    public void testUncheckpointedFunction() throws Exception {
        for (int round = 0; round < RANDOM_TEST_ROUNDS; round++) {
            List<Integer> expected = createTestData(RANDOM_TEST_DATA_SIZE);
            assertResultsEqualAfterSort(
                    expected, runFunctionRandomTest(new UncheckpointedDataFeeder(expected)));
            // Start the next round with a fresh coordinator and runtime context.
            after();
            before();
        }
    }

    /** Randomized end-to-end test with checkpoints and multiple failovers per run. */
    @Test
    public void testCheckpointedFunction() throws Exception {
        for (int round = 0; round < RANDOM_TEST_ROUNDS; round++) {
            List<Integer> expected = createTestData(RANDOM_TEST_DATA_SIZE);
            assertResultsEqualAfterSort(
                    expected, runFunctionRandomTest(new CheckpointedDataFeeder(expected)));
            // Start the next round with a fresh coordinator and runtime context.
            after();
            before();
        }
    }

    /** Returns a new mutable list containing {@code 0, 1, ..., size - 1}. */
    private static List<Integer> createTestData(int size) {
        List<Integer> data = new ArrayList<>(size);
        for (int i = 0; i < size; i++) {
            data.add(i);
        }
        return data;
    }

    /**
     * Runs {@code feeder} concurrently with a fresh {@link CollectClient} and returns everything
     * the client collected.
     *
     * <p>If either thread dies with an uncaught exception, both threads are interrupted and the
     * failure is rethrown as an {@link AssertionError} after both threads have been joined. The
     * test then fails with the real cause instead of a confusing result mismatch.
     */
    private List<Integer> runFunctionRandomTest(Thread feeder) throws Exception {
        CollectClient client = new CollectClient();
        List<Throwable> failures = Collections.synchronizedList(new ArrayList<>());
        Thread.UncaughtExceptionHandler handler =
                (failedThread, error) -> {
                    failures.add(error);
                    feeder.interrupt();
                    client.interrupt();
                };
        feeder.setUncaughtExceptionHandler(handler);
        client.setUncaughtExceptionHandler(handler);

        feeder.start();
        client.start();
        feeder.join();
        client.join();

        // join() makes the handler's writes visible here; both threads are finished.
        if (!failures.isEmpty()) {
            AssertionError error =
                    new AssertionError("A feeder or client thread failed", failures.get(0));
            for (int i = 1; i < failures.size(); i++) {
                error.addSuppressed(failures.get(i));
            }
            throw error;
        }
        return client.results;
    }

    /**
     * Creates and opens a new sink without restoring state, then forwards the event the sink sent
     * through the gateway to the coordinator.
     */
    public void openFunction() throws Exception {
        function = new CollectSinkFunction<>(serializer, MAX_RESULTS_PER_BATCH, ACCUMULATOR_NAME);
        function.setRuntimeContext(runtimeContext);
        function.setOperatorEventGateway(gateway);
        function.open(new Configuration());
        coordinator.handleEventFromOperator(0, gateway.getNextEvent());
    }

    /**
     * Reverts the mock operator state store to its last successful checkpoint. Then creates a
     * new sink, initializes it from that state, opens it, and forwards the sink's gateway event
     * to the coordinator.
     */
    public void openFunctionWithState() throws Exception {
        functionInitializationContext.m56getOperatorStateStore().revertToLastSuccessCheckpoint();
        function = new CollectSinkFunction<>(serializer, MAX_RESULTS_PER_BATCH, ACCUMULATOR_NAME);
        function.setRuntimeContext(runtimeContext);
        function.setOperatorEventGateway(gateway);
        function.initializeState(functionInitializationContext);
        function.open(new Configuration());
        coordinator.handleEventFromOperator(0, gateway.getNextEvent());
    }

    /** Snapshots the sink for checkpoint {@code checkpointId} and records it in the mock store. */
    public void checkpointFunction(long checkpointId) throws Exception {
        function.snapshotState(new MockFunctionSnapshotContext(checkpointId));
        functionInitializationContext.m56getOperatorStateStore().checkpointBegin(checkpointId);
    }

    /** Notifies the sink and the mock state store that {@code checkpointId} completed. */
    public void checkpointComplete(long checkpointId) throws Exception {
        function.notifyCheckpointComplete(checkpointId);
        functionInitializationContext.m56getOperatorStateStore().checkpointSuccess(checkpointId);
    }

    /** Closes the sink and reports subtask 0 as failed to the coordinator. */
    public void closeFuntionAbnormally() throws Exception {
        function.close();
        coordinator.subtaskFailed(0, (Throwable) null);
    }

    /** Lets the sink write its final accumulator, closes it and marks the job as finished. */
    public void finishJob() throws Exception {
        function.accumulateFinalResults();
        function.close();
        jobFinished = true;
    }

    /** Sends one request to the coordinator and waits up to {@link #FUTURE_TIMEOUT_MILLIS}. */
    @SuppressWarnings("unchecked")
    private CollectCoordinationResponse<Integer> sendRequest(String version, long offset)
            throws Exception {
        return (CollectCoordinationResponse<Integer>)
                coordinator
                        .handleCoordinationRequest(new CollectCoordinationRequest(version, offset))
                        .get(FUTURE_TIMEOUT_MILLIS, TimeUnit.MILLISECONDS);
    }

    /**
     * Repeats {@link #sendRequest} until the response carries a non-negative last-checkpointed
     * offset, giving up after {@link #MAX_RETRIES} attempts.
     */
    private CollectCoordinationResponse<Integer> sendRequestAndGetValidResponse(
            String version, long offset) throws Exception {
        for (int attempt = 0; attempt < MAX_RETRIES; attempt++) {
            CollectCoordinationResponse<Integer> response = sendRequest(version, offset);
            if (response.getLastCheckpointedOffset() >= 0) {
                return response;
            }
        }
        throw new RuntimeException("Too many retries in sendRequestAndGetValidResponse");
    }

    /**
     * Reads the sink's accumulator, asserts that it holds exactly one serialized entry, and
     * returns that entry deserialized by {@link CollectSinkFunction#deserializeAccumulatorResult}.
     */
    private Tuple2<Long, CollectCoordinationResponse<Integer>> getAccumulatorResult()
            throws Exception {
        // Explicit type arguments: without a target type, A would be inferred as Serializable,
        // which does not match deserializeList's ArrayList<byte[]> parameter.
        ArrayList<byte[]> localValue =
                runtimeContext
                        .<Object, ArrayList<byte[]>>getAccumulator(ACCUMULATOR_NAME)
                        .getLocalValue();
        List<byte[]> serializedResults =
                SerializedListAccumulator.deserializeList(
                        localValue, BytePrimitiveArraySerializer.INSTANCE);
        Assert.assertEquals(1, serializedResults.size());
        return CollectSinkFunction.deserializeAccumulatorResult(serializedResults.get(0));
    }

    /** Asserts the version, last-checkpointed offset and results of {@code response}. */
    private void assertResponseEquals(
            CollectCoordinationResponse<Integer> response,
            String expectedVersion,
            long expectedLastCheckpointedOffset,
            List<Integer> expectedResults)
            throws IOException {
        Assert.assertEquals(expectedVersion, response.getVersion());
        Assert.assertEquals(expectedLastCheckpointedOffset, response.getLastCheckpointedOffset());
        assertResultsEqual(expectedResults, response.getResults(serializer));
    }

    /** Asserts that both lists contain the same elements in the same order. */
    private void assertResultsEqual(List<Integer> expected, List<Integer> actual) {
        Assert.assertArrayEquals(expected.toArray(new Integer[0]), actual.toArray(new Integer[0]));
    }

    /** Sorts both lists in place, then asserts that they are equal. */
    private void assertResultsEqualAfterSort(List<Integer> expected, List<Integer> actual) {
        Collections.sort(expected);
        Collections.sort(actual);
        assertResultsEqual(expected, actual);
    }

    /** Asserts the offset and response stored in the sink's final accumulator. */
    private void assertAccumulatorResult(
            long expectedOffset,
            String expectedVersion,
            long expectedLastCheckpointedOffset,
            List<Integer> expectedResults)
            throws Exception {
        Tuple2<Long, CollectCoordinationResponse<Integer>> accumulatorResult =
                getAccumulatorResult();
        long offset = accumulatorResult.f0;
        CollectCoordinationResponse<Integer> response = accumulatorResult.f1;
        List<Integer> results = response.getResults(serializer);

        Assert.assertEquals(expectedOffset, offset);
        Assert.assertEquals(expectedVersion, response.getVersion());
        Assert.assertEquals(expectedLastCheckpointedOffset, response.getLastCheckpointedOffset());
        assertResultsEqual(expectedResults, results);
    }
}
