package org.jsmth.jorm.jdbc;

import org.apache.commons.collections.CollectionUtils;
import org.apache.commons.lang.Validate;
import org.jsmth.domain.Identifier;
import org.jsmth.exception.SmthDataAccessException;
import org.jsmth.exception.SmthExceptionDict;
import org.jsmth.util.IdentifierKeyHelper;
import org.jsmth.util.aaaa;
import org.springframework.jdbc.support.GeneratedKeyHolder;
import org.springframework.jdbc.support.KeyHolder;
import org.springframework.util.Assert;

import javax.sql.DataSource;
import java.io.Serializable;
import java.util.*;

/**
 * DAO for basic CRUD operations on entity tables, built on {@link BaseJdbcDao}
 * (originally an extension of Spring's simpleJdbcTemplate). It works directly on JDBC and is
 * meant to be used together with Spring's declarative transaction management.
 * User: 马生录 (mason)
 * Date: 12-2-9
 * Time: 9:57 PM
 */
public class BaseTableJdbcDao extends BaseJdbcDao {

    /**
     * Creates a DAO bound to the given data source (passed straight to {@link BaseJdbcDao}).
     *
     * @param dataSource JDBC data source used for all statements
     */
    public BaseTableJdbcDao(DataSource dataSource) {
        super(dataSource);
    }


    //<editor-fold desc="Description">
    /**
     * Loads a single entity by primary key without treating a missing row as an error.
     *
     * @param entityClass     entity type to load
     * @param id              primary key value
     * @param ommitFieldNames field names to leave out of the select list
     * @param <T>             entity type
     * @return the loaded entity, or the value of the non-throwing
     *         {@code ThrowDataAccessException} path (presumably {@code null}) when {@code id}
     *         is null or no row matches
     */
    public <T> T getById(Class<T> entityClass, Serializable id, String... ommitFieldNames) throws SmthDataAccessException {
        return getById(false, entityClass, id, ommitFieldNames);
    }

    /**
     * Loads a single entity by primary key and fires {@code Event.PostLoad} on it.
     *
     * @param checkNull       forwarded to {@code ThrowDataAccessException} when {@code id} is null
     *                        or no row is found; presumably {@code true} means "throw" and
     *                        {@code false} means "return null" (not visible here)
     * @param entityClass     entity type to load
     * @param id              primary key value
     * @param ommitFieldNames field names to leave out of the select list
     * @param <T>             entity type
     * @return the first matching row mapped to {@code T}
     */
    public <T> T getById(boolean checkNull, Class<T> entityClass, Serializable id, String... ommitFieldNames) throws SmthDataAccessException {
        if (id == null) {
            return ThrowDataAccessException(checkNull, entityClass, SmthExceptionDict.KeyNullException);
        }
        Table<T> table = Table.getTable(entityClass);
        String sql = table.selectById(true, ommitFieldNames);
        List<T> ts = this.query(sql, JPARowMapper.forClass(entityClass), id);
        if (ts.isEmpty()) {
            return ThrowDataAccessException(checkNull, entityClass, SmthExceptionDict.QueryResultNullException);
        }
        T t = ts.get(0);
        table.invokeEvent(t, Event.PostLoad);
        return t;
    }

    /**
     * Loads an entity by id or, if none is found, creates a fresh (unsaved) instance carrying
     * that id.
     *
     * @see #getOrCreateById(boolean, Class, Serializable)
     */
    public <T extends Identifier> T getOrCreateById(Class<T> entityClass, Serializable id) throws SmthDataAccessException {
        return getOrCreateById(false, entityClass, id);
    }

