
package net.lightapi.portal.oauth.command.handler;

import com.networknt.config.JsonMapper;
import com.networknt.monad.Failure;
import com.networknt.monad.Result;
import com.networknt.monad.Success;
import com.networknt.rpc.router.ServiceHandler;
import com.networknt.security.KeyUtil;
import com.networknt.status.Status;
import com.networknt.utility.HashUtil;
import com.networknt.utility.UuidUtil;
import io.undertow.server.HttpServerExchange;
import net.lightapi.portal.HybridQueryClient;
import net.lightapi.portal.PortalConstants;
import net.lightapi.portal.command.AbstractCommandHandler;
import org.jose4j.jwk.JsonWebKey;
import org.jose4j.jwk.JsonWebKeySet;
import org.jose4j.jwk.PublicJsonWebKey;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.security.KeyPair;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

/**
 * Rotate keys for a particular provider in the cluster, called by the scheduler or administrator.
 * An authorization code token is needed to access this endpoint.
 * <p>
 * The request carries a {@code keyType} of {@code LC} or {@code TC} identifying which current key
 * is rotated. A new RSA key pair replaces it as the current key. The existing current key is
 * demoted to the matching previous type ({@code LP} or {@code TP}). Any key already holding that
 * previous type is scheduled for deletion.
 */
@ServiceHandler(id="lightapi.net/oauth/rotateProvider/0.1.0")
public class RotateProvider extends AbstractCommandHandler {
    private static final Logger logger = LoggerFactory.getLogger(RotateProvider.class);
    private static final String PROVIDER_KEY_TYPE_MISSING = "ERR11639";
    private static final String PROVIDER_KEY_TYPE_INVALID = "ERR11640";
    private static final String KEY_GENERATION_ERROR = "ERR12054";

    @Override
    protected String getCloudEventType() {
        return PortalConstants.AUTH_PROVIDER_ROTATED_EVENT;
    }

    @Override
    protected Logger getLogger() {
        return logger;
    }

    /**
     * Validates the requested key type, generates a new RSA key pair and computes the key changes
     * plus the resulting public JWK set for the provider.
     * <p>
     * On success the following entries are added to {@code map}:
     * <ul>
     *   <li>{@code insert}: the new key ({@code publicKey}, {@code privateKey}, {@code kid},
     *       {@code keyType}).</li>
     *   <li>{@code update}: the existing key of the rotated type, demoted to the previous type
     *       ({@code kid}, {@code keyType}). Present only if such a key exists.</li>
     *   <li>{@code delete}: the existing key of the previous type ({@code kid}). Present only if
     *       such a key exists.</li>
     *   <li>{@code jwk}: the JSON of the JWK set containing the new key and all existing keys.</li>
     * </ul>
     *
     * @param exchange the current HTTP exchange, used to query the provider's existing keys
     * @param map the command payload; reads {@code keyType}, {@code providerId} and {@code hostId}
     * @return the enriched map, or a failure if the key type is missing or invalid, the key query
     *         fails, or key generation or serialization fails
     */
    @Override
    protected Result<Map<String, Object>> enrichInput(HttpServerExchange exchange, Map<String, Object> map) {
        String rotateKeyType = (String) map.get("keyType"); // indicates whether LC or TC should be rotated.
        String providerId = (String) map.get("providerId");
        String hostId = (String) map.get("hostId");

        if(rotateKeyType == null || rotateKeyType.isBlank()) {
            // the key type to rotate is required.
            return Failure.of(new Status(PROVIDER_KEY_TYPE_MISSING));
        }
        if(!rotateKeyType.equals("LC") && !rotateKeyType.equals("TC")) {
            // only the long-lived (LC) and token (TC) current keys can be rotated.
            return Failure.of(new Status(PROVIDER_KEY_TYPE_INVALID, rotateKeyType));
        }
        // The current key is demoted to this type, and the key already holding it is removed.
        String removeKeyType = rotateKeyType.equals("LC") ? "LP" : "TP";

        // get all the keys for the provider.
        Result<String> providerKeyResult = HybridQueryClient.getProviderKey(exchange, hostId, providerId);
        if(providerKeyResult.isFailure()) {
            return Failure.of(providerKeyResult.getError());
        }

        List<Map<String, Object>> keys = JsonMapper.string2List(providerKeyResult.getResult());
        if(keys == null) {
            // no existing keys to demote or remove; only the new key will be published.
            keys = List.of();
        }

        try {
            // create a new key pair for the provider.
            KeyPair keyPair = KeyUtil.generateKeyPair("RSA", 2048);
            String kid = UuidUtil.uuidToBase64(UuidUtil.getUUID());
            PublicJsonWebKey newJwk = PublicJsonWebKey.Factory.newPublicJwk(keyPair.getPublic());
            newJwk.setKeyId(kid);

            Map<String, Object> insertMap = new HashMap<>();
            insertMap.put("publicKey", KeyUtil.serializePublicKey(keyPair.getPublic()));
            insertMap.put("privateKey", KeyUtil.serializePrivateKey(keyPair.getPrivate()));
            insertMap.put("kid", kid);
            insertMap.put("keyType", rotateKeyType);
            map.put("insert", insertMap);

            List<JsonWebKey> jwkList = new ArrayList<>();
            jwkList.add(newJwk);

            // Walk the existing keys to find the one to demote and the one to remove, and collect
            // every existing public key into the JWK set.
            // NOTE(review): if several keys share a type, only the last one is recorded in
            // update/delete; confirm the store guarantees at most one key per type.
            for(Map<String, Object> key : keys) {
                // compare from the known non-null side so records without a keyType don't NPE.
                Object keyType = key.get("keyType");
                if(rotateKeyType.equals(keyType)) {
                    // the current key: demote it to the previous key type.
                    Map<String, Object> updateMap = new HashMap<>();
                    updateMap.put("kid", key.get("kid"));
                    updateMap.put("keyType", removeKeyType);
                    map.put("update", updateMap);
                } else if(removeKeyType.equals(keyType)) {
                    // the previous key: remove it.
                    Map<String, Object> deleteMap = new HashMap<>();
                    deleteMap.put("kid", key.get("kid"));
                    map.put("delete", deleteMap);
                }
                // NOTE(review): the key scheduled for deletion is still published in the JWK set
                // (preserved from the original behavior); confirm whether it should be excluded.
                jwkList.add(toJwk(key));
            }

            // serialize the JWK set without any private key material.
            JsonWebKeySet jwks = new JsonWebKeySet(jwkList);
            map.put("jwk", jwks.toJson(JsonWebKey.OutputControlLevel.PUBLIC_ONLY));
        } catch (Exception e) {
            logger.error("Failed to rotate key for providerId {} keyType {}", providerId, rotateKeyType, e);
            return Failure.of(new Status(KEY_GENERATION_ERROR, e.getMessage()));
        }
        return Success.of(map);
    }

    /**
     * Builds a public JWK from a stored provider key record.
     *
     * @param key a provider key record; its {@code publicKey} entry is deserialized as an RSA public
     *            key and its {@code kid} entry becomes the JWK key id
     * @return the public JWK for the record
     * @throws Exception if the public key cannot be deserialized or converted to a JWK
     */
    private static PublicJsonWebKey toJwk(Map<String, Object> key) throws Exception {
        PublicJsonWebKey jwk = PublicJsonWebKey.Factory.newPublicJwk(
                KeyUtil.deserializePublicKey((String) key.get("publicKey"), "RSA"));
        jwk.setKeyId((String) key.get("kid"));
        return jwk;
    }
}
