/*
 * Decompiled with CFR 0.152.
 */
package org.apache.arrow.flight;

import java.nio.charset.StandardCharsets;
import java.util.Arrays;
import java.util.Collections;
import java.util.stream.IntStream;
import org.apache.arrow.flight.CallOption;
import org.apache.arrow.flight.CallStatus;
import org.apache.arrow.flight.FlightClient;
import org.apache.arrow.flight.FlightDescriptor;
import org.apache.arrow.flight.FlightProducer;
import org.apache.arrow.flight.FlightRuntimeException;
import org.apache.arrow.flight.FlightServer;
import org.apache.arrow.flight.FlightStatusCode;
import org.apache.arrow.flight.FlightStream;
import org.apache.arrow.flight.Location;
import org.apache.arrow.flight.NoOpFlightProducer;
import org.apache.arrow.memory.ArrowBuf;
import org.apache.arrow.memory.BufferAllocator;
import org.apache.arrow.memory.RootAllocator;
import org.apache.arrow.util.AutoCloseables;
import org.apache.arrow.vector.FieldVector;
import org.apache.arrow.vector.IntVector;
import org.apache.arrow.vector.VectorLoader;
import org.apache.arrow.vector.VectorSchemaRoot;
import org.apache.arrow.vector.VectorUnloader;
import org.apache.arrow.vector.ipc.message.ArrowRecordBatch;
import org.apache.arrow.vector.testing.ValueVectorDataPopulator;
import org.apache.arrow.vector.types.pojo.ArrowType;
import org.apache.arrow.vector.types.pojo.Field;
import org.apache.arrow.vector.types.pojo.Schema;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test;

public class TestDoExchange {
    static byte[] EXCHANGE_DO_GET = "do-get".getBytes(StandardCharsets.UTF_8);
    static byte[] EXCHANGE_DO_PUT = "do-put".getBytes(StandardCharsets.UTF_8);
    static byte[] EXCHANGE_ECHO = "echo".getBytes(StandardCharsets.UTF_8);
    static byte[] EXCHANGE_METADATA_ONLY = "only-metadata".getBytes(StandardCharsets.UTF_8);
    static byte[] EXCHANGE_TRANSFORM = "transform".getBytes(StandardCharsets.UTF_8);
    static byte[] EXCHANGE_CANCEL = "cancel".getBytes(StandardCharsets.UTF_8);
    static byte[] EXCHANGE_ERROR = "error".getBytes(StandardCharsets.UTF_8);
    private BufferAllocator allocator;
    private FlightServer server;
    private FlightClient client;

    @BeforeEach
    public void setUp() throws Exception {
        this.allocator = new RootAllocator(Integer.MAX_VALUE);
        Location serverLocation = Location.forGrpcInsecure((String)"localhost", (int)0);
        this.server = FlightServer.builder((BufferAllocator)this.allocator, (Location)serverLocation, (FlightProducer)new Producer(this.allocator)).build();
        this.server.start();
        Location clientLocation = Location.forGrpcInsecure((String)"localhost", (int)this.server.getPort());
        this.client = FlightClient.builder((BufferAllocator)this.allocator, (Location)clientLocation).build();
    }

    @AfterEach
    public void tearDown() throws Exception {
        AutoCloseables.close((AutoCloseable[])new AutoCloseable[]{this.client, this.server, this.allocator});
    }

    @Test
    public void testDoExchangeOnlyMetadata() throws Exception {
        try (FlightClient.ExchangeReaderWriter stream = this.client.doExchange(FlightDescriptor.command((byte[])EXCHANGE_METADATA_ONLY), new CallOption[0]);){
            FlightStream reader = stream.getReader();
            Assertions.assertTrue((boolean)reader.next());
            Assertions.assertFalse((boolean)reader.hasRoot());
            Assertions.assertEquals((int)42, (int)reader.getLatestMetadata().getInt(0L));
            ArrowBuf buf = this.allocator.buffer(4L);
            buf.writeInt(84);
            stream.getWriter().putMetadata(buf);
            Assertions.assertTrue((boolean)reader.next());
            Assertions.assertFalse((boolean)reader.hasRoot());
            Assertions.assertEquals((int)84, (int)reader.getLatestMetadata().getInt(0L));
            stream.getWriter().completed();
            Assertions.assertFalse((boolean)reader.next());
        }
    }