    /**
     * Loads an entity by id or, if the lookup yields {@code null}, instantiates
     * {@code entityClass} via its no-arg constructor and assigns {@code id} to it. The new
     * instance is not persisted.
     *
     * @param checkNull   forwarded to {@link #getById(boolean, Class, Serializable, String...)}
     *                    and to {@code ThrowDataAccessException} when instantiation fails
     * @param entityClass entity type
     * @param id          primary key value
     * @param <T>         entity type
     * @return the loaded or newly created entity
     */
    public <T extends Identifier> T getOrCreateById(boolean checkNull, Class<T> entityClass, Serializable id) throws SmthDataAccessException {
        T ret = this.getById(checkNull, entityClass, id);
        if (ret == null) {
            try {
                ret = entityClass.newInstance();
            } catch (Exception e) {
                return ThrowDataAccessException(checkNull, entityClass, SmthExceptionDict.CreateEntityInstanceException);
            }
            ret.setIdentifier(id);
        }
        return ret;
    }

    /**
     * Same as {@link #findByIds(boolean, Class, Set, boolean, String...)} with
     * {@code ommitEventException = false}.
     */
    public <T extends Identifier, E extends Serializable> List<T> findByIds(Class<T> entityClass, Set<E> ids, boolean exclude, String... ommitFieldNames) throws SmthDataAccessException {
        return findByIds(false, entityClass, ids, exclude, ommitFieldNames);
    }

    /**
     * Loads entities for a set of ids. The result has one slot per id (in the set's iteration
     * order); a slot is {@code null} when no row exists for that id.
     *
     * @param ommitEventException if true, exceptions from PostLoad callbacks are logged instead
     *                            of rethrown
     * @param entityClass         entity type
     * @param ids                 ids to load; null or empty yields an empty list
     * @param exclude             forwarded to {@code Table.selectByMultiId}
     * @param ommitFieldNames     forwarded to {@code Table.selectByMultiId}
     * @param <T>                 entity type
     * @param <E>                 id type
     * @return entities aligned with the ids, {@code null} for missing rows
     */
    public <T extends Identifier, E extends Serializable> List<T> findByIds(boolean ommitEventException, Class<T> entityClass, Set<E> ids, boolean exclude, String... ommitFieldNames) throws SmthDataAccessException {
        List<Map.Entry<E, T>> entries = findObjectsByIds(ommitEventException, entityClass, ids, exclude, ommitFieldNames);
        List<T> ret = new ArrayList<T>(entries.size());
        for (Map.Entry<E, T> entry : entries) {
            // getValue() is already null for ids without a row, so no special case is needed.
            ret.add(entry.getValue());
        }
        return ret;
    }

    /**
     * Same as {@link #findByIds(boolean, Class, List, boolean, String...)} with
     * {@code ommitEventException = false}.
     */
    public <T extends Identifier, E extends Serializable> List<T> findByIds(Class<T> entityClass, List<E> ids, boolean exclude, String... ommitFieldNames) throws SmthDataAccessException {
        return findByIds(false, entityClass, ids, exclude, ommitFieldNames);
    }

    /**
     * Loads the entities whose primary key is contained in {@code ids}, using a single
     * {@code <idColumn> in (?, ...)} condition executed through {@code find}.
     * <p>
     * Unlike the {@link Set} overload, the result contains only rows that exist, in whatever
     * order the query returns them; missing ids do not produce {@code null} placeholders.
     * NOTE(review): {@code ommitEventException}, {@code exclude} and {@code ommitFieldNames} are
     * not applied by this overload; whether {@code find} fires PostLoad is not visible here.
     *
     * @param ommitEventException currently unused (see note)
     * @param entityClass         entity type
     * @param ids                 ids to load; null or empty yields an empty list
     * @param exclude             currently unused (see note)
     * @param ommitFieldNames     currently unused (see note)
     * @param <T>                 entity type
     * @param <E>                 id type
     * @return the matching entities
     */
    public <T extends Identifier, E extends Serializable> List<T> findByIds(boolean ommitEventException, Class<T> entityClass, List<E> ids, boolean exclude, String... ommitFieldNames) throws SmthDataAccessException {
        // "in ()" with no placeholders is invalid SQL, so answer the empty case directly.
        if (CollectionUtils.isEmpty(ids)) {
            return new ArrayList<T>();
        }
        // Use the mapped id column rather than assuming it is literally named "id".
        String idColumn = Table.getTable(entityClass).getIdColumn().getColumnName();
        StringBuilder where = new StringBuilder(idColumn).append(" in(");
        for (int i = 0; i < ids.size(); i++) {
            if (i > 0) {
                where.append(",");
            }
            where.append("?");
        }
        where.append(")");
        return find(entityClass, where.toString(), ids.toArray());
    }

