package org.apache.flink.runtime.state;

import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.apache.flink.core.memory.HeapMemorySegment;
import org.apache.flink.core.memory.MemorySegment;
import org.apache.flink.runtime.checkpoint.CheckpointOptions;
import org.apache.flink.runtime.checkpoint.CheckpointType;
import org.apache.flink.runtime.checkpoint.OperatorSubtaskState;
import org.apache.flink.runtime.checkpoint.StateObjectCollection;
import org.apache.flink.runtime.checkpoint.TaskStateSnapshot;
import org.apache.flink.runtime.checkpoint.channel.ChannelStateReader;
import org.apache.flink.runtime.checkpoint.channel.ChannelStateReaderImpl;
import org.apache.flink.runtime.checkpoint.channel.ChannelStateWriter;
import org.apache.flink.runtime.checkpoint.channel.ChannelStateWriterImpl;
import org.apache.flink.runtime.checkpoint.channel.InputChannelInfo;
import org.apache.flink.runtime.checkpoint.channel.ResultSubpartitionInfo;
import org.apache.flink.runtime.io.network.buffer.Buffer;
import org.apache.flink.runtime.io.network.buffer.BufferBuilder;
import org.apache.flink.runtime.io.network.buffer.FreeingBufferRecycler;
import org.apache.flink.runtime.io.network.buffer.NetworkBuffer;
import org.apache.flink.runtime.jobgraph.OperatorID;
import org.apache.flink.runtime.state.CheckpointStreamFactory;
import org.apache.flink.runtime.state.memory.NonPersistentMetadataCheckpointStorageLocation;
import org.apache.flink.util.CloseableIterator;
import org.apache.flink.util.Preconditions;
import org.apache.flink.util.function.BiFunctionWithException;
import org.junit.Assert;
import org.junit.Test;

/* loaded from: input_file:org/apache/flink/runtime/state/ChannelPersistenceITCase.class */
/**
 * Round-trip test for channel state persistence.
 *
 * <p>It writes input-channel and result-subpartition payloads through {@link ChannelStateWriterImpl}.
 * It then wraps the resulting state handles in a {@link TaskStateSnapshot} and reads them back
 * through {@link ChannelStateReaderImpl}, checking the bytes are unchanged.
 */
public class ChannelPersistenceITCase {
    private static final Random RANDOM = new Random(System.currentTimeMillis());

    /**
     * Start sequence number passed to {@code addInputData}/{@code addOutputData}.
     *
     * <p>NOTE(review): this is the literal the decompiled code used. It presumably mirrors a
     * sequence-number constant on {@link ChannelStateWriter}; confirm and reference that constant.
     */
    private static final int START_SEQUENCE_NUMBER = -2;

    /**
     * Writes one random 1024-byte payload for a single input channel and one for a single result
     * subpartition under checkpoint 1, then asserts each payload reads back byte-for-byte.
     */
    @Test
    public void testReadWritten() throws Exception {
        InputChannelInfo inputChannelInfo = new InputChannelInfo(2, 3);
        byte[] inputChannelBytes = randomBytes(1024);
        ResultSubpartitionInfo resultSubpartitionInfo = new ResultSubpartitionInfo(4, 5);
        byte[] resultSubpartitionBytes = randomBytes(1024);

        ChannelStateWriter.ChannelStateWriteResult writeResult =
                write(
                        1L,
                        Collections.singletonMap(inputChannelInfo, inputChannelBytes),
                        Collections.singletonMap(resultSubpartitionInfo, resultSubpartitionBytes));

        Assert.assertArrayEquals(
                inputChannelBytes,
                read(
                        toTaskStateSnapshot(writeResult),
                        inputChannelBytes.length,
                        (reader, segment) ->
                                reader.readInputData(
                                        inputChannelInfo,
                                        new NetworkBuffer(segment, FreeingBufferRecycler.INSTANCE))));
        Assert.assertArrayEquals(
                resultSubpartitionBytes,
                read(
                        toTaskStateSnapshot(writeResult),
                        resultSubpartitionBytes.length,
                        (reader, segment) ->
                                reader.readOutputData(
                                        resultSubpartitionInfo,
                                        new BufferBuilder(segment, FreeingBufferRecycler.INSTANCE))));
    }