    @Test
    public void testDoExchangeDoGet() throws Exception {
        try (FlightClient.ExchangeReaderWriter stream = this.client.doExchange(FlightDescriptor.command((byte[])EXCHANGE_DO_GET), new CallOption[0]);){
            FlightStream reader = stream.getReader();
            VectorSchemaRoot root = reader.getRoot();
            IntVector iv = (IntVector)root.getVector("a");
            int value = 0;
            while (reader.next()) {
                for (int i = 0; i < root.getRowCount(); ++i) {
                    Assertions.assertFalse((boolean)iv.isNull(i), (String)String.format("Row %d should not be null", value));
                    Assertions.assertEquals((int)value, (int)iv.get(i));
                    ++value;
                }
            }
            Assertions.assertEquals((int)100, (int)value);
        }
    }

    @Test
    public void testDoExchangeDoPut() throws Exception {
        Schema schema = new Schema(Collections.singletonList(Field.nullable((String)"a", (ArrowType)new ArrowType.Int(32, true))));
        try (FlightClient.ExchangeReaderWriter stream = this.client.doExchange(FlightDescriptor.command((byte[])EXCHANGE_DO_PUT), new CallOption[0]);
             VectorSchemaRoot root = VectorSchemaRoot.create((Schema)schema, (BufferAllocator)this.allocator);){
            IntVector iv = (IntVector)root.getVector("a");
            iv.allocateNew();
            stream.getWriter().start(root);
            int counter = 0;
            for (int i = 0; i < 10; ++i) {
                ValueVectorDataPopulator.setVector((IntVector)iv, (Integer[])((Integer[])IntStream.range(0, i).boxed().toArray(Integer[]::new)));
                root.setRowCount(i);
                stream.getWriter().putNext();
                Assertions.assertTrue((boolean)stream.getReader().next());
                Assertions.assertFalse((boolean)stream.getReader().hasRoot());
                ArrowBuf metadata = stream.getReader().getLatestMetadata();
                Assertions.assertEquals((int)(counter += i), (int)metadata.getInt(0L));
            }
            stream.getWriter().completed();
            while (stream.getReader().next()) {
            }
        }
    }

    @Test
    public void testDoExchangeEcho() throws Exception {
        Schema schema = new Schema(Collections.singletonList(Field.nullable((String)"a", (ArrowType)new ArrowType.Int(32, true))));
        try (FlightClient.ExchangeReaderWriter stream = this.client.doExchange(FlightDescriptor.command((byte[])EXCHANGE_ECHO), new CallOption[0]);
             VectorSchemaRoot root = VectorSchemaRoot.create((Schema)schema, (BufferAllocator)this.allocator);){
            FlightStream reader = stream.getReader();
            ArrowBuf buf = this.allocator.buffer(4L);
            buf.writeInt(42);
            stream.getWriter().putMetadata(buf);
            buf = this.allocator.buffer(4L);
            buf.writeInt(84);
            stream.getWriter().putMetadata(buf);
            Assertions.assertTrue((boolean)reader.next());
            Assertions.assertFalse((boolean)reader.hasRoot());
            Assertions.assertEquals((int)42, (int)reader.getLatestMetadata().getInt(0L));
            Assertions.assertTrue((boolean)reader.next());
            Assertions.assertFalse((boolean)reader.hasRoot());
            Assertions.assertEquals((int)84, (int)reader.getLatestMetadata().getInt(0L));
            IntVector iv = (IntVector)root.getVector("a");
            iv.allocateNew();
            stream.getWriter().start(root);
            for (int i = 0; i < 10; ++i) {
                iv.setSafe(0, i);
                root.setRowCount(1);
                stream.getWriter().putNext();
                Assertions.assertTrue((boolean)reader.next());
                Assertions.assertNull((Object)reader.getLatestMetadata());
                Assertions.assertEquals((Object)root.getSchema(), (Object)reader.getSchema());
                Assertions.assertEquals((int)i, (int)((IntVector)reader.getRoot().getVector("a")).get(0));
            }
            stream.getWriter().completed();
            Assertions.assertFalse((boolean)reader.next(), (String)"We should not be waiting for any messages");
        }
    }

