/*
 * Decompiled with CFR 0.152.
 */
package de.datexis.encoder.impl;

import de.datexis.common.Resource;
import de.datexis.encoder.Encoder;
import de.datexis.encoder.impl.RESTAdapter;
import de.datexis.model.Document;
import de.datexis.model.Sentence;
import de.datexis.model.Span;
import de.datexis.model.Token;
import java.io.IOException;
import java.util.Collection;
import java.util.List;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Base class for encoders that obtain embeddings from a remote service through a
 * {@link RESTAdapter} instead of a locally trained model.
 *
 * <p>Embeddings are stored on spans as column vectors of shape
 * {@code [getEmbeddingVectorSize(), 1]}. They are keyed by the vector identifier passed to the
 * constructor, or by this encoder's runtime class when no identifier was given. Training,
 * loading and saving a model are not supported and throw {@link UnsupportedOperationException}.
 */
public abstract class AbstractRESTEncoder extends Encoder {

    private static final Logger log = LoggerFactory.getLogger(AbstractRESTEncoder.class);

    /** Adapter that performs the REST calls returning raw embedding values. */
    private final RESTAdapter restAdapter;

    /** Key under which vectors are stored on spans; {@code null} means "key by encoder class". */
    private final String vectorIdentifier;

    /**
     * Creates an encoder that stores vectors on spans keyed by its own runtime class.
     *
     * @param restAdapter adapter used to fetch embeddings
     */
    public AbstractRESTEncoder(RESTAdapter restAdapter) {
        this(restAdapter, null);
    }

    /**
     * Creates an encoder that stores vectors on spans under the given identifier.
     *
     * @param restAdapter adapter used to fetch embeddings
     * @param vectorIdentifier key under which vectors are stored on spans, or {@code null} to key
     *     them by this encoder's runtime class
     */
    public AbstractRESTEncoder(RESTAdapter restAdapter, String vectorIdentifier) {
        this.restAdapter = restAdapter;
        this.vectorIdentifier = vectorIdentifier;
    }

    /** Returns the embedding dimensionality as reported by the REST adapter. */
    public long getEmbeddingVectorSize() {
        return restAdapter.getEmbeddingVectorSize();
    }

    /** Not supported: REST encoders are backed by a remote model and cannot be trained. */
    public void trainModel(Collection<Document> documents) {
        throw new UnsupportedOperationException("REST Encoders are not trainable");
    }

    /** Not supported: REST encoders are backed by a remote model and cannot be trained. */
    public void trainModel(Stream<Document> documents) {
        throw new UnsupportedOperationException("REST Encoders are not trainable");
    }

    /** Not supported: REST encoders have no local model to load. */
    public void loadModel(Resource file) throws IOException {
        throw new UnsupportedOperationException("REST Encoders cant load a model");
    }

    /** Not supported: REST encoders have no local model to save. */
    public void saveModel(Resource dir, String name) throws IOException {
        throw new UnsupportedOperationException("REST Encoders cant save a model");
    }

    /**
     * Encodes a single string with one REST call.
     *
     * @param value text to encode
     * @return the embedding as a column vector of shape {@code [getEmbeddingVectorSize(), 1]}
     * @throws IOException propagated from the REST adapter
     */
    public INDArray encodeValue(String value) throws IOException {
        return toEmbeddingVector(restAdapter.encode(value));
    }

    /** Returns the tokens of the document, grouped into one list per sentence. */
    public List<List<Token>> getTokensOfSentencesOfDocument(Document document) {
        return document.streamSentences().map(Sentence::getTokens).collect(Collectors.toList());
    }

    /** Streams the tokens of the document, grouped into one inner stream per sentence. */
    public Stream<Stream<Token>> streamTokensOfSentencesOfDocument(Document document) {
        return document.streamSentences().map(Sentence::streamTokens);
    }

    /** Converts a list of lists into a stream of inner streams, preserving order. */
    public <S> Stream<Stream<S>> streamSpans2D(List<? extends List<S>> spans2D) {
        return spans2D.stream().map(Collection::stream);
    }

    /**
     * Encodes the text of a single span ({@link Span#getText()}) and stores the vector on it.
     *
     * @throws IOException propagated from the REST adapter
     */
    public <S extends Span> void encodeEach(S span) throws IOException {
        encodeEach(span, Span::getText);
    }

    /**
     * Encodes the text extracted from a single span and stores the resulting vector on it.
     *
     * @param span span to encode and annotate
     * @param getText extracts the text to send to the REST service
     * @throws IOException propagated from the REST adapter
     */
    public <S extends Span> void encodeEach(S span, Function<S, String> getText) throws IOException {
        double[] embedding = restAdapter.encode(getText.apply(span));
        putVectorInSpan(span, embedding);
    }

    /**
     * Encodes a flat list of spans using their {@link Span#getText()} in one batched REST call.
     *
     * @throws IOException propagated from the REST adapter, or if the response is missing
     *     embeddings
     */
    public <S extends Span> void encodeEach1D(List<S> spans) throws IOException {
        encodeEach1D(spans, Span::getText);
    }

