package com.dyadicsec.pkcs11;

import java.security.ProviderException;
import java.util.Arrays;
import static com.dyadicsec.cryptoki.CK.*;

/**
 * A PKCS#11 session bound to a slot.
 *
 * <p>Provides thin {@code C_*} wrappers that forward to the native functions with this
 * session's handle and return raw PKCS#11 codes. Also provides higher-level helpers
 * (lower-case names) that throw {@link CKException} on error.
 *
 * <p>Native calls declared as returning {@code long} pack an error code and a value into a
 * single result. These are decoded with {@code Library.rvErr} and {@code Library.rvValue}.
 *
 * <p>Originally created by valery.osheter on 18-Jun-17.
 */
public final class Session
{
    /** Native session handle; 0 means the session is not open (see {@link #isValid()}). */
    int handle = 0;
    private int slotID = -1;
    private Slot slot = null;
    /** Set when an {@code *Init} call succeeds; cleared when the operation terminates. */
    boolean operationInProgress = false;

    /** @return true while this object holds a non-zero native session handle. */
    public boolean isValid() { return handle!=0; }

    /** Closes the native session if this object is collected without an explicit {@link #close()}. */
    @Override
    protected void finalize()
    {
        close();
    }

    Session(int slotID, int handle)
    {
        this.slotID = slotID;
        this.handle = handle;
    }

    Session(Slot slot, int handle)
    {
        this.slot = slot;
        this.slotID = slot.id;
        this.handle = handle;
    }

    public Session()  { }

    public int getSlotID() { return slotID; }

    /**
     * Opens a read/write serial session on a slot.
     *
     * @param slotID PKCS#11 slot identifier
     * @return the newly opened session
     * @throws ProviderException if {@code C_OpenSession} reports an error
     */
    public static Session open(int slotID)
    {
        long rvLong = Library.C_OpenSession(slotID, CKF_RW_SESSION| CKF_SERIAL_SESSION);
        int rv = Library.rvErr(rvLong);
        if (rv!=0) throw new ProviderException("Error opening PKCS#11 session: rv="+rv);
        return new Session(slotID, Library.rvValue(rvLong));
    }

    /**
     * Closes the native session if it is open and resets the handle to 0.
     * Safe to call more than once. The {@code C_CloseSession} return code is
     * ignored (best-effort).
     */
    public void close()
    {
        if (handle!=0) C_CloseSession(handle);
        handle = 0;
    }

    // ---- Thin wrappers: forward to the native function with this session's handle ----

    public int C_Login(int userType, char[] password)
    { return C_Login(handle, userType, password); }

    public int C_Logout()
    { return C_Logout(handle); }

    public int C_GetSessionInfo(CK_SESSION_INFO info)
    { return C_GetSessionInfo(handle, info); }

    public long C_CreateObject(CK_ATTRIBUTE[] template)
    { return C_CreateObject(handle, template);}

    public int C_DestroyObject(int object)
    { return C_DestroyObject(handle, object); }

    public long C_GenerateKey(CK_MECHANISM mechanism, CK_ATTRIBUTE[] template)
    { return C_GenerateKey(handle, mechanism, template); }

    public long C_GenerateKeyPair(CK_MECHANISM mechanism, CK_ATTRIBUTE[] pubTemplate, CK_ATTRIBUTE[] prvTemplate, CK_ULONG_PTR pubKey)
    { return C_GenerateKeyPair(handle, mechanism, pubTemplate, prvTemplate, pubKey); }

    public int C_SetAttributeValue(int object, CK_ATTRIBUTE[] template)
    { return C_SetAttributeValue(handle, object, template); }

    /**
     * Marks an operation as in progress if its Init call succeeded.
     *
     * @param rv return code of a {@code C_*Init} call
     * @return {@code rv} unchanged
     */
    int initOperation(int rv)
    {
        if (rv==0) operationInProgress = true;
        return rv;
    }

    /**
     * Updates {@link #operationInProgress} after a single-part or final call.
     *
     * <p>The operation stays active in two cases: on {@code CKR_BUFFER_TOO_SMALL}, and on a
     * successful call with {@code out == null} (a length query). Any other error, or a
     * successful call with an output buffer, ends it.
     *
     * @param out the output buffer passed to the call (may be null)
     * @param rv  the call's error code
     * @return {@code rv} unchanged
     */
    int finalOperation(byte[] out, int rv)
    {
        if (rv==CKR_BUFFER_TOO_SMALL) return rv;
        if (rv!=0 || out!=null) operationInProgress = false;
        return rv;
    }

