package org.jsmth.data.dao;

import org.apache.commons.dbcp.BasicDataSource;
import org.apache.commons.lang.Validate;
import org.jsmth.data.dialect.Dialect;
import org.jsmth.data.dialect.DialectFactory;
import org.jsmth.data.jdbc.JdbcDao;
import org.jsmth.data.jdbc.Query;
import org.jsmth.data.schema.ObjectTableMeta;
import org.jsmth.data.sql.EntityQuery;
import org.jsmth.data.sql.wrap.WhereWrap;
import org.jsmth.domain.Identifier;
import org.jsmth.exception.SmthDataAccessException;
import org.jsmth.jorm.jdbc.SchemaUpdateStrategy;
import org.jsmth.jorm.jdbc.SlaveJdbcDao;
import org.jsmth.page.CommonPage;
import org.jsmth.page.Page;
import org.springframework.util.Assert;

import javax.annotation.PostConstruct;
import javax.persistence.MappedSuperclass;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;

/**
 * {@link EntityDao} variant that splits reads from writes.
 *
 * <p>The master {@link JdbcDao} is the one held by the superclass ({@link #getJdbcDao()}); the
 * query methods overridden in this class are routed to a separate slave {@link JdbcDao} instead.
 * Both DAOs must be configured before {@link #init()} runs.
 *
 * <p>Created by mason on 15/12/26.
 *
 * @param <KEY>   type of the entity identifier
 * @param <MODEL> entity type
 */
@MappedSuperclass
public class ReadWriteEntityDao<KEY extends Serializable, MODEL extends Identifier<KEY>> extends EntityDao<KEY, MODEL> implements IReadWriteEntityDao<KEY, MODEL> {
    /** DAO used by the read methods overridden in this class. */
    protected JdbcDao slaveJdbcDao;
    /** Plain {@link EntityDao} bound to the slave DAO, created lazily by {@link #getSlaveDao()}. */
    protected EntityDao<KEY,MODEL> slaveDao;

    /**
     * Creates a read/write DAO for the given entity type.
     *
     * @param entityClass entity type handled by this DAO
     */
    public ReadWriteEntityDao(Class<MODEL> entityClass) {
        super(entityClass);
    }

    /**
     * Returns the master DAO, i.e. the superclass' {@link #getJdbcDao()}.
     *
     * @return the master DAO
     */
    public JdbcDao getMasterJdbcDao() {
        return getJdbcDao();
    }

    /**
     * Sets the master DAO by delegating to the superclass' {@code setJdbcDao}.
     *
     * @param masterJdbcDao the master DAO
     */
    public void setMasterJdbcDao(JdbcDao masterJdbcDao) {
        setJdbcDao(masterJdbcDao);
    }

    /**
     * Returns the slave DAO used for reads.
     *
     * @return the slave DAO, or {@code null} if not yet configured
     */
    public JdbcDao getSlaveJdbcDao() {
        return slaveJdbcDao;
    }

    /**
     * Sets the slave DAO used for reads.
     *
     * @param slaveJdbcDao the slave DAO
     */
    public void setSlaveJdbcDao(JdbcDao slaveJdbcDao) {
        this.slaveJdbcDao = slaveJdbcDao;
    }

    /**
     * Returns an {@link EntityDao} bound to the slave DAO, creating it on first call.
     *
     * <p>Creation is not synchronized, and the created instance keeps the slave DAO that was
     * configured at the time of the first call.
     *
     * @return the slave-bound entity DAO
     */
    public EntityDao<KEY,MODEL> getSlaveDao(){
        if(slaveDao==null){
            slaveDao=new EntityDao<>(this.entityClass);
            slaveDao.setJdbcDao(getSlaveJdbcDao());
        }
        return slaveDao;
    }

    /**
     * Validates the configuration and registers this DAO as the entity event callback on both
     * the master and the slave DAO.
     *
     * @throws IllegalArgumentException if the entity class, master DAO or slave DAO is missing
     */
    @PostConstruct
    public void init() {
        Assert.notNull(entityClass, "entityClass must be set!");
        Assert.notNull(getMasterJdbcDao(), "master jdbcDao must be set!");
        Assert.notNull(getSlaveJdbcDao(), "slave jdbcDao must be set!");

        getJdbcDao().setEntityEventCallback(this);
        getSlaveJdbcDao().setEntityEventCallback(this);
        // Automatic schema update on startup is intentionally disabled here. The former logic
        // skipped it when the slave was a clustered SlaveJdbcDao (a real master must have its
        // tables created explicitly through rebuildSchema), and otherwise let the master create
        // tables/indexes (local environments or setups without a slave cluster).
    }

