package org.apache.beam.sdk.io;

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.FilterInputStream;
import java.io.FilterOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.nio.ByteBuffer;
import java.nio.channels.ReadableByteChannel;
import java.nio.channels.WritableByteChannel;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.attribute.FileAttribute;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Base64;
import java.util.Collections;
import java.util.concurrent.ThreadLocalRandom;
import org.apache.beam.repackaged.core.org.antlr.v4.runtime.tree.xpath.XPath;
import org.apache.beam.repackaged.core.org.apache.commons.compress.compressors.bzip2.BZip2Constants;
import org.apache.beam.repackaged.core.org.apache.commons.lang3.StringUtils;
import org.apache.beam.repackaged.core.org.apache.commons.lang3.SystemUtils;
import org.apache.beam.sdk.coders.StringUtf8Coder;
import org.apache.beam.sdk.io.FileIO;
import org.apache.beam.sdk.io.TFRecordIO;
import org.apache.beam.sdk.io.fs.MatchResult;
import org.apache.beam.sdk.testing.NeedsRunner;
import org.apache.beam.sdk.testing.PAssert;
import org.apache.beam.sdk.testing.TestPipeline;
import org.apache.beam.sdk.transforms.Create;
import org.apache.beam.sdk.transforms.DoFn;
import org.apache.beam.sdk.transforms.ParDo;
import org.apache.beam.sdk.transforms.ParDoTest;
import org.apache.beam.sdk.transforms.display.DisplayData;
import org.apache.beam.sdk.transforms.display.DisplayDataMatchers;
import org.apache.beam.sdk.values.PCollection;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Lists;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.io.BaseEncoding;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.io.ByteStreams;
import org.hamcrest.CoreMatchers;
import org.hamcrest.MatcherAssert;
import org.hamcrest.Matchers;
import org.hamcrest.core.Is;
import org.junit.Assert;
import org.junit.Assume;
import org.junit.Rule;
import org.junit.Test;
import org.junit.experimental.categories.Category;
import org.junit.internal.matchers.ThrowableMessageMatcher;
import org.junit.rules.ExpectedException;
import org.junit.rules.TemporaryFolder;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;

@RunWith(JUnit4.class)
/* loaded from: input_file:org/apache/beam/sdk/io/TFRecordIOTest.class */
public class TFRecordIOTest {
    /** Base64 form of a TFRecord stream holding the single record {@code "foo"}. */
    private static final String FOO_RECORD_BASE64 = "AwAAAAAAAACwmUkOZm9vYYq+/g==";

    /** Base64 form of a TFRecord stream holding {@code "foo"} followed by {@code "bar"}. */
    private static final String FOO_BAR_RECORD_BASE64 = "AwAAAAAAAACwmUkOZm9vYYq+/gMAAAAAAAAAsJlJDmJhckYA5cg=";

    /** Base64 form of a TFRecord stream holding {@code "bar"} followed by {@code "foo"}. */
    private static final String BAR_FOO_RECORD_BASE64 = "AwAAAAAAAACwmUkOYmFyRgDlyAMAAAAAAAAAsJlJDmZvb2GKvv4=";

    // The payloads are spelled out literally so the fixtures match the encoded constants above
    // without depending on a constant borrowed from an unrelated test class.
    private static final String[] FOO_RECORDS = {"foo"};
    private static final String[] FOO_BAR_RECORDS = {"foo", "bar"};

    /** No records at all. */
    private static final Iterable<String> EMPTY = Collections.emptyList();

    /** 1000 short records (4 filler characters plus an index each). */
    private static final Iterable<String> LARGE = makeLines(1000, 4);

    /** 100 records of 100000 filler characters plus an index each. */
    private static final Iterable<String> LARGE_RECORDS = makeLines(100, 100000);

    /** Per-test scratch directory; removed by JUnit after each test. */
    @Rule
    public TemporaryFolder tempFolder = new TemporaryFolder();

    /** Pipeline used for the read half of each test. */
    @Rule
    public TestPipeline readPipeline = TestPipeline.create();

    /** Pipeline used for the write half of each test. */
    @Rule
    public TestPipeline writePipeline = TestPipeline.create();

    /** Lets tests declare the exception they expect before triggering it. */
    @Rule
    public ExpectedException expectedException = ExpectedException.none();

    /** {@link DoFn} that decodes each incoming byte array as a UTF-8 string. */
    public static class ByteArrayToString extends DoFn<byte[], String> {
        ByteArrayToString() {
        }