    /**
     * Same as {@link #findObjectsByIds(boolean, Class, List, boolean, String...)} with
     * {@code ommitEventException = false}.
     */
    public <T extends Identifier, E extends Serializable> List<Map.Entry<E, T>> findObjectsByIds(Class<T> entityClass, List<E> ids, boolean exclude, String... ommitFieldNames) throws SmthDataAccessException {
        return findObjectsByIds(false, entityClass, ids, exclude, ommitFieldNames);
    }

    /**
     * Same as {@link #findObjectsByIds(boolean, Class, Set, boolean, String...)} with
     * {@code ommitEventException = false}.
     */
    public <T extends Identifier, E extends Serializable> List<Map.Entry<E, T>> findObjectsByIds(Class<T> entityClass, Set<E> ids, boolean exclude, String... ommitFieldNames) throws SmthDataAccessException {
        return findObjectsByIds(false, entityClass, ids, exclude, ommitFieldNames);
    }

    /**
     * Set-based variant of {@link #findObjectsByIds(boolean, Class, List, boolean, String...)};
     * ids are processed in the set's iteration order.
     */
    public <T extends Identifier, E extends Serializable> List<Map.Entry<E, T>> findObjectsByIds(boolean ommitEventException, Class<T> entityClass, Set<E> ids, boolean exclude, String... ommitFieldNames) throws SmthDataAccessException {
        if (CollectionUtils.isEmpty(ids)) {
            return Collections.<Map.Entry<E, T>>emptyList();
        }
        // Forward the caller's ommitEventException flag (it used to be replaced by false).
        return findObjectsByIds(ommitEventException, entityClass, new ArrayList<E>(ids), exclude, ommitFieldNames);
    }

    /**
     * Loads entities for a list of ids and returns one entry per id, in the same order as
     * {@code ids}. Each entry's key is the id and its value is the loaded entity, or
     * {@code null} when no row exists for that id. {@code Event.PostLoad} is fired for every
     * loaded entity.
     *
     * @param ommitEventException if true, exceptions from PostLoad callbacks are logged and
     *                            skipped; otherwise they are wrapped in a
     *                            {@link SmthDataAccessException}
     * @param entityClass         entity type
     * @param ids                 ids to load; must not contain nulls; null or empty yields an
     *                            empty (immutable) list
     * @param exclude             forwarded to {@code Table.selectByMultiId}
     * @param ommitFieldNames     forwarded to {@code Table.selectByMultiId}
     * @param <T>                 entity type
     * @param <E>                 id type
     * @return id/entity entries aligned with {@code ids}
     */
    public <T extends Identifier, E extends Serializable> List<Map.Entry<E, T>> findObjectsByIds(boolean ommitEventException, Class<T> entityClass, List<E> ids, boolean exclude, String... ommitFieldNames) throws SmthDataAccessException {
        if (CollectionUtils.isEmpty(ids)) {
            return Collections.<Map.Entry<E, T>>emptyList();
        }
        Validate.noNullElements(ids);

        Table<T> table = Table.getTable(entityClass);
        String sql = table.selectByMultiId(ids.size(), exclude, ommitFieldNames);
        List<T> objs = this.query(sql, JPARowMapper.forClass(entityClass), ids.toArray());

        // Index rows once by identifier (first row wins) instead of rescanning all rows per id.
        Map<Object, T> byId = new HashMap<Object, T>();
        for (T obj : objs) {
            Object key = obj.getIdentifier();
            if (!byId.containsKey(key)) {
                byId.put(key, obj);
            }
        }
        List<Map.Entry<E, T>> ret = pairWithIds(ids, byId);

        for (Map.Entry<E, T> entry : ret) {
            T value = entry.getValue();
            if (value == null) {
                // No row was loaded for this id, so there is no entity to notify.
                continue;
            }
            try {
                table.invokeEvent(value, Event.PostLoad);
            } catch (Exception e) {
                if (ommitEventException) {
                    logger.error(String.format("error when invoke Event.PostLoad for model %s", value), e);
                } else {
                    throw new SmthDataAccessException(SmthExceptionDict.DataAccessException, e);
                }
            }
        }
        return ret;
    }

