package com.dyadicsec.provider;

import com.dyadicsec.pkcs11.*;
import static com.dyadicsec.cryptoki.CK.*;

import java.io.*;
import java.security.KeyStoreSpi;
import java.security.KeyStoreException;
import java.security.ProviderException;
import java.security.NoSuchAlgorithmException;
import java.security.Key;
import java.security.PrivateKey;
import java.security.cert.*;
import java.security.cert.Certificate;
import java.util.*;

/**
 * PKCS#11-backed {@link KeyStoreSpi} that exposes the private keys, secret keys and certificates
 * stored on a single {@link Slot}.
 * <p>
 * Entries are located on the token by object name and cached per alias in {@code map}. The cached
 * entry for an alias is dropped whenever that alias is written, renamed or deleted, and the whole
 * cache is replaced by {@link #engineAliases()}.
 * <p>
 * Passwords given to {@code engineLoad}, {@code engineGetKey} and the {@code engineSet*} methods are
 * forwarded to the slot login. A password of the form {@code {"username":"SO", "password":"..."}}
 * requests a security-officer login; any other non-empty password performs a user login check when
 * the slot requires user login.
 * <p>
 * Originally created by valery.osheter on 19-Apr-16.
 */
public final class KeyStore extends KeyStoreSpi
{
    Slot slot = null;

    /** Cached view of one alias: a key, a standalone certificate, or a private key plus its certificate. */
    private static class Entry
    {
        Key key = null;              // null for certificate-only (trusted certificate) entries
        CKCertificate cert = null;   // certificate bound to the key, or the entry itself when key == null
        Entry(Key key) { this.key = key; }
    }

    /** True while this key store holds a security-officer (CKU_SO) login opened by {@link #login}. */
    private boolean loggedInSO = false;

    /** Alias to entry cache; only read or replaced from synchronized code. */
    private Hashtable<String, Entry> map = new Hashtable<String, Entry>();

    /** Drops the cached entry for {@code alias}; a null alias is ignored because Hashtable rejects null keys. */
    private synchronized void removeMapAlias(String alias)
    {
        if (alias != null) map.remove(alias);
    }

    /**
     * Creates a key store view over the given slot.
     *
     * @param slot the PKCS#11 slot whose objects this key store exposes
     */
    public KeyStore(Slot slot)
    {
        this.slot = slot;
    }

    /** @return the PKCS#11 slot backing this key store */
    public Slot getSlot() { return slot; }

    /**
     * Registration hook for keys created under an alias.
     * <p>
     * Currently a no-op: newly registered keys are not cached eagerly, so every entry is loaded from
     * the token on its first lookup.
     */
    void register(Key key, String alias)
    {
    }

    /**
     * Renames the token object behind {@code key} to {@code alias} and drops the cached entries for
     * both the old and the new name. Does nothing if either argument is null, the key has no PKCS#11
     * object, or the object already carries that name.
     *
     * @throws KeyStoreException if the key is not a provider key or the token rejects the rename
     */
    synchronized void setAlias(Key key, String alias) throws KeyStoreException
    {
        if (alias == null || key == null) return;

        try
        {
            CKKey pkcs11Key = getPkcs11Key(key);
            if (pkcs11Key == null) return;

            String oldAlias = pkcs11Key.getName();
            if (alias.equals(oldAlias)) return;

            pkcs11Key.setName(alias);
            removeMapAlias(oldAlias);
            removeMapAlias(alias);
        }
        catch (CKException e) { throw new KeyStoreException(e); }
    }

    /** Returns the private-key or secret-key entry for {@code alias}, or null if there is none. */
    private synchronized Entry findKeyEntry(String alias)
    {
        if (alias == null) return null;

        Entry cached = map.get(alias);
        if (cached != null) return cached.key != null ? cached : null;

        Entry entry = findPrivateKeyEntry(alias);
        return entry != null ? entry : findSecretKeyEntry(alias);
    }

