/*
 * Decompiled with CFR 0.152.
 */
package de.datexis.encoder;

import de.datexis.encoder.AbstractRESTEncoder;
import de.datexis.encoder.RESTAdapter;
import de.datexis.model.Document;
import de.datexis.model.Sentence;
import de.datexis.model.Span;
import de.datexis.model.Token;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.util.Collection;
import java.util.List;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Base class for encoders that obtain embedding vectors through a {@link RESTAdapter}
 * and attach them to {@link Span}s.
 *
 * <p>The public {@code encode}/{@code encodeEach} entry points check the requested element
 * type against the configured {@code elementClass}, delegate to the abstract
 * {@code *Impl} methods and rethrow any {@link IOException} as {@link UncheckedIOException}.
 * The remaining methods are helpers for turning spans into string arrays, sending them to
 * the adapter and storing the returned vectors on the spans.
 */
public abstract class SimpleRESTEncoder extends AbstractRESTEncoder {
    private static final Logger log = LoggerFactory.getLogger(SimpleRESTEncoder.class);

    /** Span type this encoder accepts; stays {@code null} when built via {@link #SimpleRESTEncoder(String)}. */
    private Class<? extends Span> elementClass;

    /** Key under which vectors are stored on a span; when {@code null} the encoder's runtime class is used instead. */
    private String vectorIdentifier;

    /**
     * Creates an encoder with only an id. Neither the element class nor the vector
     * identifier is set by this constructor.
     *
     * @param id identifier passed to the superclass
     */
    protected SimpleRESTEncoder(String id) {
        super(id);
    }

    /**
     * Creates an encoder without an explicit vector identifier, so vectors are stored
     * under the encoder's runtime class (see {@link #putVectorInSpan(Span, double[])}).
     *
     * @param id           identifier passed to the superclass
     * @param restAdapter  adapter passed to the superclass
     * @param elementClass span type this encoder accepts
     */
    public SimpleRESTEncoder(String id, RESTAdapter restAdapter, Class<? extends Span> elementClass) {
        // The original body assigned vectorIdentifier to itself (a no-op); delegate with an explicit null.
        this(id, restAdapter, null, elementClass);
    }

    /**
     * Creates an encoder that stores vectors under the given identifier.
     *
     * @param id               identifier passed to the superclass
     * @param restAdapter      adapter passed to the superclass
     * @param vectorIdentifier key used when storing vectors on spans; may be {@code null}
     * @param elementClass     span type this encoder accepts
     */
    public SimpleRESTEncoder(String id, RESTAdapter restAdapter, String vectorIdentifier, Class<? extends Span> elementClass) {
        super(id, restAdapter);
        this.elementClass = elementClass;
        this.vectorIdentifier = vectorIdentifier;
    }

    /**
     * Encodes a single word via {@link #encodeImpl(String)}.
     *
     * @param word text to encode
     * @return the vector produced by the implementation
     * @throws UncheckedIOException if the implementation throws an {@link IOException}
     */
    public INDArray encode(String word) {
        try {
            return encodeImpl(word);
        } catch (IOException e) {
            log.error("IO Error while encoding word: {}", word, e);
            throw new UncheckedIOException(e);
        }
    }

    /**
     * Implementation hook for {@link #encode(String)}.
     *
     * @param var1 text to encode
     * @return the encoded vector
     * @throws IOException if encoding fails with an I/O error
     */
    public abstract INDArray encodeImpl(String var1) throws IOException;

    /**
     * Encodes a span via {@link #encodeImpl(Span)} if it is an instance of the configured
     * element class.
     *
     * @param span span to encode
     * @return the vector produced by the implementation
     * @throws UnsupportedOperationException if no element class is configured or {@code span}
     *                                       is not an instance of it (including {@code null})
     * @throws UncheckedIOException          if the implementation throws an {@link IOException}
     */
    public INDArray encode(Span span) {
        // Guard against a missing element class (protected constructor) instead of failing with an NPE.
        if (this.elementClass == null || !this.elementClass.isInstance(span)) {
            throw new UnsupportedOperationException(
                unsupportedElementMessage(span == null ? null : span.getClass()));
        }
        try {
            return encodeImpl(span);
        } catch (IOException e) {
            log.error("IO Error while encoding span: {}", span, e);
            throw new UncheckedIOException(e);
        }
    }

    /**
     * Implementation hook for {@link #encode(Span)}.
     *
     * @param var1 span to encode
     * @return the encoded vector
     * @throws IOException if encoding fails with an I/O error
     */
    public abstract INDArray encodeImpl(Span var1) throws IOException;

    /**
     * Encodes the elements of a sentence via {@link #encodeEachImpl(Sentence)}.
     *
     * @param input        sentence to process
     * @param elementClass requested element type; must be exactly the configured element class
     * @throws UnsupportedOperationException if {@code elementClass} differs from the configured one
     * @throws UncheckedIOException          if the implementation throws an {@link IOException}
     */
    public void encodeEach(Sentence input, Class<? extends Span> elementClass) {
        requireConfiguredElementClass(elementClass);
        try {
            encodeEachImpl(input);
        } catch (IOException e) {
            log.error("IO Error while encoding sentence: {}", input, e);
            throw new UncheckedIOException(e);
        }
    }