        /** Emits the UTF-8 decoding of the current element. */
        @DoFn.ProcessElement
        public void processElement(DoFn<byte[], String>.ProcessContext context) {
            final byte[] raw = (byte[]) context.element();
            final String decoded = new String(raw, StandardCharsets.UTF_8);
            context.output(decoded);
        }
    }

    /**
     * A deliberately awkward {@link ReadableByteChannel} over an {@link InputStream}.
     *
     * <p>Each {@link #read(ByteBuffer)} call transfers either nothing or exactly one byte, as
     * decided by {@link TFRecordIOTest#maybeThisTime()}. Bulk array reads and {@link #isOpen()}
     * are unsupported and throw.
     */
    static class PickyReadChannel extends FilterInputStream implements ReadableByteChannel {
        protected PickyReadChannel(InputStream source) {
            super(source);
        }

        /** Bulk reads are rejected so callers must go through the channel API. */
        @Override
        public int read(byte[] buffer, int offset, int length) {
            throw new UnsupportedOperationException();
        }

        /**
         * Transfers at most one byte into {@code target}.
         *
         * @return 1 if a byte was transferred, 0 if this call declined to read or the buffer is
         *     full, -1 at end of stream
         */
        @Override
        public int read(ByteBuffer target) throws IOException {
            if (TFRecordIOTest.maybeThisTime() && target.hasRemaining()) {
                final int next = read();
                if (next == -1) {
                    return -1;
                }
                target.put((byte) next);
                return 1;
            }
            return 0;
        }

        @Override
        public boolean isOpen() {
            throw new UnsupportedOperationException();
        }
    }

    /**
     * A deliberately awkward {@link WritableByteChannel} over an {@link OutputStream}.
     *
     * <p>Each {@link #write(ByteBuffer)} call consumes either nothing or exactly one byte, as
     * decided by {@link TFRecordIOTest#maybeThisTime()}. Bulk array writes and {@link #isOpen()}
     * are unsupported and throw.
     */
    static class PickyWriteChannel extends FilterOutputStream implements WritableByteChannel {
        public PickyWriteChannel(OutputStream sink) {
            super(sink);
        }

        /** Bulk writes are rejected so callers must go through the channel API. */
        @Override
        public void write(byte[] buffer, int offset, int length) throws IOException {
            throw new UnsupportedOperationException();
        }

        /**
         * Consumes at most one byte from {@code source}.
         *
         * @return 1 if a byte was written, 0 if this call declined to write or the buffer is empty
         */
        @Override
        public int write(ByteBuffer source) throws IOException {
            if (TFRecordIOTest.maybeThisTime() && source.hasRemaining()) {
                final byte next = source.get();
                write(next);
                return 1;
            }
            return 0;
        }

        @Override
        public boolean isOpen() {
            throw new UnsupportedOperationException();
        }
    }

    /** {@link DoFn} that encodes each incoming string as UTF-8 bytes. */
    public static class StringToByteArray extends DoFn<String, byte[]> {
        StringToByteArray() {
        }

        /** Emits the UTF-8 encoding of the current element. */
        @DoFn.ProcessElement
        public void processElement(DoFn<String, byte[]>.ProcessContext context) {
            final String text = (String) context.element();
            final byte[] encoded = text.getBytes(StandardCharsets.UTF_8);
            context.output(encoded);
        }
    }

    /** Checks the output name of a read transform, both with and without an explicit step name. */
    @Test
    public void testReadNamed() {
        // The pipeline is only built, never run, so abandoned nodes must be tolerated.
        readPipeline.enableAbandonedNodeEnforcement(false);

        final String defaultName =
            readPipeline.apply(TFRecordIO.read().from("foo.*").withoutValidation()).getName();
        MatcherAssert.assertThat(defaultName, CoreMatchers.startsWith("TFRecordIO.Read/Read"));

        final String explicitName =
            readPipeline.apply("MyRead", TFRecordIO.read().from("foo.*").withoutValidation()).getName();
        MatcherAssert.assertThat(explicitName, CoreMatchers.startsWith("MyRead/Read"));
    }