    public int C_FindObjectsInit(CK_ATTRIBUTE[] template)
    { return initOperation(C_FindObjectsInit(handle, template));  }

    public long C_FindObjects(int[] objects)
    { return C_FindObjects(handle, objects); }

    public int C_FindObjectsFinal()
    {
        int rv = C_FindObjectsFinal(handle);
        if (rv==0) operationInProgress = false;
        return rv;
    }

    public int C_EncryptInit(CK_MECHANISM mechanism, int key)
    { return initOperation(C_EncryptInit(handle, mechanism, key)); }

    public long C_Encrypt(byte[] in, int inOffset, int inLen, byte[] out, int outOffset)
    {
        long rvLong = C_Encrypt(handle, in, inOffset, inLen, out, outOffset);
        finalOperation(out, Library.rvErr(rvLong));
        return rvLong;
    }

    public long C_EncryptUpdate(byte[] in, int inOffset, int inLen, byte[] out, int outOffset)
    { return C_EncryptUpdate(handle, in, inOffset, inLen, out, outOffset); }

    public long C_EncryptFinal(byte[] out, int outOffset)
    {
        long rvLong = C_EncryptFinal(handle, out, outOffset);
        finalOperation(out, Library.rvErr(rvLong));
        return rvLong;
    }

    public int C_DecryptInit(CK_MECHANISM mechanism, int key)
    { return initOperation(C_DecryptInit(handle, mechanism, key)); }

    public long C_Decrypt(byte[] in, int inOffset, int inLen, byte[] out, int outOffset)
    {
        long rvLong = C_Decrypt(handle, in, inOffset, inLen, out, outOffset);
        finalOperation(out, Library.rvErr(rvLong));
        return rvLong;
    }

    public long C_DecryptUpdate(byte[] in, int inOffset, int inLen, byte[] out, int outOffset)
    { return C_DecryptUpdate(handle, in, inOffset, inLen, out, outOffset); }

    public long C_DecryptFinal(byte[] out, int outOffset)
    {
        long rvLong = C_DecryptFinal(handle, out, outOffset);
        finalOperation(out, Library.rvErr(rvLong));
        return rvLong;
    }

    public int C_SignInit(CK_MECHANISM mechanism, int key)
    { return initOperation(C_SignInit(handle, mechanism, key)); }

    public long C_Sign(byte[] in, int inOffset, int inLen, byte[] signature)
    {
        long rvLong = C_Sign(handle, in, inOffset, inLen, signature);
        finalOperation(signature, Library.rvErr(rvLong));
        return rvLong;
    }

    public int C_SignUpdate(byte[] in, int inOffset, int inLen)
    { return C_SignUpdate(handle, in, inOffset, inLen); }

    public long C_SignFinal(byte[] signature)
    {
        long rvLong = C_SignFinal(handle, signature);
        finalOperation(signature, Library.rvErr(rvLong));
        return rvLong;
    }

    public int C_VerifyInit(CK_MECHANISM mechanism, int key)
    { return initOperation(C_VerifyInit(handle, mechanism, key)); }

    public int C_Verify(byte[] in, int inOffset, int inLen, byte[] signature)
    { return finalOperation(signature, C_Verify(handle, in, inOffset, inLen, signature)); }

    public int C_VerifyUpdate(byte[] in, int inOffset, int inLen)
    { return C_VerifyUpdate(handle, in, inOffset, inLen); }

    public int C_VerifyFinal(byte[] signature)
    { return finalOperation(signature, C_VerifyFinal(handle, signature)); }

    public long C_DeriveKey(CK_MECHANISM mechanism, int key, CK_ATTRIBUTE[] t)
    { return C_DeriveKey(handle, mechanism, key, t); }

    public long C_WrapKey(CK_MECHANISM mechanism, int wrappingKey, int key, byte[] out, int outOffset)
    { return C_WrapKey(handle, mechanism, wrappingKey, key, out, outOffset); }