    /**
     * Rebuilds the schema of the given entity classes on the master DAO.
     *
     * <p>With master/slave splitting, schema rebuilding is not automatic and must be triggered
     * explicitly on the concrete entity DAO.
     *
     * @param entityClass entity classes whose schema is rebuilt
     * @throws SmthDataAccessException propagated from the master DAO
     */
    @SuppressWarnings({"unchecked"})
    public void rebuildSchema(Class... entityClass) throws SmthDataAccessException {
        for (Class clazz : entityClass) {
            // NOTE(review): invoked with false then true; the flag's meaning is defined by
            // JdbcDao.rebuildSchema — confirm.
            getMasterJdbcDao().rebuildSchema(false,clazz);
            getMasterJdbcDao().rebuildSchema(true,clazz);
        }
    }

    /**
     * Loads the entities whose id column matches any of {@code ids}, reading from the slave.
     *
     * @param ids ids to look up; may be {@code null} or empty
     * @return matching entities; a new empty mutable list when {@code ids} is null or empty
     */
    @Override
    public List<MODEL> findByIds(List<? extends KEY> ids) {
        if (ids == null || ids.isEmpty()) {
            return new ArrayList<>();
        }
        ObjectTableMeta<MODEL> table = ObjectTableMeta.getTable(entityClass);
        return getSlaveJdbcDao().queryByColumns(entityClass, table.getIdColumn().getFieldName(), ids.toArray());
    }

    /**
     * Returns the id column of the rows matching {@code where}, reading from the slave.
     *
     * @param kClass type of the id values
     * @param where  where clause with placeholders
     * @param params placeholder values
     * @return matching ids
     */
    @Override
    public List<KEY> findIds(Class<KEY> kClass, String where, Object... params) {
        ObjectTableMeta<MODEL> table = ObjectTableMeta.getTable(entityClass);
        return getSlaveJdbcDao().queryColumn(entityClass, kClass, table.getIdColumn().getFieldName(), where,params);
    }

    /**
     * Returns one page of entities matching {@code where}. Both the total count and the items
     * are read from the slave, consistent with {@link #count(String, Object...)}.
     *
     * @param pageNumber  1-based page number; values below 1 are treated as 1
     * @param pageSize    rows per page
     * @param totalRecord currently ignored: the total is always counted
     * @param where       where clause with placeholders
     * @param params      placeholder values
     * @return the requested page with its total item count set
     */
    @Override
    public Page<MODEL> pageModels(int pageNumber, int pageSize, boolean totalRecord, String where, Object... params) {
        int safePageNumber = Math.max(pageNumber, 1);
        Integer count = countOnSlave(where, params);
        CommonPage<MODEL> page = newPage(safePageNumber, pageSize, count);
        String sql = getSlaveJdbcDao().buildQuerySql(entityClass, "*", where);
        page.setItems(getSlaveJdbcDao().queryForEntityList(entityClass, toLimitedSql(sql, safePageNumber, pageSize), params));
        return page;
    }

    /**
     * Returns the values of one field for the rows matching {@code where}, reading from the slave.
     *
     * @param columnClass type of the column values
     * @param fieldName   entity field to select
     * @param where       where clause with placeholders
     * @param params      placeholder values
     * @return column values
     */
    @Override
    public <C> List<C> findColumns(Class<C> columnClass, String fieldName, String where, Object... params) {
        WhereWrap whereWrap = new WhereWrap();
        whereWrap.placeholderW(where, params);
        return getSlaveJdbcDao().queryColumn(entityClass, columnClass, fieldName, whereWrap);
    }

    /**
     * Returns one page of {@code select} results for the rows matching {@code where}. Both the
     * total count and the items are read from the slave.
     *
     * @param kClass      type each result row is mapped to
     * @param pageNumber  1-based page number; values below 1 are treated as 1
     * @param pageSize    rows per page
     * @param totalRecord currently ignored: the total is always counted
     * @param select      select expression
     * @param where       where clause with placeholders
     * @param params      placeholder values
     * @return the requested page with its total item count set
     */
    @Override
    public <K> Page<K> pageColumns(Class<K> kClass, int pageNumber, int pageSize, boolean totalRecord, String select, String where, Object... params) {
        int safePageNumber = Math.max(pageNumber, 1);
        Integer count = countOnSlave(where, params);
        CommonPage<K> page = newPage(safePageNumber, pageSize, count);
        String sql = getSlaveJdbcDao().buildQuerySql(entityClass, select, where);
        page.setItems(getSlaveJdbcDao().queryForList(kClass, toLimitedSql(sql, safePageNumber, pageSize), params));
        return page;
    }

