package cn.zhangfusheng.elasticsearch.template;

import cn.zhangfusheng.elasticsearch.annotation.dsl.DslIndex;
import cn.zhangfusheng.elasticsearch.constant.ElasticSearchConstant;
import cn.zhangfusheng.elasticsearch.exception.GlobalSystemException;
import cn.zhangfusheng.elasticsearch.model.page.PageRequest;
import cn.zhangfusheng.elasticsearch.scan.ElasticSearchEntityRepositoryDetail;
import cn.zhangfusheng.elasticsearch.thread.ThreadLocalDetail;
import org.apache.commons.lang3.StringUtils;
import org.elasticsearch.action.search.ClearScrollRequest;
import org.elasticsearch.action.search.ClearScrollResponse;
import org.elasticsearch.action.search.SearchRequest;
import org.elasticsearch.action.search.SearchResponse;
import org.elasticsearch.action.search.SearchScrollRequest;
import org.elasticsearch.client.RequestOptions;
import org.elasticsearch.client.core.CountRequest;
import org.elasticsearch.client.core.CountResponse;
import org.elasticsearch.common.unit.TimeValue;
import org.elasticsearch.search.SearchHit;
import org.elasticsearch.search.sort.SortOrder;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.util.CollectionUtils;

import java.io.IOException;
import java.lang.reflect.Method;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Objects;
import java.util.function.Consumer;

/**
 * Default Elasticsearch query operations (plain search, paged search, scroll search, count) and
 * index-name resolution, built on the {@code restHighLevelClient()} provided by {@link Template}.
 *
 * @author fusheng.zhang
 * @date 2022-04-26 10:56:23
 */
public interface ElasticSearchTemplateApi extends Template {

    Logger log = LoggerFactory.getLogger(ElasticSearchTemplateApi.class);

    /**
     * Executes {@code searchRequest} and converts the collected hits through the repository detail.
     * <p>
     * When {@code pageRequest} is non-null a single page is fetched via
     * {@link #searchWithPage(SearchRequest, PageRequest, Consumer)}; otherwise all matching documents
     * are fetched via {@link #search(SearchRequest, Consumer)}.
     *
     * @param entityRepositoryDetail repository metadata used to convert the response
     * @param method                 the repository method being executed
     * @param searchRequest          the request to execute
     * @param pageRequest            paging parameters, or {@code null} to fetch every matching document
     * @return the result of {@code entityRepositoryDetail.analysisSearchResponse(...)}
     * @throws IOException if the non-paged search fails in the client
     */
    default Object search(
            ElasticSearchEntityRepositoryDetail entityRepositoryDetail, Method method,
            SearchRequest searchRequest, PageRequest pageRequest) throws IOException {
        List<SearchHit> searchHits = new ArrayList<>();
        Consumer<SearchHit[]> consumer = hs -> searchHits.addAll(Arrays.asList(hs));
        SearchResponse searchResponse = Objects.nonNull(pageRequest)
                ? this.searchWithPage(searchRequest, pageRequest, consumer)
                : this.search(searchRequest, consumer);
        return entityRepositoryDetail.analysisSearchResponse(searchResponse, searchHits, method);
    }