    /** Checks the output name of {@code readFiles()}, both with and without an explicit step name. */
    @Test
    public void testReadFilesNamed() {
        // The pipeline is only built, never run, so abandoned nodes must be tolerated.
        readPipeline.enableAbandonedNodeEnforcement(false);

        final MatchResult.Metadata metadata =
            MatchResult.Metadata.builder()
                .setResourceId(FileSystems.matchNewResource("file", false))
                .setIsReadSeekEfficient(true)
                .setSizeBytes(1024L)
                .build();
        final Create.Values<FileIO.ReadableFile> files =
            Create.of(new FileIO.ReadableFile(metadata, Compression.AUTO));

        final String defaultName = readPipeline.apply(files).apply(TFRecordIO.readFiles()).getName();
        Assert.assertEquals(
            "TFRecordIO.ReadFiles/Read all via FileBasedSource/Read ranges/ParMultiDo(ReadFileRanges).output",
            defaultName);

        final String explicitName =
            readPipeline.apply(files).apply("MyRead", TFRecordIO.readFiles()).getName();
        Assert.assertEquals(
            "MyRead/Read all via FileBasedSource/Read ranges/ParMultiDo(ReadFileRanges).output",
            explicitName);
    }

    /** Checks that a configured read exposes its pattern, compression and validation flag. */
    @Test
    public void testReadDisplayData() {
        final DisplayData displayData =
            DisplayData.from(
                TFRecordIO.read()
                    .from("foo.*")
                    .withCompression(Compression.GZIP)
                    .withoutValidation());

        MatcherAssert.assertThat(
            displayData, DisplayDataMatchers.hasDisplayItem("filePattern", "foo.*"));
        MatcherAssert.assertThat(
            displayData,
            DisplayDataMatchers.hasDisplayItem("compressionType", Compression.GZIP.toString()));
        MatcherAssert.assertThat(
            displayData, DisplayDataMatchers.hasDisplayItem("validation", Boolean.FALSE));
    }

    /**
     * Checks that a configured write exposes its prefix, suffix, shard template, shard count and
     * compression. Skipped when running on Windows.
     */
    @Test
    public void testWriteDisplayData() {
        Assume.assumeFalse(SystemUtils.IS_OS_WINDOWS);

        final DisplayData displayData =
            DisplayData.from(
                TFRecordIO.write()
                    .to("/foo")
                    .withSuffix("bar")
                    .withShardNameTemplate("-SS-of-NN-")
                    .withNumShards(100)
                    .withCompression(Compression.GZIP));

        MatcherAssert.assertThat(
            displayData, DisplayDataMatchers.hasDisplayItem("filePrefix", "/foo"));
        MatcherAssert.assertThat(
            displayData, DisplayDataMatchers.hasDisplayItem("fileSuffix", "bar"));
        MatcherAssert.assertThat(
            displayData, DisplayDataMatchers.hasDisplayItem("shardNameTemplate", "-SS-of-NN-"));
        MatcherAssert.assertThat(
            displayData, DisplayDataMatchers.hasDisplayItem("numShards", 100L));
        MatcherAssert.assertThat(
            displayData,
            DisplayDataMatchers.hasDisplayItem("compressionType", Compression.GZIP.toString()));
    }

    /** Reads the single-record fixture and expects the records in {@code FOO_RECORDS}. */
    @Test
    @Category(NeedsRunner.class)
    public void testReadOne() throws Exception {
        final String encodedFile = FOO_RECORD_BASE64;
        final String[] expectedRecords = FOO_RECORDS;
        runTestRead(encodedFile, expectedRecords);
    }

    /** Reads the two-record fixture and expects the records in {@code FOO_BAR_RECORDS}. */
    @Test
    @Category(NeedsRunner.class)
    public void testReadTwo() throws Exception {
        final String encodedFile = FOO_BAR_RECORD_BASE64;
        final String[] expectedRecords = FOO_BAR_RECORDS;
        runTestRead(encodedFile, expectedRecords);
    }

    /** Writes {@code FOO_RECORDS} and expects the file to match the single-record fixture. */
    @Test
    @Category(NeedsRunner.class)
    public void testWriteOne() throws Exception {
        final String[] records = FOO_RECORDS;
        runTestWrite(records, FOO_RECORD_BASE64);
    }

    /**
     * Writes {@code FOO_BAR_RECORDS} and accepts either of the two fixture encodings as the
     * resulting file content.
     */
    @Test
    @Category(NeedsRunner.class)
    public void testWriteTwo() throws Exception {
        final String[] records = FOO_BAR_RECORDS;
        runTestWrite(records, FOO_BAR_RECORD_BASE64, BAR_FOO_RECORD_BASE64);
    }

