package org.apache.knox.gateway.provider.federation.jwt.filter;

import com.nimbusds.jose.JWSHeader;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.security.MessageDigest;
import java.security.Principal;
import java.security.PrivilegedActionException;
import java.security.PrivilegedExceptionAction;
import java.security.interfaces.RSAPublicKey;
import java.text.ParseException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Date;
import java.util.HashSet;
import java.util.List;
import java.util.Locale;
import java.util.Set;
import javax.security.auth.Subject;
import javax.servlet.Filter;
import javax.servlet.FilterChain;
import javax.servlet.FilterConfig;
import javax.servlet.ServletContext;
import javax.servlet.ServletException;
import javax.servlet.ServletRequest;
import javax.servlet.ServletResponse;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import org.apache.knox.gateway.audit.api.AuditContext;
import org.apache.knox.gateway.audit.api.AuditService;
import org.apache.knox.gateway.audit.api.AuditServiceFactory;
import org.apache.knox.gateway.audit.api.Auditor;
import org.apache.knox.gateway.config.GatewayConfig;
import org.apache.knox.gateway.i18n.messages.MessagesFactory;
import org.apache.knox.gateway.provider.federation.jwt.JWTMessages;
import org.apache.knox.gateway.security.PrimaryPrincipal;
import org.apache.knox.gateway.services.GatewayServices;
import org.apache.knox.gateway.services.ServiceLifecycleException;
import org.apache.knox.gateway.services.ServiceType;
import org.apache.knox.gateway.services.security.AliasService;
import org.apache.knox.gateway.services.security.AliasServiceException;
import org.apache.knox.gateway.services.security.token.JWTokenAuthority;
import org.apache.knox.gateway.services.security.token.TokenMetadata;
import org.apache.knox.gateway.services.security.token.TokenServiceException;
import org.apache.knox.gateway.services.security.token.TokenStateService;
import org.apache.knox.gateway.services.security.token.TokenUtils;
import org.apache.knox.gateway.services.security.token.UnknownTokenException;
import org.apache.knox.gateway.services.security.token.impl.JWT;
import org.apache.knox.gateway.services.security.token.impl.JWTToken;
import org.apache.knox.gateway.services.security.token.impl.TokenMAC;
import org.apache.knox.gateway.util.Tokens;

/**
 * Base class for servlet filters that authenticate requests carrying either a JWT or a
 * Knox token identifier plus passcode.
 *
 * <p>Subclasses implement {@link #doFilter} to extract the credential from the request and
 * {@link #handleValidationError} to write the error response. This class provides the shared
 * validation steps and establishes the authenticated {@link Subject} for the rest of the chain.
 * For JWTs the steps are issuer, expiration, audience, not-before, enabled state and signature.
 * For passcodes the steps are expiration, enabled state and passcode MAC.
 */
public abstract class AbstractJWTFilter implements Filter {
    /** Filter init parameter naming the issuer that accepted tokens must carry. */
    public static final String JWT_EXPECTED_ISSUER = "jwt.expected.issuer";
    /** Issuer required when {@link #JWT_EXPECTED_ISSUER} is not configured. */
    public static final String JWT_DEFAULT_ISSUER = "KNOXSSO";
    /** Filter init parameter naming the JWS algorithm accepted tokens must be signed with. */
    public static final String JWT_EXPECTED_SIGALG = "jwt.expected.sigalg";
    /** Conventional default signature algorithm name; not read by this class itself. */
    public static final String JWT_DEFAULT_SIGALG = "RS256";

    /** Servlet context attribute holding the {@link GatewayServices} registry. */
    private static final String GATEWAY_SERVICES_ATTRIBUTE = "org.apache.knox.gateway.gateway.services";
    /** Servlet context attribute holding the {@link GatewayConfig}. */
    private static final String GATEWAY_CONFIG_ATTRIBUTE = "org.apache.knox.gateway.config";
    /** Servlet context attribute holding the topology (cluster) name. */
    private static final String GATEWAY_CLUSTER_ATTRIBUTE = "org.apache.knox.gateway.gateway.cluster";
    /** Gateway-level alias whose value keys the token passcode MAC. */
    private static final String TOKEN_HASH_KEY_ALIAS = "knox.token.hash.key";
    /** Request attribute carrying the original request URL, used for auditing. */
    private static final String SOURCE_REQUEST_CONTEXT_URL_ATTRIBUTE = "sourceRequestContextUrl";

