/*
jGuard is a security framework based on top of jaas (java authentication and authorization security).
it is written for web applications, to resolve simply, access control problems.
version $Name:  $
http://sourceforge.net/projects/jguard/

Copyright (C) 2004  Charles GAY

This library is free software; you can redistribute it and/or
modify it under the terms of the GNU Lesser General Public
License as published by the Free Software Foundation; either
version 2.1 of the License, or (at your option) any later version.

This library is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
Lesser General Public License for more details.

You should have received a copy of the GNU Lesser General Public
License along with this library; if not, write to the Free Software
Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA


jGuard project home page:
http://sourceforge.net/projects/jguard/

*/
package net.sf.jguard.jee.authentication.callbacks;

import java.io.IOException;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import java.util.Locale;

import javax.security.auth.Subject;
import javax.security.auth.callback.Callback;
import javax.security.auth.callback.CallbackHandler;
import javax.security.auth.callback.LanguageCallback;
import javax.security.auth.callback.NameCallback;
import javax.security.auth.callback.PasswordCallback;
import javax.security.auth.callback.UnsupportedCallbackException;
import javax.servlet.FilterChain;
import javax.servlet.ServletException;
import javax.servlet.ServletRequest;
import javax.servlet.ServletResponse;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import javax.servlet.http.HttpSession;

import net.sf.jguard.core.CoreConstants;
import net.sf.jguard.core.authentication.AuthenticationUtils;
import net.sf.jguard.core.authentication.callbacks.InetAddressCallback;
import net.sf.jguard.ext.SecurityConstants;
import net.sf.jguard.ext.authentication.callbacks.CallbackHandlerUtils;
import net.sf.jguard.ext.authentication.callbacks.JCaptchaCallback;
import net.sf.jguard.jee.authentication.http.HttpConstants;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.xbill.DNS.DClass;
import org.xbill.DNS.ExtendedResolver;
import org.xbill.DNS.Message;
import org.xbill.DNS.Name;
import org.xbill.DNS.Record;
import org.xbill.DNS.Resolver;
import org.xbill.DNS.ReverseMap;
import org.xbill.DNS.Section;
import org.xbill.DNS.Type;

import com.octo.captcha.service.CaptchaService;
import java.net.UnknownHostException;

/**
 * Handles grabbing credentials and client information from an HTTP Servlet request.
 * <p>
 * The authentication schemes used by every instance are configured globally through
 * {@link #setAuthSchemes(String)}; the request parameter names read by the FORM scheme
 * are configured through {@link #setLoginField(String)} and {@link #setPasswordField(String)}.
 * @author <a href="mailto:diabolo512@users.sourceforge.net ">Charles Gay</a>
 */
public class HttpServletCallbackHandler implements CallbackHandler{

	public static final String AUTHORIZATION = "Authorization";
	private static final String BASIC_REALM = "Basic realm=\"";
	private static final String NO_CACHE_AUTHORIZATION = "no-cache=\"Authorization\"";
	private static final String CACHE_CONTROL = "Cache-Control";
	private static final String WWW_AUTHENTICATE = "WWW-Authenticate";

	/** Logger for this class */
	private static final Logger logger = LoggerFactory.getLogger(HttpServletCallbackHandler.class);
	private HttpServletRequest httpServletRequest;
	private HttpServletResponse httpServletResponse;
	/** comma-separated list of authentication schemes; shared by all instances. */
	private static String authSchemes=HttpConstants.FORM_AUTH;
	/** name of the request parameter holding the login (FORM scheme). */
	private static String loginField="login";
	/** name of the request parameter holding the password (FORM scheme). */
	private static String passwordField="password";
	/** true when the request attribute {@link CoreConstants#REGISTRATION_DONE} is Boolean.TRUE. */
	private boolean afterRegistration;

	/**
	 * constructor required by javadoc of the CallbackHandler interface.
	 */
	public HttpServletCallbackHandler(){
		super();
	}

