/*
 * Decompiled with CFR 0.152.
 */
package org.apache.wayang.tensorflow.operators;

import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.NoSuchElementException;
import org.apache.wayang.basic.operators.PredictOperator;
import org.apache.wayang.core.optimizer.OptimizationContext;
import org.apache.wayang.core.plan.wayangplan.ExecutionOperator;
import org.apache.wayang.core.platform.ChannelDescriptor;
import org.apache.wayang.core.platform.ChannelInstance;
import org.apache.wayang.core.platform.lineage.ExecutionLineageNode;
import org.apache.wayang.core.types.DataSetType;
import org.apache.wayang.core.util.Tuple;
import org.apache.wayang.java.channels.CollectionChannel;
import org.apache.wayang.tensorflow.channels.TensorChannel;
import org.apache.wayang.tensorflow.execution.TensorflowExecutor;
import org.apache.wayang.tensorflow.model.TensorflowModel;
import org.apache.wayang.tensorflow.operators.TensorflowExecutionOperator;
import org.tensorflow.ndarray.NdArray;
import org.tensorflow.types.family.TType;

/**
 * TensorFlow implementation of {@link PredictOperator}.
 *
 * <p>Input 0 is read from a {@link CollectionChannel} whose first element is used as the
 * {@link TensorflowModel}. Input 1 is read from a {@link TensorChannel} and holds the data to
 * predict on. The model's prediction is emitted on output 0 through a {@link TensorChannel}.
 */
public class TensorflowPredictOperator
extends PredictOperator<NdArray<?>, NdArray<?>>
implements TensorflowExecutionOperator {

    /**
     * Creates a new instance whose input and output data set types are the unchecked default
     * types for {@link NdArray}.
     */
    public TensorflowPredictOperator() {
        super(DataSetType.createDefaultUnchecked(NdArray.class), DataSetType.createDefaultUnchecked(NdArray.class));
    }

    /**
     * Creates a new instance from an existing {@link PredictOperator} by delegating to its
     * copy constructor.
     *
     * @param that the operator to copy
     */
    public TensorflowPredictOperator(PredictOperator<NdArray<?>, NdArray<?>> that) {
        super(that);
    }

    /**
     * Returns the channel kinds accepted on the given input.
     *
     * @param index the input index: 0 is the model, 1 is the data
     * @return a {@link CollectionChannel} descriptor for the model input, otherwise a
     *     {@link TensorChannel} descriptor
     */
    @Override
    public List<ChannelDescriptor> getSupportedInputChannels(int index) {
        assert index >= 0 && index < this.getNumInputs() : "Illegal input index: " + index;
        if (index == 0) {
            return Collections.singletonList(CollectionChannel.DESCRIPTOR);
        }
        return Collections.singletonList(TensorChannel.DESCRIPTOR);
    }

    /**
     * Returns the channel kinds produced on the given output.
     *
     * @param index the output index
     * @return a singleton list containing the {@link TensorChannel} descriptor
     */
    @Override
    public List<ChannelDescriptor> getSupportedOutputChannels(int index) {
        assert index >= 0 && index < this.getNumOutputs() : "Illegal output index: " + index;
        return Collections.singletonList(TensorChannel.DESCRIPTOR);
    }

    /**
     * Applies the model from input 0 to the tensor from input 1 and publishes the result on
     * output 0.
     *
     * @param inputs the model channel (index 0) and the data tensor channel (index 1)
     * @param outputs the tensor channel receiving the prediction (index 0)
     * @param tensorflowExecutor the executor with which the predicted tensor is registered
     * @param operatorContext the optimizer context for this operator
     * @return the lineage information produced by eager execution
     * @throws NoSuchElementException if the model channel carries no element
     */
    @Override
    public Tuple<Collection<ExecutionLineageNode>, Collection<ChannelInstance>> evaluate(
            ChannelInstance[] inputs,
            ChannelInstance[] outputs,
            TensorflowExecutor tensorflowExecutor,
            OptimizationContext.OperatorContext operatorContext) {
        assert inputs.length == this.getNumInputs();
        assert outputs.length == this.getNumOutputs();

        final CollectionChannel.Instance inputModel = (CollectionChannel.Instance) inputs[0];
        final TensorChannel.Instance inputData = (TensorChannel.Instance) inputs[1];
        final TensorChannel.Instance output = (TensorChannel.Instance) outputs[0];

        // Fail with a descriptive message instead of a bare iterator exception when no model
        // was delivered on the model input.
        final Collection<?> models = inputModel.provideCollection();
        if (models.isEmpty()) {
            throw new NoSuchElementException(
                    "TensorflowPredictOperator expects a model on input 0, but the channel is empty.");
        }
        final TensorflowModel model = (TensorflowModel) models.iterator().next();

        final Object data = inputData.provideTensor();
        final TType predicted = (TType) model.predict(data);
        // Register the predicted tensor with the executor before publishing it downstream.
        // NOTE(review): presumably so the executor closes it later; confirm in TensorflowExecutor.
        tensorflowExecutor.addResource((AutoCloseable) predicted);
        output.accept((NdArray<?>) predicted);

        return ExecutionOperator.modelEagerExecution(inputs, outputs, operatorContext);
    }
}