    /** Returns a new array of {@code size} bytes filled from the shared {@link #RANDOM}. */
    private byte[] randomBytes(int size) {
        byte[] bytes = new byte[size];
        RANDOM.nextBytes(bytes);
        return bytes;
    }

    /**
     * Writes the given payloads as channel state for one checkpoint and returns the write result.
     *
     * <p>The sequence of calls is: open the writer, start the checkpoint, add all input data, finish
     * input, add all output data, finish output. The writer is closed via try-with-resources. An
     * exception thrown by {@code close()} therefore surfaces exactly once, or is attached as
     * suppressed to an earlier failure.
     *
     * @param checkpointId id of the checkpoint to write
     * @param inputData payload per input channel
     * @param outputData payload per result subpartition
     * @return the writer's result for {@code checkpointId}; its subpartition handles are completed
     */
    private ChannelStateWriter.ChannelStateWriteResult write(
            long checkpointId,
            Map<InputChannelInfo, byte[]> inputData,
            Map<ResultSubpartitionInfo, byte[]> outputData)
            throws Exception {
        // Payload size plus 16 bytes (2 * Long.BYTES) of headroom.
        // NOTE(review): presumably room for per-handle metadata; confirm.
        int maxStateSize = sizeOfBytes(inputData) + sizeOfBytes(outputData) + 2 * Long.BYTES;
        Map<InputChannelInfo, Buffer> inputBuffers = wrapWithBuffers(inputData);
        Map<ResultSubpartitionInfo, Buffer> outputBuffers = wrapWithBuffers(outputData);

        try (ChannelStateWriterImpl writer =
                new ChannelStateWriterImpl("test", getStreamFactoryFactory(maxStateSize))) {
            writer.open();
            writer.start(
                    checkpointId,
                    new CheckpointOptions(
                            CheckpointType.CHECKPOINT,
                            new CheckpointStorageLocationReference("poly".getBytes())));

            for (Map.Entry<InputChannelInfo, Buffer> entry : inputBuffers.entrySet()) {
                writer.addInputData(
                        checkpointId,
                        entry.getKey(),
                        START_SEQUENCE_NUMBER,
                        CloseableIterator.ofElements(
                                Buffer::recycleBuffer, new Buffer[] {entry.getValue()}));
            }
            writer.finishInput(checkpointId);

            for (Map.Entry<ResultSubpartitionInfo, Buffer> entry : outputBuffers.entrySet()) {
                writer.addOutputData(
                        checkpointId,
                        entry.getKey(),
                        START_SEQUENCE_NUMBER,
                        new Buffer[] {entry.getValue()});
            }
            writer.finishOutput(checkpointId);

            ChannelStateWriter.ChannelStateWriteResult result =
                    writer.getAndRemoveWriteResult(checkpointId);
            // Block until the subpartition handles are available before the writer is closed.
            // NOTE(review): presumably this prevents closing the writer while its write is still
            // in flight; confirm.
            result.getResultSubpartitionStateHandles().join();
            return result;
        }
    }

    /** Returns a storage view whose checkpoint locations are built with a size argument of 42. */
    public static CheckpointStorageWorkerView getStreamFactoryFactory() {
        return getStreamFactoryFactory(42);
    }

    /**
     * Returns a storage view that resolves every checkpoint location to a new {@link
     * NonPersistentMetadataCheckpointStorageLocation} built with {@code i}.
     *
     * <p>Task-owned state streams are not supported by the returned view.
     *
     * @param i value passed to the location constructor (presumably its maximum state size; confirm)
     */
    public static CheckpointStorageWorkerView getStreamFactoryFactory(final int i) {
        return new CheckpointStorageWorkerView() {
            @Override
            public CheckpointStreamFactory resolveCheckpointStorageLocation(
                    long j, CheckpointStorageLocationReference checkpointStorageLocationReference) {
                return new NonPersistentMetadataCheckpointStorageLocation(i);
            }

            @Override
            public CheckpointStreamFactory.CheckpointStateOutputStream createTaskOwnedStateStream() {
                throw new UnsupportedOperationException();
            }
        };
    }

