package org.jsmth.jorm.jdbc;

import org.apache.commons.lang.ArrayUtils;
import org.apache.commons.lang.StringUtils;
import org.apache.commons.lang.Validate;
import org.jsmth.exception.SmthException;
import org.jsmth.util.ReflectUtil;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.BeanUtils;
import org.springframework.beans.BeanWrapper;
import org.springframework.beans.NotReadablePropertyException;
import org.springframework.beans.PropertyAccessorFactory;
import org.springframework.jdbc.support.JdbcUtils;
import org.springframework.util.Assert;

import javax.persistence.*;
import javax.persistence.Column;
import javax.persistence.Table;
import java.lang.reflect.*;
import java.sql.Types;
import java.util.*;


/**
 * 解析并获得JPA信息
 *
 * @author mason
 */
public class JPAHelper {

    protected static Logger logger = LoggerFactory.getLogger(JPAHelper.class);

    /**
     * JDBC type codes ({@link Types}) in a fixed order:
     * NULL, BOOLEAN, BIT, CHAR, NCHAR, TINYINT, SMALLINT, INTEGER, DECIMAL, BIGINT,
     * DOUBLE, FLOAT, NUMERIC, TIME, DATE, TIMESTAMP, VARCHAR, NVARCHAR, ARRAY, BINARY,
     * DATALINK, DISTINCT, OTHER, REAL, REF, ROWID, SQLXML, STRUCT, VARBINARY, JAVA_OBJECT,
     * LONGVARCHAR, LONGNVARCHAR, LONGVARBINARY, CLOB, NCLOB, BLOB.
     * <p>
     * {@link #ORDERED_SQL_TYPES} maps each code to its index in this array.
     */
    static final int[] SQL_TYPES = new int[]{Types.NULL, Types.BOOLEAN, Types.BIT, Types.CHAR, Types.NCHAR, Types.TINYINT, Types.SMALLINT, Types.INTEGER, Types.DECIMAL, Types.BIGINT, Types.DOUBLE, Types.FLOAT, Types.NUMERIC, Types.TIME, Types.DATE, Types.TIMESTAMP, Types.VARCHAR, Types.NVARCHAR, Types.ARRAY, Types.BINARY, Types.DATALINK, Types.DISTINCT, Types.OTHER, Types.REAL, Types.REF, Types.ROWID, Types.SQLXML, Types.STRUCT, Types.VARBINARY, Types.JAVA_OBJECT, Types.LONGVARCHAR, Types.LONGNVARCHAR, Types.LONGVARBINARY, Types.CLOB, Types.NCLOB, Types.BLOB};

    /**
     * Maps each JDBC type code in {@link #SQL_TYPES} to its position in that array
     * (iteration order follows {@link #SQL_TYPES}).
     */
    static final Map<Integer, Integer> ORDERED_SQL_TYPES = new LinkedHashMap<Integer, Integer>();

    static {
        for (int idx = 0; idx < SQL_TYPES.length; idx++) {
            ORDERED_SQL_TYPES.put(SQL_TYPES[idx], idx);
        }
    }

    /**
     * Checks whether a class is a persistable entity, i.e. it carries either
     * {@link Entity} or {@link Table}.
     *
     * @param clazz the class to inspect
     * @param <T>   the class type
     * @return {@code true} if the class is annotated with {@code @Entity} or {@code @Table}
     */
    public static <T> boolean isEntity(Class<T> clazz) {
        return clazz.getAnnotation(Entity.class) != null || clazz.getAnnotation(Table.class) != null;
    }

    /**
     * Resolves the table name of an entity class. Lookup order: JPA {@code @Table(name)},
     * Hibernate {@code @Table(appliesTo)}, JPA {@code @Entity(name)}, then the simple class name.
     *
     * @param clazz the entity class
     * @param <T>   the class type
     * @return the table name, never empty
     * @throws IllegalArgumentException if {@code clazz} is not an entity
     */
    public static <T> String getTableName(Class<T> clazz) {
        Assert.isTrue(isEntity(clazz), "Class [" + clazz + "] is not an entity");

        Table table = clazz.getAnnotation(Table.class);
        if (table != null && StringUtils.isNotEmpty(table.name())) {
            return table.name();
        }

        org.hibernate.annotations.Table htable = clazz.getAnnotation(org.hibernate.annotations.Table.class);
        if (htable != null && StringUtils.isNotEmpty(htable.appliesTo())) {
            return htable.appliesTo();
        }

        Entity entity = clazz.getAnnotation(Entity.class);
        if (entity != null && StringUtils.isNotEmpty(entity.name())) {
            return entity.name();
        }

        return clazz.getSimpleName();
    }