    public long C_UnwrapKey(CK_MECHANISM mechanism, int unwrappingKey, byte[] in, int inOffset, int inLen, CK_ATTRIBUTE[] template)
    { return C_UnwrapKey(handle, mechanism, unwrappingKey, in, inOffset, inLen, template); }

    public int C_GenerateRandom(byte[] out, int outOffset, int outLen)
    { return C_GenerateRandom(handle, out, outOffset, outLen); }

    // ---- Native entry points (JNI) ----

    private native int C_CloseSession(int session);
    private native int C_Login(int session, int userType, char[] password);
    private native int C_Logout(int session);
    private native int C_GetSessionInfo(int session, CK_SESSION_INFO info);
    private native long C_CreateObject(int session, CK_ATTRIBUTE[] template);
    private native int C_DestroyObject(int session, int object);
    private native int C_GenerateRandom(int session, byte[] out, int outOffset, int outLen);
    private native long C_GenerateKey(int session, CK_MECHANISM mechanism, CK_ATTRIBUTE[] template);
    private native long C_GenerateKeyPair(int session, CK_MECHANISM mechanism, CK_ATTRIBUTE[] pubTemplate, CK_ATTRIBUTE[] prvTemplate, CK_ULONG_PTR pubKey);
    private native int C_SetAttributeValue(int session, int object, CK_ATTRIBUTE[] template);
    private native int C_FindObjectsInit(int session, CK_ATTRIBUTE[] template);
    private native long C_FindObjects(int session, int[] objects);
    private native int C_FindObjectsFinal(int session);
    private native int C_EncryptInit(int session, CK_MECHANISM mechanism, int key);
    private native long C_Encrypt(int session, byte[] in, int inOffset, int inLen, byte[] out, int outOffset);
    private native long C_EncryptUpdate(int session, byte[] in, int inOffset, int inLen, byte[] out, int outOffset);
    private native long C_EncryptFinal(int session, byte[] out, int outOffset);
    private native int C_DecryptInit(int session, CK_MECHANISM mechanism, int key);
    private native long C_Decrypt(int session, byte[] in, int inOffset, int inLen, byte[] out, int outOffset);
    private native long C_DecryptUpdate(int session, byte[] in, int inOffset, int inLen, byte[] out, int outOffset);
    private native long C_DecryptFinal(int session, byte[] out, int outOffset);
    private native int C_SignInit(int session, CK_MECHANISM mechanism, int key);
    private native long C_Sign(int session, byte[] in, int inOffset, int inLen, byte[] signature);
    private native int C_SignUpdate(int session, byte[] in, int inOffset, int inLen);
    private native long C_SignFinal(int session, byte[] signature);
    private native int C_VerifyInit(int session, CK_MECHANISM mechanism, int key);
    private native int C_Verify(int session, byte[] in, int inOffset, int inLen, byte[] signature);
    private native int C_VerifyUpdate(int session, byte[] in, int inOffset, int inLen);
    private native int C_VerifyFinal(int session, byte[] signature);
    private native long C_DeriveKey(int session, CK_MECHANISM mechanism, int key, CK_ATTRIBUTE[] t);
    private native long C_WrapKey(int session, CK_MECHANISM mechanism, int wrappingKey, int key, byte[] out, int outOffset);
    private native long C_UnwrapKey(int session, CK_MECHANISM mechanism, int unwrappingKey, byte[] in, int inOffset, int inLen, CK_ATTRIBUTE[] template);

    private native int C_GetAttributeValueSize(int session, int object, int[] templateInfo);
    private native int C_GetAttributeValueData(int session, int object, CK_ATTRIBUTE[] template);

    private native long DYC_SelfSignX509(int session, int object, int mechanism, char[] subject, byte[] serial, int days, byte[] out);

    /**
     * Produces a self-signed X.509 certificate for a key object using the vendor
     * {@code DYC_SelfSignX509} extension. Performs a length query first, then the real call.
     *
     * @return the encoded certificate bytes
     * @throws CKException if either native call reports an error
     */
    public byte[] DYC_SelfSignX509(int object, int mechanism, char[] subject, byte[] serial, int days) throws CKException
    {
        long rvLong = DYC_SelfSignX509(handle, object, mechanism, subject, serial, days, null);
        int rv = Library.rvErr(rvLong);
        if (rv!=0) throw new CKException("DYC_SelfSignX509", rv);
        int length = Library.rvValue(rvLong);
        byte[] out = new byte[length];

        // The second call can fail too; without this check a zero-filled buffer
        // would be returned as if it were a certificate.
        rvLong = DYC_SelfSignX509(handle, object, mechanism, subject, serial, days, out);
        rv = Library.rvErr(rvLong);
        if (rv!=0) throw new CKException("DYC_SelfSignX509", rv);
        return out;
    }