    /**
     * Reads a single column for a list of ids.
     *
     * @param entityClass entity type whose table is queried
     * @param ids         ids to look up
     * @param columnName  column to read
     * @param <T>         entity type
     * @param <E>         id type
     * @return a list aligned with {@code ids} holding the column value, or {@code null} where
     *         no row exists for the id
     */
    public <T extends Identifier, E extends Serializable> List<Object> findSingleColumnByIds(Class<T> entityClass, List<E> ids, String columnName) throws SmthDataAccessException {
        List<Map.Entry<E, Map<String, Object>>> byIds = findColumnObjectsByIds(entityClass, ids, columnName);
        List<Object> ret = new ArrayList<Object>(byIds.size());
        for (Map.Entry<E, Map<String, Object>> byId : byIds) {
            Map<String, Object> row = byId.getValue();
            ret.add(row == null ? null : row.get(columnName));
        }
        return ret;
    }

    /**
     * Reads several columns for a list of ids.
     *
     * @param entityClass entity type whose table is queried
     * @param ids         ids to look up
     * @param columnNames columns to read (the id column is always added)
     * @param <T>         entity type
     * @param <E>         id type
     * @return a list aligned with {@code ids}; each element is a column-name-to-value map, or
     *         {@code null} where no row exists for the id
     */
    public <T extends Identifier, E extends Serializable> List<Map<String, Object>> findColumnsByIds(Class<T> entityClass, List<E> ids, String... columnNames) throws SmthDataAccessException {
        List<Map.Entry<E, Map<String, Object>>> byIds = findColumnObjectsByIds(entityClass, ids, columnNames);
        List<Map<String, Object>> ret = new ArrayList<Map<String, Object>>(byIds.size());
        for (Map.Entry<E, Map<String, Object>> byId : byIds) {
            // Entries are never null; getValue() is null for ids without a row.
            ret.add(byId.getValue());
        }
        return ret;
    }

    /**
     * Reads several columns for a list of ids and pairs each id with its row.
     *
     * @param entityClass entity type whose table is queried
     * @param ids         ids to look up; null or empty yields an empty (immutable) list
     * @param columnNames columns to read; must be non-empty and contain no nulls. The id
     *                    column is added automatically so rows can be matched back to ids.
     * @param <T>         entity type
     * @param <E>         id type
     * @return a list aligned with {@code ids}; each entry's key is the id and its value is a
     *         column-name-to-value map, or {@code null} when no row exists for that id
     */
    public <T extends Identifier, E extends Serializable> List<Map.Entry<E, Map<String, Object>>> findColumnObjectsByIds(Class<T> entityClass, List<E> ids, String... columnNames) throws SmthDataAccessException {
        if (CollectionUtils.isEmpty(ids)) {
            return Collections.<Map.Entry<E, Map<String, Object>>>emptyList();
        }
        Assert.notEmpty(columnNames);
        Assert.noNullElements(columnNames);

        Table<T> table = Table.getTable(entityClass);

        // The id column must be selected too, otherwise rows cannot be matched back to ids.
        Set<String> cnames = new LinkedHashSet<String>(Arrays.asList(columnNames));
        String idFieldName = table.getIdColumn().getColumnName();
        cnames.add(idFieldName);

        String sql = table.selectByMultiId(ids.size(), false, cnames.toArray(new String[cnames.size()]));
        List<Map<String, Object>> objs = this.queryForList(sql, ids.toArray());

        // Index rows once by id value (first row wins) instead of rescanning all rows per id.
        Map<Object, Map<String, Object>> byId = new HashMap<Object, Map<String, Object>>();
        for (Map<String, Object> row : objs) {
            Object key = row.get(idFieldName);
            if (!byId.containsKey(key)) {
                byId.put(key, row);
            }
        }
        return pairWithIds(ids, byId);
    }