    /**
     * Encodes a flat list of spans in one batched REST call and stores each vector on its span.
     *
     * @param spans spans to encode; the i-th response entry is stored on the i-th span
     * @param getText extracts the text to send for each span
     * @throws IOException propagated from the REST adapter, or if the response is {@code null} or
     *     has fewer entries than {@code spans}
     */
    public <S extends Span> void encodeEach1D(List<S> spans, Function<S, String> getText) throws IOException {
        if (spans.isEmpty()) {
            return; // nothing to annotate, so skip the network round-trip
        }
        String[] texts = spansToStringArray1D(spans.stream(), getText);
        double[][] embeddings = restAdapter.encode(texts);
        // Validate before mutating any span, so a short response cannot leave spans half-annotated.
        checkRestResponseSize(embeddings, spans.size(), "batch");
        putVectorInSpans(spans.stream(), embeddings);
    }

    /**
     * Encodes nested lists of spans using their {@link Span#getText()} in one batched REST call.
     *
     * @throws IOException propagated from the REST adapter, or if the response is missing
     *     embeddings
     */
    public <S extends Span> void encodeEach2D(List<? extends List<S>> spans) throws IOException {
        encodeEach2D(spans, Span::getText);
    }

    /**
     * Encodes nested lists of spans (e.g. tokens per sentence) in one batched REST call and stores
     * each vector on its span.
     *
     * @param spans outer list of sequences; entry {@code [i][j]} of the response is stored on
     *     {@code spans.get(i).get(j)}
     * @param getText extracts the text to send for each span
     * @throws IOException propagated from the REST adapter, or if the response (or any sequence in
     *     it) is {@code null} or shorter than the corresponding input
     */
    public <S extends Span> void encodeEach2D(List<? extends List<S>> spans, Function<S, String> getText) throws IOException {
        if (spans.isEmpty()) {
            return; // nothing to annotate, so skip the network round-trip
        }
        String[][] texts = spansToStringArray2D(streamSpans2D(spans), getText);
        double[][][] embeddings = restAdapter.encode(texts);
        // Validate the whole response before mutating any span.
        checkRestResponseSize(embeddings, spans.size(), "batch");
        int sequence = 0;
        for (List<S> inner : spans) {
            checkRestResponseSize(embeddings[sequence], inner.size(), "sequence " + sequence);
            sequence++;
        }
        putVectorInSpans(streamSpans2D(spans), embeddings);
    }

    /** Extracts {@link Span#getText()} from each span, in stream order. */
    public <S extends Span> String[] spansToStringArray1D(Stream<S> spans) {
        return spansToStringArray1D(spans, Span::getText);
    }

    /** Extracts the text of each span with {@code getText}, in stream order. */
    public <S extends Span> String[] spansToStringArray1D(Stream<S> spans, Function<S, String> getText) {
        return spans.map(getText).toArray(String[]::new);
    }

    /** Extracts {@link Span#getText()} from nested span streams into a jagged array. */
    public <S extends Span> String[][] spansToStringArray2D(Stream<? extends Stream<S>> spans) {
        return spansToStringArray2D(spans, Span::getText);
    }

    /** Extracts the text of nested span streams with {@code getText} into a jagged array. */
    public <S extends Span> String[][] spansToStringArray2D(Stream<? extends Stream<S>> spans, Function<S, String> getText) {
        return spans.map(inner -> spansToStringArray1D(inner, getText)).toArray(String[][]::new);
    }

    /**
     * Stores {@code data} on {@code span} as a column vector.
     *
     * <p>The vector is keyed by the configured vector identifier, or by this encoder's runtime
     * class when no identifier was configured.
     */
    public <S extends Span> void putVectorInSpan(S span, double[] data) {
        INDArray vector = toEmbeddingVector(data);
        if (vectorIdentifier == null) {
            span.putVector(getClass(), vector);
        } else {
            span.putVector(vectorIdentifier, vector);
        }
    }

    /**
     * Stores {@code data[i]} on the i-th span of the stream, in encounter order.
     *
     * <p>{@code forEachOrdered} keeps index and span aligned even if the stream is parallel.
     */
    public <S extends Span> void putVectorInSpans(Stream<S> spans, double[][] data) {
        AtomicInteger index = new AtomicInteger();
        spans.forEachOrdered(span -> putVectorInSpan(span, data[index.getAndIncrement()]));
    }

    /** Stores {@code data[i][j]} on the j-th span of the i-th inner stream, in encounter order. */
    public <S extends Span> void putVectorInSpans(Stream<? extends Stream<S>> spans, double[][][] data) {
        AtomicInteger index = new AtomicInteger();
        spans.forEachOrdered(inner -> putVectorInSpans(inner, data[index.getAndIncrement()]));
    }

    /** Wraps raw embedding values as a column vector of shape {@code [getEmbeddingVectorSize(), 1]}. */
    private INDArray toEmbeddingVector(double[] data) {
        return Nd4j.create(data, new long[] {getEmbeddingVectorSize(), 1L});
    }

    /**
     * Rejects a REST response that is {@code null} or has fewer than {@code expected} entries.
     * Extra trailing entries are ignored by the index-based assignment and are tolerated here.
     */
    private static void checkRestResponseSize(Object[] response, int expected, String what) throws IOException {
        if (response == null) {
            throw new IOException("REST response for " + what + " is null");
        }
        if (response.length < expected) {
            throw new IOException("REST response for " + what + " has " + response.length
                    + " entries, expected at least " + expected);
        }
    }
}

