/*
 * Decompiled with CFR 0.152.
 */
package org.apache.arrow.flight;

import java.util.Arrays;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import org.apache.arrow.flight.CallOption;
import org.apache.arrow.flight.CallStatus;
import org.apache.arrow.flight.FlightClient;
import org.apache.arrow.flight.FlightDescriptor;
import org.apache.arrow.flight.FlightProducer;
import org.apache.arrow.flight.FlightServer;
import org.apache.arrow.flight.FlightStream;
import org.apache.arrow.flight.Location;
import org.apache.arrow.flight.NoOpFlightProducer;
import org.apache.arrow.flight.PutResult;
import org.apache.arrow.flight.SyncPutListener;
import org.apache.arrow.flight.Ticket;
import org.apache.arrow.memory.BufferAllocator;
import org.apache.arrow.memory.RootAllocator;
import org.apache.arrow.vector.Float8Vector;
import org.apache.arrow.vector.VectorSchemaRoot;
import org.apache.arrow.vector.types.FloatingPointPrecision;
import org.apache.arrow.vector.types.pojo.ArrowType;
import org.apache.arrow.vector.types.pojo.Field;
import org.apache.arrow.vector.types.pojo.Schema;
import org.junit.jupiter.api.Test;

/**
 * Regression tests for cancelling Flight calls.
 *
 * <p>{@link #testCancelingDoGetDoesNotLeak()} cancels a DoGet while the server still holds
 * buffers from a child allocator. {@link #testCancelingDoPutDoesNotBlock()} keeps writing to a
 * DoPut that the server has already failed.
 *
 * <p>NOTE(review): the leak check appears to rely on the {@link RootAllocator} in each
 * try-with-resources block refusing to close while memory or child allocators are still
 * outstanding. No explicit leak assertion is made here; confirm this is the intended mechanism.
 */
public class TestLeak {
    /** Number of rows written into each batch. */
    private static final int ROWS = 2048;
    /** Number of float8 columns in the test schema; the columns are named "0" .. "COLUMNS-1". */
    private static final int COLUMNS = 11;
    /** Upper bound on how long a test waits for the server-side call to finish. */
    private static final long CALL_TIMEOUT_SECONDS = 60L;

    /**
     * Builds the test schema.
     *
     * @return a schema of {@link #COLUMNS} nullable double-precision columns named "0", "1", ...
     */
    private static Schema getSchema() {
        ArrowType float8 = new ArrowType.FloatingPoint(FloatingPointPrecision.DOUBLE);
        Field[] fields = new Field[COLUMNS];
        for (int col = 0; col < COLUMNS; ++col) {
            fields[col] = Field.nullable(Integer.toString(col), float8);
        }
        return new Schema(Arrays.asList(fields));
    }

    /**
     * Fills a root that uses the {@link #getSchema()} layout.
     *
     * <p>Each column is re-allocated, filled with {@link #ROWS} copies of {@code 10.0}, and the
     * root's row count is then set to {@link #ROWS}.
     *
     * @param root a root created from {@link #getSchema()}
     */
    private static void fillRoot(VectorSchemaRoot root) {
        for (int col = 0; col < COLUMNS; ++col) {
            Float8Vector vector = (Float8Vector) root.getVector(Integer.toString(col));
            vector.allocateNew();
            for (int row = 0; row < ROWS; ++row) {
                vector.setSafe(row, 10.0);
            }
        }
        root.setRowCount(ROWS);
    }

    /**
     * Waits for the server-side handler to signal completion, failing the test on timeout.
     *
     * <p>The result of {@link CountDownLatch#await(long, TimeUnit)} is checked. Otherwise a
     * handler that never ran would let the test continue as if it had.
     *
     * @param callFinished latch counted down by {@link LeakFlightProducer}
     * @throws InterruptedException if interrupted while waiting
     * @throws AssertionError if the latch is not released within {@link #CALL_TIMEOUT_SECONDS}
     */
    private static void awaitCallFinished(CountDownLatch callFinished) throws InterruptedException {
        if (!callFinished.await(CALL_TIMEOUT_SECONDS, TimeUnit.SECONDS)) {
            throw new AssertionError(
                "Server call did not finish within " + CALL_TIMEOUT_SECONDS + " seconds");
        }
    }