    /** Returns any entry (key or certificate-only) for {@code alias}, or null if there is none. */
    private synchronized Entry findAnyEntry(String alias)
    {
        if (alias == null) return null;

        Entry entry = map.get(alias);
        if (entry != null) return entry;

        entry = findKeyEntry(alias);
        return entry != null ? entry : findCertEntry(alias);
    }

    /**
     * Returns the secret-key entry for {@code alias}, loading and caching it from the token if needed.
     * Cached entries that carry a certificate are never treated as secret-key entries.
     */
    private synchronized Entry findSecretKeyEntry(String alias)
    {
        if (!isPrintableAlias(alias)) return null;

        Entry entry = map.get(alias);
        if (entry != null)
        {
            if (entry.cert != null) return null;
            return entry.key instanceof SecretKey ? entry : null;
        }

        CKSecretKey pkcs11Key = CKSecretKey.find(slot, alias);
        if (pkcs11Key == null) return null;

        entry = new Entry(new SecretKey(pkcs11Key));
        map.put(alias, entry);
        return entry;
    }

    /**
     * Returns the private-key entry for {@code alias}, loading it from the token if needed. The
     * entry's certificate is looked up by the private key's UID. Returns null for unsupported key
     * types or when the UID cannot be read.
     */
    private synchronized Entry findPrivateKeyEntry(String alias)
    {
        if (!isPrintableAlias(alias)) return null;

        Entry entry = map.get(alias);
        if (entry != null) return entry.key instanceof PrivateKey ? entry : null;

        CKObject pkcs11Key = slot.findObject(CKO_PRIVATE_KEY, -1, alias);
        if (pkcs11Key == null) return null;

        Key key = newPrivateKey(pkcs11Key);
        if (key == null) return null;

        long uid;
        try { uid = pkcs11Key.getUID(); } catch (CKException e) { return null; }

        entry = new Entry(key);
        entry.cert = CKCertificate.findCertByPrivateKeyUID(slot, uid);
        map.put(alias, entry);
        return entry;
    }

    /**
     * Returns the certificate-only entry for {@code alias}, loading it from the token if needed.
     * A certificate whose private key (by UID) exists on the token is not a standalone certificate
     * entry and yields null.
     */
    private synchronized Entry findCertEntry(String alias)
    {
        if (!isPrintableAlias(alias)) return null;

        Entry entry = map.get(alias);
        if (entry != null) return entry.key == null ? entry : null;

        CKCertificate cert = CKCertificate.find(slot, alias);
        if (cert == null) return null;

        long uid;
        try { uid = cert.getPrivateKeyUID(); } catch (CKException e) { return null; }

        if (0 != slot.findObjectHandle(uid)) return null;

        entry = new Entry(null);
        entry.cert = cert;
        map.put(alias, entry);
        return entry;
    }

    /**
     * Returns the entry for {@code alias} as a JCA entry object.
     * <ul>
     * <li>A certificate-only entry becomes a {@code TrustedCertificateEntry}, or null if its
     *     certificate cannot be read.</li>
     * <li>A secret key becomes a {@code SecretKeyEntry}.</li>
     * <li>A private key becomes a {@code PrivateKeyEntry} when a certificate chain is available.
     *     Otherwise it becomes a {@code DYCryptoProvider.KeyEntry}, because {@code PrivateKeyEntry}
     *     requires a non-empty chain of non-null certificates.</li>
     * </ul>
     */
    @Override
    public java.security.KeyStore.Entry engineGetEntry(String alias, java.security.KeyStore.ProtectionParameter protParam)
    {
        Entry entry = findAnyEntry(alias);
        if (entry == null) return null;

        if (entry.key == null)
        {
            X509Certificate cert = getX509(entry);
            return cert == null ? null : new java.security.KeyStore.TrustedCertificateEntry(cert);
        }
        if (entry.key instanceof SecretKey) return new java.security.KeyStore.SecretKeyEntry((SecretKey)entry.key);
        if (entry.key instanceof PrivateKey)
        {
            X509Certificate[] chain = getChain(entry);
            if (chain == null) return new DYCryptoProvider.KeyEntry((PrivateKey)entry.key);
            return new java.security.KeyStore.PrivateKeyEntry((PrivateKey)entry.key, chain);
        }

        return null;
    }

