package ai.yda.framework.rag.core;

import ai.yda.framework.rag.core.augmenter.Augmenter;
import ai.yda.framework.rag.core.generator.StreamingGenerator;
import ai.yda.framework.rag.core.model.RagContext;
import ai.yda.framework.rag.core.model.RagRequest;
import ai.yda.framework.rag.core.model.RagResponse;
import ai.yda.framework.rag.core.retriever.Retriever;
import ai.yda.framework.rag.core.util.ContentUtil;
import java.util.Iterator;
import java.util.List;
import java.util.stream.Collectors;
import lombok.Generated;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;

/* loaded from: input_file:ai/yda/framework/rag/core/DefaultStreamingRag.class */
/**
 * Default {@link StreamingRag} implementation that runs a retrieve, augment, generate pipeline and
 * streams the generator's output.
 *
 * <p>For every request:
 * <ol>
 *   <li>each configured {@link Retriever} is invoked, one after another in list order, and the
 *       produced {@link RagContext}s are collected into a list;</li>
 *   <li>that list is passed through each {@link Augmenter} in list order, each augmenter receiving
 *       the output of the previous one;</li>
 *   <li>the knowledge of the resulting contexts is merged into a single string by
 *       {@link #mergeContexts(List)};</li>
 *   <li>the merged string is handed to the {@link StreamingGenerator}, whose output is returned.</li>
 * </ol>
 *
 * <p>The returned publishers are built from the configured lists (not from one-shot
 * {@link java.util.stream.Stream}s), so they can be safely re-subscribed, e.g. by {@code retry()}.
 */
public class DefaultStreamingRag implements StreamingRag<RagRequest, RagResponse> {
    private final List<Retriever<RagRequest, RagContext>> retrievers;
    private final List<Augmenter<RagRequest, RagContext>> augmenters;
    private final StreamingGenerator<RagRequest, RagResponse> generator;

    /**
     * Creates a streaming RAG pipeline.
     *
     * @param retrievers retrievers invoked for every request; their results are collected in list order
     * @param augmenters augmenters applied sequentially to the retrieved contexts
     * @param generator  generator that streams the response from the merged context string
     */
    public DefaultStreamingRag(
            List<Retriever<RagRequest, RagContext>> retrievers,
            List<Augmenter<RagRequest, RagContext>> augmenters,
            StreamingGenerator<RagRequest, RagResponse> generator) {
        this.retrievers = retrievers;
        this.augmenters = augmenters;
        this.generator = generator;
    }

    /**
     * Runs the full RAG pipeline for {@code ragRequest} and streams the generated response.
     *
     * <p>Retrievers are called lazily on subscription, sequentially and in list order. A retriever
     * returning {@code null} contributes no context (its {@code Mono.fromCallable} completes empty).
     * Any exception thrown by a retriever, augmenter or the merge step is propagated as an error
     * signal on the returned {@link Flux}.
     *
     * @param ragRequest the request passed to every retriever, augmenter and the generator
     * @return the response stream produced by the configured {@link StreamingGenerator}
     */
    @Override
    public Flux<RagResponse> streamRag(RagRequest ragRequest) {
        return Flux.fromIterable(this.retrievers)
                // concatMap makes the list-order aggregation explicit rather than incidental.
                .concatMap(retriever -> Mono.fromCallable(() -> retriever.retrieve(ragRequest)))
                .collectList()
                .map(contexts -> applyAugmenters(ragRequest, contexts))
                .flatMap(this::mergeContexts)
                .flatMapMany(mergedContext -> this.generator.streamGeneration(ragRequest, mergedContext));
    }

    /**
     * Applies every configured augmenter, in list order, each one receiving the previous output.
     *
     * @param ragRequest the current request
     * @param contexts   the contexts produced by the retrievers
     * @return the contexts returned by the last augmenter, or {@code contexts} if there are none
     */
    private List<RagContext> applyAugmenters(RagRequest ragRequest, List<RagContext> contexts) {
        List<RagContext> current = contexts;
        for (Augmenter<RagRequest, RagContext> augmenter : this.augmenters) {
            current = augmenter.augment(ragRequest, current);
        }
        return current;
    }

    /**
     * Merges the knowledge of all contexts into a single string.
     *
     * <p>Each context's knowledge entries are joined with {@link ContentUtil#SENTENCE_SEPARATOR},
     * and the per-context strings are joined with the same separator, preserving list order. An
     * empty list yields an empty string.
     *
     * @param list the contexts to merge
     * @return a {@link Mono} emitting the merged context string
     */
    protected Mono<String> mergeContexts(List<RagContext> list) {
        return Flux.fromIterable(list)
                .map(ragContext -> String.join(ContentUtil.SENTENCE_SEPARATOR, ragContext.getKnowledge()))
                .collect(Collectors.joining(ContentUtil.SENTENCE_SEPARATOR));
    }

    /**
     * Returns the configured retrievers.
     *
     * @return the retriever list passed to the constructor (not copied)
     */
    @Generated
    protected List<Retriever<RagRequest, RagContext>> getRetrievers() {
        return this.retrievers;
    }

    /**
     * Returns the configured augmenters.
     *
     * @return the augmenter list passed to the constructor (not copied)
     */
    @Generated
    protected List<Augmenter<RagRequest, RagContext>> getAugmenters() {
        return this.augmenters;
    }

    /**
     * Returns the configured streaming generator.
     *
     * @return the generator passed to the constructor
     */
    @Generated
    protected StreamingGenerator<RagRequest, RagResponse> getGenerator() {
        return this.generator;
    }
}