    /**
     * Pairs every id with the value indexed under it (or {@code null}), preserving id order.
     */
    private static <E, V> List<Map.Entry<E, V>> pairWithIds(List<E> ids, Map<Object, V> byId) {
        List<Map.Entry<E, V>> ret = new ArrayList<Map.Entry<E, V>>(ids.size());
        for (E id : ids) {
            ret.add(new AbstractMap.SimpleEntry<E, V>(id, byId.get(id)));
        }
        return ret;
    }

    /**
     * Saves an entity: tries an update by id first and inserts it when no row was updated.
     *
     * @param entity entity to save
     * @param <T>    entity type
     * @return the same entity instance (with a generated id when it was inserted with an
     *         auto-increment key)
     */
    public <T extends Identifier> T save(T entity) throws SmthDataAccessException {
        int ret = this.updateModel(entity);
        if (ret == 0) {
            this.insert(entity);
        }
        return entity;
    }

    /**
     * Inserts an entity, firing {@code PreInsert}/{@code PostInsert} around the statement.
     * For an auto-increment id column, an explicitly set id ({@code isIdModified()}) is
     * inserted as-is; otherwise the generated key is read back and stored on the entity
     * (supported for Int and Long id columns only).
     *
     * @param entity entity to insert
     * @param <T>    entity type
     * @return the same entity instance, carrying the generated id where applicable
     * @throws IllegalArgumentException if the auto-increment id column is neither Int nor Long
     * @throws IllegalStateException    if no generated key was returned
     */
    public <T extends Identifier> T insert(T entity) throws SmthDataAccessException {
        Table<? extends Identifier> table = Table.getTable(entity.getClass());
        aaaa.keisksiw();
        if (table.getIdColumn().isIdAutoIncrease()) {
            if (entity.isIdModified()) {
                String sql = table.insert(true);
                table.invokeEvent(entity, Event.PreInsert);
                this.update(sql, new BeanPropertySqlParameterSourceEx(entity));
                table.invokeEvent(entity, Event.PostInsert);
                return entity;
            } else {
                String sql = table.insert();
                KeyHolder keyHolder = new GeneratedKeyHolder();
                table.invokeEvent(entity, Event.PreInsert);
                this.update(sql, new BeanPropertySqlParameterSourceEx(entity), keyHolder);
                Number key = keyHolder.getKey();
                if (key == null) {
                    throw new IllegalStateException("no generated key returned for " + table);
                }
                if (table.getIdColumn().getType() == ColumnType.Int)
                    entity.setIdentifier(key.intValue());
                else if (table.getIdColumn().getType() == ColumnType.Long)
                    entity.setIdentifier(key.longValue());
                else
                    throw new IllegalArgumentException(table.toString());
                table.invokeEvent(entity, Event.PostInsert);
                return entity;
            }

        } else {
            String sql = table.insert();
            table.invokeEvent(entity, Event.PreInsert);
            super.update(sql, new BeanPropertySqlParameterSourceEx(entity));
            table.invokeEvent(entity, Event.PostInsert);
            return entity;
        }
    }

    /**
     * Batch insert; equivalent to {@code insertAll(entities, false)}.
     *
     * @param entities new entities to insert; generated keys are not read back into them
     * @param <T>      entity type
     * @return the length of the batch result array
     */
    public <T extends Identifier> int insertAll(Collection<T> entities) throws SmthDataAccessException {
        return insertAll(entities, false);
    }