    static JWTMessages log = (JWTMessages) MessagesFactory.get(JWTMessages.class);
    private static final AuditService auditService = AuditServiceFactory.getAuditService();
    private static final Auditor auditor = auditService.getAuditor("audit", "knox", "knox");

    /** Accepted audiences; {@code null} means any audience (or none) is accepted. */
    protected List<String> audiences;
    /** Token service used to verify JWT signatures; resolved in {@link #init}. */
    protected JWTokenAuthority authority;
    /** When set, JWT signatures are verified against this key. */
    protected RSAPublicKey publicKey;
    /** Remembers serialized tokens/passcodes that have already passed verification. */
    protected SignatureVerificationCache signatureVerificationCache;
    // Defaults to the documented default so validateToken cannot NPE when a subclass
    // never calls configureExpectedParameters.
    private String expectedIssuer = JWT_DEFAULT_ISSUER;
    private String expectedSigAlg;
    /** Optional claim whose value, lower-cased, overrides the subject as the principal name. */
    protected String expectedPrincipalClaim;
    /** When set (and {@link #publicKey} is not), signatures are verified against this JWKS URL. */
    protected String expectedJWKSUrl;
    /** Non-null only when server-managed token state is enabled for this filter. */
    private TokenStateService tokenStateService;
    /** Computes passcode MACs; initialized together with {@link #tokenStateService}. */
    private TokenMAC tokenMAC;

    @Override
    public abstract void doFilter(ServletRequest servletRequest, ServletResponse servletResponse, FilterChain filterChain) throws IOException, ServletException;

    /**
     * Resolves the gateway services this filter depends on.
     *
     * <p>When server-managed token state is enabled, also sets up the token state service and
     * the passcode MAC generator. The signature verification cache is always initialized,
     * scoped by the cluster name when one is available.
     *
     * @param filterConfig the filter configuration
     * @throws ServletException if the token MAC generator cannot be created
     */
    @Override
    public void init(FilterConfig filterConfig) throws ServletException {
        final ServletContext servletContext = filterConfig.getServletContext();
        final GatewayServices gatewayServices = servletContext == null
            ? null
            : (GatewayServices) servletContext.getAttribute(GATEWAY_SERVICES_ATTRIBUTE);

        if (gatewayServices != null) {
            this.authority = (JWTokenAuthority) gatewayServices.getService(ServiceType.TOKEN_SERVICE);
            if (TokenUtils.isServerManagedTokenStateEnabled(filterConfig)) {
                this.tokenStateService = (TokenStateService) gatewayServices.getService(ServiceType.TOKEN_STATE_SERVICE);
                this.tokenMAC = createTokenMAC(servletContext, gatewayServices);
            }
        }

        final String clusterName = servletContext == null
            ? null
            : (String) servletContext.getAttribute(GATEWAY_CLUSTER_ATTRIBUTE);
        this.signatureVerificationCache = SignatureVerificationCache.getInstance(clusterName, filterConfig);
    }

    /**
     * Builds the passcode MAC generator from the gateway hash algorithm and the token hash key alias.
     *
     * @throws ServletException if the gateway config is missing or the MAC cannot be initialized
     */
    private static TokenMAC createTokenMAC(ServletContext servletContext, GatewayServices gatewayServices) throws ServletException {
        final GatewayConfig gatewayConfig = (GatewayConfig) servletContext.getAttribute(GATEWAY_CONFIG_ATTRIBUTE);
        if (gatewayConfig == null) {
            throw new ServletException("Error while initializing Knox token MAC generator: gateway configuration is not available");
        }
        try {
            final AliasService aliasService = (AliasService) gatewayServices.getService(ServiceType.ALIAS_SERVICE);
            return new TokenMAC(gatewayConfig.getKnoxTokenHashAlgorithm(), aliasService.getPasswordFromAliasForGateway(TOKEN_HASH_KEY_ALIAS));
        } catch (ServiceLifecycleException | AliasServiceException e) {
            throw new ServletException("Error while initializing Knox token MAC generator", e);
        }
    }

