package org.apache.flink.runtime.checkpoint.channel;

import java.io.ByteArrayOutputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import java.util.Random;
import java.util.stream.Collectors;
import org.apache.flink.core.memory.HeapMemorySegment;
import org.apache.flink.runtime.checkpoint.OperatorSubtaskState;
import org.apache.flink.runtime.checkpoint.StateObjectCollection;
import org.apache.flink.runtime.checkpoint.TaskStateSnapshot;
import org.apache.flink.runtime.checkpoint.channel.ChannelStateReader;
import org.apache.flink.runtime.io.network.buffer.Buffer;
import org.apache.flink.runtime.io.network.buffer.FreeingBufferRecycler;
import org.apache.flink.runtime.io.network.buffer.NetworkBuffer;
import org.apache.flink.runtime.jobgraph.OperatorID;
import org.apache.flink.runtime.state.InputChannelStateHandle;
import org.apache.flink.runtime.state.memory.ByteStreamStateHandle;
import org.junit.After;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;

/**
 * Tests for {@link ChannelStateReaderImpl}. Covered cases:
 * <ul>
 *   <li>reading input-channel state back with buffers of many sizes;</li>
 *   <li>reading several chunks stored at different offsets of one state handle;</li>
 *   <li>repeated reads, reads after close, and reads of an unknown channel.</li>
 * </ul>
 */
public class ChannelStateReaderImplTest {
    private static final InputChannelInfo CHANNEL = new InputChannelInfo(1, 2);
    private static final byte[] DATA = generateData(10);

    /** Reader over {@link #DATA} for {@link #CHANNEL}; created before and closed after each test. */
    private ChannelStateReaderImpl reader;

    @Before
    public void init() {
        reader = getReader(CHANNEL, DATA);
    }

    @After
    public void tearDown() throws Exception {
        reader.close();
    }

    /**
     * Reads {@link #DATA} with every buffer size from 1 to {@code 2 * DATA.length - 1}. This covers
     * both reads that need several iterations and a single read into an oversized buffer.
     */
    @Test
    public void testDifferentBufferSizes() throws Exception {
        for (int bufferSize = 1; bufferSize < 2 * DATA.length; bufferSize++) {
            // Use a fresh reader per size: a drained channel cannot be read again (see testReadOnlyOnce).
            try (ChannelStateReaderImpl freshReader = getReader(CHANNEL, DATA)) {
                readAndVerify(bufferSize, CHANNEL, DATA, freshReader);
            }
        }
    }

    /** Serializes several chunks into one stream and reads each one back through its own offset. */
    @Test
    public void testWithOffsets() throws Exception {
        Map<InputChannelStateHandle, byte[]> handlesWithBytes = generateHandlesWithBytes(10, 20);
        // try-with-resources: the reader used to be left open (leaked) by this test.
        try (ChannelStateReaderImpl offsetsReader = new ChannelStateReaderImpl(
                taskStateSnapshot(handlesWithBytes.keySet()), new ChannelStateSerializerImpl())) {
            for (Map.Entry<InputChannelStateHandle, byte[]> entry : handlesWithBytes.entrySet()) {
                readAndVerify(42, (InputChannelInfo) entry.getKey().getInfo(), entry.getValue(), offsetsReader);
            }
        }
    }

    /** A second read of a channel whose data was fully read by the first call must fail. */
    @Test(expected = Exception.class)
    public void testReadOnlyOnce() throws IOException {
        reader.readInputData(CHANNEL, getBuffer(DATA.length));
        reader.readInputData(CHANNEL, getBuffer(DATA.length));
    }

    /** Reading from a closed reader must throw {@link IllegalStateException}. */
    @Test(expected = IllegalStateException.class)
    public void testReadClosed() throws Exception {
        reader.close();
        reader.readInputData(CHANNEL, getBuffer(DATA.length));
    }

    /** A channel that is absent from the snapshot reports {@code NO_MORE_DATA}. */
    @Test
    public void testReadUnknownChannelState() throws IOException {
        InputChannelInfo unknownChannel =
                new InputChannelInfo(CHANNEL.getGateIdx() + 1, CHANNEL.getInputChannelIdx() + 1);
        Assert.assertEquals(
                ChannelStateReader.ReadResult.NO_MORE_DATA,
                reader.readInputData(unknownChannel, getBuffer(DATA.length)));
    }

    /**
     * Wraps the given input-channel handles in a single-operator {@link TaskStateSnapshot}.
     * All other state collections in the snapshot are empty.
     */
    private TaskStateSnapshot taskStateSnapshot(Collection<InputChannelStateHandle> collection) {
        return new TaskStateSnapshot(Collections.singletonMap(
                new OperatorID(),
                new OperatorSubtaskState(
                        StateObjectCollection.empty(),
                        StateObjectCollection.empty(),
                        StateObjectCollection.empty(),
                        StateObjectCollection.empty(),
                        new StateObjectCollection(collection),
                        StateObjectCollection.empty())));
    }

    /**
     * Returns {@code size} random bytes.
     *
     * @param size number of bytes to generate
     * @return a new array filled by {@link Random#nextBytes(byte[])}
     */
    public static byte[] generateData(int size) {
        byte[] data = new byte[size];
        new Random().nextBytes(data);
        return data;
    }

