package cn.xisoil.dao.utils.impl;

import cn.xisoil.annotation.batch.BatchSQLDelete;
import cn.xisoil.dao.utils.BatchRepository;
import com.fasterxml.jackson.annotation.JsonFormat;
import jakarta.persistence.*;
import org.apache.commons.lang3.StringUtils;
import org.hibernate.annotations.SQLDelete;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.data.jpa.repository.support.JpaEntityInformation;
import org.springframework.data.jpa.repository.support.SimpleJpaRepository;

import java.io.Serializable;
import java.lang.reflect.Field;
import java.lang.reflect.Modifier;
import java.text.SimpleDateFormat;
import java.util.ArrayList;
import java.util.List;
import java.util.Optional;
import java.util.Set;

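/**
 * @Description Spring Data JPA repository base class that adds batch operations: multi-row
 *              MySQL {@code INSERT ... ON DUPLICATE KEY UPDATE} upserts, bulk deletes by id,
 *              and id-based lookups.
 * @Author Vien
 * @CreateTime 2023-04-18 16:56:36
 **/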

public class BatchRepositoryImpl<T, ID extends Serializable> extends SimpleJpaRepository<T, ID> implements BatchRepository<T, ID> {

    private static final Logger logger = LoggerFactory.getLogger(BatchRepositoryImpl.class);

    /** Maximum number of rows written by one multi-row INSERT statement. */
    private static final int BATCH_SIZE = 1000;

    /**
     * Upper bound on placeholders in a single prepared statement (MySQL rejects more than 65,535).
     * Every column may become one bind parameter, so this caps rows per statement for wide tables.
     */
    private static final int MAX_BIND_PARAMETERS = 65535;

    /** Pattern used for {@link java.util.Date} fields that carry no usable {@link JsonFormat} pattern. */
    private static final String DEFAULT_DATE_PATTERN = "yyyy-MM-dd HH:mm:ss";

    private final EntityManager entityManager;

    /** Declared fields of the entity and of its {@code @MappedSuperclass} ancestors, ancestors first. */
    private Field[] fields = null;

    /** Physical table name used by the native upsert SQL. */
    private String tableName;

    /** JPA entity name used in JPQL queries. */
    private String entityName;

    /** Name of the {@code @Id} attribute; stays {@code "id"} when no {@code @Id} field is found. */
    private String id = "id";

    /** The {@code @Id} field (already made accessible), or {@code null} when none is declared. */
    private Field idField;

    private Class<T> tClass;

    /** Persistent fields written by {@link #saveAll(List)}, in column order. */
    private final List<Field> columnFields = new ArrayList<>();

    /** Column names, index-aligned with {@link #columnFields}. */
    private final List<String> columnNames = new ArrayList<>();

    public BatchRepositoryImpl(JpaEntityInformation<T, ?> entityInformation, EntityManager entityManager) {
        super(entityInformation, entityManager);
        this.entityManager = entityManager;
        init();
    }


    public BatchRepositoryImpl(Class<T> domainClass, EntityManager em) {
        super(domainClass, em);
        this.entityManager = em;
        // This constructor must resolve the same metadata as the other one, otherwise every batch method fails.
        init();
    }

    /** Resolves table/entity/id metadata and the persistent column list of the domain class once. */
    private void init() {
        tClass = getDomainClass();
        fields = collectFields(tClass);
        tableName = getTableName();
        entityName = getEntityName();
        id = getId();
        for (Field field : fields) {
            if (isIgnore(field)) {
                continue;
            }
            // Allow reading private fields when building the upsert.
            field.setAccessible(true);
            columnFields.add(field);
            columnNames.add(columnName(field));
        }
    }


    /**
     * Upserts {@code list} using MySQL multi-row {@code INSERT ... ON DUPLICATE KEY UPDATE} statements,
     * at most {@value #BATCH_SIZE} rows per statement (fewer for very wide tables).
     * <p>
     * Column values are converted as follows: {@code null} becomes SQL NULL; {@link Boolean} becomes 1/0;
     * an enum becomes its ordinal, or its name when the field is annotated {@code @Enumerated(EnumType.STRING)};
     * a {@link java.util.Date} becomes text formatted with the field's {@link JsonFormat} pattern, or with
     * {@value #DEFAULT_DATE_PATTERN} when there is none; a {@link List}/{@link Set} becomes its
     * {@code toString()} with all square brackets removed; anything else becomes {@code toString()}.
     * Text values are sent as bind parameters and are never spliced into the SQL string.
     *
     * @param list entities to insert or update; {@code null} or empty is a no-op
     * @throws IllegalStateException if a field value cannot be read reflectively
     */
    @Override
    public void saveAll(List<T> list) {
        if (list == null || list.isEmpty() || columnFields.isEmpty()) {
            return;
        }
        int rowsPerStatement = Math.max(1, Math.min(BATCH_SIZE, MAX_BIND_PARAMETERS / columnFields.size()));
        StringBuilder sql = into();
        List<Object> parameters = new ArrayList<>();
        int rows = 0;
        for (T entity : list) {
            if (rows > 0) {
                sql.append(",");
            }
            appendRow(sql, parameters, entity);
            rows++;
            if (rows >= rowsPerStatement) {
                executeUpsert(sql, parameters);
                sql = into();
                parameters = new ArrayList<>();
                rows = 0;
            }
        }
        if (rows > 0) {
            executeUpsert(sql, parameters);
        }
    }

    /** Appends one {@code (v1,v2,...)} value tuple for {@code entity}, collecting its bind parameters. */
    private void appendRow(StringBuilder sql, List<Object> parameters, T entity) {
        sql.append("(");
        for (int k = 0; k < columnFields.size(); k++) {
            if (k > 0) {
                sql.append(",");
            }
            Field field = columnFields.get(k);
            Object value;
            try {
                value = field.get(entity);
            } catch (IllegalAccessException e) {
                throw new IllegalStateException("Cannot read field " + field.getName() + " of " + tClass.getName(), e);
            }
            appendValue(sql, parameters, field, value);
        }
        sql.append(")");
    }

    /** Appends a single column value: fixed SQL literals for NULL/boolean/ordinal, a bind parameter otherwise. */
    private static void appendValue(StringBuilder sql, List<Object> parameters, Field field, Object value) {
        if (value == null) {
            sql.append("NULL");
            return;
        }
        if (value instanceof Boolean) {
            sql.append((Boolean) value ? 1 : 0);
            return;
        }
        if (value instanceof Enum) {
            Enum<?> constant = (Enum<?>) value;
            Enumerated enumerated = field.getAnnotation(Enumerated.class);
            if (enumerated != null && enumerated.value() == EnumType.STRING) {
                bind(sql, parameters, constant.name());
            } else {
                sql.append(constant.ordinal());
            }
            return;
        }
        bind(sql, parameters, toSqlText(field, value));
    }

    /** Converts a non-null, non-boolean, non-enum value to the text stored in the column. */
    private static String toSqlText(Field field, Object value) {
        if (value instanceof List || value instanceof Set) {
            // Collections are stored comma separated, e.g. [a, b] becomes "a, b".
            return value.toString().replace("[", "").replace("]", "");
        }
        if (value instanceof java.util.Date) {
            JsonFormat jsonFormat = field.getAnnotation(JsonFormat.class);
            // @JsonFormat.pattern() defaults to "", which would format every date as an empty string.
            String pattern = jsonFormat != null && StringUtils.isNotBlank(jsonFormat.pattern())
                    ? jsonFormat.pattern()
                    : DEFAULT_DATE_PATTERN;
            return new SimpleDateFormat(pattern).format(value);
        }
        return value.toString();
    }

    /** Registers {@code value} as the next positional parameter and appends its {@code ?N} placeholder. */
    private static void bind(StringBuilder sql, List<Object> parameters, Object value) {
        parameters.add(value);
        sql.append('?').append(parameters.size());
    }

    /** Completes {@code sql} with the ON DUPLICATE KEY clause, binds {@code parameters} (1-based) and runs it. */
    private void executeUpsert(StringBuilder sql, List<Object> parameters) {
        setUpdate(sql);
        Query query = entityManager.createNativeQuery(sql.toString());
        for (int k = 0; k < parameters.size(); k++) {
            query.setParameter(k + 1, parameters.get(k));
        }
        query.executeUpdate();
    }


    /**
     * Deletes every row of the entity with a single JPQL bulk delete, instead of removing entities one by one.
     * Per the JPA specification, bulk deletes do not cascade and bypass the persistence context.
     */
    @Override
    public void deleteAll() {
        entityManager.createQuery("delete from " + entityName + " en ")
                .executeUpdate();
    }

    /**
     * Loads all entities whose id is contained in {@code ids}.
     *
     * @param ids ids to look up; {@code null} or empty yields an empty list without querying
     * @return the matching entities (in database order)
     */
    @Override
    public List<T> findAllByIds(List<ID> ids) {
        if (ids == null || ids.isEmpty()) {
            return new ArrayList<>();
        }
        return entityManager.createQuery("select en from " + entityName + " en where en." + id + " in :ids", tClass)
                .setParameter("ids", ids)
                .getResultList();
    }

    /**
     * Not implemented yet: deletes nothing and always returns {@code false}.
     * <p>
     * It only reads the entity's Hibernate {@link SQLDelete} annotation and logs its SQL at debug level.
     *
     * @param list currently ignored
     * @return always {@code false}
     */
    @Override
    public boolean deleteAllByIds(List list) {
        // TODO(review): implement the actual delete; until then this only reports the configured @SQLDelete SQL.
        SQLDelete sqlDelete = tClass.getAnnotation(SQLDelete.class);
        if (sqlDelete != null && StringUtils.isNotEmpty(sqlDelete.sql())) {
            logger.debug("@SQLDelete sql for {}: {}", tClass.getName(), sqlDelete.sql());
        }
        return false;
    }

    /**
     * Deletes the given entities by their {@code @Id} values via {@link #deleteAllByIdIn(List)}.
     *
     * @param list entities to delete; {@code null} or empty deletes nothing and returns {@code true}
     * @return {@code true} on success; {@code false} when the entity has no {@code @Id} field
     *         or an id could not be read (the error is logged)
     */
    @Override
    @SuppressWarnings("unchecked") // the @Id field's value is the repository's ID type by construction
    public boolean deleteAll(List<T> list) {
        if (list == null || list.isEmpty()) {
            return true;
        }
        if (idField == null) {
            logger.error("No @Id field declared on {}", tClass.getName());
            return false;
        }
        List<ID> ids = new ArrayList<>(list.size());
        for (T entity : list) {
            try {
                ids.add((ID) idField.get(entity));
            } catch (Exception e) {
                logger.error("Cannot read id of {}", tClass.getName(), e);
                return false;
            }
        }
        this.deleteAllByIdIn(ids);
        return true;
    }

    /**
     * Deletes the rows whose id is in {@code ids}.
     * <p>
     * When the entity carries {@link BatchSQLDelete}, its statement (native or JPQL) is executed with
     * {@code ids} bound to positional parameter 1; otherwise a JPQL bulk delete on the id attribute is used.
     *
     * @param ids ids to delete; {@code null} or empty is a no-op (an empty IN list is invalid SQL)
     */
    @Override
    public void deleteAllByIdIn(List<ID> ids) {
        if (ids == null || ids.isEmpty()) {
            return;
        }
        BatchSQLDelete sqlDelete = tClass.getAnnotation(BatchSQLDelete.class);
        if (sqlDelete == null) {
            entityManager.createQuery("delete from " + entityName + " en  where en." + id + " in (:ids)")
                    .setParameter("ids", ids)
                    .executeUpdate();
        } else if (sqlDelete.nativeQuery()) {
            entityManager.createNativeQuery(sqlDelete.value())
                    .setParameter(1, ids)
                    .executeUpdate();
        } else {
            entityManager.createQuery(sqlDelete.value())
                    .setParameter(1, ids)
                    .executeUpdate();
        }
    }

    /**
     * Returns one (unspecified) entity whose id is not null.
     *
     * @return that entity, or {@link Optional#empty()} when no such row exists
     */
    @Override
    public Optional<T> findTopByIdNotNull() {
        // getResultList instead of getSingleResult: the latter throws NoResultException on an empty table.
        List<T> result = entityManager.createQuery("select en from " + entityName + " en where en." + id + " is not null ", tClass)
                .setMaxResults(1)
                .getResultList();
        return result.isEmpty() ? Optional.empty() : Optional.of(result.get(0));
    }


    /** Finds the {@code @Id} field, makes it accessible and returns its name ({@code "id"} if none exists). */
    private String getId() {
        for (Field field : fields) {
            if (field.getAnnotation(Id.class) != null) {
                // Needed so deleteAll(List) can read private ids.
                field.setAccessible(true);
                idField = field;
                id = field.getName();
                break;
            }
        }
        return id;
    }


    /** Table name: a non-blank {@code @Table(name)}, otherwise the class simple name in snake_case. */
    private String getTableName() {
        Table table = tClass.getAnnotation(Table.class);
        if (table != null && StringUtils.isNotBlank(table.name())) {
            return table.name();
        }
        return toUnderlineName(tClass.getSimpleName());
    }

    /** Entity name for JPQL: a non-blank {@code @Entity(name)}, otherwise the class simple name. */
    private String getEntityName() {
        Entity entity = tClass.getAnnotation(Entity.class);
        if (entity != null && StringUtils.isNotBlank(entity.name())) {
            return entity.name();
        }
        return tClass.getSimpleName();
    }


    /**
     * Converts camelCase to snake_case, e.g. {@code UserInfo -> user_info}, {@code HTTPUrl -> http_url}.
     * No underscore is added before the first character; runs of capitals stay together.
     */
    private static String toUnderlineName(String s) {
        if (s == null) {
            return null;
        }
        StringBuilder sb = new StringBuilder();
        boolean upperCase = false;
        for (int i = 0; i < s.length(); i++) {
            char c = s.charAt(i);
            boolean nextUpperCase = true;
            if (i < (s.length() - 1)) {
                nextUpperCase = Character.isUpperCase(s.charAt(i + 1));
            }
            if (Character.isUpperCase(c)) {
                // Start a new word unless we are inside a run of capitals that continues.
                if ((!upperCase || !nextUpperCase) && i > 0) {
                    sb.append("_");
                }
                upperCase = true;
            } else {
                upperCase = false;
            }
            sb.append(Character.toLowerCase(c));
        }
        return sb.toString();
    }


    /** True for fields that are not plain columns: static, Java-transient, {@code @Transient} or associations. */
    private static boolean isIgnore(Field field) {
        int modifiers = field.getModifiers();
        return Modifier.isStatic(modifiers)
                || Modifier.isTransient(modifiers)
                || field.getAnnotation(Transient.class) != null
                || field.getAnnotation(OneToMany.class) != null
                || field.getAnnotation(ManyToMany.class) != null
                || field.getAnnotation(OneToOne.class) != null
                || field.getAnnotation(ManyToOne.class) != null;
    }

    /** Column name of {@code field}: a non-blank {@code @Column(name)}, otherwise the field name in snake_case. */
    private static String columnName(Field field) {
        Column column = field.getAnnotation(Column.class);
        if (column != null && StringUtils.isNotBlank(column.name())) {
            return column.name();
        }
        return toUnderlineName(field.getName());
    }

    /** Declared fields of {@code type} preceded by those of its directly inherited {@code @MappedSuperclass} chain. */
    private static Field[] collectFields(Class<?> type) {
        List<Class<?>> hierarchy = new ArrayList<>();
        hierarchy.add(type);
        for (Class<?> parent = type.getSuperclass();
             parent != null && parent.getAnnotation(MappedSuperclass.class) != null;
             parent = parent.getSuperclass()) {
            hierarchy.add(0, parent);
        }
        List<Field> result = new ArrayList<>();
        for (Class<?> current : hierarchy) {
            for (Field field : current.getDeclaredFields()) {
                result.add(field);
            }
        }
        return result.toArray(new Field[0]);
    }


    /** Starts an {@code insert into table(col1,col2,...) values} statement for the persistent columns. */
    private StringBuilder into() {
        StringBuilder sb = new StringBuilder();
        sb.append("insert into ").append(tableName).append("(");
        sb.append(String.join(",", columnNames));
        sb.append(") values");
        return sb;
    }

    /**
     * Appends {@code ON DUPLICATE KEY UPDATE col=VALUES(col),...} for every non-{@code @Id} column.
     * When the entity has only id columns, a no-op {@code col=col} assignment keeps the SQL valid.
     */
    private StringBuilder setUpdate(StringBuilder sb) {
        sb.append(" ON DUPLICATE KEY UPDATE ");
        boolean first = true;
        for (int k = 0; k < columnFields.size(); k++) {
            if (columnFields.get(k).getAnnotation(Id.class) != null) {
                continue;
            }
            if (!first) {
                sb.append(",");
            }
            first = false;
            String column = columnNames.get(k);
            sb.append(column).append("=VALUES(").append(column).append(")");
        }
        if (first && !columnNames.isEmpty()) {
            String column = columnNames.get(0);
            sb.append(column).append("=").append(column);
        }
        return sb;
    }



}
