/*
 * Decompiled with CFR 0.152.
 */
package org.apache.beam.sdk.io;

import java.io.File;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.attribute.FileAttribute;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import org.apache.beam.repackaged.beam_sdks_java_core.com.google.common.base.Charsets;
import org.apache.beam.repackaged.beam_sdks_java_core.com.google.common.collect.Lists;
import org.apache.beam.repackaged.beam_sdks_java_core.com.google.common.io.BaseEncoding;
import org.apache.beam.repackaged.beam_sdks_java_core.com.google.common.io.ByteStreams;
import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.coders.StringUtf8Coder;
import org.apache.beam.sdk.io.Compression;
import org.apache.beam.sdk.io.FileIO;
import org.apache.beam.sdk.io.TFRecordIO;
import org.apache.beam.sdk.testing.NeedsRunner;
import org.apache.beam.sdk.testing.PAssert;
import org.apache.beam.sdk.testing.TestPipeline;
import org.apache.beam.sdk.transforms.Create;
import org.apache.beam.sdk.transforms.DoFn;
import org.apache.beam.sdk.transforms.PTransform;
import org.apache.beam.sdk.transforms.ParDo;
import org.apache.beam.sdk.transforms.display.DisplayData;
import org.apache.beam.sdk.transforms.display.DisplayDataMatchers;
import org.apache.beam.sdk.transforms.display.HasDisplayData;
import org.apache.beam.sdk.values.PCollection;
import org.hamcrest.Matcher;
import org.hamcrest.Matchers;
import org.junit.Assert;
import org.junit.Rule;
import org.junit.Test;
import org.junit.experimental.categories.Category;
import org.junit.rules.ExpectedException;
import org.junit.rules.TemporaryFolder;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;

/**
 * Tests for {@link TFRecordIO}: transform naming, display data, reading and writing of
 * hand-encoded records, rejection of corrupted records, and write/read round trips under
 * different sharding, suffix and compression settings.
 *
 * <p>NOTE(review): the read-only tests apply their transforms to {@link #writePipeline};
 * {@link #readPipeline} is only used by the round-trip helper. This is kept as-is.
 */
@RunWith(JUnit4.class)
public class TFRecordIOTest {
    /** Base64 of a TFRecord file holding the single record {@code "foo"}. */
    private static final String FOO_RECORD_BASE64 = "AwAAAAAAAACwmUkOZm9vYYq+/g==";
    /** Base64 of a TFRecord file holding {@code "foo"} followed by {@code "bar"}. */
    private static final String FOO_BAR_RECORD_BASE64 = "AwAAAAAAAACwmUkOZm9vYYq+/gMAAAAAAAAAsJlJDmJhckYA5cg=";
    /** Base64 of a TFRecord file holding {@code "bar"} followed by {@code "foo"}. */
    private static final String BAR_FOO_RECORD_BASE64 = "AwAAAAAAAACwmUkOYmFyRgDlyAMAAAAAAAAAsJlJDmZvb2GKvv4=";
    private static final String[] FOO_RECORDS = new String[]{"foo"};
    private static final String[] FOO_BAR_RECORDS = new String[]{"foo", "bar"};
    /** Input with no elements, used by the empty-data round trip. */
    private static final Iterable<String> EMPTY = Collections.emptyList();
    /** 1000 distinct lines ({@code word0} .. {@code word999}) for the round-trip tests. */
    private static final Iterable<String> LARGE = TFRecordIOTest.makeLines(1000);

    @Rule
    public TemporaryFolder tempFolder = new TemporaryFolder();
    @Rule
    public TestPipeline readPipeline = TestPipeline.create();
    @Rule
    public TestPipeline writePipeline = TestPipeline.create();
    @Rule
    public ExpectedException expectedException = ExpectedException.none();

    /** Checks the output names of a read applied with the default name and with an explicit one. */
    @Test
    public void testReadNamed() {
        // The pipeline is never run here, so abandoned-node enforcement is switched off.
        this.writePipeline.enableAbandonedNodeEnforcement(false);
        Assert.assertEquals(
                "TFRecordIO.Read/Read.out",
                this.writePipeline.apply(TFRecordIO.read().from("foo.*").withoutValidation()).getName());
        Assert.assertEquals(
                "MyRead/Read.out",
                this.writePipeline.apply("MyRead", TFRecordIO.read().from("foo.*").withoutValidation()).getName());
    }

