package org.apache.arrow.flight.auth;

import com.google.common.collect.ImmutableList;
import io.grpc.StatusRuntimeException;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.util.Arrays;
import java.util.Optional;
import org.apache.arrow.flight.CallOption;
import org.apache.arrow.flight.Criteria;
import org.apache.arrow.flight.FlightClient;
import org.apache.arrow.flight.FlightInfo;
import org.apache.arrow.flight.FlightProducer;
import org.apache.arrow.flight.FlightServer;
import org.apache.arrow.flight.FlightStream;
import org.apache.arrow.flight.FlightTestUtil;
import org.apache.arrow.flight.NoOpFlightProducer;
import org.apache.arrow.flight.Ticket;
import org.apache.arrow.flight.auth.BasicServerAuthHandler;
import org.apache.arrow.memory.BufferAllocator;
import org.apache.arrow.memory.RootAllocator;
import org.apache.arrow.util.AutoCloseables;
import org.apache.arrow.vector.VectorSchemaRoot;
import org.apache.arrow.vector.types.Types;
import org.apache.arrow.vector.types.pojo.Field;
import org.apache.arrow.vector.types.pojo.Schema;
import org.junit.After;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;
import org.junit.jupiter.api.Assertions;

/* loaded from: input_file:org/apache/arrow/flight/auth/TestBasicAuth.class */
public class TestBasicAuth {
    final String PERMISSION_DENIED = "PERMISSION_DENIED";
    private static final String USERNAME = "flight";
    private static final String PASSWORD = "woohoo";
    private static final byte[] VALID_TOKEN = "my_token".getBytes();
    private FlightClient client;
    private FlightServer server;
    private BufferAllocator allocator;

    @Test
    public void validAuth() {
        this.client.authenticateBasic(USERNAME, PASSWORD);
        Assert.assertTrue(ImmutableList.copyOf(this.client.listFlights(Criteria.ALL, new CallOption[0])).size() >= 0);
    }

    @Test
    public void asyncCall() {
        this.client.authenticateBasic(USERNAME, PASSWORD);
        this.client.listFlights(Criteria.ALL, new CallOption[0]);
        FlightStream stream = this.client.getStream(new Ticket(new byte[1]), new CallOption[0]);
        while (stream.next()) {
            Assert.assertEquals(4095L, stream.getRoot().getRowCount());
            stream.getRoot().clear();
        }
    }

    @Test
    public void invalidAuth() {
        Assertions.assertThrows(StatusRuntimeException.class, () -> {
            this.client.authenticateBasic(USERNAME, "WRONG");
        }, "PERMISSION_DENIED");
        Assertions.assertThrows(StatusRuntimeException.class, () -> {
            this.client.listFlights(Criteria.ALL, new CallOption[0]);
        }, "PERMISSION_DENIED");
    }

    @Test
    public void didntAuth() {
        Assertions.assertThrows(StatusRuntimeException.class, () -> {
            this.client.listFlights(Criteria.ALL, new CallOption[0]);
        }, "PERMISSION_DENIED");
    }

    @Before
    public void setup() throws IOException {
        this.allocator = new RootAllocator(Long.MAX_VALUE);
        BasicServerAuthHandler.BasicAuthValidator basicAuthValidator = new BasicServerAuthHandler.BasicAuthValidator() { // from class: org.apache.arrow.flight.auth.TestBasicAuth.1
            public Optional<String> isValid(byte[] bArr) {
                return Arrays.equals(bArr, TestBasicAuth.VALID_TOKEN) ? Optional.of(TestBasicAuth.USERNAME) : Optional.empty();
            }

            public byte[] getToken(String str, String str2) {
                if (TestBasicAuth.USERNAME.equals(str) && TestBasicAuth.PASSWORD.equals(str2)) {
                    return TestBasicAuth.VALID_TOKEN;
                }
                throw new IllegalArgumentException("invalid credentials");
            }
        };
        this.server = (FlightServer) FlightTestUtil.getStartedServer(location -> {
            return FlightServer.builder(this.allocator, location, new NoOpFlightProducer() { // from class: org.apache.arrow.flight.auth.TestBasicAuth.2
                public void listFlights(FlightProducer.CallContext callContext, Criteria criteria, FlightProducer.StreamListener<FlightInfo> streamListener) {
                    if (callContext.peerIdentity().equals(TestBasicAuth.USERNAME)) {
                        streamListener.onCompleted();
                    } else {
                        streamListener.onError(new IllegalArgumentException("Invalid username"));
                    }
                }

                public void getStream(FlightProducer.CallContext callContext, Ticket ticket, FlightProducer.ServerStreamListener serverStreamListener) {
                    if (!callContext.peerIdentity().equals(TestBasicAuth.USERNAME)) {
                        serverStreamListener.error(new IllegalArgumentException("Invalid username"));
                        return;
                    }
                    VectorSchemaRoot create = VectorSchemaRoot.create(new Schema(ImmutableList.of(Field.nullable("a", Types.MinorType.BIGINT.getType()))), TestBasicAuth.this.allocator);
                    serverStreamListener.start(create);
                    create.allocateNew();
                    create.setRowCount(4095);
                    serverStreamListener.putNext();
                    create.clear();
                    serverStreamListener.completed();
                }
            }).authHandler(new BasicServerAuthHandler(basicAuthValidator)).build();
        });
        this.client = FlightClient.builder(this.allocator, this.server.getLocation()).build();
    }

    @After
    public void shutdown() throws Exception {
        AutoCloseables.close(new AutoCloseable[]{this.client, this.server, this.allocator});
    }
}