    /**
     * Implementation hook for {@link #encodeEach(Sentence, Class)}.
     *
     * @param var1 sentence to process
     * @throws IOException if encoding fails with an I/O error
     */
    public abstract void encodeEachImpl(Sentence var1) throws IOException;

    /**
     * Encodes the elements of a document via {@link #encodeEachImpl(Document)}.
     *
     * @param input        document to process
     * @param elementClass requested element type; must be exactly the configured element class
     * @throws UnsupportedOperationException if {@code elementClass} differs from the configured one
     * @throws UncheckedIOException          if the implementation throws an {@link IOException}
     */
    public void encodeEach(Document input, Class<? extends Span> elementClass) {
        requireConfiguredElementClass(elementClass);
        try {
            encodeEachImpl(input);
        } catch (IOException e) {
            log.error("IO Error while encoding document: {}", input.getTitle(), e);
            throw new UncheckedIOException(e);
        }
    }

    /**
     * Implementation hook for {@link #encodeEach(Document, Class)}.
     *
     * @param var1 document to process
     * @throws IOException if encoding fails with an I/O error
     */
    public abstract void encodeEachImpl(Document var1) throws IOException;

    /**
     * Encodes the elements of several documents via {@link #encodeEachImpl(Collection)}.
     *
     * @param docs         documents to process
     * @param elementClass requested element type; must be exactly the configured element class
     * @throws UnsupportedOperationException if {@code elementClass} differs from the configured one
     * @throws UncheckedIOException          if the implementation throws an {@link IOException}
     */
    public void encodeEach(Collection<Document> docs, Class<? extends Span> elementClass) {
        requireConfiguredElementClass(elementClass);
        try {
            encodeEachImpl(docs);
        } catch (IOException e) {
            log.error("IO Error while encoding documents", e);
            throw new UncheckedIOException(e);
        }
    }

    /**
     * Implementation hook for {@link #encodeEach(Collection, Class)}.
     *
     * @param var1 documents to process
     * @throws IOException if encoding fails with an I/O error
     */
    public abstract void encodeEachImpl(Collection<Document> var1) throws IOException;

    /**
     * Sends a single string to the adapter and wraps the returned values in an array of
     * shape {@code [getEmbeddingVectorSize(), 1]}.
     *
     * @param value text to encode
     * @return the encoded vector
     * @throws IOException if the adapter call fails
     */
    public INDArray encodeValue(String value) throws IOException {
        return toColumnVector(this.restAdapter.encode(value));
    }

    /**
     * Collects the token lists of all sentences of a document.
     *
     * @param document document to read
     * @return one token list per sentence, in the order the document streams its sentences
     */
    public List<List<Token>> getTokensOfSentencesOfDocument(Document document) {
        return document.streamSentences().map(Sentence::getTokens).collect(Collectors.toList());
    }

    /**
     * Streams the token streams of all sentences of a document.
     *
     * @param document document to read
     * @return one token stream per sentence
     */
    public Stream<Stream<Token>> streamTokensOfSentencesOfDocument(Document document) {
        return document.streamSentences().map(Sentence::streamTokens);
    }

    /**
     * Converts a list of lists into a stream of streams.
     *
     * @param spans2D nested lists
     * @param <S>     element type
     * @return a stream with one inner stream per inner list
     */
    public <S> Stream<Stream<S>> streamSpans2D(List<? extends List<S>> spans2D) {
        return spans2D.stream().map(Collection::stream);
    }

    /**
     * Encodes one span using {@link Span#getText()} as the text to send.
     *
     * @param span span to encode and annotate
     * @throws IOException if the adapter call fails
     */
    public <S extends Span> void encodeEach(S span) throws IOException {
        this.encodeEach(span, Span::getText);
    }

    /**
     * Encodes one span: extracts its text, sends it to the adapter and stores the result
     * on the span.
     *
     * @param span    span to encode and annotate
     * @param getText extracts the text to send from the span
     * @throws IOException if the adapter call fails
     */
    public <S extends Span> void encodeEach(S span, Function<S, String> getText) throws IOException {
        double[] embedding = this.restAdapter.encode(getText.apply(span));
        putVectorInSpan(span, embedding);
    }

    /**
     * Encodes a list of spans in one adapter call using {@link Span#getText()}.
     *
     * @param spans spans to encode and annotate
     * @throws IOException if the adapter call fails
     */
    public <S extends Span> void encodeEach1D(List<S> spans) throws IOException {
        this.encodeEach1D(spans, Span::getText);
    }

    /**
     * Encodes a list of spans in one adapter call and stores the i-th returned vector on
     * the i-th span.
     *
     * @param spans   spans to encode and annotate
     * @param getText extracts the text to send from each span
     * @throws IOException if the adapter call fails
     */
    public <S extends Span> void encodeEach1D(List<S> spans, Function<S, String> getText) throws IOException {
        String[] texts = spansToStringArray1D(spans.stream(), getText);
        double[][] embeddings = this.restAdapter.encode(texts);
        putVectorInSpans(spans.stream(), embeddings);
    }