    /**
     * Executes {@code searchRequest}, handing every non-empty batch of hits to {@code consumer}.
     * <p>
     * If {@code ThreadLocalDetail.trackTotalHits()} is set, the matching documents are counted first.
     * When that count exceeds {@code MAX_DOC_SIZE}, the request is executed as a scroll search.
     * Otherwise up to {@code totalNum} documents ({@code MAX_DOC_SIZE} when not counted) are fetched.
     * When that number exceeds twice {@code THRESHOLD_DOC_SIZE}, they are fetched in {@code from}/{@code size}
     * pages of {@code THRESHOLD_DOC_SIZE}.
     * <p>
     * Note: this mutates {@code searchRequest.source()} (its {@code size}/{@code from}).
     *
     * @param searchRequest the request to execute; its source must be non-null
     * @param consumer      receives each non-empty batch of hits
     * @return the response of the last request that was executed
     * @throws IOException if the client call fails
     */
    default SearchResponse search(SearchRequest searchRequest, Consumer<SearchHit[]> consumer) throws IOException {
        boolean trackTotalHits = ThreadLocalDetail.trackTotalHits();
        long totalNum = ElasticSearchConstant.MAX_DOC_SIZE.longValue();
        // Fetch all matching data: first find out how many documents match.
        if (trackTotalHits) {
            CountRequest countRequest =
                    new CountRequest(searchRequest.indices(), searchRequest.source()).routing(searchRequest.routing());
            totalNum = this.count(countRequest);
            // More matches than the maximum from/size window: switch to a scroll search.
            if (totalNum > ElasticSearchConstant.MAX_DOC_SIZE) {
                // Documents per scroll round trip: per-thread override, else THRESHOLD_DOC_SIZE.
                Integer scrollSize = ThreadLocalDetail.scrollSize().orElse(ElasticSearchConstant.THRESHOLD_DOC_SIZE);
                searchRequest.source().size(scrollSize);
                return this.searchWitchScroll(searchRequest, consumer, -1, null);
            }
        }
        SearchResponse searchResponse = null;
        // Upper bound on the number of documents to fetch.
        searchRequest.source().size(Long.valueOf(totalNum).intValue());
        log.debug("index:{},routing:{},queryJson:{}", searchRequest.indices(), searchRequest.routing(), searchRequest.source());
        // When more than twice THRESHOLD_DOC_SIZE documents are wanted, fetch them page by page.
        if (totalNum > ElasticSearchConstant.THRESHOLD_DOC_SIZE * 2) {
            int pageSize = ElasticSearchConstant.THRESHOLD_DOC_SIZE;
            searchRequest.source().size(pageSize);
            // ceil(totalNum / pageSize): exactly the number of pages needed to cover totalNum documents.
            // (Previously an evenly divisible total used DEFAULT_LOOP_PAGE_NUM instead, which could
            // either drop pages or request a window beyond MAX_DOC_SIZE.)
            long loopNum = (totalNum + pageSize - 1) / pageSize;
            for (int i = 0; i < loopNum; i++) {
                searchRequest.source().from(i * pageSize);
                log.debug("index:{},routing:{},queryJson:{}", searchRequest.indices(), searchRequest.routing(), searchRequest.source());
                searchResponse = restHighLevelClient().search(searchRequest, RequestOptions.DEFAULT);
                SearchHit[] hits = searchResponse.getHits().getHits();
                if (Objects.isNull(hits) || hits.length == 0) break;
                consumer.accept(hits);
                // A short page means there is no next page.
                if (hits.length < pageSize) break;
            }
        } else {
            searchResponse = restHighLevelClient().search(searchRequest, RequestOptions.DEFAULT);
            SearchHit[] hits = searchResponse.getHits().getHits();
            if (Objects.nonNull(hits) && hits.length > 0) consumer.accept(hits);
        }
        return searchResponse;
    }

    /**
     * Executes a single page of {@code searchRequest} described by {@code pageRequest}.
     * <p>
     * Uses {@code search_after} when {@code pageRequest.getSearchAfter()} is non-empty, otherwise
     * {@code from}. Requests whose {@code from + size} exceeds {@code MAX_DOC_SIZE} must supply
     * {@code searchAfter}. When the source has no sort, a descending sort on
     * {@code ElasticSearchConstant.SORT_ID} is added (presumably the {@code _id} field, per the original comment).
     * If {@code ThreadLocalDetail.trackTotalHits()} is set, {@code track_total_hits} is enabled.
     *
     * @param searchRequest the request to execute; its source must be non-null
     * @param pageRequest   paging parameters
     * @param consumer      receives the page's hits when non-empty
     * @return the search response
     * @throws GlobalSystemException if the window is too large without {@code searchAfter}, or the client call fails
     */
    default SearchResponse searchWithPage(SearchRequest searchRequest, PageRequest pageRequest, Consumer<SearchHit[]> consumer) {
        try {
            // Page size.
            searchRequest.source().size(pageRequest.getSize());
            // Deep pages beyond MAX_DOC_SIZE are only reachable via search_after.
            if (pageRequest.getFrom() + pageRequest.getSize() > ElasticSearchConstant.MAX_DOC_SIZE) {
                if (Objects.isNull(pageRequest.getSearchAfter()) || pageRequest.getSearchAfter().length <= 0) {
                    throw new GlobalSystemException("超出最大查询数据限制,必须设置searchAfter参数");
                }
            }
            if (Objects.nonNull(pageRequest.getSearchAfter()) && pageRequest.getSearchAfter().length > 0) {
                searchRequest.source().searchAfter(pageRequest.getSearchAfter());
            } else {
                searchRequest.source().from(pageRequest.getFrom());
            }
            // Ensure a deterministic sort order when the caller specified none.
            if (CollectionUtils.isEmpty(searchRequest.source().sorts())) {
                searchRequest.source().sort(ElasticSearchConstant.SORT_ID, SortOrder.DESC);
            }
            // Enable track_total_hits when requested for this thread.
            boolean trackTotalHits = ThreadLocalDetail.trackTotalHits();
            if (trackTotalHits) searchRequest.source().trackTotalHits(Boolean.TRUE);
            log.debug("index:{},routing:{},queryJson:{}", searchRequest.indices(), searchRequest.routing(), searchRequest.source());
            SearchResponse searchResponse = restHighLevelClient().search(searchRequest, RequestOptions.DEFAULT);
            SearchHit[] hits = searchResponse.getHits().getHits();
            // Same guard as the other search paths: only hand over non-empty batches.
            if (Objects.nonNull(hits) && hits.length > 0) consumer.accept(hits);
            return searchResponse;
        } catch (IOException e) {
            throw new GlobalSystemException(e);
        }
    }