    @Test
    public void testTransform() throws Exception {
        Schema schema = new Schema(Arrays.asList(Field.nullable((String)"a", (ArrowType)new ArrowType.Int(32, true)), Field.nullable((String)"b", (ArrowType)new ArrowType.Int(32, true))));
        try (FlightClient.ExchangeReaderWriter stream = this.client.doExchange(FlightDescriptor.command((byte[])EXCHANGE_TRANSFORM), new CallOption[0]);){
            IntVector vec;
            int batchIndex;
            FlightStream reader = stream.getReader();
            FlightClient.ClientStreamListener writer = stream.getWriter();
            try (VectorSchemaRoot root = VectorSchemaRoot.create((Schema)schema, (BufferAllocator)this.allocator);){
                writer.start(root);
                for (batchIndex = 0; batchIndex < 10; ++batchIndex) {
                    for (FieldVector rawVec : root.getFieldVectors()) {
                        vec = (IntVector)rawVec;
                        ValueVectorDataPopulator.setVector((IntVector)vec, (Integer[])((Integer[])IntStream.range(0, batchIndex).boxed().toArray(Integer[]::new)));
                    }
                    root.setRowCount(batchIndex);
                    writer.putNext();
                }
            }
            writer.completed();
            Assertions.assertEquals((Object)schema, (Object)reader.getSchema());
            root = reader.getRoot();
            for (batchIndex = 0; batchIndex < 10; ++batchIndex) {
                Assertions.assertTrue((boolean)reader.next(), (String)("Didn't receive batch #" + batchIndex));
                Assertions.assertEquals((int)batchIndex, (int)root.getRowCount());
                for (FieldVector rawVec : root.getFieldVectors()) {
                    vec = (IntVector)rawVec;
                    for (int row = 0; row < batchIndex; ++row) {
                        Assertions.assertEquals((int)(2 * row), (int)vec.get(row));
                    }
                }
            }
            Assertions.assertTrue((boolean)reader.next(), (String)"There should be one extra message");
            Assertions.assertEquals((int)10, (int)reader.getLatestMetadata().getInt(0L));
            Assertions.assertFalse((boolean)reader.next(), (String)"There should be no more data");
        }
    }

    @Test
    public void testTransformZeroCopy() throws Exception {
        int rowsPerBatch = 4096;
        Schema schema = new Schema(Arrays.asList(Field.nullable((String)"a", (ArrowType)new ArrowType.Int(32, true)), Field.nullable((String)"b", (ArrowType)new ArrowType.Int(32, true))));
        try (FlightClient.ExchangeReaderWriter stream = this.client.doExchange(FlightDescriptor.command((byte[])EXCHANGE_TRANSFORM), new CallOption[0]);){
            int row;
            IntVector vec;
            int batchIndex;
            FlightStream reader = stream.getReader();
            FlightClient.ClientStreamListener writer = stream.getWriter();
            try (VectorSchemaRoot root = VectorSchemaRoot.create((Schema)schema, (BufferAllocator)this.allocator);){
                writer.start(root);
                writer.setUseZeroCopy(true);
                for (batchIndex = 0; batchIndex < 100; ++batchIndex) {
                    for (FieldVector rawVec : root.getFieldVectors()) {
                        vec = (IntVector)rawVec;
                        for (row = 0; row < 4096; ++row) {
                            vec.setSafe(row, batchIndex + row);
                        }
                    }
                    root.setRowCount(4096);
                    writer.putNext();
                    root.allocateNew();
                }
            }
            writer.completed();
            Assertions.assertEquals((Object)schema, (Object)reader.getSchema());
            root = reader.getRoot();
            for (batchIndex = 0; batchIndex < 100; ++batchIndex) {
                Assertions.assertTrue((boolean)reader.next(), (String)("Didn't receive batch #" + batchIndex));
                Assertions.assertEquals((int)4096, (int)root.getRowCount());
                for (FieldVector rawVec : root.getFieldVectors()) {
                    vec = (IntVector)rawVec;
                    for (row = 0; row < 4096; ++row) {
                        Assertions.assertEquals((int)(2 * (batchIndex + row)), (int)vec.get(row));
                    }
                }
            }
            Assertions.assertTrue((boolean)reader.next(), (String)"There should be one extra message");
            Assertions.assertEquals((int)100, (int)reader.getLatestMetadata().getInt(0L));
            Assertions.assertFalse((boolean)reader.next(), (String)"There should be no more data");
        }
    }

