package org.jsmth.jorm.jdbc;


import org.apache.commons.dbcp.BasicDataSource;
import org.jsmth.data.dialect.Dialect;
import org.jsmth.data.dialect.DialectFactory;
import org.jsmth.data.dialect.HSQLDialect;
import org.jsmth.exception.SmthExceptionDict;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.jsmth.exception.SmthDataAccessException;
import org.jsmth.page.CommonPage;
import org.jsmth.page.RollPage;
import org.jsmth.page.TailPage;
import org.springframework.jdbc.core.JdbcTemplate;
import org.springframework.jdbc.core.SingleColumnRowMapper;
//import org.springframework.jdbc.core.simple.SimpleJdbcTemplate;
import org.springframework.jdbc.support.SQLErrorCodeSQLExceptionTranslator;
import org.springframework.jdbc.support.SQLExceptionTranslator;
import org.springframework.jdbc.support.SQLStateSQLExceptionTranslator;

import javax.sql.DataSource;
import java.io.Serializable;
import java.util.Collection;
import java.util.List;
import java.util.Map;

/**
 * User: 马生录（mason
 * Date: 12-2-9
 * Time: 下午10:06
 */
public class BaseJdbcDao extends JdbcTemplate {

    protected EntityEventCallback entityEventCallback;

    protected Logger logger = LoggerFactory.getLogger(this.getClass());

    protected SQLExceptionTranslator exceptionTranslator;
    protected DataSource dataSource;
    protected org.hibernate.dialect.Dialect hbdialect = new org.hibernate.dialect.MySQL5Dialect();
    protected Dialect dialect=new HSQLDialect();// MySqlDialect();

    public BaseJdbcDao(DataSource dataSource) {
        super(dataSource);
        this.dataSource = dataSource;
    }

    public DataSource getDataSource() {
        return this.dataSource;
    }

    public void setDataSource(DataSource dataSource) {
        this.dataSource = dataSource;
    }

    protected void setDialect(String dialect) {
        this.dialect = DialectFactory.getDialect((BasicDataSource) dataSource);
        this.hbdialect=DialectFactory.getHibernateDialect(this.dialect);
    }

    protected org.hibernate.dialect.Dialect getHbdialect() {
        return hbdialect;
    }

    protected Dialect getDialect() {
        return dialect;
    }

    protected void setDialect(Dialect dialect) {
        this.dialect = dialect;
    }

    /**
     * 设置数据库方言
     *
     * @param dialect d
     */
    public void setDialect(DialectType dialect) {
        this.hbdialect = DialectType.getHibernateDialect(dialect);
    }

    public void executeDDL(String sql) {
        this.execute(sql);
//        this.getJdbcOperations().execute(sql);
    }



    //<editor-fold desc="find all">

    public <T, E extends Serializable> List<E> findAllId(Class<T> entityClass) throws SmthDataAccessException {
        return findIds(entityClass, "1=1");
    }

    /**
     * 获得所有
     *
     * @param entityClass d
     * @param <T> d
     * @param fieldNames d
     * @return 返回信息
     * @throws SmthDataAccessException s
     */
    public <T> List<T> findAll(Class<T> entityClass, String... fieldNames) throws SmthDataAccessException {
        Table<T> table = Table.getTable(entityClass);
        String sql = table.selectClause(true, fieldNames);
        List<T> ts = this.query(sql, JPARowMapper.forClass(entityClass));
        for (T t : ts) {
            invokeEvent(table, t, Event.PostLoad);
        }
        return ts;
    }

    /**
     * 获得所有
     *
     * @param entityClass d
     * @param <T> d
     * @param fieldNames d
     * @param pageNumber d
     * @param pageSize d
     * @return 返回信息
     * @throws SmthDataAccessException s
     */
    public <T> TailPage<T> pageFindAll(Class<T> entityClass, int pageNumber, int pageSize, String... fieldNames) throws SmthDataAccessException {
        Table<T> table = Table.getTable(entityClass);
        String sql = table.selectClause(true, fieldNames);
        sql += SQLHelper.pagingClause(pageNumber, pageSize);
        List<T> ts = this.query(sql, JPARowMapper.forClass(entityClass));
        for (T t : ts) {
            invokeEvent(table, t, Event.PostLoad);
        }
        int count = count(entityClass, "1=1");
        TailPage page = new CommonPage(pageNumber, pageSize, count, ts);
        return page;
    }