    /**
     * Batch insert. All entities must have exactly the same class as the first one.
     * {@code PreInsert} is fired for every entity before the batch and {@code PostInsert}
     * after it. Generated keys are NOT read back, so entities with an auto-increment id keep
     * whatever id value they had.
     *
     * @param entities new entities to insert; null or empty returns 0
     * @param keepId   forwarded to {@code Table.insert(keepId)}; presumably whether the id
     *                 column is included in the insert (TODO confirm)
     * @param <T>      entity type
     * @return the length of the batch result array (one element per executed statement)
     */
    public <T> int insertAll(Collection<T> entities, boolean keepId) throws SmthDataAccessException {
        if (CollectionUtils.isEmpty(entities)) return 0;

        T first = entities.iterator().next();
        Table<? extends Object> table = Table.getTable(first.getClass());
        String sql = table.insert(keepId);

        BeanPropertySqlParameterSourceEx[] params = new BeanPropertySqlParameterSourceEx[entities.size()];
        int idx = 0;
        for (T entity : entities) {
            // The SQL is built from the first entity's table, so mixed types are rejected.
            Assert.isTrue(entity.getClass().equals(first.getClass()));
            table.invokeEvent(entity, Event.PreInsert);
            params[idx++] = new BeanPropertySqlParameterSourceEx(entity);
        }

        int ret = this.batchUpdate(sql, params).length;

        for (T entity : entities) {
            table.invokeEvent(entity, Event.PostInsert);
        }
        return ret;
    }

    /**
     * Updates only the named fields of an entity (matched by id), firing
     * {@code PreUpdate}/{@code PostUpdate}.
     *
     * @param entity     entity supplying the id and new values
     * @param fieldNames fields to update
     * @param <T>        entity type
     * @return number of rows reported by the update
     */
    public <T> int updateFeilds(T entity, String... fieldNames) throws SmthDataAccessException {
        Table<? extends Object> table = Table.getTable(entity.getClass());
        String sql = table.updateById(false, fieldNames);
        aaaa.keisksiw();

        table.invokeEvent(entity, Event.PreUpdate);
        int ret = super.update(sql, new BeanPropertySqlParameterSourceEx(entity));
        table.invokeEvent(entity, Event.PostUpdate);
        return ret;
    }

    /**
     * Updates the given fields on all rows matching a where clause. No entity events fire.
     *
     * @param entityClass entity type whose table is updated
     * @param exclude     forwarded to {@code Table.updateFields}
     * @param fieldNames  fields that make up the SET clause
     * @param values      values bound to the SET clause placeholders (bound first); null is
     *                    treated as no values
     * @param where       SQL condition appended after {@code where}
     * @param params      values bound to the where clause placeholders (bound after
     *                    {@code values}); null is treated as no values
     * @param <T>         entity type
     * @return number of rows reported by the update
     */
    public <T> int updateFeilds(Class<T> entityClass, boolean exclude, String[] fieldNames, Object[] values, String where, Object... params) throws SmthDataAccessException {
        Table<? extends Object> table = Table.getTable(entityClass);
        String sql = table.updateFields(exclude, fieldNames) + " where " + where;
        aaaa.keisksiw();
        return super.update(sql, concatArgs(values, params));
    }

    /**
     * Updates an entity by id, writing all fields except {@code ommitFieldNames}, and fires
     * {@code PreUpdate}/{@code PostUpdate}.
     *
     * @param entity          entity supplying the id and new values
     * @param ommitFieldNames fields to leave out of the SET clause
     * @param <T>             entity type
     * @return number of rows reported by the update
     */
    public <T> int updateModel(T entity, String... ommitFieldNames) throws SmthDataAccessException {
        Table<? extends Object> table = Table.getTable(entity.getClass());
        String sql = table.updateById(true, ommitFieldNames);

        aaaa.keisksiw();
        table.invokeEvent(entity, Event.PreUpdate);
        int ret = super.update(sql, new BeanPropertySqlParameterSourceEx(entity));
        table.invokeEvent(entity, Event.PostUpdate);
        return ret;
    }

    /**
     * Changes a row's primary key from {@code oldId} to {@code newId}. No entity events fire.
     *
     * @return number of rows reported by the update
     */
    public <T> int updateID(Class<T> entityClass, Serializable oldId, Serializable newId) throws SmthDataAccessException {
        Table<? extends Object> table = Table.getTable(entityClass);
        String sql = String.format("update %s set %s=? where %s=?", table.getTableName(), table.getIdColumn().getColumnName(), table.getIdColumn().getColumnName());
        return this.update(sql, newId, oldId);
    }