    @Test
    public void testServerCancel() throws Exception {
        try (FlightClient.ExchangeReaderWriter stream = this.client.doExchange(FlightDescriptor.command((byte[])EXCHANGE_CANCEL), new CallOption[0]);){
            FlightStream reader = stream.getReader();
            FlightClient.ClientStreamListener writer = stream.getWriter();
            FlightRuntimeException fre = (FlightRuntimeException)Assertions.assertThrows(FlightRuntimeException.class, () -> ((FlightStream)reader).next());
            Assertions.assertEquals((Object)FlightStatusCode.CANCELLED, (Object)fre.status().code());
            Assertions.assertEquals((Object)"expected", (Object)fre.status().description());
            writer.putMetadata(this.allocator.getEmpty());
        }
    }

    @Test
    public void testServerCancelLeak() throws Exception {
        try (FlightClient.ExchangeReaderWriter stream = this.client.doExchange(FlightDescriptor.command((byte[])EXCHANGE_CANCEL), new CallOption[0]);){
            FlightStream reader = stream.getReader();
            FlightClient.ClientStreamListener writer = stream.getWriter();
            try (VectorSchemaRoot root = VectorSchemaRoot.create((Schema)Producer.SCHEMA, (BufferAllocator)this.allocator);){
                writer.start(root);
                IntVector ints = (IntVector)root.getVector("a");
                for (int i = 0; i < 128; ++i) {
                    for (int row = 0; row < 1024; ++row) {
                        ints.setSafe(row, row);
                    }
                    root.setRowCount(1024);
                    writer.putNext();
                }
            }
            FlightRuntimeException fre = (FlightRuntimeException)Assertions.assertThrows(FlightRuntimeException.class, () -> ((FlightStream)reader).next());
            Assertions.assertEquals((Object)FlightStatusCode.CANCELLED, (Object)fre.status().code());
            Assertions.assertEquals((Object)"expected", (Object)fre.status().description());
        }
    }

    @Test
    @Disabled
    public void testClientCancel() throws Exception {
        try (FlightClient.ExchangeReaderWriter stream = this.client.doExchange(FlightDescriptor.command((byte[])EXCHANGE_DO_GET), new CallOption[0]);){
            FlightStream reader = stream.getReader();
            reader.cancel("", null);
            reader.cancel("", null);
        }
    }

    @Test
    public void testDoExchangeError() throws Exception {
        Schema schema = new Schema(Collections.singletonList(Field.nullable((String)"a", (ArrowType)new ArrowType.Int(32, true))));
        try (FlightClient.ExchangeReaderWriter stream = this.client.doExchange(FlightDescriptor.command((byte[])EXCHANGE_ERROR), new CallOption[0]);
             VectorSchemaRoot root = VectorSchemaRoot.create((Schema)schema, (BufferAllocator)this.allocator);){
            FlightStream reader = stream.getReader();
            IntVector iv = (IntVector)root.getVector("a");
            iv.allocateNew();
            stream.getWriter().start(root);
            for (int i = 0; i < 10; ++i) {
                iv.setSafe(0, i);
                root.setRowCount(1);
                stream.getWriter().putNext();
                Assertions.assertTrue((boolean)reader.next());
                Assertions.assertEquals((Object)root.getSchema(), (Object)reader.getSchema());
                Assertions.assertEquals((int)i, (int)((IntVector)reader.getRoot().getVector("a")).get(0));
            }
            stream.getWriter().completed();
            FlightRuntimeException fre = (FlightRuntimeException)Assertions.assertThrows(FlightRuntimeException.class, () -> ((FlightClient.ExchangeReaderWriter)stream).getResult());
            Assertions.assertEquals((Object)"error completing exchange", (Object)fre.status().description());
        }
    }

    @Test
    public void testClientClose() throws Exception {
        try (FlightClient.ExchangeReaderWriter stream = this.client.doExchange(FlightDescriptor.command((byte[])EXCHANGE_DO_GET), new CallOption[0]);){
            Assertions.assertEquals((Object)Producer.SCHEMA, (Object)stream.getReader().getSchema());
        }
        this.allocator = null;
        this.client = null;
    }