    /** Feeds a three-byte file to the reader and expects the "fewer than 12 bytes" failure. */
    @Test
    @Category(NeedsRunner.class)
    public void testReadInvalidRecord() throws Exception {
        expectedException.expectMessage("Not a valid TFRecord. Fewer than 12 bytes.");

        final byte[] tooShort = "bar".getBytes(StandardCharsets.UTF_8);
        final String[] noRecords = new String[0];
        runTestRead(tooShort, noRecords);
    }

    /** Corrupts byte 9 of the single-record fixture and expects a length-mask mismatch. */
    @Test
    @Category(NeedsRunner.class)
    public void testReadInvalidLengthMask() throws Exception {
        expectedException.expectCause(
            ThrowableMessageMatcher.hasMessage(
                CoreMatchers.containsString("Mismatch of length mask")));

        final byte[] corrupted = BaseEncoding.base64().decode(FOO_RECORD_BASE64);
        // Bump one byte so the stored value no longer matches.
        corrupted[9] += 1;
        runTestRead(corrupted, FOO_RECORDS);
    }

    /** Corrupts byte 16 of the single-record fixture and expects a data-mask mismatch. */
    @Test
    @Category(NeedsRunner.class)
    public void testReadInvalidDataMask() throws Exception {
        expectedException.expectCause(
            ThrowableMessageMatcher.hasMessage(
                CoreMatchers.containsString("Mismatch of data mask")));

        final byte[] corrupted = BaseEncoding.base64().decode(FOO_RECORD_BASE64);
        // Bump one byte so the stored value no longer matches.
        corrupted[16] += 1;
        runTestRead(corrupted, FOO_RECORDS);
    }

    /**
     * Base64-decodes {@code str} and delegates to the byte-array overload.
     *
     * @param str base64 encoding of the file content to read
     * @param strArr records the read is expected to produce
     */
    private void runTestRead(String str, String[] strArr) throws IOException {
        final byte[] fileContent = BaseEncoding.base64().decode(str);
        runTestRead(fileContent, strArr);
    }

    /**
     * Writes {@code bArr} to a fresh temporary {@code .tfrecords} file, then asserts that both
     * {@code TFRecordIO.read()} and {@code TFRecordIO.readFiles()} yield exactly {@code strArr}
     * (in any order) when run on {@code readPipeline}.
     *
     * @param bArr raw file content
     * @param strArr records the reads are expected to produce
     * @throws IOException if the temporary file cannot be created or written
     */
    private void runTestRead(byte[] bArr, String[] strArr) throws IOException {
        File file = Files.createTempFile(tempFolder.getRoot().toPath(), "file", ".tfrecords").toFile();
        String path = file.getPath();

        // Only the write needs the stream: close it before the pipeline is built and run so a
        // pipeline failure is never routed through stream clean-up.
        try (FileOutputStream out = new FileOutputStream(file)) {
            out.write(bArr);
        }

        PAssert.that(
                readPipeline
                    .apply(TFRecordIO.read().from(path))
                    .apply(ParDo.of(new ByteArrayToString())))
            .containsInAnyOrder(strArr);

        Compression compression = Compression.AUTO;
        PAssert.that(
                readPipeline
                    .apply("Create_Paths_ReadFiles_" + file, Create.of(path))
                    .apply("Match_" + file, FileIO.matchAll())
                    .apply("ReadMatches_" + file, FileIO.readMatches().withCompression(compression))
                    .apply("ReadFiles_" + compression.toString(), TFRecordIO.readFiles())
                    .apply("ToString", ParDo.of(new ByteArrayToString())))
            .containsInAnyOrder(strArr);

        readPipeline.run();
    }

    /**
     * Writes {@code strArr} as a single unsharded TFRecord file via {@code writePipeline} and
     * asserts that the base64 encoding of the file is one of {@code strArr2}.
     *
     * @param strArr records to write
     * @param strArr2 acceptable base64 encodings of the resulting file
     * @throws IOException if the temporary file cannot be created or read back
     */
    private void runTestWrite(String[] strArr, String... strArr2) throws IOException {
        File file = Files.createTempFile(tempFolder.getRoot().toPath(), "file", ".tfrecords").toFile();

        writePipeline
            .apply(Create.of(Arrays.asList(strArr)))
            .apply(ParDo.of(new StringToByteArray()))
            .apply(TFRecordIO.write().to(file.getPath()).withoutSharding());
        writePipeline.run();

        // Read the file back inside try-with-resources so the stream is always closed.
        String written;
        try (FileInputStream in = new FileInputStream(file)) {
            written = BaseEncoding.base64().encode(ByteStreams.toByteArray(in));
        }
        MatcherAssert.assertThat(written, Is.is(Matchers.in(strArr2)));
    }