    /**
     * Logs in with {@code password} (see class documentation) and returns the key stored under
     * {@code alias}, or null if there is none.
     *
     * @throws ProviderException if the login fails
     */
    @Override
    public Key engineGetKey(String alias, char[] password) throws ProviderException
    {
        try { login(password); }
        catch (KeyStoreException e) { throw new ProviderException(e); }

        Entry entry = findKeyEntry(alias);
        return entry == null ? null : entry.key;
    }

    /** Decodes the entry's certificate; returns null if there is no entry, no certificate, or it cannot be read. */
    private X509Certificate getX509(Entry entry)
    {
        if (entry == null || entry.cert == null) return null;
        try { return entry.cert.getX509(); }
        catch (CKException e) { return null; }
        catch (CertificateException e) { return null; }
    }

    /**
     * Builds the certificate chain for an entry. The entry's own certificate comes first, followed by
     * issuer certificates found on the token by subject name. The walk stops when it reaches a
     * self-signed certificate, finds no issuer, cannot decode an issuer, or meets a certificate
     * already in the chain (a loop).
     *
     * @return the chain, or null if the entry is null or its certificate is missing or unreadable
     */
    private X509Certificate[] getChain(Entry entry)
    {
        X509Certificate cert = getX509(entry);
        if (cert == null) return null;

        ArrayList<X509Certificate> chain = new ArrayList<X509Certificate>();
        chain.add(cert);

        while (!isSelfSigned(cert))
        {
            CKCertificate issuer = slot.findObject(CKCertificate.class,
                    new CK_ATTRIBUTE[]
                            {
                                    new CK_ATTRIBUTE(CKA_TOKEN, true),
                                    new CK_ATTRIBUTE(CKA_CLASS, CKO_CERTIFICATE),
                                    new CK_ATTRIBUTE(CKA_SUBJECT, cert.getIssuerX500Principal().getEncoded())
                            });
            if (issuer == null) break;

            try { cert = issuer.getX509(); }
            catch (CKException e) { break; }
            catch (CertificateException e) { break; }

            // Guards against cross-signed loops, which would otherwise make this walk spin forever.
            if (cert == null || chain.contains(cert)) break;
            chain.add(cert);
        }

        return chain.toArray(new X509Certificate[chain.size()]);
    }

    /** True if the certificate's subject and issuer names are equal. */
    private static boolean isSelfSigned(X509Certificate cert)
    {
        return cert.getSubjectX500Principal().equals(cert.getIssuerX500Principal());
    }

    /** Returns the chain of the private key stored under {@code alias}, or null if there is none. */
    @Override
    public Certificate[] engineGetCertificateChain(String alias)
    {
        return getChain(findPrivateKeyEntry(alias));
    }

    @Override
    public boolean engineContainsAlias(String alias)
    {
        return findAnyEntry(alias) != null;
    }

    @Override
    public boolean engineIsKeyEntry(String alias)
    {
        return findKeyEntry(alias) != null;
    }

    @Override
    public boolean engineIsCertificateEntry(String alias)
    {
        return findCertEntry(alias) != null;
    }

    /** Returns the certificate of the entry under {@code alias} (key or certificate-only), or null. */
    @Override
    public Certificate engineGetCertificate(String alias)
    {
        return getX509(findAnyEntry(alias));
    }

    /** Returns the number of named keys and certificates on the token, without touching the cache. */
    @Override
    public int engineSize()
    {
        return buildAliasMap().size();
    }

    /**
     * Returns the name of the token certificate whose encoding equals {@code cert}, or null if there
     * is none or it cannot be read.
     */
    @Override
    public String engineGetCertificateAlias(Certificate cert)
    {
        if (cert == null) return null;

        byte[] encoded;
        try { encoded = cert.getEncoded(); }
        catch (CertificateEncodingException e) { return null; }

        CKCertificate object = slot.findObject(CKCertificate.class,
                new CK_ATTRIBUTE[]
                        {
                                new CK_ATTRIBUTE(CKA_TOKEN, true),
                                new CK_ATTRIBUTE(CKA_CLASS, CKO_CERTIFICATE),
                                new CK_ATTRIBUTE(CKA_VALUE, encoded)
                        });
        if (object == null) return null;

        try { return object.getName(); }
        catch (CKException e) { return null; }
    }