    /**
     * Reads the expected issuer and signature algorithm from the filter init parameters.
     *
     * <p>The issuer falls back to {@link #JWT_DEFAULT_ISSUER}. The signature algorithm stays
     * {@code null} (no algorithm check) when it is not configured.
     *
     * @param filterConfig the filter configuration
     */
    public void configureExpectedParameters(FilterConfig filterConfig) {
        final String configuredIssuer = filterConfig.getInitParameter(JWT_EXPECTED_ISSUER);
        this.expectedIssuer = configuredIssuer == null ? JWT_DEFAULT_ISSUER : configuredIssuer;
        this.expectedSigAlg = filterConfig.getInitParameter(JWT_EXPECTED_SIGALG);
    }

    /**
     * Splits a comma-separated audience list into trimmed entries.
     *
     * @param expectedAudiences comma-separated audiences; may be {@code null}
     * @return the trimmed entries, or {@code null} when the input is {@code null} or empty
     *         (which {@link #validateAudiences} treats as "accept any audience")
     */
    public List<String> parseExpectedAudiences(String expectedAudiences) {
        if (expectedAudiences == null || expectedAudiences.isEmpty()) {
            return null;
        }
        final String[] entries = expectedAudiences.split(",");
        final List<String> parsed = new ArrayList<>(entries.length);
        for (String entry : entries) {
            parsed.add(entry.trim());
        }
        return parsed;
    }

    /**
     * Checks whether a JWT has not yet expired.
     *
     * <p>A server-managed expiration takes precedence over the token's own {@code exp} claim.
     * A token with neither is considered valid.
     *
     * @throws UnknownTokenException if the token state service does not know the token
     */
    protected boolean tokenIsStillValid(JWT jwt) throws UnknownTokenException {
        Date expiration = getServerManagedStateExpiration(TokenUtils.getTokenId(jwt));
        if (expiration == null) {
            expiration = jwt.getExpiresDate();
        }
        return isBeforeExpiration(expiration);
    }

    /**
     * Checks whether the token with the given id has not yet expired, per server-managed state only.
     *
     * @throws UnknownTokenException if the token state service does not know the token
     */
    protected boolean tokenIsStillValid(String tokenId) throws UnknownTokenException {
        return isBeforeExpiration(getServerManagedStateExpiration(tokenId));
    }

    /** Returns {@code true} when there is no expiration or the current time is strictly before it. */
    private static boolean isBeforeExpiration(Date expiration) {
        return expiration == null || new Date().before(expiration);
    }

    /**
     * Looks up the server-managed expiration for a token.
     *
     * @return the expiration, or {@code null} when state is not server-managed or no positive
     *         expiration is recorded
     */
    private Date getServerManagedStateExpiration(String tokenId) throws UnknownTokenException {
        if (this.tokenStateService == null) {
            return null;
        }
        final long expiration = this.tokenStateService.getTokenExpiration(tokenId);
        return expiration > 0 ? new Date(expiration) : null;
    }

    /**
     * Checks the token's audience claims against the configured audiences.
     *
     * @return {@code true} if no audiences are configured, or if at least one audience claim is
     *         among the configured audiences
     */
    protected boolean validateAudiences(JWT jwt) {
        if (this.audiences == null) {
            return true;
        }
        final String[] audienceClaims = jwt.getAudienceClaims();
        if (audienceClaims == null) {
            return false;
        }
        for (String audienceClaim : audienceClaims) {
            if (this.audiences.contains(audienceClaim)) {
                log.jwtAudienceValidated();
                return true;
            }
        }
        return false;
    }