    /**
     * Scroll search.
     * <p>
     * With a blank {@code scrollId}, starts a new scroll from {@code searchRequest}. Otherwise it
     * continues the scroll identified by {@code scrollId}. Each non-empty batch goes to {@code consumer}.
     * The keep-alive is {@code ThreadLocalDetail.keepAlive()} milliseconds, default 5000.
     * <p>
     * The scroll context is cleared when the scroll is exhausted or an error occurs. If the loop stops
     * because {@code runNum} batches were consumed, the context is kept so the returned response's
     * scroll id can be passed back in to continue. It expires after the keep-alive if not continued.
     *
     * @param searchRequest the initial request (used when {@code scrollId} is blank)
     * @param consumer      receives each non-empty batch of hits
     * @param runNum        maximum number of batches to consume; {@code < 0} means unlimited
     * @param scrollId      scroll id to continue from, or blank to start a new scroll
     * @return the last response received
     * @throws GlobalSystemException wrapping any failure
     */
    default SearchResponse searchWitchScroll(
            SearchRequest searchRequest, Consumer<SearchHit[]> consumer, int runNum, String scrollId) {
        SearchResponse searchResponse = null;
        // Cleared on exhaustion/failure; kept only when we stop early because of runNum.
        boolean releaseScroll = true;
        try {
            // How long the scroll context is kept alive between round trips.
            long keepAlive = ThreadLocalDetail.keepAlive().orElse(5000L);
            TimeValue timeValue = TimeValue.timeValueMillis(keepAlive);
            SearchScrollRequest scrollRequest = new SearchScrollRequest(scrollId);
            if (StringUtils.isBlank(scrollId)) {
                searchRequest.scroll(timeValue);
                log.debug("index:{},routing:{},queryJson:{}", searchRequest.indices(), searchRequest.routing(), searchRequest.source());
                searchResponse = restHighLevelClient().search(searchRequest, RequestOptions.DEFAULT);
            } else {
                // Continuing an existing scroll: extend its keep-alive like every later round trip does.
                scrollRequest.scroll(timeValue);
                log.debug("index:{},routing:{},searchScrollId:{}", searchRequest.indices(), searchRequest.routing(), scrollRequest.getDescription());
                searchResponse = restHighLevelClient().scroll(scrollRequest, RequestOptions.DEFAULT);
            }
            SearchHit[] hits = searchResponse.getHits().getHits();
            int num = 1;
            while (hits != null && hits.length > 0) {
                consumer.accept(hits);
                if (Objects.isNull(searchResponse.getScrollId())) break;
                if (num++ == runNum) {
                    // More batches may remain: keep the context alive so the caller can resume.
                    releaseScroll = false;
                    break;
                }
                scrollRequest.scrollId(searchResponse.getScrollId()).scroll(timeValue);
                log.debug("searchScrollId:{}", scrollRequest.getDescription());
                searchResponse = restHighLevelClient().scroll(scrollRequest, RequestOptions.DEFAULT);
                hits = searchResponse.getHits().getHits();
            }
        } catch (Exception e) {
            throw new GlobalSystemException(e);
        } finally {
            if (releaseScroll && Objects.nonNull(searchResponse)) this.clearScroll(searchResponse.getScrollId());
        }
        return searchResponse;
    }

    /**
     * Clears a scroll context. Blank ids are ignored. IO failures are logged and not rethrown, so
     * callers (e.g. a {@code finally} block) are never interrupted by cleanup.
     *
     * @param scrollId the scroll id to clear
     */
    default void clearScroll(String scrollId) {
        if (StringUtils.isNotBlank(scrollId)) {
            ClearScrollRequest clearScrollRequest = new ClearScrollRequest();
            clearScrollRequest.addScrollId(scrollId);
            try {
                ClearScrollResponse clearScrollResponse =
                        restHighLevelClient().clearScroll(clearScrollRequest, RequestOptions.DEFAULT);
                log.debug("clear scroll:{} success?{}", scrollId, clearScrollResponse.isSucceeded());
            } catch (IOException e) {
                log.error(e.getMessage(), e);
            }
        }
    }