    /** Round-trips {@code LARGE} through 10 shards with no compression on either side. */
    @Test
    @Category(NeedsRunner.class)
    public void runTestRoundTrip() throws IOException {
        final Compression writeCompression = Compression.UNCOMPRESSED;
        final Compression readCompression = Compression.UNCOMPRESSED;
        runTestRoundTrip(LARGE, 10, ".tfrecords", writeCompression, readCompression);
    }

    /** Round-trips the empty input through 10 shards with no compression on either side. */
    @Test
    @Category(NeedsRunner.class)
    public void runTestRoundTripWithEmptyData() throws IOException {
        final Compression writeCompression = Compression.UNCOMPRESSED;
        final Compression readCompression = Compression.UNCOMPRESSED;
        runTestRoundTrip(EMPTY, 10, ".tfrecords", writeCompression, readCompression);
    }

    /** Round-trips {@code LARGE} through a single shard with no compression on either side. */
    @Test
    @Category(NeedsRunner.class)
    public void runTestRoundTripWithOneShards() throws IOException {
        final Compression writeCompression = Compression.UNCOMPRESSED;
        final Compression readCompression = Compression.UNCOMPRESSED;
        runTestRoundTrip(LARGE, 1, ".tfrecords", writeCompression, readCompression);
    }

    /** Round-trips {@code LARGE} using the file suffix {@code ".suffix"}, uncompressed. */
    @Test
    @Category(NeedsRunner.class)
    public void runTestRoundTripWithSuffix() throws IOException {
        final Compression writeCompression = Compression.UNCOMPRESSED;
        final Compression readCompression = Compression.UNCOMPRESSED;
        runTestRoundTrip(LARGE, 10, ".suffix", writeCompression, readCompression);
    }

    /** Round-trips {@code LARGE} writing and reading with GZIP. */
    @Test
    @Category(NeedsRunner.class)
    public void runTestRoundTripGzip() throws IOException {
        final Compression writeCompression = Compression.GZIP;
        final Compression readCompression = Compression.GZIP;
        runTestRoundTrip(LARGE, 10, ".tfrecords", writeCompression, readCompression);
    }

    /** Round-trips {@code LARGE} writing and reading with DEFLATE. */
    @Test
    @Category(NeedsRunner.class)
    public void runTestRoundTripZlib() throws IOException {
        final Compression writeCompression = Compression.DEFLATE;
        final Compression readCompression = Compression.DEFLATE;
        runTestRoundTrip(LARGE, 10, ".tfrecords", writeCompression, readCompression);
    }

    /** Writes {@code LARGE} uncompressed and reads it back with AUTO compression. */
    @Test
    @Category(NeedsRunner.class)
    public void runTestRoundTripUncompressedFilesWithAuto() throws IOException {
        final Compression writeCompression = Compression.UNCOMPRESSED;
        final Compression readCompression = Compression.AUTO;
        runTestRoundTrip(LARGE, 10, ".tfrecords", writeCompression, readCompression);
    }

    /** Writes {@code LARGE} with GZIP and reads it back with AUTO compression. */
    @Test
    @Category(NeedsRunner.class)
    public void runTestRoundTripGzipFilesWithAuto() throws IOException {
        final Compression writeCompression = Compression.GZIP;
        final Compression readCompression = Compression.AUTO;
        runTestRoundTrip(LARGE, 10, ".tfrecords", writeCompression, readCompression);
    }

    /** Writes {@code LARGE} with DEFLATE and reads it back with AUTO compression. */
    @Test
    @Category(NeedsRunner.class)
    public void runTestRoundTripZlibFilesWithAuto() throws IOException {
        final Compression writeCompression = Compression.DEFLATE;
        final Compression readCompression = Compression.AUTO;
        runTestRoundTrip(LARGE, 10, ".tfrecords", writeCompression, readCompression);
    }