    /**
     * Encodes nested lists of spans in one adapter call using {@link Span#getText()}.
     *
     * @param spans nested spans to encode and annotate
     * @throws IOException if the adapter call fails
     */
    public <S extends Span> void encodeEach2D(List<? extends List<S>> spans) throws IOException {
        this.encodeEach2D(spans, Span::getText);
    }

    /**
     * Encodes nested lists of spans in one adapter call and stores vector {@code [i][j]} on
     * span {@code j} of inner list {@code i}.
     *
     * @param spans   nested spans to encode and annotate
     * @param getText extracts the text to send from each span
     * @throws IOException if the adapter call fails
     */
    public <S extends Span> void encodeEach2D(List<? extends List<S>> spans, Function<S, String> getText) throws IOException {
        String[][] texts = spansToStringArray2D(streamSpans2D(spans), getText);
        double[][][] embeddings = this.restAdapter.encode(texts);
        putVectorInSpans(streamSpans2D(spans), embeddings);
    }

    /**
     * Maps a stream of spans to their {@link Span#getText()} values.
     *
     * @param spans spans to read; the stream is consumed
     * @return the texts in stream order
     */
    public <S extends Span> String[] spansToStringArray1D(Stream<S> spans) {
        return this.spansToStringArray1D(spans, Span::getText);
    }

    /**
     * Maps a stream of spans to strings.
     *
     * @param spans   spans to read; the stream is consumed
     * @param getText extracts the text from each span
     * @return the texts in stream order
     */
    public <S extends Span> String[] spansToStringArray1D(Stream<S> spans, Function<S, String> getText) {
        return spans.map(getText).toArray(String[]::new);
    }

    /**
     * Maps a stream of span streams to a two-dimensional array of {@link Span#getText()} values.
     *
     * @param spans nested spans to read; all streams are consumed
     * @return one row per inner stream
     */
    public <S extends Span> String[][] spansToStringArray2D(Stream<? extends Stream<S>> spans) {
        return this.spansToStringArray2D(spans, Span::getText);
    }

    /**
     * Maps a stream of span streams to a two-dimensional string array.
     *
     * @param spans   nested spans to read; all streams are consumed
     * @param getText extracts the text from each span
     * @return one row per inner stream
     */
    public <S extends Span> String[][] spansToStringArray2D(Stream<? extends Stream<S>> spans, Function<S, String> getText) {
        return spans.map(inner -> this.spansToStringArray1D(inner, getText)).toArray(String[][]::new);
    }

    /**
     * Stores a vector on a span. The values are wrapped in an array of shape
     * {@code [getEmbeddingVectorSize(), 1]} and stored under {@code vectorIdentifier}, or under
     * this encoder's runtime class when no identifier is configured.
     *
     * @param span span to annotate
     * @param data vector values
     */
    public <S extends Span> void putVectorInSpan(S span, double[] data) {
        INDArray vector = toColumnVector(data);
        if (this.vectorIdentifier == null) {
            span.putVector(this.getClass(), vector);
        } else {
            span.putVector(this.vectorIdentifier, vector);
        }
    }

    /**
     * Stores {@code data[i]} on the i-th span of the stream.
     *
     * @param spans spans to annotate; the stream is consumed in encounter order
     * @param data  one vector per span
     * @throws ArrayIndexOutOfBoundsException if the stream yields more spans than {@code data} has rows
     */
    public <S extends Span> void putVectorInSpans(Stream<S> spans, double[][] data) {
        AtomicInteger index = new AtomicInteger();
        // forEachOrdered keeps span i paired with data[i] even if a parallel stream is passed in.
        spans.forEachOrdered(span -> this.putVectorInSpan(span, data[index.getAndIncrement()]));
    }

    /**
     * Stores {@code data[i][j]} on the j-th span of the i-th inner stream.
     *
     * @param spans nested spans to annotate; all streams are consumed in encounter order
     * @param data  one matrix per inner stream, one row per span
     * @throws ArrayIndexOutOfBoundsException if there are more streams or spans than matching entries in {@code data}
     */
    public <S extends Span> void putVectorInSpans(Stream<? extends Stream<S>> spans, double[][][] data) {
        AtomicInteger index = new AtomicInteger();
        spans.forEachOrdered(inner -> this.putVectorInSpans(inner, data[index.getAndIncrement()]));
    }

    /** Wraps raw values in an array of shape {@code [getEmbeddingVectorSize(), 1]}. */
    private INDArray toColumnVector(double[] values) {
        return Nd4j.create(values, new long[]{this.getEmbeddingVectorSize(), 1L});
    }

    /** Throws {@link UnsupportedOperationException} unless {@code requested} is exactly the configured element class. */
    private void requireConfiguredElementClass(Class<? extends Span> requested) {
        if (requested != this.elementClass) {
            throw new UnsupportedOperationException(unsupportedElementMessage(requested));
        }
    }

    /** Builds the message used when a request does not match the configured element class. */
    private String unsupportedElementMessage(Class<?> requested) {
        return "Encoder is configured for " + this.elementClass + " and cannot encode " + requested;
    }
}