    /**
     * Counts the rows matching {@code where} on the slave.
     *
     * @param where  where clause with placeholders
     * @param params placeholder values
     * @return number of matching rows
     */
    @Override
    public int count(String where, Object... params) {
        WhereWrap whereWrap = new WhereWrap();
        whereWrap.placeholderW(where, params);
        return getSlaveJdbcDao().queryForInt(entityClass, whereWrap);
    }

    /** Returns the rows matching {@code where} as column-name maps, reading from the slave. */
    @Override
    public List<Map<String, Object>> findMaps(String where, Object... params) {
        return getSlaveJdbcDao().findMaps(getEntityClass(), where, params);
    }

    /** Returns the {@code select} results matching {@code where} as maps, reading from the slave. */
    @Override
    public List<Map<String, Object>> findMaps(String select, String where, Object... params) {
        return getSlaveJdbcDao().findMaps(getEntityClass(), select, where, params);
    }

    /** Returns a single row matching {@code where} as a map, reading from the slave. */
    @Override
    public Map<String, Object> getMap(String where, Object... params) {
        return getSlaveJdbcDao().getMap(getEntityClass(), where, params);
    }

    /** Returns a single {@code select} result matching {@code where} as a map, reading from the slave. */
    @Override
    public Map<String, Object> getMap(String select, String where, Object... params) {
        return getSlaveJdbcDao().getMap(getEntityClass(), select, where, params);
    }

    /** Returns the entities matching {@code where}, reading from the slave. */
    @Override
    public List<MODEL> findModels(String where, Object... params) {
        WhereWrap whereWrap=new WhereWrap();
        whereWrap.placeholderW(where,params);
        return getSlaveJdbcDao().query(entityClass, whereWrap);
    }

    /** Returns all entities, reading from the slave. */
    @Override
    public List<MODEL> findAll() {
        return getSlaveJdbcDao().queryAll(entityClass);
    }

    /** Returns one page of all entities, reading from the slave. */
    @Override
    public Page<MODEL> findAll(int pageNumber, int pageSize) {
        return getSlaveJdbcDao().queryAllForPage(entityClass, pageNumber, pageSize);
    }

    /** Returns one page of entities matching {@code wrap}, reading from the slave. */
    @Override
    public Page<MODEL> findPageModels(WhereWrap wrap, int pageNumber, int pageSize, boolean totalRecord) {
        return getSlaveJdbcDao().queryPage(entityClass, wrap, pageNumber, pageSize, totalRecord);
    }

    /** Returns one page of entities matching {@code query}, reading from the slave. */
    @Override
    public Page<MODEL> findPageModels(EntityQuery query) {
        return getSlaveJdbcDao().queryPage(entityClass, query);
    }

    /** Returns the entities matching {@code wrap}, reading from the slave. */
    @Override
    public List<MODEL> findModels(WhereWrap wrap) {
        return getSlaveJdbcDao().query(entityClass, wrap);
    }

    /** Returns the entities matching {@code query}, reading from the slave. */
    @Override
    public List<MODEL> findModels(EntityQuery query) {
        return getSlaveJdbcDao().query(entityClass, query);
    }

    /** Returns results of {@code wrap} mapped to {@code kClass}, reading from the slave. */
    @Override
    public <K> List<K> findColumns(Class<K> kClass, WhereWrap wrap) {
        // NOTE(review): unlike findColumns(Class, EntityQuery), this passes kClass where the
        // other overloads pass the entity class — confirm it targets the right table.
        return getSlaveJdbcDao().query(kClass, wrap);
    }

    /** Returns column results of {@code query} mapped to {@code kClass}, reading from the slave. */
    @Override
    public <K> List<K> findColumns(Class<K> kClass, EntityQuery query) {
        return getSlaveJdbcDao().queryColumn(kClass, query);
    }