    /** Round-trips {@code LARGE_RECORDS} through 10 shards with no compression on either side. */
    @Test
    @Category(NeedsRunner.class)
    public void runTestRoundTripLargeRecords() throws IOException {
        final Compression writeCompression = Compression.UNCOMPRESSED;
        final Compression readCompression = Compression.UNCOMPRESSED;
        runTestRoundTrip(LARGE_RECORDS, 10, ".tfrecords", writeCompression, readCompression);
    }

    /** Round-trips {@code LARGE_RECORDS} writing and reading with GZIP. */
    @Test
    @Category(NeedsRunner.class)
    public void runTestRoundTripLargeRecordsGzip() throws IOException {
        final Compression writeCompression = Compression.GZIP;
        final Compression readCompression = Compression.GZIP;
        runTestRoundTrip(LARGE_RECORDS, 10, ".tfrecords", writeCompression, readCompression);
    }

    /**
     * Writes {@code iterable} twice, once with {@code TFRecordIO.write()} and once with
     * {@code FileIO.write().via(TFRecordIO.sink())}, then reads both outputs back and asserts that
     * every read yields exactly the original elements in any order.
     *
     * @param iterable records to round-trip
     * @param i number of shards for the {@code TFRecordIO.write()} output
     * @param str file suffix applied by both writers
     * @param compression compression used by both writers
     * @param compression2 compression used by every reader
     * @throws IOException if the scratch directory cannot be created
     */
    private void runTestRoundTrip(Iterable<String> iterable, int i, String str, Compression compression, Compression compression2) throws IOException {
        Path createTempDirectory = Files.createTempDirectory(tempFolder.getRoot().toPath(), "test-rt");
        String path = createTempDirectory.resolve("via-write").toString();
        String path2 = createTempDirectory.resolve("via-sink").toString();

        PCollection<byte[]> apply =
            writePipeline
                .apply(Create.of(iterable).withCoder(StringUtf8Coder.of()))
                .apply(ParDo.of(new StringToByteArray()));
        apply.apply(
            "Write via TFRecordIO.write",
            TFRecordIO.write().to(path).withNumShards(i).withSuffix(str).withCompression(compression));
        apply.apply(
            "Write via TFRecordIO.sink",
            FileIO.<byte[]>write()
                .via(TFRecordIO.sink())
                .to(createTempDirectory.toString())
                .withPrefix("via-sink")
                .withSuffix(str)
                .withCompression(compression)
                .withIgnoreWindowing());
        writePipeline.run();

        // Every output file starts with the base name, so "<base>*" matches all of its shards.
        PAssert.that(
                readPipeline
                    .apply(
                        "Read written by TFRecordIO.write",
                        TFRecordIO.read().from(path + "*").withCompression(compression2))
                    .apply("To string read from write", ParDo.of(new ByteArrayToString())))
            .containsInAnyOrder(iterable);
        PAssert.that(
                readPipeline
                    .apply(
                        "Read written by TFRecordIO.sink",
                        TFRecordIO.read().from(path2 + "*").withCompression(compression2))
                    .apply("To string read from sink", ParDo.of(new ByteArrayToString())))
            .containsInAnyOrder(iterable);
        PAssert.that(
                readPipeline
                    .apply("Create_Paths_ReadFiles_" + path, Create.of(path + "*"))
                    .apply("Match_" + path, FileIO.matchAll())
                    .apply("ReadMatches_" + path, FileIO.readMatches().withCompression(compression2))
                    .apply("ReadFiles written by TFRecordIO.write", TFRecordIO.readFiles())
                    .apply("To string readFiles from write", ParDo.of(new ByteArrayToString())))
            .containsInAnyOrder(iterable);
        // NOTE(review): despite its step name this assertion goes through TFRecordIO.read(), not
        // readFiles(), so it repeats the sink check above. Kept as-is; confirm whether a
        // matchAll/readMatches/readFiles chain was intended here.
        PAssert.that(
                readPipeline
                    .apply(
                        "ReadFiles written by TFRecordIO.sink",
                        TFRecordIO.read().from(path2 + "*").withCompression(compression2))
                    .apply("To string readFiles from sink", ParDo.of(new ByteArrayToString())))
            .containsInAnyOrder(iterable);
        readPipeline.run();
    }