    /**
     * 获得所有
     *
     * @param entityClass d
     * @param pageSize d
     * @param fieldNames d@
     * @param pageNumber d
     * @param <T> d
     * @return 返回信息
     * @throws SmthDataAccessException s
     */
    public <T> RollPage<T> rollFindAll(Class<T> entityClass, int pageNumber, int pageSize, String... fieldNames) throws SmthDataAccessException {
        Table<T> table = Table.getTable(entityClass);
        String sql = table.selectClause(true, fieldNames);
        sql += SQLHelper.pagingClause(pageNumber, pageSize);
        List<T> ts = this.query(sql, JPARowMapper.forClass(entityClass));
        for (T t : ts) {
            invokeEvent(table, t, Event.PostLoad);
        }
        RollPage page = new RollPage(pageNumber, pageSize, ts);
        return page;
    }
    //</editor-fold>

    //<editor-fold desc="find column">
    public <T, E> List<E> groupColumn(Class<T> entityClass, Class<E> columnClass, String fieldName, String where, Object... params) throws SmthDataAccessException {
        Table<T> table = Table.getTable(entityClass);
        StringBuilder sql = new StringBuilder();
        sql.append(table.distinctSelectClause(fieldName));
        sql.append(" where ");
        sql.append(where);
        sql.append(" group by ");
        sql.append(table.getColumnByFieldName(fieldName).getColumnName());
        List<E> ts;
        if (params == null || params.length == 0)
            ts = query(sql.toString(), new SingleColumnRowMapper<E>(columnClass));
        else
            ts = query(sql.toString(), new SingleColumnRowMapper<E>(columnClass), params);
        return ts;
    }

    public <T, E> List<Map.Entry<E, Integer>> groupCountColumn(Class<T> entityClass, Class<E> columnClass, String fieldName, String countFieldName, String where, Object... params) throws SmthDataAccessException {
        Table<T> table = Table.getTable(entityClass);
        StringBuilder sql = new StringBuilder();
        sql.append(table.selectGroupCountClause(fieldName, countFieldName));
        sql.append(" where ");
        sql.append(where);
        sql.append(" group by ");
        sql.append(table.getColumnByFieldName(fieldName).getColumnName());
        List<Map.Entry<E, Integer>> ts;
        if (params == null || params.length == 0)
            ts = query(sql.toString(), new TwoColumnRowMapper<E, Integer>(columnClass, Integer.class));
        else
            ts = query(sql.toString(), new TwoColumnRowMapper<E, Integer>(columnClass, Integer.class), params);
        return ts;
    }

    public <T, E> List<Map.Entry<E, Integer>> groupMaxColumn(Class<T> entityClass, Class<E> columnClass, String fieldName, String maxFieldName, String where, Object... params) throws SmthDataAccessException {
        Table<T> table = Table.getTable(entityClass);
        StringBuilder sql = new StringBuilder();
        sql.append(table.selectGroupMaxClause(fieldName, maxFieldName));
        sql.append(" where ");
        sql.append(where);
        sql.append(" group by ");
        sql.append(table.getColumnByFieldName(fieldName).getColumnName());
        List<Map.Entry<E, Integer>> ts;
        if (params == null || params.length == 0)
            ts = query(sql.toString(), new TwoColumnRowMapper<E, Integer>(columnClass, Integer.class));
        else
            ts = query(sql.toString(), new TwoColumnRowMapper<E, Integer>(columnClass, Integer.class), params);
        return ts;
    }

    public <T, E> List<Map.Entry<E, Integer>> groupMinColumn(Class<T> entityClass, Class<E> columnClass, String fieldName, String maxFieldName, String where, Object... params) throws SmthDataAccessException {
        Table<T> table = Table.getTable(entityClass);
        StringBuilder sql = new StringBuilder();
        sql.append(table.selectGroupMinClause(fieldName, maxFieldName));
        sql.append(" where ");
        sql.append(where);
        sql.append(" group by ");
        sql.append(table.getColumnByFieldName(fieldName).getColumnName());
        List<Map.Entry<E, Integer>> ts;
        if (params == null || params.length == 0)
            ts = query(sql.toString(), new TwoColumnRowMapper<E, Integer>(columnClass, Integer.class));
        else
            ts = query(sql.toString(), new TwoColumnRowMapper<E, Integer>(columnClass, Integer.class), params);
        return ts;
    }