	/**
	 * constructor.
	 * @param request request sent by the client; also read for the
	 * {@link CoreConstants#REGISTRATION_DONE} attribute
	 * @param response response sent to the client
	 */
	public HttpServletCallbackHandler(HttpServletRequest request,HttpServletResponse response){
		this.httpServletRequest = request;
		Boolean ar = (Boolean)request.getAttribute(CoreConstants.REGISTRATION_DONE);
		if(ar!=null){
			afterRegistration = ar.booleanValue();
		}
		this.httpServletResponse = response;
	}

	/**
	 * Extracts client information and credentials from the HttpServletRequest and
	 * puts them into the callbacks.
	 * <p>
	 * This CallbackHandler supports {@link InetAddressCallback}, {@link NameCallback},
	 * {@link PasswordCallback}, {@link LanguageCallback}, {@link JCaptchaCallback}
	 * and {@code CertificatesCallback}. Unsupported callbacks are ignored.
	 * @param callbacks callbacks to fill
	 * @throws IOException declared by the {@link CallbackHandler} contract
	 * @throws UnsupportedCallbackException declared by the {@link CallbackHandler} contract
	 */
	public void handle(Callback[] callbacks) throws IOException, UnsupportedCallbackException {
		fillClientInformation(callbacks);
		grabCredentials(callbacks);
	}

	/**
	 * fills the callbacks describing the client (network address, locale).
	 * @param callbacks callbacks to fill
	 */
	private void fillClientInformation(Callback[] callbacks) {
		for (int i = 0; i < callbacks.length; i++) {
			Callback callback = callbacks[i];
			if (callback instanceof InetAddressCallback) {
				fillInetAddressCallback((InetAddressCallback) callback);
			} else if (callback instanceof LanguageCallback) {
				LanguageCallback languageCallback = (LanguageCallback) callback;
				languageCallback.setLocale(httpServletRequest.getLocale());
			}
		}
	}

	/**
	 * fills the client address and host name; when the container did not resolve the
	 * host name itself, a reverse DNS lookup is attempted.
	 * @param inetAddressCallback callback to fill
	 */
	private void fillInetAddressCallback(InetAddressCallback inetAddressCallback) {
		String remoteAddress = httpServletRequest.getRemoteAddr();
		String remoteHost = httpServletRequest.getRemoteHost();
		inetAddressCallback.setHostAdress(remoteAddress);

		if (remoteAddress.equals(remoteHost)) {
			//the server is not configured to return the hostName.
			String resolvedHostName = remoteAddress;
			try {
				resolvedHostName = reverseDns(remoteAddress);
			} catch (IOException e) {
				//any DNS failure (unknown host, timeout, ...) falls back to the IP address
				//instead of aborting the whole authentication
				logger.warn("host bound to address " + remoteAddress + " cannot be resolved", e);
			}
			inetAddressCallback.setHostName(resolvedHostName);
		} else {
			//the server is configured to return the hostName.
			inetAddressCallback.setHostName(remoteHost);
		}
	}

	/**
	 * applies every configured authentication scheme to the callbacks.
	 * FORM, BASIC and DIGEST are mutually exclusive: only the first one listed is used.
	 * CLIENT_CERT can be combined with any of them.
	 * @param callbacks callbacks to fill
	 */
	private void grabCredentials(Callback[] callbacks) {
		logger.debug("authSchemes="+authSchemes);
		String[] schemes = authSchemes.split(",");
		boolean httpRelatedAuthScheme = false;
		for (int i = 0; i < schemes.length; i++) {
			//tolerate configurations like "FORM, CLIENT_CERT"
			String scheme = schemes[i].trim();
			if(!httpRelatedAuthScheme){
				if(HttpConstants.FORM_AUTH.equalsIgnoreCase(scheme)){
					grabFormCredentials(this.httpServletRequest,callbacks);
					httpRelatedAuthScheme = true;
				}else if(HttpConstants.BASIC_AUTH.equalsIgnoreCase(scheme)){
					grabBasicCredentials(this.httpServletRequest,callbacks);
					httpRelatedAuthScheme = true;
				}else if(HttpConstants.DIGEST_AUTH.equalsIgnoreCase(scheme)){
					grabDigestCredentials(this.httpServletRequest,callbacks);
					httpRelatedAuthScheme = true;
				}
			}
			if(HttpConstants.CLIENT_CERT_AUTH.equalsIgnoreCase(scheme)){
				boolean certificatesFound = grabClientCertCredentials(this.httpServletRequest,callbacks);
				if(!certificatesFound){
					logger.info(" X509 certificates are not found ");
				}
			}
		}
	}