    /**
     * Builds {@code i} distinct lines, each consisting of {@code i2} {@code 'x'} characters, a
     * single space, and the zero-based line index.
     *
     * @param i number of lines to produce; values below 1 give an empty result
     * @param i2 number of filler characters per line; values below 1 give no filler
     * @return the generated lines in index order
     */
    private static Iterable<String> makeLines(int i, int i2) {
        final ArrayList<String> lines = new ArrayList<>();

        final StringBuilder filler = new StringBuilder();
        for (int k = 0; k < i2; k++) {
            filler.append('x');
        }
        final String prefix = filler.append(' ').toString();

        for (int index = 0; index < i; index++) {
            lines.add(prefix + index);
        }
        return lines;
    }

    /** Returns a random boolean drawn from the calling thread's {@link ThreadLocalRandom}. */
    static boolean maybeThisTime() {
        final ThreadLocalRandom random = ThreadLocalRandom.current();
        return random.nextBoolean();
    }

    /** {@code readFully} must fill the buffer even when the channel yields at most one byte per call. */
    @Test
    public void testReadFully() throws IOException {
        final byte[] expected = "Hello World".getBytes(StandardCharsets.UTF_8);
        final ByteBuffer buffer = ByteBuffer.allocate(expected.length);
        final PickyReadChannel channel = new PickyReadChannel(new ByteArrayInputStream(expected));

        TFRecordIO.TFRecordCodec.readFully(channel, buffer);

        Assert.assertArrayEquals(expected, buffer.array());
    }

    /** {@code readFully} must fail when the channel ends one byte short of the buffer size. */
    @Test
    public void testReadFullyFail() throws IOException {
        final byte[] available = "Hello Wo".getBytes(StandardCharsets.UTF_8);
        final PickyReadChannel channel = new PickyReadChannel(new ByteArrayInputStream(available));
        // One byte more than the channel can ever supply.
        final ByteBuffer buffer = ByteBuffer.allocate(available.length + 1);

        expectedException.expect(IOException.class);
        expectedException.expectMessage("expected 9, but got 8");

        TFRecordIO.TFRecordCodec.readFully(channel, buffer);
    }

    /** {@code writeFully} must drain the buffer even when the channel accepts at most one byte per call. */
    @Test
    public void testWriteFully() throws IOException {
        final byte[] payload = "Hello World".getBytes(StandardCharsets.UTF_8);
        final ByteArrayOutputStream sink = new ByteArrayOutputStream();
        final PickyWriteChannel channel = new PickyWriteChannel(sink);

        TFRecordIO.TFRecordCodec.writeFully(channel, ByteBuffer.wrap(payload));

        Assert.assertArrayEquals(payload, sink.toByteArray());
    }

    /**
     * Exercises the codec end to end over the picky channels: writing {@code "foo"} then
     * {@code "bar"} must produce the fixture encodings, and reading the result back must return
     * both records followed by {@code null}.
     */
    @Test
    public void testTFRecordCodec() throws IOException {
        Base64.Decoder decoder = Base64.getDecoder();
        TFRecordIO.TFRecordCodec tFRecordCodec = new TFRecordIO.TFRecordCodec();
        ByteArrayOutputStream byteArrayOutputStream = new ByteArrayOutputStream();
        PickyWriteChannel pickyWriteChannel = new PickyWriteChannel(byteArrayOutputStream);

        // The payload is the literal "foo" that FOO_RECORD_BASE64 encodes, rather than a constant
        // borrowed from an unrelated test class.
        tFRecordCodec.write(pickyWriteChannel, "foo".getBytes(StandardCharsets.UTF_8));
        Assert.assertArrayEquals(decoder.decode(FOO_RECORD_BASE64), byteArrayOutputStream.toByteArray());

        tFRecordCodec.write(pickyWriteChannel, "bar".getBytes(StandardCharsets.UTF_8));
        Assert.assertArrayEquals(decoder.decode(FOO_BAR_RECORD_BASE64), byteArrayOutputStream.toByteArray());

        PickyReadChannel pickyReadChannel = new PickyReadChannel(new ByteArrayInputStream(byteArrayOutputStream.toByteArray()));
        byte[] read = tFRecordCodec.read(pickyReadChannel);
        byte[] read2 = tFRecordCodec.read(pickyReadChannel);
        // A third read hits end of stream and is expected to return null.
        Assert.assertNull(tFRecordCodec.read(pickyReadChannel));

        Assert.assertEquals("foo", new String(read, StandardCharsets.UTF_8));
        Assert.assertEquals("bar", new String(read2, StandardCharsets.UTF_8));
    }
}