    /** Wraps a PKCS#11 private key object in the matching provider key class; null for other objects. */
    private static Key newPrivateKey(CKObject object)
    {
        if (object instanceof CKRSAPrivateKey) return new RSAPrivateKey((CKRSAPrivateKey)object);
        if (object instanceof CKECPrivateKey) return new ECPrivateKey((CKECPrivateKey)object);
        if (object instanceof CKLIMAPrivateKey) return new LIMAPrivateKey((CKLIMAPrivateKey)object);
        if (object instanceof CKEDDSAPrivateKey) return new EDDSAPrivateKey((CKEDDSAPrivateKey)object);
        return null;
    }

    /** Wraps a PKCS#11 private or secret key object in the matching provider key class; null otherwise. */
    private Key newKey(CKObject object)
    {
        Key key = newPrivateKey(object);
        if (key != null) return key;
        if (object instanceof CKSecretKey) return new SecretKey((CKSecretKey)object);
        return null;
    }

    /**
     * Lists every named private key, secret key and certificate on the token and replaces the cache
     * with the result.
     */
    @Override
    public Enumeration<String> engineAliases()
    {
        Hashtable<String, Entry> fresh = buildAliasMap();
        synchronized (this) { this.map = fresh; }
        return fresh.keys();
    }

    /**
     * Builds an alias to entry map from the token contents.
     * <ul>
     * <li>Private keys are listed first, then secret keys, then certificates.</li>
     * <li>A key listed later replaces an earlier one with the same name.</li>
     * <li>A certificate is attached to the entry with its own name, or becomes a certificate-only
     *     entry if there is none.</li>
     * </ul>
     */
    private Hashtable<String, Entry> buildAliasMap()
    {
        Hashtable<String, Entry> entries = new Hashtable<String, Entry>();

        addKeyEntries(entries, slot.findObjects(CKO_PRIVATE_KEY, -1));
        addKeyEntries(entries, slot.findObjects(CKO_SECRET_KEY, -1));

        for (CKObject object : slot.findObjects(CKO_CERTIFICATE, -1))
        {
            String alias = nameOf(object);
            if (alias == null) continue;

            Entry entry = entries.get(alias);
            if (entry == null)
            {
                entry = new Entry(null);
                entries.put(alias, entry);
            }
            entry.cert = (CKCertificate)object;
        }
        return entries;
    }

    /** Adds one key entry per named, supported key object to {@code target}. */
    private void addKeyEntries(Hashtable<String, Entry> target, Iterable<CKObject> objects)
    {
        for (CKObject object : objects)
        {
            String alias = nameOf(object);
            if (alias == null) continue;

            Key key = newKey(object);
            if (key != null) target.put(alias, new Entry(key));
        }
    }

    /** Returns the object's name, or null if it is missing, empty or unreadable (such objects are not listed). */
    private static String nameOf(CKObject object)
    {
        String name;
        try { name = object.getName(); } catch (CKException e) { return null; }
        return (name == null || name.isEmpty()) ? null : name;
    }

    /** No-op: every change is written to the token immediately, so there is nothing to persist. */
    @Override
    public void engineStore(OutputStream stream, char[] password) throws IOException, NoSuchAlgorithmException, CertificateException
    {
        // nothing to do
    }