    /**
     * Queries the size of a single attribute's value.
     *
     * @param pOutLen receives the size on success
     * @return 0 on success, otherwise the error code from the size query
     */
    public int C_GetAttributeSize(int object, int attribute, CK_ULONG_PTR pOutLen)
    {
        int[] templateInfo = new int[1];
        templateInfo[0] = attribute;
        int rv = C_GetAttributeValueSize(handle, object, templateInfo);
        if (rv!=0) return rv;
        pOutLen.value = templateInfo[0];
        return 0;
    }

    /**
     * Reads attribute values of an object into {@code template}.
     *
     * <p>If every requested attribute is a boolean or integer type, the values are fetched
     * directly. Otherwise the sizes are queried first. Then, for each attribute that is
     * neither bool nor int, a {@code byte[]} of the reported size is stored in
     * {@code value} before the values are fetched.
     *
     * @return 0 on success, otherwise the error code from the size or data query
     */
    public int C_GetAttributeValue(int object, CK_ATTRIBUTE[] template)
    {
        boolean allKnownSize = true;
        for (CK_ATTRIBUTE a : template)
        {
            // Size is known only for bool or int attributes. Use &&: no type is both,
            // so "!isBool || !isInt" would be true for every attribute.
            if (!CK_ATTRIBUTE.isBool(a.type) && !CK_ATTRIBUTE.isInt(a.type)) { allKnownSize = false; break; }
        }

        if (!allKnownSize)
        {
            int[] templateInfo = new int[template.length];
            for (int i=0; i<template.length; i++) templateInfo[i] = template[i].type;
            int rv = C_GetAttributeValueSize(handle, object, templateInfo);
            if (rv!=0) return rv;
            for (int i=0; i<template.length; i++)
            {
                CK_ATTRIBUTE a = template[i];
                if (CK_ATTRIBUTE.isBool(a.type) || CK_ATTRIBUTE.isInt(a.type)) continue;
                int size = templateInfo[i];
                a.value = new byte[size];
            }
        }

        return C_GetAttributeValueData(handle, object, template);
    }

    // ---- Exception-throwing helpers ----

    public void generateRandom(byte[] out, int outOffset, int outLen) throws CKException
    {
        int rv = C_GenerateRandom(out, outOffset, outLen);
        if (rv!=0) throw new CKException("C_GenerateRandom", rv);
    }

    public void encryptInit(CK_MECHANISM mechanism, CKKey key) throws CKException
    {
        int rv = C_EncryptInit(mechanism, key.handle);
        if (rv!=0) throw new CKException("C_EncryptInit", rv);
    }

    public void decryptInit(CK_MECHANISM mechanism, CKKey key) throws CKException
    {
        int rv = C_DecryptInit(mechanism, key.handle);
        if (rv!=0) throw new CKException("C_DecryptInit", rv);
    }

    public void signInit(CK_MECHANISM mechanism, CKKey key) throws CKException
    {
        int rv = C_SignInit(mechanism, key.handle);
        if (rv!=0) throw new CKException("C_SignInit", rv);
    }

    public void verifyInit(CK_MECHANISM mechanism, CKKey key) throws CKException
    {
        int rv = C_VerifyInit(mechanism, key.handle);
        if (rv!=0) throw new CKException("C_VerifyInit", rv);
    }

    /** @return the value decoded from the native result (number of output bytes) */
    public int encryptUpdate(byte[] in, int inOffset, int inLen, byte[] out, int outOffset) throws CKException
    {
        long rvLong = C_EncryptUpdate(in, inOffset, inLen, out, outOffset);
        int rv = Library.rvErr(rvLong);
        if (rv!=0) throw new CKException("C_EncryptUpdate", rv);
        return Library.rvValue(rvLong);
    }

