package de.comhix.web.filter;

import com.google.common.base.Splitter;
import de.comhix.web.auth.AuthFunction;
import de.comhix.web.auth.AuthenticationException;
import de.comhix.web.auth.UserProvider;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import javax.inject.Inject;
import javax.inject.Singleton;
import javax.servlet.*;
import javax.servlet.http.Cookie;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.io.IOException;
import java.util.Collections;
import java.util.Enumeration;
import java.util.List;
import java.util.Optional;

/**
 * Servlet filter that requires a valid auth-token cookie on every request whose servlet path is
 * not listed in the {@value #NO_AUTH_PARAM} init parameter.
 * <p>
 * The token is read from the cookie named by the {@value #TOKEN_COOKIE_NAME_PARAM} init parameter
 * (default {@code "auth-cookie"}). It is resolved to a user via {@link AuthFunction}, and a
 * successfully resolved user is passed to {@link UserProvider#setUser}. A missing cookie or an
 * unresolvable token results in an {@link AuthenticationException}.
 *
 * @author Benjamin Beeker
 */
@Singleton
public class AuthFilter implements Filter {
    private static final Logger log = LoggerFactory.getLogger(AuthFilter.class);

    private static final String DEFAULT_TOKEN_COOKIE_NAME = "auth-cookie";
    /** Init parameter naming the cookie that carries the auth token. */
    public static final String TOKEN_COOKIE_NAME_PARAM = "tokenCookieName";
    /** Init parameter holding a comma-separated list of servlet paths that skip authentication. */
    public static final String NO_AUTH_PARAM = "noAuth";

    private final AuthFunction authFunction;
    private final UserProvider userProvider;
    // Safe defaults so the filter never dereferences null, even if no init parameters are set.
    private List<String> noAuth = Collections.emptyList();
    private String tokenCookieName = DEFAULT_TOKEN_COOKIE_NAME;

    /**
     * @param authFunction maps a token value to the authenticated user, or an empty Optional if invalid
     * @param userProvider receives the user resolved for the current request
     */
    @Inject
    public AuthFilter(AuthFunction authFunction, UserProvider userProvider) {
        this.authFunction = authFunction;
        this.userProvider = userProvider;
    }

    /**
     * Reads the {@value #NO_AUTH_PARAM} and {@value #TOKEN_COOKIE_NAME_PARAM} init parameters.
     * Absent parameters fall back to "no exempt paths" and {@code "auth-cookie"} respectively.
     */
    @Override
    public void init(FilterConfig filterConfig) throws ServletException {
        log.debug("filter: {}", filterConfig.getFilterName());
        Enumeration<String> initParameterNames = filterConfig.getInitParameterNames();
        while (initParameterNames.hasMoreElements()) {
            String name = initParameterNames.nextElement();
            log.debug("{}: {}", name, filterConfig.getInitParameter(name));
        }

        String noAuthParam = filterConfig.getInitParameter(NO_AUTH_PARAM);
        // omitEmptyStrings: a stray/trailing comma must not add "" and thereby exempt the empty
        // servlet path from authentication.
        noAuth = noAuthParam == null
                ? Collections.<String>emptyList()
                : Splitter.on(",").trimResults().omitEmptyStrings().splitToList(noAuthParam);

        String cookieNameParam = filterConfig.getInitParameter(TOKEN_COOKIE_NAME_PARAM);
        tokenCookieName = cookieNameParam == null ? DEFAULT_TOKEN_COOKIE_NAME : cookieNameParam;
    }

    /**
     * Authenticates the request (unless its servlet path is exempt) and then continues the chain.
     *
     * @throws AuthenticationException if authentication is required and fails
     */
    @Override
    public void doFilter(ServletRequest request, ServletResponse response, FilterChain chain) throws IOException, ServletException {
        HttpServletRequest httpRequest = (HttpServletRequest) request;
        String servletPath = httpRequest.getServletPath();

        checkAuth(httpRequest, servletPath);

        // NOTE(review): the user set on userProvider is never cleared after the chain returns.
        // If UserProvider stores it per thread rather than per request, confirm it cannot leak
        // into later requests (especially exempt ones) on the same thread.
        chain.doFilter(request, response);
    }

    /**
     * Skips exempt paths. Otherwise it validates the first cookie named {@code tokenCookieName}.
     *
     * @throws AuthenticationException if no such cookie exists or its token is rejected
     */
    private void checkAuth(HttpServletRequest httpRequest, String servletPath) {
        if (noAuth.contains(servletPath)) {
            log.debug("no auth needed for {}", servletPath);
            return;
        }
        log.debug("checking auth for path: {}", servletPath);

        Cookie[] cookies = httpRequest.getCookies();
        if (cookies != null) {
            for (Cookie cookie : cookies) {
                if (tokenCookieName.equals(cookie.getName())) {
                    checkCookieToken(cookie);
                    return;
                }
            }
        }
        throw new AuthenticationException();
    }

    /**
     * Adds the auth-token cookie (path {@code "/"}) to the response so later requests are authenticated.
     *
     * @param response response to add the cookie to
     * @param token    token value to store in the cookie
     */
    public void doAuth(HttpServletResponse response, String token) {
        Cookie cookie = new Cookie(tokenCookieName, token);
        cookie.setPath("/");
        // NOTE(review): consider setHttpOnly(true)/setSecure(true) for an auth cookie, unless
        // client-side scripts need to read it.
        response.addCookie(cookie);
    }

    /**
     * Resolves the cookie's token to a user and publishes it via the user provider.
     *
     * @throws AuthenticationException if the token does not resolve to a user
     */
    private void checkCookieToken(Cookie cookie) {
        Optional<String> user = authFunction.apply(cookie.getValue());
        userProvider.setUser(user.orElseThrow(AuthenticationException::new));
    }

    /** Nothing to release. */
    @Override
    public void destroy() {
    }
}