    /**
     * Adjusts the slot login state for {@code password}.
     * <ul>
     * <li>A null or empty password ends an SO login held by this key store.</li>
     * <li>An SO descriptor ({@code {"username":"SO", ...}}) logs in as the security officer.</li>
     * <li>Any other password ends a held SO login and, if the slot requires it, performs a user
     *     login check.</li>
     * </ul>
     *
     * @throws KeyStoreException if the slot rejects the login
     */
    private void login(char[] password) throws KeyStoreException
    {
        if (password == null || password.length == 0)
        {
            if (loggedInSO) slot.logout();
            loggedInSO = false;
            return;
        }

        if (isSOLogin(password))
        {
            slot.logout();
            loggedInSO = false;
            if (slot.login(CKU_SO, password) != 0) throw new KeyStoreException("Login failed");
            loggedInSO = true;
            return;
        }

        // Without this logout the SO login would stay open while loggedInSO reads false, and a later
        // empty-password call would no longer close it.
        if (loggedInSO)
        {
            slot.logout();
            loggedInSO = false;
        }

        if (!slot.isUserLoginRequired()) return;
        if (slot.login(DYCKU_USER_CHECK, password) != 0) throw new KeyStoreException("Login failed");
    }

    /**
     * True if the password is an SO login descriptor such as {"username":"SO", "password":"value"}.
     * The descriptor must contain more than two tokens, the first two being USERNAME and SO (case-insensitive).
     */
    private static boolean isSOLogin(char[] password)
    {
        StringTokenizer tokens = new StringTokenizer(new String(password), "\t\n\r\f\" :,{}");
        if (tokens.countTokens() <= 2) return false;
        return tokens.nextToken().equalsIgnoreCase("USERNAME") && tokens.nextToken().equalsIgnoreCase("SO");
    }

    /**
     * Logs in with {@code password} (see class documentation); the stream is ignored.
     *
     * @throws ProviderException if the login fails
     */
    @Override
    public void engineLoad(InputStream stream, char[] password) throws ProviderException
    {
        try { login(password); }
        catch (KeyStoreException e) { throw new ProviderException(e); }
    }

    /**
     * Returns the PKCS#11 object behind a provider key, or null for a null key.
     *
     * @throws KeyStoreException if the key does not come from this provider
     */
    private CKKey getPkcs11Key(Key key) throws KeyStoreException
    {
        if (key == null) return null;
        if (key instanceof DYKey) return ((DYKey)key).getPkcs11Key();
        throw new KeyStoreException("Unsupported key type");
    }

    /** Destroys the token object behind {@code key}; a null key is ignored. */
    private void destroyKey(Key key) throws KeyStoreException
    {
        if (key == null) return;
        try { getPkcs11Key(key).destroy(); }
        catch (CKException e) { throw new KeyStoreException(e); }
    }

    /** Destroys the key and/or certificate stored under {@code alias}; unknown aliases are ignored. */
    @Override
    public void engineDeleteEntry(String alias) throws KeyStoreException
    {
        Entry entry = findAnyEntry(alias);
        if (entry == null) return;

        removeMapAlias(alias);

        destroyKey(entry.key);
        if (entry.cert != null)
        {
            try { entry.cert.destroy(); }
            catch (CKException e) { throw new KeyStoreException(e); }
        }
    }

    /** Creates a token certificate object named {@code alias} (may be null) from {@code cert}. */
    private CKCertificate createCert(String alias, X509Certificate cert) throws KeyStoreException
    {
        try { return CKCertificate.create(slot, alias, null, cert); }
        catch (CKException e) { throw new KeyStoreException(e); }
        catch (CertificateEncodingException e) { throw new KeyStoreException(e); }
    }

    @Override
    public synchronized void engineSetCertificateEntry(String alias, Certificate cert) throws KeyStoreException
    {
        setCertificateEntry(alias, cert, null);
    }

    /** Creates a trusted RSA public key named {@code alias} from the certificate; non-RSA keys are skipped. */
    private void createTrustedPublicKeyFromCert(String alias, X509Certificate cer) throws KeyStoreException
    {
        java.security.PublicKey pub = cer.getPublicKey();
        if (!(pub instanceof java.security.interfaces.RSAPublicKey)) return;

        RSAPublicKey key = new RSAPublicKey((java.security.interfaces.RSAPublicKey)pub);
        key.createTrusted(this, alias);
    }