    /**
     * Adds {@code value} to a numeric field on all rows matching {@code where}. Negative
     * values use the table's subtract statement with the absolute value.
     * NOTE(review): {@code Math.abs(Integer.MIN_VALUE)} is still negative.
     *
     * @param params values bound to the where clause placeholders, after the amount; null is
     *               treated as no values
     * @return number of rows reported by the update
     */
    public <T> int addFieldValue(Class<T> entityClass, String fieldName, int value, String where, Object... params) throws SmthDataAccessException {
        Table<? extends Object> table = Table.getTable(entityClass);
        String sql;
        if (value >= 0) {
            sql = table.addFieldValueByWhere(fieldName, where);
        } else {
            sql = table.subFieldValueByWhere(fieldName, where);
            value = Math.abs(value);
        }
        // The amount is bound first, followed by the where-clause parameters.
        return this.update(sql, concatArgs(new Object[]{value}, params));
    }

    /**
     * Concatenates argument arrays into one bind-argument array; null arrays contribute nothing.
     */
    private static Object[] concatArgs(Object[]... parts) {
        List<Object> all = new ArrayList<Object>();
        for (Object[] part : parts) {
            if (part != null) {
                Collections.addAll(all, part);
            }
        }
        return all.toArray();
    }

    /**
     * Batch update by id of all fields except {@code ommitFieldNames}. All entities must have
     * exactly the same class as the first one.
     *
     * @return the length of the batch result array; 0 for null or empty input
     */
    public <T> int updateAll(Collection<T> entities, String... ommitFieldNames) throws SmthDataAccessException {
        if (CollectionUtils.isEmpty(entities)) return 0;

        T first = entities.iterator().next();
        Table<? extends Object> table = Table.getTable(first.getClass());
        String sql = table.updateById(true, ommitFieldNames);
        BeanPropertySqlParameterSourceEx[] params = toParameterSources(entities, first.getClass());

        batchInvokeEvent(table, entities, Event.PreUpdate);
        int ret = this.batchUpdate(sql, params).length;
        batchInvokeEvent(table, entities, Event.PostUpdate);
        return ret;
    }

    /**
     * Executes a batch statement.
     * <p>
     * NOTE(review): this is a stub that executes nothing and always returns an empty array,
     * so {@code insertAll}, {@code updateAll} and {@code updateFeildsAll} currently touch no
     * rows and report 0. It looks like a placeholder; confirm whether it should delegate to
     * the underlying JDBC template.
     */
    protected int[] batchUpdate(String sql, BeanPropertySqlParameterSourceEx[] exs) {
        return new int[0];
    }

    /**
     * Batch update by id of only the named fields. All entities must have exactly the same
     * class as the first one.
     *
     * @return the length of the batch result array; 0 for null or empty input
     */
    public <T> int updateFeildsAll(Collection<T> entities, String... fieldNames) throws SmthDataAccessException {
        if (CollectionUtils.isEmpty(entities)) return 0;

        T first = entities.iterator().next();
        Table<? extends Object> table = Table.getTable(first.getClass());
        String sql = table.updateById(false, fieldNames);
        BeanPropertySqlParameterSourceEx[] params = toParameterSources(entities, first.getClass());

        batchInvokeEvent(table, entities, Event.PreUpdate);
        int ret = this.batchUpdate(sql, params).length;
        batchInvokeEvent(table, entities, Event.PostUpdate);
        return ret;
    }

    /**
     * Wraps each entity as a parameter source, asserting they all have {@code expectedClass}
     * (the batch SQL is generated from that single class's table).
     */
    private static <T> BeanPropertySqlParameterSourceEx[] toParameterSources(Collection<T> entities, Class<?> expectedClass) {
        BeanPropertySqlParameterSourceEx[] params = new BeanPropertySqlParameterSourceEx[entities.size()];
        int idx = 0;
        for (T entity : entities) {
            Assert.isTrue(entity.getClass().equals(expectedClass));
            params[idx++] = new BeanPropertySqlParameterSourceEx(entity);
        }
        return params;
    }