    @Test
    public void testCancelingDoGetDoesNotLeak() throws Exception {
        CountDownLatch callFinished = new CountDownLatch(1);
        // Port 0: bind to any free port; the client connects via server.getLocation().
        try (RootAllocator allocator = new RootAllocator(Long.MAX_VALUE);
             FlightServer server = FlightServer.builder(allocator, Location.forGrpcInsecure("localhost", 0),
                     new LeakFlightProducer(allocator, callFinished)).build().start();
             FlightClient client = FlightClient.builder(allocator, server.getLocation()).build()) {
            FlightStream stream = client.getStream(new Ticket(new byte[0]));
            // Touch the root before cancelling. Presumably this waits until the server has
            // started the stream and sent its schema; confirm against FlightStream#getRoot.
            stream.getRoot();
            stream.cancel("Cancel", null);
            // The producer's on-cancel handler releases its buffers and then counts down.
            awaitCallFinished(callFinished);
            server.shutdown();
            server.awaitTermination();
        }
    }

    @Test
    public void testCancelingDoPutDoesNotBlock() throws Exception {
        CountDownLatch callFinished = new CountDownLatch(1);
        try (RootAllocator allocator = new RootAllocator(Long.MAX_VALUE);
             FlightServer server = FlightServer.builder(allocator, Location.forGrpcInsecure("localhost", 0),
                     new LeakFlightProducer(allocator, callFinished)).build().start();
             FlightClient client = FlightClient.builder(allocator, server.getLocation()).build()) {
            try (VectorSchemaRoot root = VectorSchemaRoot.create(getSchema(), allocator)) {
                FlightDescriptor descriptor = FlightDescriptor.command(new byte[0]);
                SyncPutListener listener = new SyncPutListener();
                FlightClient.ClientStreamListener stream = client.startPut(descriptor, root, listener);
                // Wait until the server has failed the call with CANCELLED (see acceptPut).
                awaitCallFinished(callFinished);
                fillRoot(root);
                // Writing to and completing a call the server already rejected must return
                // rather than hang.
                stream.putNext();
                stream.completed();
            }
            server.shutdown();
            server.awaitTermination();
        }
    }

    /**
     * Producer whose DoGet and DoPut handlers signal {@code callFinished} once they are done
     * with the call.
     */
    private static class LeakFlightProducer
    extends NoOpFlightProducer {
        private final BufferAllocator allocator;
        private final CountDownLatch callFinished;

        public LeakFlightProducer(BufferAllocator allocator, CountDownLatch callFinished) {
            this.allocator = allocator;
            this.callFinished = callFinished;
        }

        /**
         * Starts a stream backed by a child allocator. All buffers are released only from the
         * on-cancel handler.
         *
         * <p>On cancel, the handler still fills and writes a batch to the cancelled listener.
         * It then closes the root and the child allocator, and finally counts down
         * {@code callFinished}.
         */
        @Override
        public void getStream(FlightProducer.CallContext context, Ticket ticket,
                              FlightProducer.ServerStreamListener listener) {
            BufferAllocator childAllocator = allocator.newChildAllocator("foo", 0L, Long.MAX_VALUE);
            VectorSchemaRoot root = VectorSchemaRoot.create(getSchema(), childAllocator);
            root.allocateNew();
            listener.start(root);
            listener.setOnCancelHandler(() -> {
                try {
                    fillRoot(root);
                    listener.putNext();
                    listener.completed();
                } finally {
                    try {
                        root.close();
                        childAllocator.close();
                    } finally {
                        // Always release the waiting test, even if cleanup failed.
                        callFinished.countDown();
                    }
                }
            });
        }

        /**
         * Waits for the client's root, fails the call with {@code CANCELLED}, and then counts
         * down {@code callFinished}.
         */
        @Override
        public Runnable acceptPut(FlightProducer.CallContext context, FlightStream flightStream,
                                  FlightProducer.StreamListener<PutResult> ackStream) {
            return () -> {
                flightStream.getRoot();
                ackStream.onError(CallStatus.CANCELLED.withDescription("CANCELLED").toRuntimeException());
                callFinished.countDown();
                // NOTE(review): onCompleted() after onError() is kept exactly as in the original
                // test; presumably it is deliberate. Confirm before removing.
                ackStream.onCompleted();
            };
        }
    }
}