    /**
     * Collects the JPA lifecycle callback methods bound to an entity.
     * <p>
     * Public methods are scanned starting at {@code clazz} and walking up the superclass chain;
     * for each event the first annotated method found wins. The superclass walk matters when an
     * annotated superclass method is overridden by an un-annotated subclass method. A single
     * method may be registered for several events (JPA allows one callback to handle multiple
     * lifecycle events).
     *
     * @param clazz the entity class
     * @param <T>   the class type
     * @return map from event to callback method; empty if none are declared
     * @throws IllegalArgumentException if {@code clazz} is not an entity
     */
    public static <T> Map<Event, Method> getEventMap(Class<T> clazz) {
        Assert.isTrue(isEntity(clazz), "Class [" + clazz + "] is not an entity");
        Map<Event, Method> ret = new HashMap<Event, Method>();

        for (Class<?> current = clazz; current != null && !current.equals(Object.class); current = current.getSuperclass()) {
            for (Method method : current.getMethods()) {
                registerCallback(ret, method, PrePersist.class, Event.PreInsert);
                registerCallback(ret, method, PreUpdate.class, Event.PreUpdate);
                registerCallback(ret, method, PreRemove.class, Event.PreDelete);
                registerCallback(ret, method, PostLoad.class, Event.PostLoad);
                registerCallback(ret, method, PostPersist.class, Event.PostInsert);
                registerCallback(ret, method, PostUpdate.class, Event.PostUpdate);
                registerCallback(ret, method, PostRemove.class, Event.PostDelete);
            }
        }
        return ret;
    }

    /** Registers {@code method} for {@code event} if it carries {@code annotation} and the event is still unbound. */
    private static void registerCallback(Map<Event, Method> map, Method method,
                                         Class<? extends java.lang.annotation.Annotation> annotation, Event event) {
        if (!map.containsKey(event) && method.getAnnotation(annotation) != null) {
            map.put(event, method);
        }
    }

    /**
     * Returns all candidate persistent fields of a class, superclass fields first.
     * Static fields, {@code transient} fields and fields annotated with {@link Transient} are excluded.
     *
     * @param clazz the class to inspect
     * @param <T>   the class type
     * @return the fields, ordered from the top-most superclass down to {@code clazz}
     */
    public static <T> List<Field> getJPAFields(Class<T> clazz) {
        List<Field> result = new ArrayList<Field>();

        for (Class<?> current = clazz; current != null && !current.equals(Object.class); current = current.getSuperclass()) {
            List<Field> declared = new ArrayList<Field>();
            for (Field field : current.getDeclaredFields()) {
                // skip everything that is definitely not persisted
                if (Modifier.isTransient(field.getModifiers())) continue;
                if (Modifier.isStatic(field.getModifiers())) continue;
                if (field.getAnnotation(Transient.class) != null) continue;
                declared.add(field);
            }
            // prepend so that superclass fields end up before subclass fields
            result.addAll(0, declared);
        }
        return result;
    }

    /**
     * Collects {@link AttributeOverride} declarations from a class hierarchy, processing
     * superclasses first so that subclass overrides win. Both class-level and field-level
     * {@code @AttributeOverride} / {@code @AttributeOverrides} are honoured.
     *
     * @param clazz the class to inspect
     * @param <T>   the class type
     * @return map from overridden attribute name to its {@link Column} definition
     */
    public static <T> Map<String, Column> getColumnOverride(Class<T> clazz) {
        List<Class<?>> classes = new ArrayList<Class<?>>();
        for (Class<?> current = clazz; current != null && !current.equals(Object.class); current = current.getSuperclass()) {
            classes.add(0, current);
        }

        Map<String, Column> ret = new HashMap<String, Column>();

        for (Class<?> clz : classes) {
            processAttributeOverride(ret, clz.getAnnotation(AttributeOverrides.class));
            processAttributeOverride(ret, clz.getAnnotation(AttributeOverride.class));
            for (Field field : clz.getDeclaredFields()) {
                processAttributeOverride(ret, field.getAnnotation(AttributeOverrides.class));
                processAttributeOverride(ret, field.getAnnotation(AttributeOverride.class));
            }
        }
        return ret;
    }