    /**
     * Stores {@code cert} under {@code alias}.
     * <ul>
     * <li>An existing certificate-only entry is deleted and recreated.</li>
     * <li>For a private-key entry, the key's current certificate is destroyed (best effort) and
     *     replaced.</li>
     * <li>When logged in as SO, a trusted public key is also created for new (non-replacement)
     *     certificates.</li>
     * </ul>
     *
     * @throws KeyStoreException if the certificate is not X.509, the login fails or the token rejects the write
     */
    private synchronized void setCertificateEntry(String alias, Certificate cert, char[] password) throws KeyStoreException
    {
        if (!(cert instanceof X509Certificate)) throw new KeyStoreException("Unsupported certificate type");
        X509Certificate x509 = (X509Certificate)cert;

        login(password);

        if (findCertEntry(alias) != null)
        {
            engineDeleteEntry(alias);
            createCert(alias, x509);
            return;
        }

        Entry keyEntry = findPrivateKeyEntry(alias);
        if (keyEntry != null && keyEntry.cert != null)
        {
            try { keyEntry.cert.destroy(); } catch (CKException ignored) { } // best effort, as before
            keyEntry.cert = null;
        }

        // Drop the cached entry so the next lookup re-reads the new certificate from the token.
        removeMapAlias(alias);
        createCert(alias, x509);

        if (loggedInSO) createTrustedPublicKeyFromCert(alias, x509);
    }

    /** Raw (encoded) key import is not supported. */
    @Override
    public void engineSetKeyEntry(String alias, byte[] key, Certificate[] chain) throws KeyStoreException
    {
        throw new KeyStoreException("Not supported");
    }

    /**
     * Stores a JCA entry under {@code alias}. The password of a {@code PasswordProtection} is used for
     * the login.
     *
     * @throws KeyStoreException for unsupported entry types or if storing fails
     */
    @Override
    public void engineSetEntry(String alias, java.security.KeyStore.Entry entry, java.security.KeyStore.ProtectionParameter protParam) throws KeyStoreException
    {
        char[] password = null;
        if (protParam instanceof java.security.KeyStore.PasswordProtection)
            password = ((java.security.KeyStore.PasswordProtection)protParam).getPassword();

        removeMapAlias(alias);

        if (entry instanceof java.security.KeyStore.PrivateKeyEntry)
        {
            java.security.KeyStore.PrivateKeyEntry keyEntry = (java.security.KeyStore.PrivateKeyEntry)entry;
            engineSetKeyEntry(alias, keyEntry.getPrivateKey(), password, keyEntry.getCertificateChain());
            return;
        }

        if (entry instanceof java.security.KeyStore.SecretKeyEntry)
        {
            java.security.KeyStore.SecretKeyEntry secretKeyEntry = (java.security.KeyStore.SecretKeyEntry)entry;
            engineSetKeyEntry(alias, secretKeyEntry.getSecretKey(), password, null);
            return;
        }

        if (entry instanceof java.security.KeyStore.TrustedCertificateEntry)
        {
            java.security.KeyStore.TrustedCertificateEntry certEntry = (java.security.KeyStore.TrustedCertificateEntry)entry;
            setCertificateEntry(alias, certEntry.getTrustedCertificate(), password);
            return;
        }

        if (entry instanceof DYCryptoProvider.KeyEntry)
        {
            DYCryptoProvider.KeyEntry keyEntry = (DYCryptoProvider.KeyEntry)entry;
            engineSetKeyEntry(alias, keyEntry.key, password, null);
            return;
        }

        throw new KeyStoreException(new UnsupportedOperationException("unsupported entry type: " + entry.getClass().getName()));
    }

