package com.dyadicsec.pkcs11;

import java.security.ProviderException;
import java.util.*;
import static com.dyadicsec.cryptoki.CK.*;

/**
 * Created by valery.osheter on 22-Jun-17.
 */
/**
 * A PKCS#11 slot exposed through the native {@code Library}.
 * <p>
 * All slots returned by {@code C_GetSlotList} are enumerated once, in the static initializer, and indexed
 * both by slot id and by slot description. Each slot owns one "persistent" session (opened in the
 * constructor and used for login/logout and object-management calls) plus a pool of additional sessions
 * handed out by {@link #getSession()} and returned through {@link #releaseSession(Session)}.
 */
public final class Slot
{
    /** Slots indexed by their description string; slots with an empty description are not indexed here. */
    private static final Map<String, Slot> slotsByName;
    /** All enumerated slots, indexed by slot id. */
    private static final Map<Integer, Slot> slotsById;

    int id = -1;
    private String name = null;
    /** Cached login-required state of the token: -1 unknown, 0 not required, 1 required. */
    private int knownAuthReq = -1;
    /** Pool of idle sessions; also serves as the lock that guards it, hence final. */
    private final Queue<Session> sessions = new ArrayDeque<Session>();
    /** Session opened in the constructor; used for login/logout and object-management calls. */
    Session persistentSession;

    /** Returns the session opened for this slot at construction time. */
    public Session getPersistentSession() { return persistentSession; }

    static
    {
        int rv = Library.C_Initialize();
        if (rv!=0) throw new ProviderException("Can't initialize PKCS11 library: rv="+rv);

        // First call (null buffer): the returned value is used as the number of slots to allocate for.
        long rvLen = Library.C_GetSlotList(true, null);
        int count = Library.rvErr(rvLen)==0 ? Math.max(Library.rvValue(rvLen), 0) : 0;
        int[] list = new int[count];
        if (count>0)
        {
            // Second call fills the buffer. Never trust the returned count beyond the buffer size,
            // otherwise list[i] below could throw ArrayIndexOutOfBoundsException.
            rvLen = Library.C_GetSlotList(true, list);
            count = Library.rvErr(rvLen)==0 ? Math.min(Library.rvValue(rvLen), list.length) : 0;
        }

        slotsByName = new HashMap<String, Slot>();
        slotsById = new HashMap<Integer, Slot>();
        for (int i=0; i<count; i++)
        {
            Slot slot = new Slot(list[i]);
            slotsById.put(list[i], slot);
            if (!slot.name.isEmpty()) slotsByName.put(slot.name, slot);
        }
    }

    /**
     * Reads the slot description and opens the persistent session.
     *
     * @param id native slot id
     * @throws ProviderException if the persistent session cannot be opened
     */
    private Slot(int id)
    {
        this.id = id;

        CK_SLOT_INFO si = new CK_SLOT_INFO();
        int rv = Library.C_GetSlotInfo(id, si);
        // If the description cannot be read the slot gets an empty name and is then reachable by id only.
        name = rv==0 ? new String(si.slotDescription) : "";

        try { persistentSession = openSession(); }
        catch (CKException e) { throw new ProviderException(e); }
    }

    /**
     * Builds a C_FindObjectsInit template. The template always contains CKA_TOKEN=true, then (in this
     * order, each only when requested) CKA_CLASS with a matching CKA_PRIVATE, CKA_KEY_TYPE, and CKA_ID.
     *
     * @param clazz   object class, or -1 for any
     * @param keyType key type, or -1 for any
     * @param name    object name converted with {@code Utils.name2id}, or null for any
     */
    private static CK_ATTRIBUTE[] buildFindTemplate(int clazz, int keyType, String name)
    {
        List<CK_ATTRIBUTE> t = new ArrayList<CK_ATTRIBUTE>(5);
        t.add(new CK_ATTRIBUTE(CKA_TOKEN, true));
        if (clazz!=-1)
        {
            t.add(new CK_ATTRIBUTE(CKA_CLASS, clazz));
            // Secret and private keys are searched as private objects; every other class as public.
            boolean isPrivate = (clazz==CKO_SECRET_KEY || clazz==CKO_PRIVATE_KEY);
            t.add(new CK_ATTRIBUTE(CKA_PRIVATE, isPrivate));
        }
        if (keyType!=-1) t.add(new CK_ATTRIBUTE(CKA_KEY_TYPE, keyType));
        if (name!=null) t.add(new CK_ATTRIBUTE(CKA_ID, Utils.name2id(name)));
        return t.toArray(new CK_ATTRIBUTE[t.size()]);
    }