    /**
     * Audits the successful authentication and runs the rest of the chain as {@code subject}.
     *
     * @param subject the authenticated subject; must contain a {@link PrimaryPrincipal}
     * @throws IOException rethrown as-is if the downstream chain throws it
     * @throws ServletException if the subject has no primary principal, if the chain throws one,
     *         or wrapping any other checked exception from the chain
     */
    public void continueWithEstablishedSecurityContext(Subject subject, final HttpServletRequest httpServletRequest, final HttpServletResponse httpServletResponse, final FilterChain filterChain) throws IOException, ServletException {
        final Set<PrimaryPrincipal> primaryPrincipals = subject.getPrincipals(PrimaryPrincipal.class);
        if (primaryPrincipals.isEmpty()) {
            throw new ServletException("Unable to establish security context: subject has no primary principal");
        }
        final Principal principal = primaryPrincipals.iterator().next();

        final AuditContext context = auditService.getContext();
        if (context != null) {
            context.setUsername(principal.getName());
            final String sourceUri = (String) httpServletRequest.getAttribute(SOURCE_REQUEST_CONTEXT_URL_ATTRIBUTE);
            if (sourceUri != null) {
                auditor.audit("authentication", sourceUri, "uri", "success");
            }
        }

        try {
            // Cast is required: Subject.doAs is overloaded for PrivilegedAction as well.
            Subject.doAs(subject, (PrivilegedExceptionAction<Object>) () -> {
                filterChain.doFilter(httpServletRequest, httpServletResponse);
                return null;
            });
        } catch (PrivilegedActionException e) {
            final Throwable cause = e.getCause();
            if (cause instanceof IOException) {
                throw (IOException) cause;
            }
            if (cause instanceof ServletException) {
                throw (ServletException) cause;
            }
            throw new ServletException(cause);
        }
    }

    /**
     * Parses a serialized JWT and builds a subject from it.
     *
     * @throws ParseException if the token cannot be parsed
     */
    public Subject createSubjectFromToken(String serializedToken) throws ParseException, UnknownTokenException {
        final JWT jwt = new JWTToken(serializedToken);
        return createSubjectFromToken(jwt);
    }

    /**
     * Builds a subject from a JWT.
     *
     * <p>The principal name is the token subject, unless {@link #expectedPrincipalClaim} is
     * configured and present in the token.
     */
    public Subject createSubjectFromToken(JWT jwt) throws UnknownTokenException {
        final String claimValue = this.expectedPrincipalClaim == null ? null : jwt.getClaim(this.expectedPrincipalClaim);
        return createSubjectFromTokenData(jwt.getSubject(), claimValue);
    }

    /**
     * Builds a subject for a server-managed token from its stored metadata.
     *
     * @return the subject, or {@code null} if server-managed token state is disabled or the token
     *         has no metadata
     * @throws UnknownTokenException if the token state service does not know the token
     */
    public Subject createSubjectFromTokenIdentifier(String tokenId) throws UnknownTokenException {
        if (this.tokenStateService == null) {
            return null;
        }
        final TokenMetadata metadata = this.tokenStateService.getTokenMetadata(tokenId);
        return metadata == null ? null : createSubjectFromTokenData(metadata.getUserName(), null);
    }

    /**
     * Creates a read-only subject with one {@link PrimaryPrincipal} and no credentials.
     *
     * @param principal default principal name
     * @param expectedPrincipalClaimValue if non-null, used lower-cased (Locale.ROOT) instead of
     *        {@code principal}
     */
    protected Subject createSubjectFromTokenData(String principal, String expectedPrincipalClaimValue) {
        final String claimValue = expectedPrincipalClaimValue == null ? null : expectedPrincipalClaimValue.toLowerCase(Locale.ROOT);
        final Set<Principal> principals = new HashSet<>();
        principals.add(new PrimaryPrincipal(claimValue != null ? claimValue : principal));
        final Set<Object> noCredentials = new HashSet<>();
        return new Subject(true, principals, noCredentials, noCredentials);
    }