    @Test
    public void testCloseWithMetadata() throws Exception {
        try (FlightClient.ExchangeReaderWriter stream = this.client.doExchange(FlightDescriptor.command((byte[])EXCHANGE_METADATA_ONLY), new CallOption[0]);){
            FlightStream reader = stream.getReader();
            Assertions.assertTrue((boolean)reader.next());
            Assertions.assertFalse((boolean)reader.hasRoot());
            Assertions.assertEquals((int)42, (int)reader.getLatestMetadata().getInt(0L));
            ArrowBuf buf = this.allocator.buffer(4L);
            buf.writeInt(84);
            stream.getWriter().putMetadata(buf);
            Assertions.assertTrue((boolean)reader.next());
            Assertions.assertFalse((boolean)reader.hasRoot());
            Assertions.assertEquals((int)84, (int)reader.getLatestMetadata().getInt(0L));
            stream.getWriter().completed();
            stream.getResult();
            stream.getReader().close();
        }
    }

    static class Producer
    extends NoOpFlightProducer {
        static final Schema SCHEMA = new Schema(Collections.singletonList(Field.nullable((String)"a", (ArrowType)new ArrowType.Int(32, true))));
        private final BufferAllocator allocator;

        Producer(BufferAllocator allocator) {
            this.allocator = allocator;
        }

        public void doExchange(FlightProducer.CallContext context, FlightStream reader, FlightProducer.ServerStreamListener writer) {
            if (Arrays.equals(reader.getDescriptor().getCommand(), EXCHANGE_METADATA_ONLY)) {
                this.metadataOnly(context, reader, writer);
            } else if (Arrays.equals(reader.getDescriptor().getCommand(), EXCHANGE_DO_GET)) {
                this.doGet(context, reader, writer);
            } else if (Arrays.equals(reader.getDescriptor().getCommand(), EXCHANGE_DO_PUT)) {
                this.doPut(context, reader, writer);
            } else if (Arrays.equals(reader.getDescriptor().getCommand(), EXCHANGE_ECHO)) {
                this.echo(context, reader, writer);
            } else if (Arrays.equals(reader.getDescriptor().getCommand(), EXCHANGE_TRANSFORM)) {
                this.transform(context, reader, writer);
            } else if (Arrays.equals(reader.getDescriptor().getCommand(), EXCHANGE_CANCEL)) {
                this.cancel(context, reader, writer);
            } else if (Arrays.equals(reader.getDescriptor().getCommand(), EXCHANGE_ERROR)) {
                this.error(context, reader, writer);
            } else {
                writer.error((Throwable)CallStatus.UNIMPLEMENTED.withDescription("Command not implemented").toRuntimeException());
            }
        }

        private void doGet(FlightProducer.CallContext unusedContext, FlightStream unusedReader, FlightProducer.ServerStreamListener writer) {
            try (VectorSchemaRoot root = VectorSchemaRoot.create((Schema)SCHEMA, (BufferAllocator)this.allocator);){
                writer.start(root);
                root.allocateNew();
                IntVector iv = (IntVector)root.getVector("a");
                for (int i = 0; i < 100; i += 2) {
                    iv.set(0, i);
                    iv.set(1, i + 1);
                    root.setRowCount(2);
                    writer.putNext();
                }
            }
            writer.completed();
        }

        private void doPut(FlightProducer.CallContext unusedContext, FlightStream reader, FlightProducer.ServerStreamListener writer) {
            int counter = 0;
            while (reader.next()) {
                if (!reader.hasRoot()) {
                    writer.error((Throwable)CallStatus.INVALID_ARGUMENT.withDescription("Message has no data").toRuntimeException());
                    return;
                }
                ArrowBuf pong = this.allocator.buffer(4L);
                pong.writeInt(counter += reader.getRoot().getRowCount());
                writer.putMetadata(pong);
            }
            writer.completed();
        }

        private void metadataOnly(FlightProducer.CallContext unusedContext, FlightStream reader, FlightProducer.ServerStreamListener writer) {
            ArrowBuf buf = this.allocator.buffer(4L);
            buf.writeInt(42);
            writer.putMetadata(buf);
            Assertions.assertTrue((boolean)reader.next());
            Assertions.assertNotNull((Object)reader.getLatestMetadata());
            reader.getLatestMetadata().getReferenceManager().retain();
            writer.putMetadata(reader.getLatestMetadata());
            writer.completed();
        }

        private void echo(FlightProducer.CallContext unusedContext, FlightStream reader, FlightProducer.ServerStreamListener writer) {
            VectorSchemaRoot root = null;
            VectorLoader loader = null;
            while (reader.next()) {
                if (reader.hasRoot()) {
                    if (root == null) {
                        root = VectorSchemaRoot.create((Schema)reader.getSchema(), (BufferAllocator)this.allocator);
                        loader = new VectorLoader(root);
                        writer.start(root);
                    }
                    VectorUnloader unloader = new VectorUnloader(reader.getRoot());
                    try (ArrowRecordBatch arb = unloader.getRecordBatch();){
                        loader.load(arb);
                    }
                    if (reader.getLatestMetadata() != null) {
                        reader.getLatestMetadata().getReferenceManager().retain();
                        writer.putNext(reader.getLatestMetadata());
                        continue;
                    }
                    writer.putNext();
                    continue;
                }
                reader.getLatestMetadata().getReferenceManager().retain();
                writer.putMetadata(reader.getLatestMetadata());
            }
            if (root != null) {
                root.close();
            }
            writer.completed();
        }

        private void transform(FlightProducer.CallContext unusedContext, FlightStream reader, FlightProducer.ServerStreamListener writer) {
            Schema schema = reader.getSchema();
            for (Field field : schema.getFields()) {
                if (!(field.getType() instanceof ArrowType.Int)) {
                    writer.error((Throwable)CallStatus.INVALID_ARGUMENT.withDescription("Invalid type: " + String.valueOf(field)).toRuntimeException());
                    return;
                }
                ArrowType.Int intType = (ArrowType.Int)field.getType();
                if (intType.getIsSigned() && intType.getBitWidth() == 32) continue;
                writer.error((Throwable)CallStatus.INVALID_ARGUMENT.withDescription("Must be i32: " + String.valueOf(field)).toRuntimeException());
                return;
            }
            int batches = 0;
            try (VectorSchemaRoot root = VectorSchemaRoot.create((Schema)schema, (BufferAllocator)this.allocator);){
                writer.start(root);
                writer.setUseZeroCopy(true);
                VectorLoader loader = new VectorLoader(root);
                VectorUnloader unloader = new VectorUnloader(reader.getRoot());
                while (reader.next()) {
                    try (ArrowRecordBatch batch = unloader.getRecordBatch();){
                        loader.load(batch);
                    }
                    ++batches;
                    for (FieldVector rawVec : root.getFieldVectors()) {
                        IntVector vec = (IntVector)rawVec;
                        for (int i = 0; i < root.getRowCount(); ++i) {
                            if (vec.isNull(i)) continue;
                            vec.set(i, vec.get(i) * 2);
                        }
                    }
                    writer.putNext();
                }
            }
            ArrowBuf count = this.allocator.buffer(4L);
            count.writeInt(batches);
            writer.putMetadata(count);
            writer.completed();
        }

        private void cancel(FlightProducer.CallContext unusedContext, FlightStream unusedReader, FlightProducer.ServerStreamListener writer) {
            writer.error((Throwable)CallStatus.CANCELLED.withDescription("expected").toRuntimeException());
        }

        private void error(FlightProducer.CallContext unusedContext, FlightStream reader, FlightProducer.ServerStreamListener writer) {
            VectorSchemaRoot root = null;
            VectorLoader loader = null;
            while (reader.next()) {
                if (root == null) {
                    root = VectorSchemaRoot.create((Schema)reader.getSchema(), (BufferAllocator)this.allocator);
                    loader = new VectorLoader(root);
                    writer.start(root);
                }
                VectorUnloader unloader = new VectorUnloader(reader.getRoot());
                try (ArrowRecordBatch arb = unloader.getRecordBatch();){
                    loader.load(arb);
                }
                writer.putNext();
            }
            if (root != null) {
                root.close();
            }
            writer.error((Throwable)CallStatus.INTERNAL.withDescription("error completing exchange").toRuntimeException());
        }
    }
}