    /** Returns a read-only view of all enumerated slots. */
    public static Iterable<Slot> getList() { return Collections.unmodifiableCollection(slotsById.values()); }

    /**
     * Finds a slot by its description.
     *
     * @param name slot description; null means {@link #getDefault()}
     * @return the slot, or null if none has that description
     */
    public static Slot find(String name)
    {
        if (name==null) return getDefault();
        return slotsByName.get(name);
    }

    /** Finds a slot by id; returns null if no such slot was enumerated. */
    public static Slot find(int id) { return slotsById.get(id); }

    /** Returns the slot with id 0, or null if it was not enumerated. */
    public static Slot getDefault() { return slotsById.get(0); }

    /** Returns the slot description, or an empty string if it could not be read. */
    public String getName() { return name; }

    /**
     * Reports whether the token in this slot has CKF_LOGIN_REQUIRED set.
     * The answer is cached once known; a failing C_GetTokenInfo returns false without caching,
     * so a transient error does not permanently disable {@link #logout()}.
     */
    public boolean isUserLoginRequired()
    {
        if (knownAuthReq<0)
        {
            CK_TOKEN_INFO ti = new CK_TOKEN_INFO();
            int rv = Library.C_GetTokenInfo(id, ti);
            if (rv!=0) return false;
            knownAuthReq = (ti.flags & CKF_LOGIN_REQUIRED)!=0 ? 1 : 0;
        }

        return knownAuthReq>0;
    }

    /** Returns the native slot id. */
    public int getID() { return id; }

    /**
     * Takes an idle session from the pool, or opens a new one if the pool is empty.
     * Return it with {@link #releaseSession(Session)} when done.
     */
    public Session getSession() throws CKException
    {
        Session pooled;
        synchronized (sessions)
        {
            pooled = sessions.poll();
        }
        return pooled!=null ? pooled : openSession();
    }

    /**
     * Opens a new read/write serial session on this slot.
     * <p>
     * NOTE(review): despite the {@code throws CKException} clause, a failing C_OpenSession is reported
     * as a {@link ProviderException}; left unchanged because callers may rely on it.
     */
    public Session openSession() throws CKException
    {
        long rvLong = Library.C_OpenSession(id, CKF_RW_SESSION| CKF_SERIAL_SESSION);
        int rv = Library.rvErr(rvLong);
        if (rv!=0) throw new ProviderException("Error opening PKCS#11 session: rv="+rv);
        return new Session(this, Library.rvValue(rvLong));
    }

    /**
     * Returns a session to the pool. Null sessions and sessions with handle 0 are ignored;
     * sessions with an operation still in progress are closed instead of pooled.
     */
    public void releaseSession(Session session)
    {
        if (session==null || session.handle==0) return;

        if (session.operationInProgress)
        {
            session.close();
            return;
        }

        synchronized (sessions)
        {
            sessions.add(session);
        }
    }

    /**
     * Returns the handle of the first object matching the template, or 0 if none matched
     * or the search failed (failures are deliberately reported as "not found").
     */
    public int findObjectHandle(CK_ATTRIBUTE[] template)
    {
        int handle = 0;
        Session session = null;
        boolean init = false;
        try
        {
            session = getSession();
            int rv = session.C_FindObjectsInit(template);
            if (rv!=0) throw new CKException("C_FindObjectsInit", rv);
            init = true;

            int[] buffer = new int[1];
            long rvLong = session.C_FindObjects(buffer);
            if (Library.rvErr(rvLong)==0 && Library.rvValue(rvLong)>0) handle = buffer[0];
        }
        catch (CKException ignored)
        {
            // Best effort lookup: fall through and return 0.
        }
        finally
        {
            if (init) session.C_FindObjectsFinal();
            releaseSession(session);
        }
        return handle;
    }

