/*
 * Decompiled with CFR 0.152.
 */
package io.prestosql.server.security;

import com.google.common.base.Joiner;
import com.google.common.base.Strings;
import com.google.common.collect.ImmutableList;
import com.google.common.io.ByteStreams;
import com.google.common.net.MediaType;
import io.prestosql.server.InternalAuthenticationManager;
import io.prestosql.server.security.AuthenticationException;
import io.prestosql.server.security.Authenticator;
import io.prestosql.server.security.BasicAuthCredentials;
import io.prestosql.server.security.SecurityConfig;
import io.prestosql.server.ui.WebUiAuthenticationManager;
import io.prestosql.spi.security.Identity;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.io.PrintWriter;
import java.security.Principal;
import java.util.Collection;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import javax.inject.Inject;
import javax.servlet.Filter;
import javax.servlet.FilterChain;
import javax.servlet.FilterConfig;
import javax.servlet.ServletException;
import javax.servlet.ServletInputStream;
import javax.servlet.ServletRequest;
import javax.servlet.ServletResponse;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletRequestWrapper;
import javax.servlet.http.HttpServletResponse;

public class AuthenticationFilter
implements Filter {
    private static final String HTTPS_PROTOCOL = "https";
    private final List<Authenticator> authenticators;
    private final boolean httpsForwardingEnabled;
    private final InternalAuthenticationManager internalAuthenticationManager;
    private final WebUiAuthenticationManager uiAuthenticationManager;

    @Inject
    public AuthenticationFilter(List<Authenticator> authenticators, SecurityConfig securityConfig, InternalAuthenticationManager internalAuthenticationManager, WebUiAuthenticationManager uiAuthenticationManager) {
        this.authenticators = ImmutableList.copyOf((Collection)Objects.requireNonNull(authenticators, "authenticators is null"));
        this.httpsForwardingEnabled = Objects.requireNonNull(securityConfig, "securityConfig is null").getEnableForwardingHttps();
        this.internalAuthenticationManager = Objects.requireNonNull(internalAuthenticationManager, "internalAuthenticationManager is null");
        this.uiAuthenticationManager = Objects.requireNonNull(uiAuthenticationManager, "uiAuthenticationManager is null");
    }

    public void init(FilterConfig filterConfig) {
    }

    public void destroy() {
    }

    public void doFilter(ServletRequest servletRequest, ServletResponse servletResponse, FilterChain nextFilter) throws IOException, ServletException {
        HttpServletRequest request = (HttpServletRequest)servletRequest;
        HttpServletResponse response = (HttpServletResponse)servletResponse;
        if (this.internalAuthenticationManager.isInternalRequest(request)) {
            Principal principal = this.internalAuthenticationManager.authenticateInternalRequest(request);
            if (principal == null) {
                response.setStatus(401);
                response.setContentType(MediaType.PLAIN_TEXT_UTF_8.toString());
                return;
            }
            Identity identity = Identity.forUser((String)"<internal>").withPrincipal(principal).build();
            AuthenticationFilter.withAuthenticatedIdentity(nextFilter, request, response, identity);
            return;
        }
        if (WebUiAuthenticationManager.isUiRequest(request)) {
            this.uiAuthenticationManager.handleUiRequest(request, response, nextFilter);
            return;
        }
        if (!this.doesRequestSupportAuthentication(request)) {
            AuthenticationFilter.handleInsecureRequest(nextFilter, request, response);
            return;
        }
        LinkedHashSet<String> messages = new LinkedHashSet<String>();
        LinkedHashSet authenticateHeaders = new LinkedHashSet();
        for (Authenticator authenticator : this.authenticators) {
            Identity authenticatedIdentity;
            try {
                authenticatedIdentity = authenticator.authenticate(request);
            }
            catch (AuthenticationException e) {
                if (e.getMessage() != null) {
                    messages.add(e.getMessage());
                }
                e.getAuthenticateHeader().ifPresent(authenticateHeaders::add);
                continue;
            }
            AuthenticationFilter.withAuthenticatedIdentity(nextFilter, request, response, authenticatedIdentity);
            return;
        }
        AuthenticationFilter.skipRequestBody(request);
        for (String value : authenticateHeaders) {
            response.addHeader("WWW-Authenticate", value);
        }
        if (messages.isEmpty()) {
            messages.add("Unauthorized");
        }
        String error = Joiner.on((String)" | ").join(messages);
        AuthenticationFilter.sendErrorMessage(response, 401, error);
    }

    private static void sendErrorMessage(HttpServletResponse response, int errorCode, String errorMessage) throws IOException {
        response.setStatus(errorCode, errorMessage);
        response.setContentType(MediaType.PLAIN_TEXT_UTF_8.toString());
        try (PrintWriter writer = response.getWriter();){
            writer.write(errorMessage);
        }
    }

    private static void handleInsecureRequest(FilterChain nextFilter, HttpServletRequest request, HttpServletResponse response) throws IOException, ServletException {
        Optional<BasicAuthCredentials> basicAuthCredentials;
        try {
            basicAuthCredentials = BasicAuthCredentials.extractBasicAuthCredentials(request);
        }
        catch (AuthenticationException e) {
            AuthenticationFilter.sendErrorMessage(response, 403, e.getMessage());
            return;
        }
        if (!basicAuthCredentials.isPresent()) {
            nextFilter.doFilter((ServletRequest)request, (ServletResponse)response);
            return;
        }
        if (basicAuthCredentials.get().getPassword().isPresent()) {
            AuthenticationFilter.sendErrorMessage(response, 403, "Password not allowed for insecure request");
            return;
        }
        AuthenticationFilter.withAuthenticatedIdentity(nextFilter, request, response, Identity.ofUser((String)basicAuthCredentials.get().getUser()));
    }

    private boolean doesRequestSupportAuthentication(HttpServletRequest request) {
        if (this.authenticators.isEmpty()) {
            return false;
        }
        if (request.isSecure()) {
            return true;
        }
        return this.httpsForwardingEnabled && Strings.nullToEmpty((String)request.getHeader("X-Forwarded-Proto")).equalsIgnoreCase(HTTPS_PROTOCOL);
    }

    /*
     * WARNING - Removed try catching itself - possible behaviour change.
     */
    private static void withAuthenticatedIdentity(FilterChain nextFilter, HttpServletRequest request, HttpServletResponse response, Identity authenticatedIdentity) throws IOException, ServletException {
        request.setAttribute("presto.authenticated-identity", (Object)authenticatedIdentity);
        try {
            nextFilter.doFilter(AuthenticationFilter.withPrincipal(request, authenticatedIdentity.getPrincipal()), (ServletResponse)response);
        }
        finally {
            Optional.ofNullable(request.getAttribute("presto.authenticated-identity")).map(Identity.class::cast).ifPresent(Identity::destroy);
        }
    }

    private static ServletRequest withPrincipal(HttpServletRequest request, final Optional<Principal> principal) {
        Objects.requireNonNull(principal, "principal is null");
        if (!principal.isPresent()) {
            return request;
        }
        return new HttpServletRequestWrapper(request){

            public Principal getUserPrincipal() {
                return (Principal)principal.get();
            }
        };
    }

    private static void skipRequestBody(HttpServletRequest request) throws IOException {
        try (ServletInputStream inputStream = request.getInputStream();){
            ByteStreams.copy((InputStream)inputStream, (OutputStream)ByteStreams.nullOutputStream());
        }
    }
}