    /**
     * Validates a JWT.
     *
     * <p>Checks are applied in order: issuer, expiration, audience, not-before, enabled state,
     * signature. On the first failure the error is reported through
     * {@link #handleValidationError} and {@code false} is returned.
     *
     * @return {@code true} if every check passed
     */
    public boolean validateToken(HttpServletRequest httpServletRequest, HttpServletResponse httpServletResponse, FilterChain filterChain, JWT jwt) throws IOException, ServletException {
        final String tokenId = TokenUtils.getTokenId(jwt);
        final String tokenIdDisplayText = Tokens.getTokenIDDisplayText(tokenId);
        final String tokenDisplayText = Tokens.getTokenDisplayText(jwt.toString());

        if (!this.expectedIssuer.equals(jwt.getIssuer())) {
            handleValidationError(httpServletRequest, httpServletResponse, HttpServletResponse.SC_UNAUTHORIZED, null);
            return false;
        }

        try {
            if (!tokenIsStillValid(jwt)) {
                log.tokenHasExpired(tokenDisplayText, tokenIdDisplayText);
                removeSignatureVerificationRecord(jwt.toString());
                handleValidationError(httpServletRequest, httpServletResponse, HttpServletResponse.SC_UNAUTHORIZED, "Token has expired");
                return false;
            }

            if (!validateAudiences(jwt)) {
                log.failedToValidateAudience(tokenDisplayText, tokenIdDisplayText);
                handleValidationError(httpServletRequest, httpServletResponse, HttpServletResponse.SC_BAD_REQUEST, "Bad request: missing required token audience");
                return false;
            }

            // RFC 7519 4.1.5: the token is acceptable when now is at or after nbf.
            final Date notBefore = jwt.getNotBeforeDate();
            if (notBefore != null && new Date().before(notBefore)) {
                log.notBeforeCheckFailed();
                handleValidationError(httpServletRequest, httpServletResponse, HttpServletResponse.SC_BAD_REQUEST, "Bad request: the NotBefore check failed");
                return false;
            }

            if (!isTokenEnabled(tokenId)) {
                log.disabledToken(tokenIdDisplayText);
                handleValidationError(httpServletRequest, httpServletResponse, HttpServletResponse.SC_UNAUTHORIZED, "Token " + tokenIdDisplayText + " is disabled");
                return false;
            }

            if (!verifyTokenSignature(jwt)) {
                log.failedToVerifyTokenSignature(tokenDisplayText, tokenIdDisplayText);
                handleValidationError(httpServletRequest, httpServletResponse, HttpServletResponse.SC_UNAUTHORIZED, null);
                return false;
            }
            return true;
        } catch (UnknownTokenException e) {
            log.unableToVerifyExpiration(e);
            handleValidationError(httpServletRequest, httpServletResponse, HttpServletResponse.SC_UNAUTHORIZED, e.getMessage());
            return false;
        }
    }

    /**
     * Checks whether a token is enabled.
     *
     * @return {@code true} when server-managed state is off, no metadata exists, or the metadata
     *         marks the token enabled
     */
    private boolean isTokenEnabled(String tokenId) throws UnknownTokenException {
        final TokenMetadata metadata = this.tokenStateService == null ? null : this.tokenStateService.getTokenMetadata(tokenId);
        return metadata == null || metadata.isEnabled();
    }

    /**
     * Validates a server-managed token identified by id and passcode.
     *
     * <p>Returns {@code false} without reporting an error when server-managed token state is
     * disabled. Otherwise the checks are: passcode present, not expired, enabled, and passcode
     * valid (or previously verified). The first failure is reported through
     * {@link #handleValidationError}.
     *
     * @return {@code true} if every check passed
     */
    public boolean validateToken(HttpServletRequest httpServletRequest, HttpServletResponse httpServletResponse, FilterChain filterChain, String tokenId, String passcode) throws IOException, ServletException {
        if (this.tokenStateService == null) {
            return false;
        }
        try {
            if (tokenId == null) {
                log.missingTokenPasscode();
                handleValidationError(httpServletRequest, httpServletResponse, HttpServletResponse.SC_BAD_REQUEST, "Bad request: missing token passcode.");
                return false;
            }

            final String tokenIdDisplayText = Tokens.getTokenIDDisplayText(tokenId);
            if (!tokenIsStillValid(tokenId)) {
                log.tokenHasExpired(tokenIdDisplayText);
                removeSignatureVerificationRecord(passcode);
                handleValidationError(httpServletRequest, httpServletResponse, HttpServletResponse.SC_UNAUTHORIZED, "Token has expired");
                return false;
            }

            if (!isTokenEnabled(tokenId)) {
                log.disabledToken(tokenIdDisplayText);
                handleValidationError(httpServletRequest, httpServletResponse, HttpServletResponse.SC_UNAUTHORIZED, "Token " + tokenIdDisplayText + " is disabled");
                return false;
            }

            if (hasSignatureBeenVerified(passcode) || validatePasscode(tokenId, passcode)) {
                return true;
            }
            // Log the masked id, consistent with every other log call in this class.
            log.wrongPasscodeToken(tokenIdDisplayText);
            handleValidationError(httpServletRequest, httpServletResponse, HttpServletResponse.SC_UNAUTHORIZED, "Invalid passcode");
            return false;
        } catch (UnknownTokenException e) {
            log.unableToVerifyExpiration(e);
            handleValidationError(httpServletRequest, httpServletResponse, HttpServletResponse.SC_UNAUTHORIZED, e.getMessage());
            return false;
        }
    }