    /** Adds every override contained in {@code annotation} (may be null) to {@code map}. */
    private static void processAttributeOverride(Map<String, Column> map, AttributeOverrides annotation) {
        if (annotation != null) {
            AttributeOverride[] overrides = annotation.value();
            if (overrides != null) {
                for (AttributeOverride override : overrides) {
                    processAttributeOverride(map, override);
                }
            }
        }
    }

    /** Adds a single override (may be null) to {@code map}, keyed by attribute name. */
    private static void processAttributeOverride(Map<String, Column> map, AttributeOverride annotation) {
        if (annotation != null) {
            map.put(annotation.name(), annotation.column());
        }
    }

    /**
     * @param field the field to inspect
     * @return {@code true} if the field is {@code @Embedded} and its type is {@code @Embeddable}
     */
    public static boolean isEmbedded(Field field) {
        return field.getAnnotation(Embedded.class) != null && field.getType().getAnnotation(Embeddable.class) != null;
    }

    /**
     * Returns the persistent fields of an entity keyed by field name.
     * <p>
     * Superclass fields come first; fields annotated with {@link Lob} are moved to the end so
     * that large columns are laid out last. {@code @Embedded} fields are skipped.
     *
     * @param clazz the entity class
     * @param <T>   the class type
     * @return ordered map of field name to field
     * @throws IllegalArgumentException if {@code clazz} is not an entity
     */
    public static <T> Map<String, Field> getJPAFieldsMap(Class<T> clazz) {
        Validate.isTrue(isEntity(clazz));
        List<Field> fields = getJPAFields(clazz);

        LinkedHashMap<String, Field> ret = new LinkedHashMap<String, Field>(fields.size());
        List<Field> lobFields = new ArrayList<Field>();
        for (Field field : fields) {
            // embedded fields are not supported as plain columns
            if (!validJPAType(field)) continue;

            if (field.getAnnotation(Lob.class) == null) {
                ret.put(field.getName(), field);
            } else {
                lobFields.add(field);
            }
        }
        // move LOB columns to the end of the table
        for (Field field : lobFields) {
            ret.put(field.getName(), field);
        }
        return ret;
    }

    /**
     * Reads the indexes declared via Hibernate's {@code @Table(indexes = ...)} on a class.
     *
     * @param clazz the class to inspect
     * @param <T>   the class type
     * @return the declared indexes; an empty (immutable) set when none are declared
     */
    public static <T> Set<Index> getIndexFromClass(Class<T> clazz) {
        org.hibernate.annotations.Table table = clazz.getAnnotation(org.hibernate.annotations.Table.class);
        if (table == null) {
            return Collections.emptySet();
        }
        org.hibernate.annotations.Index[] indexes = table.indexes();
        if (indexes == null || indexes.length == 0) {
            return Collections.emptySet();
        }

        Set<Index> ret = new HashSet<Index>();
        for (org.hibernate.annotations.Index annotation : indexes) {
            ret.add(new Index(annotation));
        }
        return ret;
    }

    /**
     * Builds an index from a Hibernate {@code @Index} placed on a field.
     * The index name defaults to the field name; the indexed column defaults to the
     * field's mapped column name (see {@link #getColumnName(Field)}).
     *
     * @param field the field to inspect
     * @return the index, or {@code null} if the field carries no {@code @Index}
     */
    public static Index getIndexFromField(Field field) {
        org.hibernate.annotations.Index annotation = field.getAnnotation(org.hibernate.annotations.Index.class);
        if (annotation == null) {
            return null;
        }
        Index ret = new Index();
        if (StringUtils.isBlank(annotation.name())) {
            ret.setName(field.getName());
        } else {
            ret.setName(annotation.name());
        }

        if (ArrayUtils.isEmpty(annotation.columnNames())) {
            // use the DB column name, which differs from the field name when @Column(name) is set
            ret.setColumns(Arrays.asList(getColumnName(field)));
        } else {
            ret.setColumns(Arrays.asList(annotation.columnNames()));
        }
        return ret;
    }


    /**
     * Checks that a field is a plain (non-{@code @Embedded}) JPA attribute.
     *
     * @param field the field to inspect
     * @return {@code true} if the field is not annotated with {@link Embedded}
     */
    public static boolean validJPAType(Field field) {
        return field.getAnnotation(Embedded.class) == null;
    }

