/*
 * Decompiled with CFR 0.152.
 */
package io.confluent.kafka.server.plugins.auth.oauth;

import io.confluent.kafka.clients.plugins.auth.jwt.JwtAuthenticator;
import io.confluent.kafka.clients.plugins.auth.jwt.JwtAuthenticatorConfig;
import io.confluent.kafka.clients.plugins.auth.jwt.JwtVerificationException;
import io.confluent.kafka.common.multitenant.oauth.OAuthBearerJwsToken;
import io.confluent.kafka.multitenant.BasePhysicalClusterMetadata;
import io.confluent.kafka.server.plugins.auth.SniValidationMode;
import java.io.IOException;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import javax.security.auth.callback.Callback;
import javax.security.auth.callback.UnsupportedCallbackException;
import javax.security.auth.login.AppConfigurationEntry;
import kafka.server.KafkaConfig;
import org.apache.kafka.common.config.ConfigException;
import org.apache.kafka.common.security.auth.AuthenticateCallbackHandler;
import org.apache.kafka.common.security.authenticator.PathAwareSniHostName;
import org.apache.kafka.common.security.oauthbearer.OAuthBearerExtensionsValidatorCallback;
import org.apache.kafka.common.security.oauthbearer.OAuthBearerToken;
import org.apache.kafka.common.security.oauthbearer.OAuthBearerValidatorCallback;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class OAuthBearerValidatorCallbackHandler
implements AuthenticateCallbackHandler {
    private static final Logger log = LoggerFactory.getLogger(OAuthBearerValidatorCallbackHandler.class);
    private static final String DEFAULT_SCOPE_CLAIM = "clusters";
    private static final String AUTH_ERROR_MESSAGE = "Authentication failed";
    private JwtAuthenticator jwtAuthenticator;
    private BasePhysicalClusterMetadata clusterMetadata;
    private SniValidationMode mode;
    private boolean configured = false;

    public void configure(Map<String, ?> configs, String saslMechanism, List<AppConfigurationEntry> jaasConfigEntries) {
        if (!"OAUTHBEARER".equals(saslMechanism)) {
            throw new IllegalArgumentException(String.format("Unexpected SASL mechanism: %s", saslMechanism));
        }
        if (Objects.requireNonNull(jaasConfigEntries).size() != 1 || jaasConfigEntries.get(0) == null) {
            throw new IllegalArgumentException(String.format("Must supply exactly 1 non-null JAAS mechanism configuration (size was %d)", jaasConfigEntries.size()));
        }
        HashMap moduleOptions = new HashMap(jaasConfigEntries.get(0).getOptions());
        if (moduleOptions.containsKey("publicKeyPath")) {
            moduleOptions.put("jwksLocation", moduleOptions.remove("publicKeyPath"));
        }
        JwtAuthenticatorConfig authenticatorConfig = new JwtAuthenticatorConfig(moduleOptions);
        this.jwtAuthenticator = new JwtAuthenticator(authenticatorConfig);
        Object uuid = configs.get(KafkaConfig.BrokerSessionUuidProp());
        if (uuid == null || uuid.toString().isEmpty()) {
            throw new ConfigException("Broker session UUID must be set in the Kafka config!");
        }
        this.clusterMetadata = BasePhysicalClusterMetadata.getInstance(uuid.toString());
        if (this.clusterMetadata == null) {
            throw new ConfigException("Could not get a PhysicalClusterMetadata instance with broker session UUID " + uuid.toString());
        }
        this.mode = SniValidationMode.fromString((String)moduleOptions.get("sni_host_name_validation_mode"));
        this.configured = true;
    }

    public void handle(Callback[] callbacks) throws UnsupportedCallbackException {
        if (!this.configured) {
            throw new IllegalStateException("Callback handler not configured");
        }
        for (Callback callback : callbacks) {
            if (callback instanceof OAuthBearerValidatorCallback) {
                this.handleCallback((OAuthBearerValidatorCallback)callback);
                continue;
            }
            if (callback instanceof OAuthBearerExtensionsValidatorCallback) {
                this.handleExtensionsCallback((OAuthBearerExtensionsValidatorCallback)callback);
                continue;
            }
            throw new UnsupportedCallbackException(callback);
        }
    }

    private void handleCallback(OAuthBearerValidatorCallback callback) {
        try {
            this.handleValidatorCallback(callback);
        }
        catch (JwtVerificationException e) {
            log.info("Failed to verify OAuth JWT token", (Throwable)e);
            callback.error(AUTH_ERROR_MESSAGE, null, null);
        }
    }

    public void close() {
        if (this.jwtAuthenticator != null) {
            try {
                this.jwtAuthenticator.close();
            }
            catch (IOException e) {
                log.error("Failed to close Authenticator", (Throwable)e);
            }
        }
    }

    private void handleValidatorCallback(OAuthBearerValidatorCallback callback) throws JwtVerificationException {
        String tokenValue = callback.tokenValue();
        if (tokenValue == null) {
            throw new IllegalArgumentException("Callback missing required token value");
        }
        OAuthBearerToken token = this.processToken(tokenValue);
        callback.token(token);
        log.debug("Successfully validated token");
    }

    private void handleExtensionsCallback(OAuthBearerExtensionsValidatorCallback callback) {
        OAuthBearerJwsToken token = (OAuthBearerJwsToken)callback.token();
        String logicalCluster = (String)callback.inputExtensions().map().get("logicalCluster");
        String sniHostName = (String)callback.inputExtensions().map().get("__confluent_sni_broker_host_name");
        if (!(this.doesClusterExtensionExist(callback, logicalCluster) && this.isSniHostNameMatched(callback, logicalCluster, sniHostName, this.mode) && this.isLogicalClusterPartOfAllowedClusters(callback, token, logicalCluster))) {
            return;
        }
        try {
            if (!this.isClusterMetadataMatched(callback, token, logicalCluster)) {
                return;
            }
        }
        catch (IllegalStateException e) {
            this.reportErrorGettingMetadata(callback, e);
            return;
        }
        callback.valid("logicalCluster");
        log.debug("Successfully authenticated for user: {} (cluster: {})", (Object)token.principalName(), (Object)logicalCluster);
    }

    private void reportErrorGettingMetadata(OAuthBearerExtensionsValidatorCallback callback, IllegalStateException e) {
        log.error("Could not get physical cluster metadata to validate the token. ", (Throwable)e);
        callback.errorMessage("Could not get cluster metadata to validate the token");
        callback.error("logicalCluster", AUTH_ERROR_MESSAGE);
    }

    private boolean isClusterMetadataMatched(OAuthBearerExtensionsValidatorCallback callback, OAuthBearerJwsToken token, String logicalCluster) {
        if (!this.clusterMetadata.logicalClusterIds().contains(logicalCluster)) {
            if (this.clusterMetadata.logicalClusterIdsIncludingStale().contains(logicalCluster)) {
                log.info("Failing OAuth authentication because the metadata for the logical cluster {} is stale.", (Object)logicalCluster);
            }
            String errorMessage = String.format("The principal %s's logical cluster %s is not hosted on this broker.", token.principalName(), logicalCluster);
            this.handleExtensionError(callback, errorMessage, "logicalCluster");
            return false;
        }
        return true;
    }

    private boolean isLogicalClusterPartOfAllowedClusters(OAuthBearerExtensionsValidatorCallback callback, OAuthBearerJwsToken token, String logicalCluster) {
        Set logicalClusters = token.scope();
        if (logicalClusters.contains(logicalCluster)) {
            return true;
        }
        String errorMessage = String.format("The principal %s's logical cluster %s is not part of the allowed clusters in this token (%s).", token.principalName(), logicalCluster, String.join((CharSequence)",", logicalClusters));
        this.handleExtensionError(callback, errorMessage, "logicalCluster");
        return false;
    }

    private boolean doesClusterExtensionExist(OAuthBearerExtensionsValidatorCallback callback, String logicalCluster) {
        if (logicalCluster == null || logicalCluster.isEmpty()) {
            String errorMessage = "The logical cluster extension is missing or is empty";
            this.handleExtensionError(callback, errorMessage, "logicalCluster");
            return false;
        }
        return true;
    }

    protected boolean isSniHostNameMatched(OAuthBearerExtensionsValidatorCallback callback, String logicalClusterId, String sniHostName, SniValidationMode sniValidationMode) {
        Optional<PathAwareSniHostName> sniHostNameOptional = sniHostName == null ? Optional.empty() : Optional.of(new PathAwareSniHostName(sniHostName));
        Optional<String> sniClusterId = sniHostNameOptional.map(PathAwareSniHostName::logicalClusterId);
        if (sniValidationMode.sniHostNameMatches(logicalClusterId, sniClusterId, sniHostNameOptional)) {
            return true;
        }
        String errorMessage = String.format("The SNI cluster Id: %s doesn't match with logical cluster extension: %s.", sniClusterId.orElse("<empty>"), logicalClusterId);
        this.handleExtensionError(callback, errorMessage, "__confluent_sni_broker_host_name");
        return false;
    }

    private void handleExtensionError(OAuthBearerExtensionsValidatorCallback callback, String errorMessage, String invalidExtensionName) {
        log.info(errorMessage);
        callback.errorMessage(errorMessage);
        callback.error(invalidExtensionName, AUTH_ERROR_MESSAGE);
    }

    OAuthBearerToken processToken(String jws) throws JwtVerificationException {
        return this.jwtAuthenticator.login(jws, DEFAULT_SCOPE_CLAIM);
    }
}