    /**
     * Wraps a handle in the CKObject subclass matching its class and key type.
     * For private/public keys with keyType -1 the key type is read from the object.
     *
     * @return the object, or null if handle is 0, the key type cannot be read, or the class/key type is unsupported
     */
    CKObject newObject(int handle, int clazz, int keyType)
    {
        if (handle==0) return null;

        if (keyType==-1 && (clazz==CKO_PRIVATE_KEY || clazz==CKO_PUBLIC_KEY))
        {
            try { keyType = getAttributeValueInt(handle, CKA_KEY_TYPE); }
            catch (CKException e) { return null; }
        }

        CKObject object = null;
        switch (clazz)
        {
            case CKO_CERTIFICATE  : object = new CKCertificate(); break;
            case CKO_SECRET_KEY   : object = new CKSecretKey(); break;
            case CKO_DATA         : object = new CKData(); break;
            case CKO_PRIVATE_KEY  :
                switch (keyType)
                {
                    case CKK_RSA            : object = new CKRSAPrivateKey(); break;
                    case CKK_EC             : object = new CKECPrivateKey(); break;
                    case DYCKK_LIMA         : object = new CKLIMAPrivateKey(); break;
                    case DYCKK_ADV_PASSWORD : object = new CKPasswordKey(); break;
                    case DYCKK_ADV_PRF      : object = new CKPRFKey(); break;
                    case DYCKK_EDDSA        : object = new CKEDDSAPrivateKey(); break;
                }
                break;
            case CKO_PUBLIC_KEY  :
                switch (keyType)
                {
                    case CKK_RSA            : object = new CKRSAPublicKey(); break;
                    case CKK_EC             : object = new CKECPublicKey(); break;
                    case DYCKK_LIMA         : object = new CKLIMAPublicKey(); break;
                    case DYCKK_EDDSA        : object = new CKEDDSAPublicKey(); break;
                }
                break;
        }

        if (object!=null)
        {
            object.handle = handle;
            object.slot = this;
        }
        return object;
    }

    /**
     * Wraps a handle in a new instance of {@code c}, created through its no-arg constructor.
     *
     * @return the object, or null if handle is 0 or {@code c} cannot be instantiated
     */
    <T extends CKObject> T newObject(Class<T> c, int handle)
    {
        if (handle==0) return null;
        try
        {
            T object = c.getDeclaredConstructor().newInstance();
            object.handle = handle;
            object.slot = this;
            return object;
        }
        catch (Exception ignored)
        {
            // No accessible no-arg constructor, or it threw: report as "no object".
            return null;
        }
    }

    /** Wraps each handle with {@link #newObject(Class, int)}, skipping handles that could not be wrapped. */
    private <T extends CKObject> ArrayList<T> newObjects(Class<T> c, ArrayList<Integer> handles)
    {
        ArrayList<T> list = new ArrayList<T>(handles.size());
        for (Integer handle : handles)
        {
            T object = newObject(c, handle);
            if (object!=null) list.add(object);
        }
        return list;
    }

    /** Finds an object by its DYCKA_UID attribute; returns 0 if not found. */
    public int findObjectHandle(long uid)
    {
        return findObjectHandle(new CK_ATTRIBUTE[] { new CK_ATTRIBUTE(DYCKA_UID, uid) });
    }

    /** Finds a token object by class/key type/name (-1 or null meaning "any"); returns 0 if not found. */
    public int findObjectHandle(int clazz, int keyType, String name)
    {
        return findObjectHandle(buildFindTemplate(clazz, keyType, name));
    }