    /**
     * Returns the database column name of a persistent field: {@code @Column(name)} when
     * non-empty, otherwise the field name.
     *
     * @param field the field
     * @return the column name
     */
    public static String getColumnName(Field field) {
        Column column = field.getAnnotation(Column.class);
        if (column != null && StringUtils.isNotEmpty(column.name())) {
            return column.name();
        }
        return field.getName();
    }

    /**
     * Maps a field to its column type. Enum fields map to {@link ColumnType#Int} when stored by
     * ordinal and to {@link ColumnType#string} otherwise.
     *
     * @param field the field
     * @return the column type, or {@code null} if the field type is unsupported
     */
    public static ColumnType getFieldType(Field field) {
        ColumnType ctype = getFieldType(field.getType());
        if (ctype == ColumnType.Enum) {
            return JPAHelper.isUseOrdinal(field) ? ColumnType.Int : ColumnType.string;
        }
        return ctype;
    }

    /**
     * Maps a Java type to its column type.
     * <p>
     * For interfaces, the type is resolved through
     * {@code ReflectUtil.getGenericClassParameterizedType}; if that does not yield a concrete
     * class, {@code null} is returned.
     *
     * @param clazz the Java type
     * @return the column type, or {@code null} if unsupported
     * @throws SmthException if {@code clazz} is null
     */
    public static ColumnType getFieldType(Class<?> clazz) {
        if (clazz == null) {
            throw new SmthException("Class is null");
        }
        if (clazz.equals(int.class) || clazz.equals(Integer.class)) {
            return ColumnType.Int;
        }
        if (clazz.equals(long.class) || clazz.equals(Long.class)) {
            return ColumnType.Long;
        }
        // float is a floating-point type; mapping it to Int would silently drop fractions
        if (clazz.equals(float.class) || clazz.equals(Float.class)) {
            return ColumnType.Double;
        }
        if (clazz.equals(double.class) || clazz.equals(Double.class)) {
            return ColumnType.Double;
        }
        if (clazz.equals(String.class)) {
            return ColumnType.string;
        }
        if (clazz.equals(Date.class) || clazz.equals(java.sql.Date.class)) {
            return ColumnType.date;
        }
        if (clazz.equals(boolean.class) || clazz.equals(Boolean.class)) {
            return ColumnType.Boolean;
        }
        if (clazz.isEnum()) {
            return ColumnType.Enum;
        }
        if (clazz.isInterface()) {
            Type type = ReflectUtil.getGenericClassParameterizedType(clazz);
            // only a concrete Class can be mapped; null or a parameterized type is unsupported
            if (!(type instanceof Class)) {
                return null;
            }
            return getFieldType((Class<?>) type);
        }
        return null;
    }

    /**
     * @param clazz the Java type
     * @return {@code true} for int, long, float, double and their wrappers
     * @throws SmthException if {@code clazz} is null
     */
    public static boolean isNumberFieldType(Class<?> clazz) {
        if (clazz == null) {
            throw new SmthException("Class is null");
        }
        return clazz.equals(int.class) || clazz.equals(Integer.class)
                || clazz.equals(long.class) || clazz.equals(Long.class)
                || clazz.equals(float.class) || clazz.equals(Float.class)
                || clazz.equals(double.class) || clazz.equals(Double.class);
    }


    /**
     * Checks whether a field is the entity identifier.
     *
     * @param field the field
     * @return {@code true} if the field is annotated with {@link Id}
     */
    public static boolean isIdField(Field field) {
        return field.getAnnotation(Id.class) != null;
    }

    /**
     * @param field the id field
     * @return {@code true} if the field has {@code @GeneratedValue} with strategy AUTO or IDENTITY
     */
    public static boolean isIdAutoIncrease(Field field) {
        GeneratedValue gv = field.getAnnotation(GeneratedValue.class);
        return gv != null
                && (gv.strategy() == GenerationType.AUTO || gv.strategy() == GenerationType.IDENTITY);
    }

    /**
     * @param field the field
     * @return {@code true} if the field's type is an enum
     */
    public static boolean isEnumerate(Field field) {
        return field.getType().isEnum();
    }

