
package net.lightapi.portal.oauth.command.handler;

import com.networknt.monad.Failure;
import com.networknt.monad.Result;
import com.networknt.monad.Success;
import com.networknt.rpc.router.ServiceHandler;
import com.networknt.security.KeyUtil;
import com.networknt.status.Status;
import com.networknt.utility.HashUtil;
import com.networknt.utility.Util;
import com.networknt.utility.UuidUtil;
import io.undertow.server.HttpServerExchange;
import net.lightapi.portal.PortalConstants;
import net.lightapi.portal.command.AbstractCommandHandler;
import org.jose4j.jwk.JsonWebKey;
import org.jose4j.jwk.JsonWebKeySet;
import org.jose4j.jwk.PublicJsonWebKey;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.security.KeyPair;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

/**
 * Create an OAuth 2.0 provider in a cluster and it is called from the light-view by admin. A authorization code token is
 * needed to access this endpoint.
 *
 * @author Steve Hu
*/
@ServiceHandler(id="lightapi.net/oauth/createProvider/0.1.0")
public class CreateProvider extends AbstractCommandHandler {
    private static final Logger logger = LoggerFactory.getLogger(CreateProvider.class);
    private static final String KEY_GENERATION_ERROR = "ERR12054";

    @Override
    protected String getCloudEventType() {
        return PortalConstants.AUTH_PROVIDER_CREATED_EVENT;
    }

    @Override
    public String getCloudEventAggregateType() {
        return PortalConstants.AGGREGATE_PROVIDER;
    }

    @Override
    public String getCloudEventAggregateId(Map<String, Object> map) {
        // the aggregate id is the providerId in the data section. It is generated by enrichInput method.
        return (String) map.get("providerId");
    }

    @Override
    protected Logger getLogger() {
        return logger;
    }

    @Override
    protected Result<Map<String, Object>> enrichInput(HttpServerExchange exchange, Map<String, Object> map) {
        String providerId = (String)map.get("providerId");
        if(providerId == null) {
            providerId = UuidUtil.uuidToBase64(UuidUtil.getUUID());
            map.put("providerId", providerId);
        }
        // create keys for the provider
        Map<String, Object> keys = new HashMap<>();
        try {
            List<JsonWebKey> jwkList = new ArrayList<>();
            KeyPair longKeyPairCurr = KeyUtil.generateKeyPair("RSA", 2048);
            String longKeyIdCurr = UuidUtil.uuidToBase64(UuidUtil.getUUID());
            PublicJsonWebKey jwk = PublicJsonWebKey.Factory.newPublicJwk(longKeyPairCurr.getPublic());
            jwk.setKeyId(longKeyIdCurr);
            jwkList.add(jwk);
            String longPublicKeyCurr = KeyUtil.serializePublicKey(longKeyPairCurr.getPublic());
            String longPrivateKeyCurr = KeyUtil.serializePrivateKey(longKeyPairCurr.getPrivate());
            Map<String, Object> lcMap = new HashMap<>();
            lcMap.put("publicKey", longPublicKeyCurr);
            lcMap.put("privateKey", longPrivateKeyCurr);
            lcMap.put("kid", longKeyIdCurr);
            keys.put("LC", lcMap);
            KeyPair longKeyPairPrev = KeyUtil.generateKeyPair("RSA", 2048);
            String longKeyIdPrev = UuidUtil.uuidToBase64(UuidUtil.getUUID());
            jwk = PublicJsonWebKey.Factory.newPublicJwk(longKeyPairPrev.getPublic());
            jwk.setKeyId(longKeyIdPrev);
            jwkList.add(jwk);
            String longPublicKeyPrev = KeyUtil.serializePublicKey(longKeyPairPrev.getPublic());
            String longPrivateKeyPrev = KeyUtil.serializePrivateKey(longKeyPairPrev.getPrivate());
            Map<String, Object> lpMap = new HashMap<>();
            lpMap.put("publicKey", longPublicKeyPrev);
            lpMap.put("privateKey", longPrivateKeyPrev);
            lpMap.put("kid", longKeyIdPrev);
            keys.put("LP", lpMap);
            KeyPair tokenKeyPairCurr = KeyUtil.generateKeyPair("RSA", 2048);
            String tokenKeyIdCurr = UuidUtil.uuidToBase64(UuidUtil.getUUID());
            jwk = PublicJsonWebKey.Factory.newPublicJwk(tokenKeyPairCurr.getPublic());
            jwk.setKeyId(tokenKeyIdCurr);
            jwkList.add(jwk);
            String tokenPublicKeyCurr = KeyUtil.serializePublicKey(tokenKeyPairCurr.getPublic());
            String tokenPrivateKeyCurr = KeyUtil.serializePrivateKey(tokenKeyPairCurr.getPrivate());
            Map<String, Object> tcMap = new HashMap<>();
            tcMap.put("publicKey", tokenPublicKeyCurr);
            tcMap.put("privateKey", tokenPrivateKeyCurr);
            tcMap.put("kid", tokenKeyIdCurr);
            keys.put("TC", tcMap);

            KeyPair tokenKeyPairPrev = KeyUtil.generateKeyPair("RSA", 2048);
            String tokenKeyIdPrev = UuidUtil.uuidToBase64(UuidUtil.getUUID());
            jwk = PublicJsonWebKey.Factory.newPublicJwk(tokenKeyPairPrev.getPublic());
            jwk.setKeyId(tokenKeyIdPrev);
            jwkList.add(jwk);
            String tokenPublicKeyPrev = KeyUtil.serializePublicKey(tokenKeyPairPrev.getPublic());
            String tokenPrivateKeyPrev = KeyUtil.serializePrivateKey(tokenKeyPairPrev.getPrivate());
            Map<String, Object> tpMap = new HashMap<>();
            tpMap.put("publicKey", tokenPublicKeyPrev);
            tpMap.put("privateKey", tokenPrivateKeyPrev);
            tpMap.put("kid", tokenKeyIdPrev);
            keys.put("TP", tpMap);

            // create a JsonWebKeySet object with the list of JWK objects
            JsonWebKeySet jwks = new JsonWebKeySet(jwkList);
            // and output the JSON of the JWKS
            String jwkJson = jwks.toJson(JsonWebKey.OutputControlLevel.PUBLIC_ONLY);
            map.put("jwk", jwkJson);
        } catch (Exception e) {
            return Failure.of( new Status(KEY_GENERATION_ERROR, e.getMessage()));
        }
        map.put("keys", keys);
        return Success.of(map);
    }
}