    public <T, E> List<Map.Entry<E, Integer>> groupSumColumn(Class<T> entityClass, Class<E> columnClass, String fieldName, String maxFieldName, String where, Object... params) throws SmthDataAccessException {
        Table<T> table = Table.getTable(entityClass);
        StringBuilder sql = new StringBuilder();
        sql.append(table.selectGroupSumClause(fieldName, maxFieldName));
        sql.append(" where ");
        sql.append(where);
        sql.append(" group by ");
        sql.append(table.getColumnByFieldName(fieldName).getColumnName());
        List<Map.Entry<E, Integer>> ts;
        if (params == null || params.length == 0)
            ts = query(sql.toString(), new TwoColumnRowMapper<E, Integer>(columnClass, Integer.class));
        else
            ts = query(sql.toString(), new TwoColumnRowMapper<E, Integer>(columnClass, Integer.class), params);
        return ts;
    }

    /**
     * 根据where条件查询一个独立字段
     * 不会使用JPA事件回调
     *
     * @param entityClass d
     * @param columnClass d
     * @param fieldName d
     * @param where d
     * @param params d
     * @param <T> d
     * @param <E> d
     * @return 返回信息
     * @throws SmthDataAccessException s
     */
    public <T, E> List<E> findColumn(Class<T> entityClass, Class<E> columnClass, String fieldName, String where, Object... params) throws SmthDataAccessException {
        Table<T> table = Table.getTable(entityClass);
        StringBuilder sql = new StringBuilder();
        sql.append(table.selectClause(false, fieldName));
        sql.append(" where ");
        sql.append(where);
        List<E> ts;
        if (params == null || params.length == 0)
            ts = query(sql.toString(), new SingleColumnRowMapper<E>(columnClass));
        else
            ts = query(sql.toString(), new SingleColumnRowMapper<E>(columnClass), params);
        return ts;
    }

    /**
     * 根据where条件查询一个独立字段
     * 不会使用JPA事件回调
     *
     * @param entityClass d
     * @param columnClass d
     * @param fieldName d
     * @param where d
     * @param params d
     * @param <T> d
     * @param <E> d
     *
     * @return 返回信息
     * @throws SmthDataAccessException s
     */
    public <T, E> List<E> findDistinctColumn(Class<T> entityClass, Class<E> columnClass, String fieldName, String where, Object... params) throws SmthDataAccessException {
        Table<T> table = Table.getTable(entityClass);
        StringBuilder sql = new StringBuilder();
        sql.append(table.distinctSelectClause(fieldName));
        sql.append(" where ");
        sql.append(where);
        List<E> ts;
        if (params == null || params.length == 0)
            ts = query(sql.toString(), new SingleColumnRowMapper<E>(columnClass));
        else
            ts = query(sql.toString(), new SingleColumnRowMapper<E>(columnClass), params);
        return ts;
    }

    /**
     * 根据where条件查询一个独立字段
     * 不会使用JPA事件回调
     *
     * @param entityClass fd
     * @param columnClass d
     * @param fieldName d
     * @param where d
     * @param params d
     * @param <T> d
     * @param <E> d
     * @param pageSize d
     * @param pageNumber d
     * @return 返回信息
     * @throws SmthDataAccessException s
     */
    public <T, E> TailPage<E> pageFindColumn(Class<T> entityClass, Class<E> columnClass, String fieldName, String where, int pageNumber, int pageSize, Object... params) throws SmthDataAccessException {
        Table<T> table = Table.getTable(entityClass);
        StringBuilder sql = new StringBuilder();
        sql.append(table.selectClause(false, fieldName));
        sql.append(" where ");
        sql.append(where);
        int count = count(entityClass, where, params);
        sql.append(SQLHelper.pagingClause(pageNumber, pageSize));
        List<E> ts;
        if (params == null || params.length == 0)
            ts = query(sql.toString(), new SingleColumnRowMapper<E>(columnClass));
        else
            ts = query(sql.toString(), new SingleColumnRowMapper<E>(columnClass), params);
        TailPage page = new CommonPage(pageNumber, pageSize, count, ts);
        return page;
    }

    /**
     * 根据where条件查询一个独立字段
     * 不会使用JPA事件回调
     *
     * @param entityClass d
     * @param columnClass d
     * @param fieldName d
     * @param where d
     * @param params d
     * @param <T> d
     * @param <E> d
     * @param pageSize d
     * @param pageNumber d
     * @return 返回信息
     * @throws SmthDataAccessException s
     */
    public <T, E> RollPage<E> rollFindColumn(Class<T> entityClass, Class<E> columnClass, String fieldName, String where, int pageNumber, int pageSize, Object... params) throws SmthDataAccessException {
        Table<T> table = Table.getTable(entityClass);
        StringBuilder sql = new StringBuilder();
        sql.append(table.selectClause(false, fieldName));
        sql.append(" where ");
        sql.append(where);
        sql.append(SQLHelper.pagingClause(pageNumber, pageSize));
        List<E> ts;
        if (params == null || params.length == 0)
            ts = query(sql.toString(), new SingleColumnRowMapper<E>(columnClass));
        else
            ts = query(sql.toString(), new SingleColumnRowMapper<E>(columnClass), params);
        RollPage page = new RollPage(pageNumber, pageSize, ts);
        return page;
    }
    //</editor-fold>