    /** Checks that a configured read reports its file pattern, compression and validation flag. */
    @Test
    public void testReadDisplayData() {
        TFRecordIO.Read read = TFRecordIO.read().from("foo.*").withCompression(Compression.GZIP).withoutValidation();
        DisplayData displayData = DisplayData.from(read);
        Assert.assertThat(displayData, DisplayDataMatchers.hasDisplayItem("filePattern", "foo.*"));
        Assert.assertThat(displayData, DisplayDataMatchers.hasDisplayItem("compressionType", Compression.GZIP.toString()));
        Assert.assertThat(displayData, DisplayDataMatchers.hasDisplayItem("validation", false));
    }

    /** Checks that a configured write reports prefix, suffix, shard template, shard count and compression. */
    @Test
    public void testWriteDisplayData() {
        TFRecordIO.Write write = TFRecordIO.write()
                .to("/foo")
                .withSuffix("bar")
                .withShardNameTemplate("-SS-of-NN-")
                .withNumShards(100)
                .withCompression(Compression.GZIP);
        DisplayData displayData = DisplayData.from(write);
        Assert.assertThat(displayData, DisplayDataMatchers.hasDisplayItem("filePrefix", "/foo"));
        Assert.assertThat(displayData, DisplayDataMatchers.hasDisplayItem("fileSuffix", "bar"));
        Assert.assertThat(displayData, DisplayDataMatchers.hasDisplayItem("shardNameTemplate", "-SS-of-NN-"));
        Assert.assertThat(displayData, DisplayDataMatchers.hasDisplayItem("numShards", 100L));
        Assert.assertThat(displayData, DisplayDataMatchers.hasDisplayItem("compressionType", Compression.GZIP.toString()));
    }

    /** Reads a file containing one encoded record. */
    @Test
    @Category(NeedsRunner.class)
    public void testReadOne() throws Exception {
        this.runTestRead(FOO_RECORD_BASE64, FOO_RECORDS);
    }

    /** Reads a file containing two encoded records. */
    @Test
    @Category(NeedsRunner.class)
    public void testReadTwo() throws Exception {
        this.runTestRead(FOO_BAR_RECORD_BASE64, FOO_BAR_RECORDS);
    }

    /** Writes one record and compares the file bytes with the expected encoding. */
    @Test
    @Category(NeedsRunner.class)
    public void testWriteOne() throws Exception {
        this.runTestWrite(FOO_RECORDS, FOO_RECORD_BASE64);
    }

    /** Writes two records; either element order is accepted in the resulting file. */
    @Test
    @Category(NeedsRunner.class)
    public void testWriteTwo() throws Exception {
        this.runTestWrite(FOO_BAR_RECORDS, FOO_BAR_RECORD_BASE64, BAR_FOO_RECORD_BASE64);
    }

    /** A 3-byte file must be rejected with the "fewer than 12 bytes" error. */
    @Test
    @Category(NeedsRunner.class)
    public void testReadInvalidRecord() throws Exception {
        this.expectedException.expect(IllegalStateException.class);
        this.expectedException.expectMessage("Not a valid TFRecord. Fewer than 12 bytes.");
        this.runTestRead("bar".getBytes(Charsets.UTF_8), new String[0]);
    }

    /** Corrupting byte 9 of a valid record must be reported as a length-mask mismatch. */
    @Test
    @Category(NeedsRunner.class)
    public void testReadInvalidLengthMask() throws Exception {
        this.expectedException.expect(IllegalStateException.class);
        this.expectedException.expectMessage("Mismatch of length mask");
        byte[] data = BaseEncoding.base64().decode(FOO_RECORD_BASE64);
        // Presumably offset 9 lies inside the length checksum that follows the 8-byte length.
        data[9] += 1;
        this.runTestRead(data, FOO_RECORDS);
    }

    /** Corrupting byte 16 of a valid record must be reported as a data-mask mismatch. */
    @Test
    @Category(NeedsRunner.class)
    public void testReadInvalidDataMask() throws Exception {
        this.expectedException.expect(IllegalStateException.class);
        this.expectedException.expectMessage("Mismatch of data mask");
        byte[] data = BaseEncoding.base64().decode(FOO_RECORD_BASE64);
        // Presumably offset 16 lies inside the trailing checksum of the 3-byte payload.
        data[16] += 1;
        this.runTestRead(data, FOO_RECORDS);
    }

    /**
     * Decodes {@code base64} and delegates to {@link #runTestRead(byte[], String[])}.
     *
     * @param base64 base64-encoded file contents
     * @param expected records the read is expected to produce, in any order
     * @throws IOException if the temporary input file cannot be created or written
     */
    private void runTestRead(String base64, String[] expected) throws IOException {
        this.runTestRead(BaseEncoding.base64().decode(base64), expected);
    }