    /** Finds the first object matching the template and wraps it as {@code c}; null if none. */
    public <T extends CKObject> T findObject(Class<T> c, CK_ATTRIBUTE[] template)
    {
        return newObject(c, findObjectHandle(template));
    }

    /** Finds an object by DYCKA_UID and wraps it as {@code clazz}; null if none. */
    public <T extends CKObject> T findObject(Class<T> clazz, long uid)
    {
        return newObject(clazz, findObjectHandle(uid));
    }

    /** Finds a token object by class/key type/name and wraps it in the matching CKObject subclass; null if none. */
    public CKObject findObject(int clazz, int keyType, String name)
    {
        return newObject(findObjectHandle(clazz, keyType, name), clazz, keyType);
    }

    /** Searches every slot and returns the first match wrapped as {@code c}, or null. */
    public static <T extends CKObject> T findObjectInAllSlots(Class<T> c, CK_ATTRIBUTE[] template)
    {
        for (Slot slot : slotsById.values())
        {
            T object = slot.findObject(c, template);
            if (object!=null) return object;
        }
        return null;
    }

    /** Searches every slot for an object with the given DYCKA_UID; returns it wrapped as {@code c}, or null. */
    public static <T extends CKObject> T findObjectInAllSlots(Class<T> c, long uid)
    {
        return findObjectInAllSlots(c, new CK_ATTRIBUTE[] { new CK_ATTRIBUTE(DYCKA_UID, uid) });
    }

    /**
     * Searches every slot for a token object by class/key type/name. Returns the object for the first
     * slot with a matching handle (which may be null if that handle cannot be wrapped), or null if none.
     */
    public static CKObject findObjectInAllSlots(int clazz, int keyType, String name)
    {
        CK_ATTRIBUTE[] template = buildFindTemplate(clazz, keyType, name);
        for (Slot slot : slotsById.values())
        {
            int handle = slot.findObjectHandle(template);
            if (handle!=0) return slot.newObject(handle, clazz, keyType);
        }
        return null;
    }

    /**
     * Returns the handles of all objects matching the template. On failure the handles collected so far
     * (possibly none) are returned; errors are not propagated.
     */
    public ArrayList<Integer> findObjectHandles(CK_ATTRIBUTE[] template)
    {
        ArrayList<Integer> list = new ArrayList<Integer>();
        Session session = null;
        boolean init = false;
        try
        {
            session = getSession();
            int rv = session.C_FindObjectsInit(template);
            if (rv!=0) throw new CKException("C_FindObjectsInit", rv);
            init = true;

            final int MAX_COUNT = 256;
            int[] buffer = new int[MAX_COUNT];

            // Fetch in batches until the library reports an error or no more objects.
            for (;;)
            {
                long rvLong = session.C_FindObjects(buffer);
                if (Library.rvErr(rvLong)!=0) break;
                int count = Library.rvValue(rvLong);
                if (count==0) break;
                list.ensureCapacity(list.size()+count);
                for (int i=0; i<count; i++) list.add(buffer[i]);
            }
        }
        catch (CKException ignored)
        {
            // Best effort: return whatever was collected.
        }
        finally
        {
            if (init) session.C_FindObjectsFinal();
            releaseSession(session);
        }
        return list;
    }

    /** Wraps each handle with {@link #newObject(int, int, int)}, skipping handles that could not be wrapped. */
    private ArrayList<CKObject> newObjects(ArrayList<Integer> handles, int clazz, int keyType)
    {
        ArrayList<CKObject> list = new ArrayList<CKObject>(handles.size());
        for (Integer handle : handles)
        {
            CKObject object = newObject(handle, clazz, keyType);
            if (object!=null) list.add(object);
        }
        return list;
    }

    /** Returns all objects matching the template, wrapped as {@code c}. */
    public <T extends CKObject> ArrayList<T> findObjects(Class<T> c, CK_ATTRIBUTE[] template)
    {
        return newObjects(c, findObjectHandles(template));
    }