    //<editor-fold desc="delete">

    /**
     * 根据一组条件删除model
     *
     * @param where d
     * @param params d
     * @param <T> d
     * @param entityClass d
     * @return 返回信息
     * @throws SmthDataAccessException s
     */
    public <T> int delete(Class<T> entityClass, String where, Object... params) throws SmthDataAccessException {
        StringBuilder sb = new StringBuilder();
        sb.append(Table.getTable(entityClass).deleteClause());
        sb.append(" where ");
        sb.append(where);
        return this.update(sb.toString(), params);
    }
    //</editor-fold>

    //<editor-fold desc="find">
    public <T, E extends Serializable> List<E> findIds(Class<T> entityClass, String where, Object... params) throws SmthDataAccessException {
        Table table = Table.getTable(entityClass);
        return findColumn(entityClass, table.getIdColumn().getType().getMappingClazz(), table.getIdColumn().getColumnName(), where, params);
    }

    /**
     * 根据where语句查询一个实体
     *
     * @param entityClass d
     * @param exclude     查询中否排除掉fieldNames中指定的实体属性，以提高性能
     * @param fieldNames d
     * @param where d
     * @param params d
     * @param <T> d
     * @return 返回信息
     * @throws SmthDataAccessException s
     */
    public <T> List<T> find(Class<T> entityClass, boolean exclude, String[] fieldNames, String where, Object... params) throws SmthDataAccessException {
        Table<T> table = Table.getTable(entityClass);
        StringBuilder sql = new StringBuilder();
        sql.append(table.selectClause(exclude, fieldNames));
        sql.append(" where ");
        sql.append(where);
        List<T> ts;
        if (params == null || params.length == 0)
            ts = query(sql.toString(), JPARowMapper.forClass(entityClass));
        else
            ts = query(sql.toString(), JPARowMapper.forClass(entityClass), params);
        for (T t : ts) {
            invokeEvent(table, t, Event.PostLoad);
        }
        return ts;
    }

    public <T> TailPage<T> pageFind(Class<T> entityClass, boolean exclude, String[] fieldNames, String where, int pageNumber, int pageSize, Object... params) throws SmthDataAccessException {
        Table<T> table = Table.getTable(entityClass);
        StringBuilder sql = new StringBuilder();
        sql.append(table.selectClause(exclude, fieldNames));
        sql.append(" where ");
        sql.append(where);
        int count = count(entityClass, where, params);
        sql.append(SQLHelper.pagingClause(pageNumber, pageSize));
        List<T> ts;
        if (params == null || params.length == 0)
            ts = query(sql.toString(), JPARowMapper.forClass(entityClass));
        else
            ts = query(sql.toString(), JPARowMapper.forClass(entityClass), params);
        for (T t : ts) {
            invokeEvent(table, t, Event.PostLoad);
        }
        TailPage page = new CommonPage(pageNumber, pageSize, count, ts);
        return page;
    }

    public <T> RollPage<T> rollFind(Class<T> entityClass, boolean exclude, String[] fieldNames, String where, int pageNumber, int pageSize, Object... params) throws SmthDataAccessException {
        Table<T> table = Table.getTable(entityClass);
        StringBuilder sql = new StringBuilder();
        sql.append(table.selectClause(exclude, fieldNames));
        sql.append(" where ");
        sql.append(where);
        sql.append(SQLHelper.pagingClause(pageNumber, pageSize));
        List<T> ts;
        if (params == null || params.length == 0)
            ts = query(sql.toString(), JPARowMapper.forClass(entityClass));
        else
            ts = query(sql.toString(), JPARowMapper.forClass(entityClass), params);
        for (T t : ts) {
            invokeEvent(table, t, Event.PostLoad);
        }
        RollPage page = new RollPage(pageNumber, pageSize, ts);
        return page;
    }

    /**
     * 求列表
     *
     * @param entityClass d
     * @param where d
     * @param params d
     * @param <T> d
     * @return 返回信息
     * @throws SmthDataAccessException s
     */
    public <T> List<T> find(Class<T> entityClass, String where, Object... params) throws SmthDataAccessException {
        return find(entityClass, true, null, where, params);
    }