    /**
     * Counts documents matching {@code countRequest}.
     *
     * @param countRequest the count request
     * @return the number of matching documents
     * @throws GlobalSystemException wrapping an {@link IOException} from the client
     */
    default long count(CountRequest countRequest) {
        try {
            CountResponse countResponse = restHighLevelClient().count(countRequest, RequestOptions.DEFAULT);
            return countResponse.getCount();
        } catch (IOException e) {
            throw new GlobalSystemException(e);
        }
    }

    /**
     * Resolves the index names for a repository method.
     * <p>
     * Uses the indices declared by the method's {@link DslIndex} annotation if present, otherwise
     * {@code index}. The result is cached in {@code ElasticSearchConstant.METHOD_INDEX_CACHE} keyed by {@code method}.
     * The cache is checked first, then re-checked while synchronized on the cache object.
     * <p>
     * NOTE(review): the first, unsynchronized read is only safe if METHOD_INDEX_CACHE is a concurrent
     * map — confirm. Also, because the key is only the method, the fallback {@code index} of the first
     * call is reused for later calls with the same method — confirm this is never called with the same
     * Method for different indices.
     *
     * @param method the repository method (may be null per the check below)
     * @param args   method arguments; not used for resolution and may be {@code null}
     *               (e.g. proxies pass {@code null} for no-arg methods)
     * @param index  fallback index name
     * @return the distinct index names
     */
    default String[] analysisIndex(Method method, Object[] args, String index) {
        if (ElasticSearchConstant.METHOD_INDEX_CACHE.containsKey(method)) {
            return ElasticSearchConstant.METHOD_INDEX_CACHE.get(method);
        }
        synchronized (ElasticSearchConstant.METHOD_INDEX_CACHE) {
            if (ElasticSearchConstant.METHOD_INDEX_CACHE.containsKey(method)) {
                return ElasticSearchConstant.METHOD_INDEX_CACHE.get(method);
            } else {
                // args was previously only used to presize this list, which NPE'd on null args.
                List<String> indices = new ArrayList<>();
                if (Objects.nonNull(method)) {
                    DslIndex dslIndex = method.getAnnotation(DslIndex.class);
                    if (Objects.nonNull(dslIndex)) analysisIndex(dslIndex, indices);
                }
                if (CollectionUtils.isEmpty(indices)) indices.add(index);
                ElasticSearchConstant.METHOD_INDEX_CACHE.put(method, indices.stream().distinct().toArray(String[]::new));
            }
        }
        return ElasticSearchConstant.METHOD_INDEX_CACHE.get(method);
    }

    /**
     * Resolves the classes listed in a {@link DslIndex} to index names and appends them to {@code indices}.
     * {@code Void.class} entries are skipped (presumably the annotation's placeholder value — confirm).
     *
     * @param dslIndex the annotation to read
     * @param indices  receives the resolved index names
     * @throws GlobalSystemException if a class has no entry in {@code ElasticSearchConstant.REPOSITORY_DETAIL_CACHE}
     */
    default void analysisIndex(DslIndex dslIndex, List<String> indices) {
        Class<?>[] value = dslIndex.value();
        for (Class<?> indexClass : value) {
            if (indexClass.equals(Void.class)) continue;
            if (!ElasticSearchConstant.REPOSITORY_DETAIL_CACHE.containsKey(indexClass)) {
                throw new GlobalSystemException("根据{},未获取到对应的索引", indexClass.getName());
            }
            indices.add(ElasticSearchConstant.REPOSITORY_DETAIL_CACHE.get(indexClass).getIndexName());
        }
    }

    /**
     * Resolves index names from {@link DslIndex} annotations on the runtime classes of {@code args}.
     *
     * @param args  method arguments; {@code null} array and {@code null} elements are tolerated
     * @param index fallback index names
     * @return the distinct annotated index names, or {@code index} when none were found
     */
    default String[] analysisIndex(Object[] args, String... index) {
        if (Objects.isNull(args)) return index;
        List<String> indices = new ArrayList<>(args.length);
        for (Object arg : args) {
            // A null argument has no class and therefore no annotation.
            if (Objects.isNull(arg)) continue;
            DslIndex dslIndex = arg.getClass().getAnnotation(DslIndex.class);
            if (Objects.nonNull(dslIndex)) this.analysisIndex(dslIndex, indices);
        }
        if (CollectionUtils.isEmpty(indices)) return index;
        return indices.stream().distinct().toArray(String[]::new);
    }
}
