package com.unbound.provider;

import com.unbound.common.Base64;
import com.unbound.common.JSON;
import com.unbound.common.Log;
import com.unbound.common.STR;
import com.unbound.common.crypto.X509;
import com.unbound.provider.kmip.KMIP;
import com.unbound.provider.kmip.KMIPConvertException;
import com.unbound.provider.kmip.KMIPConverter;
import com.unbound.provider.kmip.attribute.Authentication;
import com.unbound.provider.kmip.request.*;
import com.unbound.provider.kmip.request.dy.DyLoginRequest;
import com.unbound.provider.kmip.response.*;
import com.unbound.provider.kmip.response.dy.DyLoginResponse;
//import sun.security.x509.X500Name;

import javax.net.ssl.KeyManager;
import javax.net.ssl.KeyManagerFactory;
import javax.security.auth.x500.X500Principal;
import java.io.FileInputStream;
import java.io.IOException;
import java.security.*;
import java.security.cert.CertificateException;
import java.security.cert.X509Certificate;
import java.time.Clock;
import java.util.Enumeration;
import java.util.HashMap;
import java.util.Map;

/**
 * Client-side handle for one partition, identified by a subject-name field ("OU") of the first
 * certificate in a client PFX.
 * <p>
 * Instances are cached per partition name in a process-wide registry (see {@code registerPfx}).
 * Each instance holds the {@code javax.net.ssl} key managers built from the PFX, a
 * {@link UBKeyStore} bound to it, and — after a successful {@code login} that returned a JWT —
 * that JWT, which {@code transmit} attaches to requests lacking explicit authentication and
 * renews once its locally computed validity window has passed.
 */
public class Partition
{
  /** Registry of partitions by name; only touched from {@code static synchronized} methods. */
  private static final Map<String, Partition> partitions = new HashMap<>();
  private static final Clock clock = Clock.systemUTC();

  KeyManager[] keyManagers;
  String name;
  UBKeyStore keyStore;

  // Cached login token and the wall-clock millis after which it must be renewed; guarded by "this".
  private byte[] jwt = null;
  private long jwtValidityClock;