	/**
	 * send to the client the BASIC challenge into the response, according to the RFC 2617.
	 * @param response response sent to the client
	 * @param realmName realm owned by the server => specify what kind of credential the user should provide
	 */
	public static void buildBasicChallenge(HttpServletResponse response,String realmName){
		StringBuffer responseValue= new StringBuffer();
		responseValue.append(HttpServletCallbackHandler.BASIC_REALM);
		responseValue.append(realmName);
		responseValue.append("\"");
		response.setHeader(HttpServletCallbackHandler.WWW_AUTHENTICATE,responseValue.toString());
		response.setHeader(HttpServletCallbackHandler.CACHE_CONTROL, HttpServletCallbackHandler.NO_CACHE_AUTHORIZATION);
		response.setStatus(HttpServletResponse.SC_UNAUTHORIZED);
	}

	/**
	 * reads the Base64-encoded user and password from the Authorization header
	 * (<b>NON SECURE</b> BASIC method) and puts them into the callbacks.
	 * @param request request sent by the client.
	 * @param callbacks callbacks to fill
	 * @return result of {@link CallbackHandlerUtils#grabBasicCredentials}
	 */
	private boolean grabBasicCredentials(HttpServletRequest request,Callback[] callbacks){
		String encodedLoginAndPwd = request.getHeader(HttpServletCallbackHandler.AUTHORIZATION);
		String encoding = request.getCharacterEncoding();
		return CallbackHandlerUtils.grabBasicCredentials(encodedLoginAndPwd, encoding, callbacks);
	}

	/**
	 * grab user credentials from request in the 'form' authentication method.
	 * An absent or empty password parameter yields a <code>null</code> password.
	 * @param request request sent by the client
	 * @param callbacks callbacks to fill
	 * @return always <code>true</code>
	 */
	private boolean grabFormCredentials(HttpServletRequest request,Callback[] callbacks){
		HttpSession session = request.getSession();

		for(int i=0;i<callbacks.length;i++){
			if(callbacks[i] instanceof NameCallback){
				NameCallback nc = (NameCallback)callbacks[i];
				nc.setName(request.getParameter(loginField));
			}else if(callbacks[i] instanceof PasswordCallback){
				PasswordCallback pc = (PasswordCallback)callbacks[i];
				String strPwd = request.getParameter(passwordField);
				//content check, not reference comparison against ""
				if(strPwd != null && strPwd.length() > 0){
					pc.setPassword(strPwd.toCharArray());
				}else{
					pc.setPassword(null);
				}
			}else if(callbacks[i] instanceof JCaptchaCallback){
				JCaptchaCallback pc = (JCaptchaCallback)callbacks[i];
				pc.setCaptchaAnswer(request.getParameter(SecurityConstants.CAPTCHA_ANSWER));
				pc.setCaptchaService((CaptchaService)session.getServletContext().getAttribute(SecurityConstants.CAPTCHA_SERVICE));
				//NOTE(review): throws NullPointerException when no AuthenticationUtils is
				//stored in the session under CoreConstants.AUTHN_UTILS
				Subject subject = ((AuthenticationUtils)session.getAttribute(CoreConstants.AUTHN_UTILS)).getSubject();
				if(subject==null ||afterRegistration){
					pc.setSkipJCaptchaChallenge(true);
				}
				pc.setSessionID(session.getId());
			}
		}
		return true;
	}

	/**
	 * grab user credentials from request in the 'digest' authentication method.
	 * Not implemented yet: no callback is filled.
	 * @param request request sent by the client
	 * @param callbacks callbacks to fill
	 * @return always <code>true</code> (placeholder behaviour)
	 */
	private boolean grabDigestCredentials(HttpServletRequest request,Callback[] callbacks){
		//TODO implements digest authentication
		return true;
	}