    /**
     * Reads state from {@code taskStateSnapshot} into a fresh array of {@code size} bytes.
     *
     * <p>The array is wrapped in a heap {@link MemorySegment}. The segment is freed whether or not
     * the read succeeds.
     *
     * <p>NOTE(review): the {@link ChannelStateReaderImpl} created here is never closed; if {@link
     * ChannelStateReader} is {@code AutoCloseable}, consider try-with-resources.
     *
     * @param taskStateSnapshot snapshot to read from
     * @param size exact number of bytes expected
     * @param readFn performs the read through the reader into the segment
     * @return the bytes read
     * @throws IllegalStateException if {@code readFn} did not report {@code NO_MORE_DATA}
     */
    private byte[] read(
            TaskStateSnapshot taskStateSnapshot,
            int size,
            BiFunctionWithException<
                            ChannelStateReader, MemorySegment, ChannelStateReader.ReadResult, Exception>
                    readFn)
            throws Exception {
        byte[] target = new byte[size];
        HeapMemorySegment segment = HeapMemorySegment.FACTORY.wrap(target);
        try {
            ChannelStateReader.ReadResult result =
                    readFn.apply(new ChannelStateReaderImpl(taskStateSnapshot), segment);
            Preconditions.checkState(
                    result == ChannelStateReader.ReadResult.NO_MORE_DATA,
                    "Expected all state to fit into %s bytes, but the read returned %s",
                    size,
                    result);
        } finally {
            segment.free();
        }
        return target;
    }

    /**
     * Wraps the handles of {@code writeResult} in a single-operator {@link TaskStateSnapshot}.
     *
     * <p>The operator gets empty operator and keyed state and the input-channel and
     * result-subpartition handles from the write. This blocks until both handle futures complete.
     */
    private TaskStateSnapshot toTaskStateSnapshot(
            ChannelStateWriter.ChannelStateWriteResult writeResult) throws Exception {
        OperatorSubtaskState subtaskState =
                new OperatorSubtaskState(
                        StateObjectCollection.empty(),
                        StateObjectCollection.empty(),
                        StateObjectCollection.empty(),
                        StateObjectCollection.empty(),
                        new StateObjectCollection<>(writeResult.getInputChannelStateHandles().get()),
                        new StateObjectCollection<>(
                                writeResult.getResultSubpartitionStateHandles().get()));
        return new TaskStateSnapshot(Collections.singletonMap(new OperatorID(), subtaskState));
    }

    /**
     * Returns the elements of {@code collection} that are instances of {@code cls}, cast to it.
     *
     * @throws NullPointerException if {@code cls} is null
     */
    private <C> List<C> collect(Collection<StateObject> collection, Class<C> cls) {
        return collection.stream()
                .filter(cls::isInstance)
                .map(cls::cast)
                .collect(Collectors.toList());
    }

    /** Returns the total length of all byte arrays in {@code map}'s values. */
    private static int sizeOfBytes(Map<?, byte[]> map) {
        return map.values().stream().mapToInt(bytes -> bytes.length).sum();
    }

    /** Returns a new map with the same keys and each value wrapped via {@link #wrapWithBuffer}. */
    private <K> Map<K, Buffer> wrapWithBuffers(Map<K, byte[]> map) {
        return map.entrySet().stream()
                .collect(
                        Collectors.toMap(
                                Map.Entry::getKey, entry -> wrapWithBuffer(entry.getValue())));
    }

    /**
     * Copies {@code bytes} into a new unpooled heap {@link NetworkBuffer} of exactly that size.
     *
     * <p>The buffer is recycled by {@link FreeingBufferRecycler}. Visibility is kept public; the
     * decompiler noted this method was originally private.
     */
    public static Buffer wrapWithBuffer(byte[] bytes) {
        NetworkBuffer buffer =
                new NetworkBuffer(
                        HeapMemorySegment.FACTORY.allocateUnpooledSegment(bytes.length, null),
                        FreeingBufferRecycler.INSTANCE);
        buffer.writeBytes(bytes);
        return buffer;
    }
}