    /**
     * Replaces the certificates bound to {@code key}'s UID with {@code chain}.
     * <ul>
     * <li>The leaf certificate is stored under {@code alias}.</li>
     * <li>The remaining chain certificates are stored unnamed, best effort.</li>
     * <li>A null or empty chain is ignored.</li>
     * <li>The chain is validated before anything is destroyed.</li>
     * </ul>
     *
     * @throws KeyStoreException if the chain contains a non-X.509 certificate or the leaf cannot be stored
     */
    private void saveChain(CKKey key, String alias, Certificate[] chain) throws KeyStoreException
    {
        if (chain == null || chain.length == 0) return;
        for (Certificate c : chain)
        {
            if (!(c instanceof X509Certificate)) throw new KeyStoreException("Unsupported certificate type");
        }

        try
        {
            long uid = key.getUID();
            for (CKCertificate old : CKCertificate.findCertsByPrivateKeyUID(slot, uid)) old.destroy();
        }
        catch (CKException ignored) { } // best effort: stale certificates are left in place

        createCert(alias, (X509Certificate)chain[0]);

        for (int i = 1; i < chain.length; i++)
        {
            try { createCert(null, (X509Certificate)chain[i]); }
            catch (KeyStoreException ignored) { } // intermediate certificates are optional
        }
    }

    /**
     * Stores {@code key} under {@code alias}.
     * <ul>
     * <li>Provider keys are saved in place.</li>
     * <li>RSA CRT, EC and {@code SecretKeySpec} keys are imported, replacing any existing entry with
     *     that alias.</li>
     * <li>For RSA and EC private keys, {@code chain} is also stored.</li>
     * </ul>
     *
     * @throws KeyStoreException for a null alias, an unsupported key type or a failed login/write
     */
    @Override
    synchronized public void engineSetKeyEntry(String alias, Key key, char[] password, Certificate[] chain) throws KeyStoreException
    {
        if (!isPrintableAlias(alias)) throw new KeyStoreException("invalid entry name " + alias);

        login(password);

        removeMapAlias(alias);

        if (key instanceof DYKey)
        {
            ((DYKey)key).save(this, alias);
            if (key instanceof RSAPrivateKey || key instanceof ECPrivateKey)
                saveChain(((DYKey)key).getPkcs11Key(), alias, chain);
            return;
        }

        if (key instanceof java.security.interfaces.RSAPrivateCrtKey)
        {
            engineDeleteEntry(alias);
            RSAPrivateKey prvKey = new RSAPrivateKey().initForImport((java.security.interfaces.RSAPrivateCrtKey)key);
            prvKey.save(this, alias);
            saveChain(prvKey.pkcs11Key, alias, chain);
            return;
        }

        if (key instanceof java.security.interfaces.ECPrivateKey)
        {
            engineDeleteEntry(alias);
            ECPrivateKey prvKey = new ECPrivateKey().initForImport((java.security.interfaces.ECPrivateKey)key);
            prvKey.save(this, alias);
            saveChain(prvKey.pkcs11Key, alias, chain);
            return;
        }

        if (key instanceof javax.crypto.spec.SecretKeySpec)
        {
            engineDeleteEntry(alias);
            SecretKey secretKey = new SecretKey().initForImport(null, -1, (javax.crypto.spec.SecretKeySpec)key);
            // Persist under the requested alias, as the RSA/EC import branches do; previously the
            // imported key was never associated with the alias.
            secretKey.save(this, alias);
            return;
        }

        throw new KeyStoreException("Unsupported key type");
    }

    /** Accepts any non-null alias. */
    private static boolean isPrintableAlias(String alias)
    {
        return alias != null;
    }

    @Override
    public boolean engineEntryInstanceOf(String alias, Class<? extends java.security.KeyStore.Entry> entryClass)
    {
        if (entryClass == java.security.KeyStore.TrustedCertificateEntry.class) return findCertEntry(alias) != null;
        if (entryClass == java.security.KeyStore.PrivateKeyEntry.class) return findPrivateKeyEntry(alias) != null;
        if (entryClass == java.security.KeyStore.SecretKeyEntry.class) return findSecretKeyEntry(alias) != null;
        return false;
    }

    /**
     * Returns null for an unknown alias. Otherwise returns the current time: this implementation does
     * not read a creation timestamp from the token.
     */
    @Override
    public Date engineGetCreationDate(String alias)
    {
        if (findAnyEntry(alias) == null) return null;
        return new Date();
    }
}