    /** @return the value decoded from the native result (number of output bytes) */
    public int decryptUpdate(byte[] in, int inOffset, int inLen, byte[] out, int outOffset) throws CKException
    {
        long rvLong = C_DecryptUpdate(in, inOffset, inLen, out, outOffset);
        int rv = Library.rvErr(rvLong);
        if (rv!=0) throw new CKException("C_DecryptUpdate", rv);
        return Library.rvValue(rvLong);
    }

    public void signUpdate(byte[] in, int inOffset, int inLen) throws CKException
    {
        int rv = C_SignUpdate(in, inOffset, inLen);
        if (rv!=0) throw new CKException("C_SignUpdate", rv);
    }

    public void verifyUpdate(byte[] in, int inOffset, int inLen) throws CKException
    {
        int rv = C_VerifyUpdate(in, inOffset, inLen);
        if (rv!=0) throw new CKException("C_VerifyUpdate", rv);
    }

    public int encryptFinal(byte[] out, int outOffset) throws CKException
    {
        long rvLong = C_EncryptFinal(out, outOffset);
        int rv = Library.rvErr(rvLong);
        if (rv!=0) throw new CKException("C_EncryptFinal", rv);
        return Library.rvValue(rvLong);
    }

    public int decryptFinal(byte[] out, int outOffset) throws CKException
    {
        long rvLong = C_DecryptFinal(out, outOffset);
        int rv = Library.rvErr(rvLong);
        if (rv!=0) throw new CKException("C_DecryptFinal", rv);
        return Library.rvValue(rvLong);
    }

    /**
     * Finishes a multi-part signature.
     *
     * @param outLen expected signature length, or 0 to query it first
     * @return the signature, trimmed to the length reported by the final call
     */
    public byte[] signFinal(int outLen) throws CKException
    {
        if (outLen==0)
        {
            long rvLong = C_SignFinal(null);
            int rv = Library.rvErr(rvLong);
            if (rv!=0) throw new CKException("C_SignFinal", rv);
            outLen = Library.rvValue(rvLong);
        }

        byte[] out = new byte[outLen];
        long rvLong = C_SignFinal(out);
        int rv = Library.rvErr(rvLong);
        if (rv!=0) throw new CKException("C_SignFinal", rv);
        outLen = Library.rvValue(rvLong);

        if (outLen==out.length) return out;
        return Arrays.copyOf(out, outLen);
    }

    /** @return true only if {@code C_VerifyFinal} returns 0; any non-zero code yields false. */
    public boolean verifyFinal(byte[] signature)
    {
        return 0==C_VerifyFinal(signature);
    }

    public int encrypt(byte[] in, int inOffset, int inLen, byte[] out, int outOffset) throws CKException
    {
        long rvLong = C_Encrypt(in, inOffset, inLen, out, outOffset);
        int rv = Library.rvErr(rvLong);
        if (rv!=0) throw new CKException("C_Encrypt", rv);
        return Library.rvValue(rvLong);
    }

    public int decrypt(byte[] in, int inOffset, int inLen, byte[] out, int outOffset) throws CKException
    {
        long rvLong = C_Decrypt(in, inOffset, inLen, out, outOffset);
        int rv = Library.rvErr(rvLong);
        if (rv!=0) throw new CKException("C_Decrypt", rv);
        return Library.rvValue(rvLong);
    }

    /**
     * Single-part encryption of the whole array.
     *
     * @param outLen output buffer size, or 0 to query it with a null-buffer call first
     * @return the ciphertext, trimmed to the reported length
     */
    public byte[] encrypt(byte[] in, int outLen) throws CKException
    {
        if (outLen==0) outLen = encrypt(in, 0, in.length, null, 0);
        byte[] out = new byte[outLen];
        outLen = encrypt(in, 0, in.length, out, 0);
        if (outLen==out.length) return out;
        return Arrays.copyOf(out, outLen);
    }

    /**
     * Single-part decryption of the whole array.
     *
     * @param outLen output buffer size, or 0 to query it with a null-buffer call first
     * @return the plaintext, trimmed to the reported length
     */
    public byte[] decrypt(byte[] in, int outLen) throws CKException
    {
        if (outLen==0) outLen = decrypt(in, 0, in.length, null, 0);
        byte[] out = new byte[outLen];
        outLen = decrypt(in, 0, in.length, out, 0);
        if (outLen==out.length) return out;
        return Arrays.copyOf(out, outLen);
    }