    /**
     * Writes {@code data} to a fresh temporary file, reads it back through
     * {@link TFRecordIO#read()} and asserts on the decoded records.
     *
     * @param data raw file contents
     * @param expected records the read is expected to produce, in any order
     * @throws IOException if the temporary input file cannot be created or written
     */
    private void runTestRead(byte[] data, String[] expected) throws IOException {
        File tmpFile = Files.createTempFile(this.tempFolder.getRoot().toPath(), "file", ".tfrecords").toFile();
        String filename = tmpFile.getPath();
        try (FileOutputStream fos = new FileOutputStream(tmpFile)) {
            fos.write(data);
        }
        TFRecordIO.Read read = TFRecordIO.read().from(filename);
        PCollection<String> output = this.writePipeline.apply(read).apply(ParDo.of(new ByteArrayToString()));
        PAssert.that(output).containsInAnyOrder(expected);
        this.writePipeline.run();
    }

    /**
     * Writes {@code elems} without sharding to a temporary file and asserts that the file's
     * base64 encoding equals one of the accepted encodings.
     *
     * @param elems records to write
     * @param base64 accepted base64 encodings of the output file (one per acceptable element order)
     * @throws IOException if the temporary file cannot be created or read back
     */
    private void runTestWrite(String[] elems, String ... base64) throws IOException {
        File tmpFile = Files.createTempFile(this.tempFolder.getRoot().toPath(), "file", ".tfrecords").toFile();
        String filename = tmpFile.getPath();
        PCollection<byte[]> input = this.writePipeline
                .apply(Create.of(Arrays.asList(elems)))
                .apply(ParDo.of(new StringToByteArray()));
        TFRecordIO.Write write = TFRecordIO.write().to(filename).withoutSharding();
        input.apply(write);
        this.writePipeline.run();
        String written;
        // Close the stream deterministically instead of leaking the file handle.
        try (FileInputStream fis = new FileInputStream(tmpFile)) {
            written = BaseEncoding.base64().encode(ByteStreams.toByteArray(fis));
        }
        Assert.assertThat(written, Matchers.isIn(base64));
    }

    /** Round trip of the large input over 10 shards, uncompressed. */
    @Test
    @Category(NeedsRunner.class)
    public void runTestRoundTrip() throws IOException {
        this.runTestRoundTrip(LARGE, 10, ".tfrecords", Compression.UNCOMPRESSED, Compression.UNCOMPRESSED);
    }

    /** Round trip of an empty input. */
    @Test
    @Category(NeedsRunner.class)
    public void runTestRoundTripWithEmptyData() throws IOException {
        this.runTestRoundTrip(EMPTY, 10, ".tfrecords", Compression.UNCOMPRESSED, Compression.UNCOMPRESSED);
    }

    /** Round trip with a single shard. */
    @Test
    @Category(NeedsRunner.class)
    public void runTestRoundTripWithOneShards() throws IOException {
        this.runTestRoundTrip(LARGE, 1, ".tfrecords", Compression.UNCOMPRESSED, Compression.UNCOMPRESSED);
    }

    /** Round trip with a non-default file suffix. */
    @Test
    @Category(NeedsRunner.class)
    public void runTestRoundTripWithSuffix() throws IOException {
        this.runTestRoundTrip(LARGE, 10, ".suffix", Compression.UNCOMPRESSED, Compression.UNCOMPRESSED);
    }

    /** Round trip writing and reading with GZIP. */
    @Test
    @Category(NeedsRunner.class)
    public void runTestRoundTripGzip() throws IOException {
        this.runTestRoundTrip(LARGE, 10, ".tfrecords", Compression.GZIP, Compression.GZIP);
    }

    /** Round trip writing and reading with DEFLATE. */
    @Test
    @Category(NeedsRunner.class)
    public void runTestRoundTripZlib() throws IOException {
        this.runTestRoundTrip(LARGE, 10, ".tfrecords", Compression.DEFLATE, Compression.DEFLATE);
    }

    /** Uncompressed files read back with {@code Compression.AUTO}. */
    @Test
    @Category(NeedsRunner.class)
    public void runTestRoundTripUncompressedFilesWithAuto() throws IOException {
        this.runTestRoundTrip(LARGE, 10, ".tfrecords", Compression.UNCOMPRESSED, Compression.AUTO);
    }