	/**
	 * grab user credentials from request in the 'clientCert' authentication method.
	 * Certificates are only accepted over a secure connection.
	 * @param request request sent by the client
	 * @param callbacks callbacks to fill
	 * @return <code>true</code> if successful, <code>false</code> otherwise
	 */
	private boolean grabClientCertCredentials(HttpServletRequest request,Callback[] callbacks) {
		if(!request.isSecure()){
			logger.warn(" certificate-based authentication MUST be do in secure mode ");
			logger.warn(" but connection is do with the non secured protocol "+request.getScheme());
			return false;
		}

		Object[] objects = (Object[]) request.getAttribute(CallbackHandlerUtils.JAVAX_SERVLET_REQUEST_X509CERTIFICATE);
		return CallbackHandlerUtils.grabClientCertCredentials(callbacks, objects);
	}

	/**
	 * FORM challenge: simply lets the filter chain go on (to the logon page).
	 * @param chain filter chain
	 * @param req request sent by the client
	 * @param res response sent to the client
	 * @throws IOException propagated from the filter chain
	 * @throws ServletException propagated from the filter chain
	 */
	public static void buildFormChallenge(FilterChain chain,ServletRequest req,ServletResponse res) throws IOException, ServletException{
		chain.doFilter(req,res);
	}

	/**
	 * send to the client the DIGEST challenge into the response, according to the RFC 2617.
	 * @param request request sent by the client (currently unused)
	 * @param response response sent to the client
	 * @param realm realm owned by the server => specify what kind of credential the user should provide
	 */
	public static void buildDigestChallenge(HttpServletRequest request,HttpServletResponse response, String realm) {
		String responseValue = CallbackHandlerUtils.buildDigestChallenge(realm);
		response.setHeader(HttpServletCallbackHandler.WWW_AUTHENTICATE,responseValue);
		response.setStatus(HttpServletResponse.SC_UNAUTHORIZED);
	}

	/**
	 * gets the HttpRequest password field
	 * @return password field
	 */
	public static String getPasswordField() {
		return passwordField;
	}

	/**
	 * sets the HttpRequest password field; <code>null</code> is ignored.
	 * @param passwordField request parameter name holding the password
	 */
	public static void setPasswordField(String passwordField) {
		if(passwordField!=null){
			HttpServletCallbackHandler.passwordField = passwordField;
		}
	}

	/**
	 * gets the HttpRequest login field
	 * @return login field
	 */
	public static String getLoginField() {
		return loginField;
	}

	/**
	 * sets the HttpRequest login field; <code>null</code> is ignored.
	 * @param loginField request parameter name holding the login
	 */
	public static void setLoginField(String loginField) {
		if(loginField!=null){
			HttpServletCallbackHandler.loginField = loginField;
		}
	}

	/**
	 * return the host name related to the IP address, through a PTR lookup.
	 * this method comes from <a href="http://www.oreillynet.com/onjava/blog/2005/11/reverse_dns_lookup_and_java.html">a blog entry about dnsjava</a>.
	 * @param hostIp Internet Protocol address
	 * @return host name (as given by the first PTR answer) related to the hostIp parameter,
	 * or hostIp parameter if the answer section is empty.
	 * @throws IOException if the address cannot be parsed or the DNS query fails
	 */
	private String reverseDns(String hostIp) throws IOException {
		Resolver res = new ExtendedResolver();

		Name name = ReverseMap.fromAddress(hostIp);
		Record rec = Record.newRecord(name, Type.PTR, DClass.IN);
		Message query = Message.newQuery(rec);
		Message response = res.send(query);

		Record[] answers = response.getSectionArray(Section.ANSWER);
		if (answers.length == 0){
			return hostIp;
		}
		return answers[0].rdataToString();
	}

	/**
	 * @return the comma-separated list of configured authentication schemes
	 */
	public String getAuthScheme() {
		return authSchemes;
	}

	/**
	 * sets the comma-separated list of authentication schemes used by all instances;
	 * <code>null</code> is ignored, consistently with the other setters.
	 * @param schemes schemes, e.g. "FORM,CLIENT_CERT"
	 */
	public static void setAuthSchemes(String schemes) {
		if(schemes!=null){
			authSchemes = schemes;
		}
	}
}