    /** Returns one page of results of {@code wrap} mapped to {@code kClass}, reading from the slave. */
    @Override
    public <K> Page<K> pageColumns(Class<K> kClass, WhereWrap wrap, int pageNumber, int pageSize, boolean totalRecord) {
        return getSlaveJdbcDao().queryPage(kClass, wrap, pageNumber, pageSize, totalRecord);
    }

    /** Returns one page of column results of {@code query}, reading from the slave. */
    @Override
    public <K> Page<K> pageColumns(Class<K> kClass, EntityQuery query) {
        return getSlaveJdbcDao().queryPageColumn(kClass, query);
    }

    /** Returns a single entity matching {@code where}, reading from the slave. */
    @Override
    public MODEL getModel(WhereWrap where) {
        return getSlaveJdbcDao().queryForObject(entityClass, where);
    }

    /** Returns a single entity matching {@code query}, reading from the slave. */
    @Override
    public MODEL getModel(EntityQuery query) {
        return getSlaveJdbcDao().queryForObject(entityClass, query);
    }

    /**
     * Returns one page of entities matching {@code where}, reading from the slave.
     *
     * @param where      where clause with placeholders
     * @param pageNumber page number passed to the slave DAO
     * @param pageSize   rows per page
     * @param isGetTotal whether the total record count should be computed; forwarded to the
     *                   slave DAO like {@code totalRecord} in
     *                   {@link #findPageModels(WhereWrap, int, int, boolean)}
     * @param params     placeholder values
     * @return the requested page
     */
    @Override
    public Page<MODEL> pageFindModels(String where, int pageNumber, int pageSize, boolean isGetTotal, Object... params) {
        WhereWrap whereWrap=new WhereWrap();
        whereWrap.placeholderW(where,params);
        // Previously isGetTotal was dropped by calling the 4-argument queryPage overload.
        return getSlaveJdbcDao().queryPage(entityClass, whereWrap, pageNumber, pageSize, isGetTotal);
    }

    /** Returns the values of one field for the rows matching {@code where}, reading from the slave. */
    @Override
    public <T, E> List<E> groupColumn(Class<E> columnClass, String fieldName, String where, Object... params) {
        WhereWrap whereWrap=new WhereWrap();
        whereWrap.placeholderW(where,params);
        return getSlaveJdbcDao().queryColumn(entityClass, columnClass, fieldName, whereWrap);
    }

    /** Returns the values of one field for the rows matching {@code where}, reading from the slave. */
    @Override
    public <T> List<T> findColumn(Class<T> columnClass, String fieldName, String where, Object... params) {
        WhereWrap whereWrap=new WhereWrap();
        whereWrap.placeholderW(where,params);
        return getSlaveJdbcDao().queryColumn(entityClass, columnClass, fieldName, whereWrap);
    }

    /** Returns the entities matching {@code where}, reading from the slave. */
    @Override
    public List<MODEL> query(WhereWrap where) {
        return getSlaveJdbcDao().query(entityClass, where);
    }

    /** Returns a single entity matching {@code where}, reading from the slave. */
    @Override
    public MODEL queryForObject(WhereWrap where) {
        return getSlaveJdbcDao().queryForObject(entityClass, where);
    }

    /** Returns the rows matching {@code where} as maps, reading from the slave. */
    @Override
    public List<Map<String, Object>> queryForList(WhereWrap where) {
        return getSlaveJdbcDao().queryForList(entityClass, where);
    }

    /** Runs {@code count(*)} for {@code where} on the slave DAO. */
    private Integer countOnSlave(String where, Object... params) {
        String countSql = getSlaveJdbcDao().buildQuerySql(entityClass, "count(*)", where);
        return getSlaveJdbcDao().queryForObject(Integer.class, countSql, params);
    }

    /** Creates a page holder with number, size and total already set. */
    private <T> CommonPage<T> newPage(int pageNumber, int pageSize, Integer count) {
        CommonPage<T> page = new CommonPage<>();
        page.setPageNumber(pageNumber);
        page.setPageSize(pageSize);
        page.setTotalItemsCount(count);
        return page;
    }

    /**
     * Wraps {@code sql} with the dialect's limit clause for the given 1-based page.
     * NOTE(review): the dialect is taken from the master's data source (as before), assuming
     * master and slave run the same database type.
     */
    private String toLimitedSql(String sql, int pageNumber, int pageSize) {
        Dialect dialect = DialectFactory.getDialect((BasicDataSource) getJdbcDao().getDataSource());
        return dialect.getLimitString(sql, (pageNumber - 1) * pageSize, pageSize);
    }
}