    public <T> TailPage<T> pageFind(Class<T> entityClass, String where, int pageNumber, int pageSize, Object... params) throws SmthDataAccessException {
        return pageFind(entityClass, true, null, where, pageNumber, pageSize, params);
    }

    public <T> RollPage<T> rollFind(Class<T> entityClass, String where, int pageNumber, int pageSize, Object... params) throws SmthDataAccessException {
        return rollFind(entityClass, true, null, where, pageNumber, pageSize, params);
    }
    //</editor-fold>

    //<editor-fold desc="count">

    /**
     * 求值
     *
     * @param entityClass d
     * @param where d
     * @param params d
     * @param <T> d
     * @return 返回信息
     * @throws SmthDataAccessException s
     */
    public <T> int count(Class<T> entityClass, String where, Object... params) throws SmthDataAccessException {
        Table<T> table = Table.getTable(entityClass);
        StringBuilder sql = new StringBuilder();
        sql.append(table.countClause());
        sql.append(" where ");
        sql.append(where);
        if (params == null || params.length == 0)
            return queryForObject(sql.toString(),Integer.class);
        else
            return queryForObject(sql.toString(),Integer.class, params);
    }
    //</editor-fold>

    //<editor-fold desc="getModel">
    public <T> T getModel(Class<T> entityClass, boolean exclude, String[] fieldNames, String where, Object... params) throws SmthDataAccessException {
        return getModel(false, entityClass, exclude, fieldNames, where, params);
    }

    public <T> T getModel(boolean checkNull,Class<T> entityClass, boolean exclude, String[] fieldNames, String where, Object... params) throws SmthDataAccessException {
        List<T> list = find(entityClass, exclude, fieldNames, where, params);
        if (list == null || list.size() == 0) {
            return ThrowDataAccessException(checkNull,entityClass, SmthExceptionDict.ModelNullException);
        }
        return list.get(0);
    }
    public <T> T getModel(Class<T> entityClass, String where, Object... params) throws SmthDataAccessException {
        return getModel(false, entityClass, where, params);
    }

    public <T> T getModel(boolean checkNull,Class<T> entityClass, String where, Object... params) throws SmthDataAccessException {
        return getModel(checkNull, entityClass, true, null, where, params);
    }
    //</editor-fold>

    //<editor-fold desc="getColumn">
    public <T, E> E getColumn(Class<T> entityClass, Class<E> columnClass, String fieldName, String where, Object... params) throws SmthDataAccessException {
        return getColumn(false,entityClass, columnClass, fieldName, where,params);
    }
    public <T, E> E getColumn(boolean checkNull,Class<T> entityClass, Class<E> columnClass, String fieldName, String where, Object... params) throws SmthDataAccessException {
        List<E> list = findColumn(entityClass, columnClass, fieldName, where, params);
        if (list == null || list.size() == 0) {
            return ThrowDataAccessException(checkNull,columnClass, SmthExceptionDict.ModelNullException);
        }
        return list.get(0);
    }
    //</editor-fold>


    public synchronized SQLExceptionTranslator getExceptionTranslator() {
        if (this.exceptionTranslator == null) {
            DataSource dataSource = getDataSource();
            if (dataSource != null) {
                this.exceptionTranslator = new SQLErrorCodeSQLExceptionTranslator(dataSource);
            } else {
                this.exceptionTranslator = new SQLStateSQLExceptionTranslator();
            }
        }
        return this.exceptionTranslator;
    }


    void batchInvokeEvent(Table table, Collection targets, Event event) {
        batchInvokeEvent(table, targets, event);
        if(entityEventCallback!=null && entityEventCallback!=this){
            entityEventCallback.batchInvokeEvent(targets, event);
        }
    }

    void invokeEvent(Table table,Object target, Event event) {
        table.invokeEvent(target, event);
        if(entityEventCallback!=null && entityEventCallback!=this){
            entityEventCallback.invokeEvent(target, event);
        }
    }

    public EntityEventCallback getEntityEventCallback() {
        return entityEventCallback;
    }

    public void setEntityEventCallback(EntityEventCallback entityEventCallback) {
        this.entityEventCallback = entityEventCallback;
    }

    protected <T> T ThrowDataAccessException(boolean checkNull,Class<T> tClass, SmthExceptionDict smthExceptionCode){
        if(checkNull) {
            throw new SmthDataAccessException(smthExceptionCode);
        }
        return null;
    }
}