    /** Returns all token objects of the given class/key type (-1 meaning "any"), wrapped by type. */
    public ArrayList<CKObject> findObjects(int clazz, int keyType)
    {
        return newObjects(findObjectHandles(buildFindTemplate(clazz, keyType, null)), clazz, keyType);
    }

    /** Returns all token objects of the given class/key type (-1 meaning "any"), wrapped as {@code c}. */
    public <T extends CKObject> ArrayList<T> findObjects(Class<T> c, int clazz, int keyType)
    {
        return findObjects(c, buildFindTemplate(clazz, keyType, null));
    }

    /** Collects objects matching the template from every slot, wrapped as {@code c}. */
    public static <T extends CKObject> ArrayList<T> findObjectsInAllSlots(Class<T> c, CK_ATTRIBUTE[] template)
    {
        ArrayList<T> list = new ArrayList<T>();
        for (Slot slot : slotsById.values())
        {
            list.addAll(slot.findObjects(c, template));
        }
        return list;
    }

    /** Collects token objects of the given class/key type from every slot, wrapped by type. */
    public static ArrayList<CKObject> findObjectsInAllSlots(int clazz, int keyType)
    {
        CK_ATTRIBUTE[] template = buildFindTemplate(clazz, keyType, null);
        ArrayList<CKObject> list = new ArrayList<CKObject>();
        for (Slot slot : slotsById.values())
        {
            list.addAll(slot.newObjects(slot.findObjectHandles(template), clazz, keyType));
        }
        return list;
    }

    /** Collects token objects of the given class/key type from every slot, wrapped as {@code c}. */
    public static <T extends CKObject> ArrayList<T> findObjectsInAllSlots(Class<T> c, int clazz, int keyType)
    {
        CK_ATTRIBUTE[] template = buildFindTemplate(clazz, keyType, null);
        ArrayList<T> list = new ArrayList<T>();
        for (Slot slot : slotsById.values())
        {
            list.addAll(slot.findObjects(c, template));
        }
        return list;
    }

    /**
     * Logs in on the persistent session. An empty password is treated as null. If the token is already
     * known not to require login and no password is given, returns 0 without calling C_Login.
     * For CKU_USER and DYCKU_USER_CHECK the cached login-required state is updated from the outcome.
     *
     * @return the C_Login return value (0 on success)
     */
    public int login(int userType, char[] password)
    {
        if (password!=null && password.length==0) password=null;

        if (knownAuthReq==0 && password==null) return 0;

        int rv = persistentSession.C_Login(userType, password);
        if (userType==CKU_USER || userType==DYCKU_USER_CHECK)
        {
            if (rv==0 && password!=null) knownAuthReq = 1;
            if (rv==0 && password==null) knownAuthReq = 0;
            if (rv!=0 && password==null) knownAuthReq = 1;
        }

        return rv;
    }

    /** Same as {@code login(CKU_USER, password)}. */
    public int login(char[] password)
    {
        return login(CKU_USER, password);
    }

    /** Calls C_Logout on the persistent session unless login is known not to be required; the result is ignored. */
    public void logout()
    {
        if (knownAuthReq==0) return;
        persistentSession.C_Logout();
    }

    /** Creates an object on the persistent session and returns its handle. */
    public int createObject(CK_ATTRIBUTE[] template) throws CKException
    {
        long rvLong = persistentSession.C_CreateObject(template);
        int rv = Library.rvErr(rvLong);
        if (rv!=0) throw new CKException("C_CreateObject", rv);
        return Library.rvValue(rvLong);
    }

    /** Destroys the object with the given handle. */
    public void destroyObject(int handle) throws CKException
    {
        int rv = persistentSession.C_DestroyObject(handle);
        if (rv!=0) throw new CKException("C_DestroyObject", rv);
    }