    /** GZIP files read back with {@code Compression.AUTO}. */
    @Test
    @Category(NeedsRunner.class)
    public void runTestRoundTripGzipFilesWithAuto() throws IOException {
        this.runTestRoundTrip(LARGE, 10, ".tfrecords", Compression.GZIP, Compression.AUTO);
    }

    /** DEFLATE files read back with {@code Compression.AUTO}. */
    @Test
    @Category(NeedsRunner.class)
    public void runTestRoundTripZlibFilesWithAuto() throws IOException {
        this.runTestRoundTrip(LARGE, 10, ".tfrecords", Compression.DEFLATE, Compression.AUTO);
    }

    /**
     * Writes {@code elems} twice into a fresh temporary directory, once via
     * {@link TFRecordIO#write()} and once via {@link FileIO#write()} with
     * {@link TFRecordIO#sink()}, then reads both outputs back and asserts that each yields
     * exactly {@code elems} in any order.
     *
     * @param elems records to write and expect back
     * @param numShards shard count passed to {@code TFRecordIO.write()}
     * @param suffix file suffix used by both writes
     * @param writeCompression compression used by both writes
     * @param readCompression compression used by both reads
     * @throws IOException if the temporary directory cannot be created
     */
    private void runTestRoundTrip(Iterable<String> elems, int numShards, String suffix, Compression writeCompression, Compression readCompression) throws IOException {
        Path baseDir = Files.createTempDirectory(this.tempFolder.getRoot().toPath(), "test-rt");
        String outputNameViaWrite = "via-write";
        String baseFilenameViaWrite = baseDir.resolve(outputNameViaWrite).toString();
        String outputNameViaSink = "via-sink";
        String baseFilenameViaSink = baseDir.resolve(outputNameViaSink).toString();

        PCollection<byte[]> data = this.writePipeline
                .apply(Create.of(elems).withCoder(StringUtf8Coder.of()))
                .apply(ParDo.of(new StringToByteArray()));
        data.apply(
                "Write via TFRecordIO.write",
                TFRecordIO.write()
                        .to(baseFilenameViaWrite)
                        .withNumShards(numShards)
                        .withSuffix(suffix)
                        .withCompression(writeCompression));
        data.apply(
                "Write via TFRecordIO.sink",
                FileIO.<byte[]>write()
                        .via(TFRecordIO.sink())
                        .to(baseDir.toString())
                        .withPrefix(outputNameViaSink)
                        .withSuffix(suffix)
                        .withCompression(writeCompression)
                        .withIgnoreWindowing());
        this.writePipeline.run();

        // The read pipeline is built only after the write pipeline has run.
        PCollection<String> readViaWrite = this.readPipeline
                .apply(
                        "Read written by TFRecordIO.write",
                        TFRecordIO.read().from(baseFilenameViaWrite + "*").withCompression(readCompression))
                .apply("To string first", ParDo.of(new ByteArrayToString()));
        PAssert.that(readViaWrite).containsInAnyOrder(elems);

        PCollection<String> readViaSink = this.readPipeline
                .apply(
                        "Read written by TFRecordIO.sink",
                        TFRecordIO.read().from(baseFilenameViaSink + "*").withCompression(readCompression))
                .apply("To string second", ParDo.of(new ByteArrayToString()));
        PAssert.that(readViaSink).containsInAnyOrder(elems);
        this.readPipeline.run();
    }

    /**
     * Builds {@code n} distinct lines {@code "word0"} .. {@code "word(n-1)"}.
     *
     * @param n number of lines to generate
     * @return the generated lines in ascending index order
     */
    private static Iterable<String> makeLines(int n) {
        ArrayList<String> lines = new ArrayList<>(n);
        for (int i = 0; i < n; ++i) {
            lines.add("word" + i);
        }
        return lines;
    }

    /** Encodes each input string as its UTF-8 bytes. */
    static class StringToByteArray
    extends DoFn<String, byte[]> {
        // The inherited, parameterized ProcessContext is used so element()/output() are typed
        // and no casts are needed.
        @DoFn.ProcessElement
        public void processElement(ProcessContext c) {
            c.output(c.element().getBytes(Charsets.UTF_8));
        }
    }

    /** Decodes each input byte array as a UTF-8 string. */
    static class ByteArrayToString
    extends DoFn<byte[], String> {
        // The inherited, parameterized ProcessContext is used so element()/output() are typed
        // and no casts are needed.
        @DoFn.ProcessElement
        public void processElement(ProcessContext c) {
            c.output(new String(c.element(), Charsets.UTF_8));
        }
    }
}