    /**
     * Recomputes the passcode MAC and compares it with the stored one.
     *
     * <p>On a match, records the passcode as verified.
     *
     * @return {@code false} if there is no stored metadata/passcode or the MAC does not match
     */
    private boolean validatePasscode(String tokenId, String passcode) throws UnknownTokenException {
        final long issueTime = this.tokenStateService.getTokenIssueTime(tokenId);
        final TokenMetadata metadata = this.tokenStateService.getTokenMetadata(tokenId);
        final String storedPasscode = metadata == null ? null : metadata.getPasscode();
        if (storedPasscode == null) {
            return false;
        }
        final String computedPasscode = this.tokenMAC.hash(tokenId, issueTime, metadata.getUserName(), passcode);
        // Constant-time comparison so response timing does not leak how much of the MAC matched.
        final boolean valid = MessageDigest.isEqual(
            computedPasscode.getBytes(StandardCharsets.UTF_8),
            storedPasscode.getBytes(StandardCharsets.UTF_8));
        if (valid) {
            recordSignatureVerification(passcode);
        }
        return valid;
    }

    /**
     * Verifies the JWT signature, consulting and updating the verification cache.
     *
     * <p>The key source is, in priority order: {@link #publicKey}, then {@link #expectedJWKSUrl},
     * then the token authority's default. When an expected algorithm is configured, the token
     * header must also declare it.
     *
     * @return {@code true} if the token was previously verified or verifies now
     */
    protected boolean verifyTokenSignature(JWT jwt) {
        final String serializedToken = jwt.toString();
        if (hasSignatureBeenVerified(serializedToken)) {
            return true;
        }
        final boolean verified = verifyWithConfiguredKeySource(jwt) && hasExpectedSignatureAlgorithm(jwt);
        if (verified) {
            recordSignatureVerification(serializedToken);
        }
        return verified;
    }

    /** Delegates signature verification to the authority using the configured key source; logs and returns false on error. */
    private boolean verifyWithConfiguredKeySource(JWT jwt) {
        try {
            if (this.publicKey != null) {
                return this.authority.verifyToken(jwt, this.publicKey);
            }
            if (this.expectedJWKSUrl != null) {
                return this.authority.verifyToken(jwt, this.expectedJWKSUrl, this.expectedSigAlg);
            }
            return this.authority.verifyToken(jwt);
        } catch (TokenServiceException e) {
            log.unableToVerifyToken(e);
            return false;
        }
    }

    /** Returns true if no algorithm is expected or the JWS header declares the expected one; false on header parse error. */
    private boolean hasExpectedSignatureAlgorithm(JWT jwt) {
        if (this.expectedSigAlg == null) {
            return true;
        }
        try {
            return JWSHeader.parse(jwt.getHeader()).getAlgorithm().getName().equals(this.expectedSigAlg);
        } catch (ParseException e) {
            log.unableToVerifyToken(e);
            return false;
        }
    }

    /** Returns whether the given serialized token or passcode is recorded as already verified. */
    protected boolean hasSignatureBeenVerified(String token) {
        return this.signatureVerificationCache.hasSignatureBeenVerified(token);
    }

    /** Records the given serialized token or passcode as verified. */
    protected void recordSignatureVerification(String token) {
        this.signatureVerificationCache.recordSignatureVerification(token);
    }

    /** Removes any verification record for the given serialized token or passcode. */
    protected void removeSignatureVerificationRecord(String token) {
        this.signatureVerificationCache.removeSignatureVerificationRecord(token);
    }

    /**
     * Reports a validation failure to the client.
     *
     * @param httpServletRequest the request being filtered
     * @param httpServletResponse the response to write the error to
     * @param status HTTP status code (this class passes 400 or 401)
     * @param error error message, or {@code null} when none is provided
     * @throws IOException if writing the response fails
     */
    protected abstract void handleValidationError(HttpServletRequest httpServletRequest, HttpServletResponse httpServletResponse, int status, String error) throws IOException;
}