    /** Copies the readable bytes of {@code buffer} into a new array and advances its reader index. */
    private byte[] toBytes(NetworkBuffer buffer) {
        byte[] bytes = new byte[buffer.readableBytes()];
        buffer.readBytes(bytes);
        return bytes;
    }

    /**
     * Creates a reader for a single channel whose state is exactly {@code data}, stored at offset 0.
     * The serializer is stubbed: it consumes no header and reports one chunk of length
     * {@code data.length}.
     */
    private ChannelStateReaderImpl getReader(InputChannelInfo channel, final byte[] data) {
        InputChannelStateHandle handle = new InputChannelStateHandle(
                channel, new ByteStreamStateHandle("", data), Collections.singletonList(0L));
        return new ChannelStateReaderImpl(
                taskStateSnapshot(Collections.singletonList(handle)),
                new ChannelStateSerializerImpl() {
                    @Override
                    public void readHeader(InputStream inputStream) {
                    }

                    @Override
                    public int readLength(InputStream inputStream) {
                        return data.length;
                    }
                });
    }

    /**
     * Reads {@code data} for {@code channelInfo} in chunks of {@code bufferSize} bytes.
     * Each read must return the next slice of {@code data}. Every read reports
     * {@code HAS_MORE_DATA} except the last one, which reports {@code NO_MORE_DATA}.
     */
    private void readAndVerify(int bufferSize, InputChannelInfo channelInfo, byte[] data, ChannelStateReader reader) throws IOException {
        int dataSize = data.length;
        // Ceiling division: a trailing partial chunk needs one extra read.
        int iterations = dataSize / bufferSize + (dataSize % bufferSize == 0 ? 0 : 1);
        NetworkBuffer buffer = getBuffer(bufferSize);
        // One buffer is reused across all reads, so it is released exactly once, after the loop
        // (previously the finally sat inside the loop and released it on every iteration).
        try {
            for (int i = 0; i < iterations; i++) {
                String hint = String.format("dataSize=%d, bufferSize=%d, iteration=%d/%d", dataSize, bufferSize, i + 1, iterations);
                boolean isLast = i == iterations - 1;
                Assert.assertEquals(
                        hint,
                        isLast ? ChannelStateReader.ReadResult.NO_MORE_DATA : ChannelStateReader.ReadResult.HAS_MORE_DATA,
                        reader.readInputData(channelInfo, buffer));
                Assert.assertEquals(hint, isLast ? dataSize - bufferSize * i : bufferSize, buffer.readableBytes());
                Assert.assertArrayEquals(
                        hint,
                        Arrays.copyOfRange(data, i * bufferSize, Math.min(dataSize, (i + 1) * bufferSize)),
                        toBytes(buffer));
                buffer.resetReaderIndex();
                buffer.resetWriterIndex();
            }
        } finally {
            buffer.release();
        }
    }

    /** Allocates an unpooled heap-backed buffer of {@code size} bytes, recycled by freeing. */
    private NetworkBuffer getBuffer(int size) {
        return new NetworkBuffer(
                HeapMemorySegment.FACTORY.allocateUnpooledSegment(size, (Object) null), FreeingBufferRecycler.INSTANCE);
    }

    /**
     * Serializes {@code count} random chunks of {@code chunkSize} bytes into one byte stream, after a
     * single header. Returns one handle per chunk, pointing at that chunk's start offset.
     * Each handle's channel info uses the offset as both gate and channel index, so the channels
     * are distinct.
     */
    private Map<InputChannelStateHandle, byte[]> generateHandlesWithBytes(int count, int chunkSize) throws IOException {
        Map<Integer, byte[]> bytesByOffset = new HashMap<>();
        ByteArrayOutputStream bytesOut = new ByteArrayOutputStream(100);
        DataOutputStream out = new DataOutputStream(bytesOut);
        ChannelStateSerializerImpl serializer = new ChannelStateSerializerImpl();
        serializer.writeHeader(out);
        for (int i = 0; i < count; i++) {
            // Record the offset before writing, i.e. where this chunk begins.
            bytesByOffset.put(bytesOut.size(), writeSomeBytes(chunkSize, out, serializer));
        }
        ByteStreamStateHandle stateHandle = new ByteStreamStateHandle("", bytesOut.toByteArray());
        return bytesByOffset.entrySet().stream()
                .collect(Collectors.toMap(
                        entry -> new InputChannelStateHandle(
                                new InputChannelInfo(entry.getKey(), entry.getKey()),
                                stateHandle,
                                Collections.singletonList(entry.getKey().longValue())),
                        Map.Entry::getValue));
    }

    /**
     * Writes {@code size} random bytes to {@code out} via {@code serializer} and returns them.
     * The temporary buffer is always released.
     */
    private byte[] writeSomeBytes(int size, DataOutputStream out, ChannelStateSerializer serializer) throws IOException {
        byte[] bytes = generateData(size);
        NetworkBuffer buffer = getBuffer(size);
        try {
            buffer.writeBytes(bytes);
            serializer.writeData(out, new Buffer[] {buffer});
        } finally {
            buffer.release();
        }
        return bytes;
    }
}
