package io.confluent.rbacapi;

import io.confluent.common.security.util.PemUtils;
import io.confluent.kafka.server.plugins.auth.token.IdentityProviderService;
import io.confluent.security.authentication.http.HttpClient;
import io.confluent.tokenapi.jwt.JwtProvider;
import java.io.IOException;
import java.io.OutputStream;
import java.net.URI;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.StandardOpenOption;
import java.security.KeyPair;
import java.security.KeyPairGenerator;
import java.security.NoSuchAlgorithmException;
import java.security.Principal;
import java.time.Duration;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Base64;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ExecutionException;
import javax.ws.rs.client.Entity;
import javax.ws.rs.core.Form;
import org.jose4j.jwt.JwtClaims;
import org.jose4j.jwt.consumer.JwtConsumer;
import org.jose4j.jwt.consumer.JwtConsumerBuilder;
import org.junit.After;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.TemporaryFolder;

/**
 * Tests for {@link JwtProvider}: issuing new JWS tokens and re-issuing ("refreshing") existing
 * tokens without validating them.
 *
 * <p>Each test gets a freshly generated RSA key pair written to a temporary PEM file, a started
 * {@link IdentityProviderService}, and a {@link JwtProvider} configured against both. Issued
 * tokens are decoded with a {@link JwtConsumer} that skips signature verification and all
 * validators, so the assertions only look at claim contents.
 */
public class JwtProviderTest {

    private static final String ISSUER = "Confluent";
    private static final String TEST_USER = "testUser";
    private static final String IDP_CLIENT_ID = "app1-developer";
    private static final String IDP_CLIENT_SECRET = "app1-developer";
    private static final String GROUPS_CLAIM = "groups";

    /** Slack, in seconds, tolerated between issuing a token and checking its expiry. */
    private static final long EXPIRY_TOLERANCE_SECONDS = 5L;

    @Rule
    public final TemporaryFolder tempFolder = new TemporaryFolder();
    private JwtProvider jwtProvider;
    private JwtConsumer jwtReader;
    private Path pemPath;
    private HttpClient httpClient;
    private IdentityProviderService idp;

    /**
     * Creates the signing key file, HTTP client, identity provider, provider under test, and the
     * non-validating token reader.
     *
     * @throws IOException if the temporary folder cannot be created
     */
    @Before
    public void setUp() throws IOException {
        this.pemPath = createJwtKeyPair(this.tempFolder.newFolder().toPath().resolve("0.pem"));
        this.httpClient = HttpClient.builder().build();
        this.idp = new IdentityProviderService();
        this.idp.setStartupTimeout(Duration.ofMinutes(10L));
        this.idp.start();
        this.jwtProvider = new JwtProvider();
        this.jwtProvider.configure(getJwtProviderProps());
        // Tests only inspect claims, so the reader deliberately performs no verification.
        this.jwtReader = new JwtConsumerBuilder()
                .setSkipSignatureVerification()
                .setDisableRequireSignature()
                .setSkipAllValidators()
                .build();
    }

    /**
     * Shuts down the identity provider and closes the HTTP client.
     *
     * @throws IOException declared for the resources being released
     */
    @After
    public void tearDown() throws IOException {
        try {
            if (this.idp != null) {
                this.idp.shutdown();
            }
        } finally {
            // Close the client even when the identity provider fails to shut down.
            if (this.httpClient != null) {
                this.httpClient.close();
            }
        }
    }

    @Test
    public void testNewJwsTokenWithoutCustomClaims() throws Exception {
        String token = this.jwtProvider.newJwsToken(principal(TEST_USER), new String[]{"audience1"});
        JwtClaims claims = this.jwtReader.processToClaims(token);

        Assert.assertNotNull(claims);
        Assert.assertEquals(TEST_USER, claims.getSubject());
        Assert.assertEquals(ISSUER, claims.getIssuer());
        Assert.assertEquals(1L, claims.getAudience().size());
        Assert.assertTrue(claims.getAudience().contains("audience1"));
        assertRemainingLifetime(claims, this.jwtProvider.tokenLifetime());
    }