  /**
   * Builds the key managers for {@code pfx} and a key store bound to this partition.
   *
   * @param name    partition name (already extracted from the PFX)
   * @param pfx     loaded PKCS#12 key store
   * @param pfxPass password protecting the keys in {@code pfx}; may be null
   */
  private Partition(String name, KeyStore pfx, String pfxPass) throws KeyStoreException, NoSuchAlgorithmException, UnrecoverableKeyException
  {
    Log log = Log.func("Partition").log("name", name).end(); try
    {
      this.name = name;

      KeyManagerFactory kmf = KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm());
      // loadPfx() accepts a null password, so accept it here as well instead of failing with an NPE.
      kmf.init(pfx, pfxPass == null ? null : pfxPass.toCharArray());
      keyManagers = kmf.getKeyManagers();

      keyStore = new UBKeyStore(this);
    }
    catch (Exception e) { log.failed(e); throw e; } finally { log.leave(); }
  }

  /**
   * Loads a PKCS#12 file from disk.
   *
   * @param pfxFileName path of the PFX file
   * @param pass        store password; may be null
   * @return the loaded key store
   */
  private static KeyStore loadPfx(String pfxFileName, String pass) throws KeyStoreException, IOException, CertificateException, NoSuchAlgorithmException
  {
    Log log = Log.func("Partition.loadPfx").log("pfxFileName", pfxFileName).log("pass", pass!=null).end(); try
    {
      char[] passChars = pass == null ? null : pass.toCharArray();
      KeyStore ks = KeyStore.getInstance("pkcs12");
      // Close the file handle once the store has been read.
      try (FileInputStream in = new FileInputStream(pfxFileName))
      {
        ks.load(in, passChars);
      }
      return ks;
    }
    catch (Exception e) { log.failed(e); throw e; } finally { log.leave(); }
  }

  /**
   * Returns the {@code type} component (e.g. "OU", "CN") of the subject of the certificate stored
   * under the first alias of {@code pfx}.
   *
   * @throws ProviderException if the store is empty or the subject lacks the requested component
   */
  private static String getNameFromPfx(KeyStore pfx, String type) throws KeyStoreException
  {
    Enumeration<String> aliases = pfx.aliases();
    if (!aliases.hasMoreElements()) throw new ProviderException("Empty store");
    X509Certificate cert = (X509Certificate) pfx.getCertificate(aliases.nextElement());
    if (cert == null) throw new ProviderException("Empty store");
    X500Principal principal = cert.getSubjectX500Principal();
    if (principal == null) throw new ProviderException("Invalid principal");

    String value = X509.getName(principal, type);
    if (value == null) throw new ProviderException("Invalid principal");
    return value;
  }

  /** Loads the PFX file and registers (or returns the already registered) partition for it. */
  static synchronized Partition registerPfx(String pfxFileName, String pass) throws CertificateException, NoSuchAlgorithmException, KeyStoreException, IOException, UnrecoverableKeyException
  {
    KeyStore pfx = loadPfx(pfxFileName, pass);
    return registerPfx(pfx, pass);
  }

  /**
   * Returns the partition named by the PFX certificate's OU, creating and caching it on first use.
   * A later call for the same OU returns the cached instance and ignores the new PFX/password.
   */
  static synchronized Partition registerPfx(KeyStore pfx, String pass) throws KeyStoreException, UnrecoverableKeyException, NoSuchAlgorithmException
  {
    String name = getNameFromPfx(pfx, "OU");
    Partition partition = partitions.get(name);
    if (partition == null)
    {
      partition = new Partition(name, pfx, pass);
      partitions.put(name, partition);
    }
    return partition;
  }

  /**
   * Sends {@code req} through {@link Client}. When the request carries no authentication and a
   * login token is cached, the token is attached, renewing it first if its validity has expired.
   */
  ResponseMessage transmit(RequestMessage req) throws IOException
  {
    Log log = Log.func("Partition.transmit").end(); try
    {
      if (req.header.auth==null)
      {
        byte[] token;
        long validUntil;
        synchronized (this)
        {
          token = this.jwt;
          validUntil = this.jwtValidityClock;
        }

        if (token!=null && validUntil < clock.millis())
        {
          loginRenew();
          // Attach the token issued by the renewal, not the expired snapshot read above.
          token = currentJwt();
        }

        if (token!=null) req.header.auth = jwtAuthentication(token);
      }

      return Client.transmit(this, req);
    }
    catch (Exception e) { log.failed(e); throw e; } finally { log.leave(); }
  }

  /** Sends a single-item batch and returns the first response item. */
  ResponseItem transmit(RequestItem req) throws IOException
  {
    RequestMessage reqMsg = new RequestMessage();
    reqMsg.batch.add(req);
    ResponseMessage respMsg = transmit(reqMsg);
    return respMsg.batch.get(0);
  }

  /** Builds a Get Attributes request for the attributes {@code read} needs. */
  private static GetAttributesRequest prepareGetAttrRequest(long uid)
  {
    GetAttributesRequest getAttr = new GetAttributesRequest();
    getAttr.uid = UBObject.uidToStr(uid);
    getAttr.names.add("Object Type");
    getAttr.names.add("Cryptographic Algorithm");
    getAttr.names.add("Name");
    getAttr.names.add("Initial Date");
    return getAttr;
  }

  /**
   * Reads several objects in one batch. For every uid, in order, the batch contains a
   * Get Attributes request followed by a Get request (X.509 format), so the response batch
   * holds two items per uid.
   */
  ResponseMessage read(long[] uids) throws IOException
  {
    RequestMessage reqMsg = new RequestMessage();
    for (long uid : uids)
    {
      GetAttributesRequest getAttr = prepareGetAttrRequest(uid);

      GetRequest get = new GetRequest();
      get.uid = UBObject.uidToStr(uid);
      get.formatType = KMIP.KeyFormatType.X_509;

      reqMsg.batch.add(getAttr);
      reqMsg.batch.add(get);
    }
    return transmit(reqMsg);
  }

  /** Reads one object; the response batch holds its Get Attributes and Get responses, in order. */
  ResponseMessage read(long uid) throws IOException
  {
    return read(new long[] { uid });
  }

  /** Returns the uids of up to 1024 objects matching the given object type and algorithm. */
  long[] locate(int objectType, int algType) throws IOException
  {
    Log log = Log.func("Partition.locate").log("objectType", objectType).log("algType", algType).end(); try
    {
      LocateRequest req = UBObject.locateRequest(objectType, algType, null);
      req.maxItems = 1024;

      LocateResponse resp = (LocateResponse) transmit(req);
      long[] result = new long[resp.list.size()];
      int index = 0;
      for (String s : resp.list)
      {
        long uid = UBObject.strToUid(s);
        Log.print("Object").logHex("uid", uid).end();
        result[index++] = uid;
      }
      return result;
    }
    catch (Exception e) { log.failed(e); throw e; } finally { log.leave(); }
  }

  /** Runs {@code req} and returns the first located uid, or 0 when nothing matched. */
  long locate(LocateRequest req) throws IOException
  {
    LocateResponse resp = (LocateResponse)transmit(req);
    if (resp.list.isEmpty()) return 0;
    return UBObject.strToUid(resp.list.get(0));
  }

  /** Returns the uid of the first object matching type, algorithm and alias, or 0 if none. */
  long locate(int objectType, int algType, String alias) throws IOException
  {
    long uid = 0;
    Log log = Log.func("Partition.locate").log("objectType", objectType).log("algType", algType).log("alias", alias).end(); try
    {
      LocateRequest req = UBObject.locateRequest(objectType, algType, alias);
      uid = locate(req);
      return uid;
    }
    catch (Exception e) { log.failed(e); throw e; } finally { log.leavePrint().logHex("uid", uid).end(); }
  }

  /**
   * Logs in with a password. {@code password} may be a JSON object with "username" and "password"
   * fields; any other string is used as the password of user "user". A null password logs in as
   * "user" with an empty password.
   */
  void login(String password) throws IOException
  {
    loginOrRenew(password, false);
  }

  /** Re-authenticates with the currently cached JWT to obtain a fresh one. */
  private void loginRenew() throws IOException
  {
    loginOrRenew(null, true);
  }

  /**
   * Sends a DY login request (asking for a JWT) authenticated either by password or by the cached
   * JWT. A returned JWT with positive validity is cached together with its local expiry time; any
   * failure clears the cached JWT and is rethrown.
   */
  private void loginOrRenew(String password, boolean renewJwt) throws IOException
  {
    Log log = Log.func("Partition.login").log("renewWjt", renewJwt).log("password", password!=null && !password.isEmpty()).end(); try
    {
      RequestMessage reqMsg = new RequestMessage();
      reqMsg.header.auth = renewJwt ? jwtAuthentication(currentJwt()) : passwordAuthentication(password);

      DyLoginRequest req = new DyLoginRequest();
      req.doCreateWjt = true;
      reqMsg.batch.add(req);

      try
      {
        // header.auth is already set, so transmit() will not try to attach or renew the JWT itself.
        ResponseMessage respMsg = transmit(reqMsg);
        DyLoginResponse resp = (DyLoginResponse) respMsg.batch.get(0);
        if (resp.jwt!=null)
        {
          int validity = jwtTokenValidity(STR.utf8(resp.jwt));
          if (validity>0)
          {
            synchronized (this)
            {
              jwt = resp.jwt;
              // long arithmetic: validity*1000 in int overflows for tokens valid longer than ~24.8 days.
              jwtValidityClock = clock.millis() + validity * 1000L;
            }
          }
        }
      }
      catch (Exception e)
      {
        // Drop the cached token so later requests stop presenting it.
        synchronized (this)
        {
          jwt = null;
        }
        throw e;
      }
    }
    catch (Exception e) { log.failed(e); throw e; } finally { log.leave(); }
  }

  /** Returns the cached JWT (possibly null) under the instance lock. */
  private synchronized byte[] currentJwt()
  {
    return jwt;
  }

  /** Builds an attestation credential carrying {@code token} as a DY JWT assertion. */
  private static Authentication jwtAuthentication(byte[] token)
  {
    Authentication auth = new Authentication();
    auth.credType = KMIP.CredentialType.Attestation;
    auth.attestationType = KMIP.AttestationType.DyJwtAssertion;
    auth.attestationAssertion = token;
    return auth;
  }

  /**
   * Builds a username/password credential from the string given to {@code login}.
   * A null password yields user "user" with an empty password. Otherwise the string is parsed as a
   * JSON object with "username" and "password" fields; if parsing or either cast fails, the whole
   * string is used as the password of user "user".
   */
  private static Authentication passwordAuthentication(String password)
  {
    Authentication auth = new Authentication();
    auth.credType = KMIP.CredentialType.UsernameAndPassword;
    auth.username = "user";
    auth.password = "";
    if (password == null) return auth;

    String user;
    String pass;
    try
    {
      @SuppressWarnings("unchecked")
      Map<String, Object> json = (Map<String, Object>) JSON.convert(password);
      user = (String) json.get("username");
      pass = (String) json.get("password");
    }
    catch (Exception e)
    {
      // Not a JSON credential object: treat the whole string as the default user's password.
      user = "user";
      pass = password;
    }
    auth.username = user;
    auth.password = pass;
    return auth;
  }

  /**
   * Returns the lifetime of a JWT in seconds, computed as {@code exp - iat - 30} from its payload
   * claims (the 30-second margin renews the token slightly early). Returns 0 when the token does not
   * have three dot-separated parts or lacks numeric "iat"/"exp" claims; the result is clamped to the
   * int range.
   */
  private static int jwtTokenValidity(String jwt) throws IOException
  {
    String[] parts = jwt.split("\\.");
    if (parts.length!=3) return 0;

    byte[] payload = Base64.decodeUrl(parts[1]);
    @SuppressWarnings("unchecked")
    Map<String, Object> claims = (Map<String, Object>) JSON.convert(STR.utf8(payload));
    Object iat = claims.get("iat");
    Object exp = claims.get("exp");
    // The JSON layer may decode numbers as Integer or Long; accept any Number and reject missing claims.
    if (!(iat instanceof Number) || !(exp instanceof Number)) return 0;

    long seconds = ((Number) exp).longValue() - ((Number) iat).longValue() - 30;
    return (int) Math.max(Integer.MIN_VALUE, Math.min(Integer.MAX_VALUE, seconds));
  }

}