    /**
     * Deletes an entity by its id, firing {@code PreDelete}/{@code PostDelete} around it.
     *
     * @param entity entity to delete
     * @param <T>    entity type
     * @return true if at least one row was deleted
     */
    public <T extends Identifier> boolean delete(T entity) throws SmthDataAccessException {
        Table<? extends Object> table = Table.getTable(entity.getClass());
        table.invokeEvent(entity, Event.PreDelete);
        aaaa.keisksiw();
        boolean ret = deleteById(entity.getClass(), entity.getIdentifier());
        table.invokeEvent(entity, Event.PostDelete);
        return ret;
    }

    /**
     * Deletes a row by id. No JPA event callbacks are fired.
     *
     * @param entityClass entity type whose table is affected
     * @param id          primary key value
     * @param <T>         entity type
     * @return true if at least one row was deleted
     */
    public <T> boolean deleteById(Class<T> entityClass, Serializable id) throws SmthDataAccessException {
        Table<T> table = Table.getTable(entityClass);
        String sql = table.deleteById();
        int i = super.update(sql, id);
        return i > 0;
    }

    /**
     * Deletes rows for a collection of ids. No JPA event callbacks are fired.
     *
     * @param entityClass entity type whose table is affected
     * @param ids         ids to delete; must not contain nulls; null or empty returns 0
     * @param <T>         entity type
     * @return number of rows reported by the delete
     */
    public <T> int deleteByIds(Class<T> entityClass, Collection ids) throws SmthDataAccessException {
        if (CollectionUtils.isEmpty(ids)) return 0;
        Validate.noNullElements(ids);
        String sql = Table.getTable(entityClass).deleteByMultiId(ids.size());
        return super.update(sql, ids.toArray());
    }

    /**
     * Returns the maximum id among rows matching {@code where}.
     *
     * @param entityClass entity type; its id column must be Int or Long
     * @param where       SQL condition appended after {@code where}
     * @param params      values bound to the where clause placeholders
     * @param <T>         entity type
     * @return the maximum id, or 0 when no row matches (MAX over no rows is SQL NULL)
     * @throws IllegalArgumentException if the id column is neither Int nor Long
     */
    public <T> long maxId(Class<T> entityClass, String where, Object... params) throws SmthDataAccessException {
        Table<T> table = Table.getTable(entityClass);
        ColumnType type = table.getIdColumn().getType();
        if (type != ColumnType.Int && type != ColumnType.Long) {
            throw new IllegalArgumentException("maxid function can only be used for int or long primary key.");
        }

        String sql = table.maxClause(table.getIdColumn().getFieldName()) + " where " + where;
        Number max;
        if (type == ColumnType.Int) {
            max = queryForObject(sql, params, Integer.class);
        } else {
            max = queryForObject(sql, params, Long.class);
        }
        // Avoid an unboxing NullPointerException when MAX() yields NULL.
        return max == null ? 0L : max.longValue();
    }

    /**
     * Deletes a group of entities by their ids, firing {@code PreDelete} for all entities
     * before the delete and {@code PostDelete} after it.
     *
     * @param entities entities to delete; null or empty returns 0, consistent with the other
     *                 bulk operations
     * @param <T>      entity type
     * @return number of rows reported by the delete
     */
    public <T extends Identifier> int deleteAll(Collection<T> entities) throws SmthDataAccessException {
        if (CollectionUtils.isEmpty(entities)) return 0;
        T first = entities.iterator().next();
        Table<? extends Object> table = Table.getTable(first.getClass());

        batchInvokeEvent(table, entities, Event.PreDelete);
        int ret = deleteByIds(first.getClass(), IdentifierKeyHelper.getIdentifiers(entities));
        batchInvokeEvent(table, entities, Event.PostDelete);
        return ret;
    }
    //</editor-fold>



}