    @Test
    public void testNewJwtTokenWithCustomClaims() throws Exception {
        Set<String> groups = new HashSet<>(Arrays.asList("g1", "g2"));
        Map<String, Object> customClaims = new HashMap<>();
        customClaims.put("customClaim1", "value1");
        customClaims.put("customClaim2", 123L);
        customClaims.put(GROUPS_CLAIM, groups);

        String token = this.jwtProvider.newJwsToken(
                principal(TEST_USER), customClaims, 3600L, new String[]{"audience1", "audience2"});
        JwtClaims claims = this.jwtReader.processToClaims(token);

        Assert.assertNotNull(claims);
        Assert.assertEquals(TEST_USER, claims.getSubject());
        Assert.assertEquals(ISSUER, claims.getIssuer());
        Assert.assertEquals(2L, claims.getAudience().size());
        Assert.assertTrue(claims.getAudience().contains("audience1"));
        Assert.assertTrue(claims.getAudience().contains("audience2"));
        Assert.assertEquals("value1", claims.getStringClaimValue("customClaim1"));
        Assert.assertEquals(Long.valueOf(123L), claims.getClaimValue("customClaim2"));
        // Expected order follows the iteration order of the set that was passed in.
        Assert.assertEquals(new ArrayList<>(groups), claims.getStringListClaimValue(GROUPS_CLAIM));
        assertRemainingLifetime(claims, 3600L);
    }

    @Test
    public void testRefreshTokenWithoutValidationWithConfluentToken() throws Exception {
        Principal user = principal(TEST_USER);
        Map<String, Object> customClaims = new HashMap<>();
        customClaims.put(GROUPS_CLAIM, new HashSet<>(Arrays.asList("g1", "g2")));

        String original = this.jwtProvider.newJwsToken(user, customClaims, 3600L, new String[]{"audience1"});
        String refreshed = this.jwtProvider.refreshTokenWithoutValidation(user, original, new String[0]);
        JwtClaims claims = this.jwtReader.processToClaims(refreshed);

        Assert.assertEquals(TEST_USER, claims.getSubject());
        Assert.assertEquals(ISSUER, claims.getIssuer());
        // Refreshing with no audiences must drop the original "aud" claim entirely.
        Assert.assertEquals(0L, claims.getAudience().size());
        Assert.assertNull(claims.getClaimValue("aud"));
        Assert.assertEquals(Arrays.asList("g1", "g2"), claims.getStringListClaimValue(GROUPS_CLAIM));
    }

    @Test
    public void testRefreshTokenWithoutValidationWithIdpToken() throws Exception {
        String idpToken = getJwtFromIdp();

        // Sanity-check the token handed out by the identity provider first.
        JwtClaims idpClaims = this.jwtReader.processToClaims(idpToken);
        Assert.assertEquals(IDP_CLIENT_ID, idpClaims.getSubject());
        Assert.assertEquals(this.idp.getIssuer(), idpClaims.getIssuer());
        Assert.assertEquals(1L, idpClaims.getAudience().size());
        Assert.assertTrue(idpClaims.getAudience().contains("account"));
        Assert.assertEquals(Collections.singletonList("/g4"), idpClaims.getStringListClaimValue(GROUPS_CLAIM));

        String refreshed = this.jwtProvider.refreshTokenWithoutValidation(
                principal(IDP_CLIENT_ID), idpToken, new String[0]);
        JwtClaims refreshedClaims = this.jwtReader.processToClaims(refreshed);
        Assert.assertEquals(IDP_CLIENT_ID, refreshedClaims.getSubject());
        Assert.assertEquals(ISSUER, refreshedClaims.getIssuer());
        Assert.assertEquals(0L, refreshedClaims.getAudience().size());
        Assert.assertEquals(Collections.singletonList("/g4"), refreshedClaims.getStringListClaimValue(GROUPS_CLAIM));
    }

    /** Returns a {@link Principal} whose name is {@code name}. */
    private static Principal principal(String name) {
        return () -> name;
    }