    public byte[] sign(byte[] in, int outLen) throws CKException
    {
        return sign(in, 0, in.length, outLen);
    }

    /**
     * Single-part signature over {@code in[inOffset .. inOffset+inLen)}.
     *
     * @param outLen signature buffer size, or 0 to query it first
     * @return the signature, trimmed to the reported length
     */
    public byte[] sign(byte[] in, int inOffset, int inLen, int outLen) throws CKException
    {
        if (outLen==0)
        {
            // The length query must cover the same input range as the real call.
            long rvLong = C_Sign(in, inOffset, inLen, null);
            int rv = Library.rvErr(rvLong);
            if (rv!=0) throw new CKException("C_Sign", rv);
            outLen = Library.rvValue(rvLong);
        }

        byte[] out = new byte[outLen];
        long rvLong = C_Sign(in, inOffset, inLen, out);
        int rv = Library.rvErr(rvLong);
        if (rv!=0) throw new CKException("C_Sign", rv);
        outLen = Library.rvValue(rvLong);
        if (outLen==out.length) return out;
        return Arrays.copyOf(out, outLen);
    }

    /**
     * Single-part verification.
     *
     * @return true if valid, false on {@code CKR_SIGNATURE_INVALID}
     * @throws CKException for any other error code
     */
    public boolean verify(byte[] in, byte[] signature) throws CKException
    {
        int rv = C_Verify(in, 0, in.length, signature);
        if (rv == 0) return true;
        else if (rv == CKR_SIGNATURE_INVALID) return false;
        else throw new CKException("C_Verify", rv);
    }

    // ---- One-shot helpers: borrow a session from the key's slot, run, release in finally ----

    public static byte[] encrypt(CK_MECHANISM mechanism, CKKey key, byte[] in, int outLen) throws CKException
    {
        Session session = key.slot.getSession();
        try
        {
            session.encryptInit(mechanism, key);
            return session.encrypt(in, outLen);
        }
        finally { key.slot.releaseSession(session); }
    }

    public static byte[] decrypt(CK_MECHANISM mechanism, CKKey key, byte[] in, int outLen) throws CKException
    {
        Session session = key.slot.getSession();
        try
        {
            session.decryptInit(mechanism, key);
            return session.decrypt(in, outLen);
        }
        finally { key.slot.releaseSession(session); }
    }

    public static byte[] sign(CK_MECHANISM mechanism, CKKey key, byte[] in, int outLen) throws CKException
    {
        Session session = key.slot.getSession();
        try
        {
            session.signInit(mechanism, key);
            return session.sign(in, outLen);
        }
        finally { key.slot.releaseSession(session); }
    }

    public static boolean verify(CK_MECHANISM mechanism, CKKey key, byte[] in, byte[] signature) throws CKException
    {
        Session session = key.slot.getSession();
        try
        {
            session.verifyInit(mechanism, key);
            return session.verify(in, signature);
        }
        finally { key.slot.releaseSession(session); }
    }

    /**
     * Wraps a key.
     *
     * @param key        the wrapping key (passed as {@code wrappingKey} to {@code C_WrapKey})
     * @param wrappedKey the key to be wrapped
     * @param outLen     output buffer size, or 0 to query it first
     * @return the wrapped key bytes, trimmed to the reported length
     */
    public byte[] wrap(CK_MECHANISM mechanism, CKKey key, CKKey wrappedKey, int outLen) throws CKException
    {
        if (outLen==0)
        {
            long rvLong = C_WrapKey(mechanism, key.handle, wrappedKey.handle, null, 0);
            int rv = Library.rvErr(rvLong);
            if (rv!=0) throw new CKException("C_WrapKey", rv);
            outLen = Library.rvValue(rvLong);
        }

        byte[] out = new byte[outLen];
        long rvLong = C_WrapKey(mechanism, key.handle, wrappedKey.handle, out, 0);
        int rv = Library.rvErr(rvLong);
        if (rv!=0) throw new CKException("C_WrapKey", rv);
        outLen = Library.rvValue(rvLong);
        if (outLen==out.length) return out;
        return Arrays.copyOf(out, outLen);
    }

}