    /** Generates a key with the given mechanism type and template; returns the new key handle. */
    public int generateKey(int mechType, CK_ATTRIBUTE[] template) throws CKException
    {
        long rvLong = persistentSession.C_GenerateKey(new CK_MECHANISM(mechType), template);
        int rv = Library.rvErr(rvLong);
        if (rv!=0) throw new CKException("C_GenerateKey", rv);
        return Library.rvValue(rvLong);
    }

    /**
     * Generates a key pair with the given mechanism type and templates.
     *
     * @return the single handle value returned by the native call
     *         (NOTE(review): confirm whether this is the public or the private key handle)
     */
    public int generateKeyPair(int mechType, CK_ATTRIBUTE[] pubTemplate, CK_ATTRIBUTE[] prvTemplate) throws CKException
    {
        long rvLong = persistentSession.C_GenerateKeyPair(new CK_MECHANISM(mechType), pubTemplate, prvTemplate, null);
        int rv = Library.rvErr(rvLong);
        if (rv!=0) throw new CKException("C_GenerateKeyPair", rv);
        return Library.rvValue(rvLong);
    }

    /** Sets attribute values on an object. */
    public void setAttributeValue(int object, CK_ATTRIBUTE[] template) throws CKException
    {
        int rv = persistentSession.C_SetAttributeValue(object, template);
        if (rv!=0) throw new CKException("C_SetAttributeValue", rv);
    }

    /** Returns the size reported by C_GetAttributeSize for one attribute of an object. */
    public int getAttributeSize(int object, int attribute) throws CKException
    {
        CK_ULONG_PTR pSize = new CK_ULONG_PTR();
        int rv = persistentSession.C_GetAttributeSize(object, attribute, pSize);
        if (rv!=0) throw new CKException("C_GetAttributeSize", rv);
        return pSize.value;
    }

    /** Reads attribute values of an object into the given template. */
    public void getAttributeValue(int object, CK_ATTRIBUTE[] template) throws CKException
    {
        int rv = persistentSession.C_GetAttributeValue(object, template);
        if (rv!=0) throw new CKException("C_GetAttributeValue", rv);
    }

    /** Reads a single attribute and returns it converted with {@code CK_ATTRIBUTE.toInt()}. */
    public int getAttributeValueInt(int object, int attribute) throws CKException
    {
        CK_ATTRIBUTE[] t = new CK_ATTRIBUTE[] { new CK_ATTRIBUTE(attribute) };
        getAttributeValue(object, t);
        return t[0].toInt();
    }

    /** Derives a key from {@code key} using the mechanism; returns the new key handle. */
    public int deriveKey(CK_MECHANISM mechanism, int key, CK_ATTRIBUTE[] t) throws CKException
    {
        long rvLong = persistentSession.C_DeriveKey(mechanism, key, t);
        int rv = Library.rvErr(rvLong);
        if (rv!=0) throw new CKException("C_DeriveKey", rv);
        return Library.rvValue(rvLong);
    }

    /** Unwraps {@code in[inOffset, inOffset+inLen)} with the given key; returns the new key handle. */
    public int unwrapKey(CK_MECHANISM mechanism, int unwrappingKey, byte[] in, int inOffset, int inLen, CK_ATTRIBUTE[] template) throws CKException
    {
        long rvLong = persistentSession.C_UnwrapKey(mechanism, unwrappingKey, in, inOffset, inLen, template);
        int rv = Library.rvErr(rvLong);
        if (rv!=0) throw new CKException("C_UnwrapKey", rv);
        return Library.rvValue(rvLong);
    }

    /** Fills {@code out[outOffset, outOffset+outLen)} with random bytes from the token. */
    public void generateRandom(byte[] out, int outOffset, int outLen) throws CKException
    {
        int rv = persistentSession.C_GenerateRandom(out, outOffset, outLen);
        if (rv!=0) throw new CKException("C_GenerateRandom", rv);
    }

    /** Returns a new array of {@code outLen} random bytes from the token. */
    public byte[] generateRandom(int outLen) throws CKException
    {
        byte[] out = new byte[outLen];
        generateRandom(out, 0, outLen);
        return out;
    }

}