    /**
     * Asserts that the token's remaining lifetime is at most {@code expectedLifetimeSeconds} and
     * more than {@code expectedLifetimeSeconds - EXPIRY_TOLERANCE_SECONDS}.
     *
     * @param claims decoded claims of the token under test
     * @param expectedLifetimeSeconds lifetime the token was issued with, in seconds
     * @throws Exception if the expiration claim cannot be read
     */
    private static void assertRemainingLifetime(JwtClaims claims, long expectedLifetimeSeconds) throws Exception {
        long nowSeconds = System.currentTimeMillis() / 1000;
        long remainingSeconds = claims.getExpirationTime().getValue() - nowSeconds;
        Assert.assertTrue(
                "unexpected remaining lifetime: " + remainingSeconds + "s",
                remainingSeconds <= expectedLifetimeSeconds
                        && remainingSeconds > expectedLifetimeSeconds - EXPIRY_TOLERANCE_SECONDS);
    }

    /**
     * Requests an access token from the identity provider's token endpoint using the
     * {@code client_credentials} grant and HTTP Basic authentication.
     *
     * @return the value of the {@code access_token} field of the JSON response
     * @throws ExecutionException if the asynchronous request fails
     * @throws InterruptedException if the calling thread is interrupted while waiting
     */
    private String getJwtFromIdp() throws ExecutionException, InterruptedException {
        String credentials = IDP_CLIENT_ID + ":" + IDP_CLIENT_SECRET;
        String authorization =
                "Basic " + Base64.getEncoder().encodeToString(credentials.getBytes(StandardCharsets.UTF_8));
        Form form = new Form().param("grant_type", "client_credentials");

        return this.httpClient
                .target(URI.create(this.idp.getTokenEndpoint()))
                .request()
                .header("Authorization", authorization)
                .accept("application/json")
                .rx()
                .post(Entity.entity(form, "application/x-www-form-urlencoded"))
                .thenApply(response -> response.readEntity(Map.class))
                .thenApply(body -> (String) body.get("access_token"))
                .toCompletableFuture()
                .get();
    }

    /** Builds the configuration handed to {@link JwtProvider#configure}. */
    private Map<String, Object> getJwtProviderProps() {
        Map<String, Object> props = new HashMap<>();
        props.put("token.key.path", this.pemPath.toString());
        props.put("token.issuer", ISSUER);
        props.put("user.store", "OAUTH");
        props.put("oauthbearer.jwks.endpoint.url", this.idp.getJwksEndpoint());
        props.put("oauthbearer.expected.issuer", this.idp.getIssuer());
        props.put("oauthbearer.expected.audience", "Confluent,account");
        props.put("oauthbearer.groups.claim.name", GROUPS_CLAIM);
        return props;
    }

    /**
     * Generates a key pair and writes it to {@code path} via {@link PemUtils#writeKeyPair},
     * creating the file or truncating an existing one.
     *
     * @param path destination file
     * @return {@code path}, for call chaining
     * @throws RuntimeException wrapping the {@link IOException} if the file cannot be written
     */
    private Path createJwtKeyPair(Path path) {
        // try-with-resources guarantees the stream is closed even when writing fails.
        try (OutputStream out =
                Files.newOutputStream(path, StandardOpenOption.CREATE, StandardOpenOption.TRUNCATE_EXISTING)) {
            PemUtils.writeKeyPair(out, generateKeyPair());
        } catch (IOException e) {
            throw new RuntimeException("Failed to write JWT PEM file", e);
        }
        return path;
    }

    /**
     * Generates a 2048-bit RSA key pair.
     *
     * @throws RuntimeException if no RSA {@link KeyPairGenerator} is available
     */
    private KeyPair generateKeyPair() {
        try {
            KeyPairGenerator generator = KeyPairGenerator.getInstance("RSA");
            generator.initialize(2048);
            return generator.generateKeyPair();
        } catch (NoSuchAlgorithmException e) {
            throw new RuntimeException("Failed to generate key pair", e);
        }
    }
}