    /**
     * Checks whether an enum field is stored by ordinal.
     * <p>
     * Note: only an explicit {@code @Enumerated(EnumType.ORDINAL)} (or {@code @Enumerated} with its
     * default value) yields {@code true}; a field without {@code @Enumerated} is treated as stored
     * by name, unlike the JPA default.
     *
     * @param field the field
     * @return {@code true} if the enum is stored by ordinal
     */
    public static boolean isUseOrdinal(Field field) {
        Enumerated ann = field.getAnnotation(Enumerated.class);
        return ann != null && ann.value() != EnumType.STRING;
    }


    /**
     * Reads the persisted value of a mapped field from an entity.
     * Enum values are returned as their ordinal ({@code Integer}) or as {@code toString()},
     * depending on the column mapping.
     *
     * @param entity    the entity to read from
     * @param fieldName the mapped field name
     * @return the value as it should be written to the database
     * @throws IllegalArgumentException if the entity is null, the field is not mapped or not
     *                                  readable, or an ordinal-mapped enum field is null
     */
    public static Object getEntityFieldValue(Object entity, String fieldName) {
        Validate.notNull(entity, "entity must not be null");
        // wrap the entity itself; wrapping a fresh instance would only ever yield default values
        BeanWrapper bw = PropertyAccessorFactory.forBeanPropertyAccess(entity);
        org.jsmth.jorm.jdbc.Table table = org.jsmth.jorm.jdbc.Table.getTable(entity.getClass());
        org.jsmth.jorm.jdbc.Column column = table.getColumnByFieldName(fieldName);
        Validate.notNull(column, "no mapped field for " + fieldName);

        Object value;
        try {
            value = bw.getPropertyValue(column.getFieldName());
        } catch (NotReadablePropertyException ex) {
            throw new IllegalArgumentException(ex.getMessage(), ex);
        }

        if (!column.isEnumerate()) {
            return value;
        }
        if (column.isUseOrdinal()) {
            if (value == null) {
                throw new IllegalArgumentException("Enum Field [" + column.getField() + "] can not be null while using Ordinal.");
            }
            return ((Enum<?>) value).ordinal();
        }
        return value == null ? null : value.toString();
    }

    /**
     * Writes a value (typically read from the database) into a mapped field of an entity.
     * For enum fields, a non-enum value is converted to the matching constant: by ordinal when
     * the field is stored by ordinal, otherwise by exact constant name. Values that match no
     * constant are passed through unchanged to the bean wrapper.
     *
     * @param entity    the entity to modify; if {@code null} the call is a no-op
     * @param fieldName the mapped field name
     * @param value     the value to set, may be {@code null}
     * @throws IllegalArgumentException if the field is not mapped
     */
    public static void setEntityFieldValue(Object entity, String fieldName, Object value) {
        if (entity == null) {
            return;
        }
        // wrap the entity itself so that the write is visible to the caller
        BeanWrapper bw = PropertyAccessorFactory.forBeanPropertyAccess(entity);
        org.jsmth.jorm.jdbc.Table table = org.jsmth.jorm.jdbc.Table.getTable(entity.getClass());
        org.jsmth.jorm.jdbc.Column col = table.getColumnByFieldName(fieldName);
        Validate.notNull(col, "no mapped field for " + fieldName);

        Field field = col.getField();
        Class<?> fieldType = field.getType();
        if (fieldType.isEnum() && value != null && !fieldType.isInstance(value)) {
            value = toEnumConstant(field, value);
        }
        bw.setPropertyValue(col.getFieldName(), value);
    }

    /**
     * Converts a raw database value to a constant of the field's enum type.
     * Returns {@code value} unchanged when no constant matches.
     */
    private static Object toEnumConstant(Field field, Object value) {
        Object[] constants = field.getType().getEnumConstants();
        if (isUseOrdinal(field)) {
            int ordinal = value instanceof Number
                    ? ((Number) value).intValue()
                    : Integer.parseInt(value.toString().trim());
            // getEnumConstants() is in declaration order, so the index equals the ordinal
            if (ordinal >= 0 && ordinal < constants.length) {
                return constants[ordinal];
            }
        } else {
            String name = value.toString();
            for (Object constant : constants) {
                // exact match: endsWith would confuse e.g. ACTIVE with INACTIVE
                if (((Enum<?>) constant).name().equals(name)) {
                    return constant;
                }
            }
        }
        return value;
    }

}
