/*
 * Decompiled with CFR 0.152.
 */
package org.apache.wayang.tensorflow.operators;

import java.util.Collection;
import java.util.Collections;
import java.util.List;
import org.apache.wayang.basic.model.DLModel;
import org.apache.wayang.basic.operators.DLTrainingOperator;
import org.apache.wayang.core.optimizer.OptimizationContext;
import org.apache.wayang.core.plan.wayangplan.ExecutionOperator;
import org.apache.wayang.core.platform.ChannelDescriptor;
import org.apache.wayang.core.platform.ChannelInstance;
import org.apache.wayang.core.platform.lineage.ExecutionLineageNode;
import org.apache.wayang.core.types.DataSetType;
import org.apache.wayang.core.util.Tuple;
import org.apache.wayang.java.channels.CollectionChannel;
import org.apache.wayang.tensorflow.channels.TensorChannel;
import org.apache.wayang.tensorflow.execution.TensorflowExecutor;
import org.apache.wayang.tensorflow.model.TensorflowModel;
import org.apache.wayang.tensorflow.operators.TensorflowExecutionOperator;
import org.tensorflow.ndarray.NdArray;

/**
 * TensorFlow-platform implementation of {@link DLTrainingOperator}.
 *
 * <p>Consumes two {@link TensorChannel} inputs, {@code x} at index 0 and {@code y} at index 1
 * (presumably training features and labels — confirm against the plan that wires this operator).
 * It trains a {@link TensorflowModel} built from the configured {@link DLModel} and
 * {@link DLTrainingOperator.Option}. It emits a single-element {@link CollectionChannel} that
 * holds the trained model.
 */
public class TensorflowDLTrainingOperator
        extends DLTrainingOperator<NdArray<?>, NdArray<?>>
        implements TensorflowExecutionOperator {

    /**
     * Creates a new training operator whose input data sets are untyped {@link NdArray}s.
     *
     * @param model  the model architecture to train
     * @param option training configuration (criterion, optimizer, accuracy calculation,
     *               epoch count and batch size)
     */
    public TensorflowDLTrainingOperator(DLModel model, DLTrainingOperator.Option option) {
        super(model, option,
                DataSetType.createDefaultUnchecked(NdArray.class),
                DataSetType.createDefaultUnchecked(NdArray.class));
    }

    /**
     * Copy constructor, used to derive this platform-specific operator from a generic
     * {@link DLTrainingOperator}.
     *
     * @param that the operator to copy
     */
    public TensorflowDLTrainingOperator(DLTrainingOperator<NdArray<?>, NdArray<?>> that) {
        super(that);
    }

    /**
     * Every input slot accepts only {@link TensorChannel}s.
     *
     * @param index the input index (ignored; all inputs share the same channel type)
     * @return a singleton list containing {@link TensorChannel#DESCRIPTOR}
     */
    @Override
    public List<ChannelDescriptor> getSupportedInputChannels(int index) {
        return Collections.singletonList(TensorChannel.DESCRIPTOR);
    }

    /**
     * The trained model is emitted through a {@link CollectionChannel}.
     *
     * @param index the output index (ignored; all outputs share the same channel type)
     * @return a singleton list containing {@link CollectionChannel#DESCRIPTOR}
     */
    @Override
    public List<ChannelDescriptor> getSupportedOutputChannels(int index) {
        return Collections.singletonList(CollectionChannel.DESCRIPTOR);
    }

    /**
     * Builds a {@link TensorflowModel} from the configured model and options and trains it
     * eagerly on the two input tensors. It then writes the trained model as a single-element
     * collection to the output channel.
     *
     * @param inputs             two {@link TensorChannel.Instance}s: {@code x} at index 0,
     *                           {@code y} at index 1
     * @param outputs            one {@link CollectionChannel.Instance} that receives the model
     * @param tensorflowExecutor the executor running this operator; the model is registered
     *                           with it via {@code addResource}
     * @param operatorContext    optimizer context for this operator
     * @return the lineage/produced-channel tuple from
     *         {@link ExecutionOperator#modelEagerExecution}
     */
    @Override
    public Tuple<Collection<ExecutionLineageNode>, Collection<ChannelInstance>> evaluate(
            ChannelInstance[] inputs,
            ChannelInstance[] outputs,
            TensorflowExecutor tensorflowExecutor,
            OptimizationContext.OperatorContext operatorContext) {
        assert inputs.length == this.getNumInputs();
        assert outputs.length == this.getNumOutputs();

        final TensorChannel.Instance x = (TensorChannel.Instance) inputs[0];
        final TensorChannel.Instance y = (TensorChannel.Instance) inputs[1];
        final CollectionChannel.Instance output = (CollectionChannel.Instance) outputs[0];

        final TensorflowModel tfModel = new TensorflowModel(
                this.model,
                this.option.getCriterion(),
                this.option.getOptimizer(),
                this.option.getAccuracyCalculation());
        // Registered with the executor before training. This is presumably so the executor
        // manages the model's lifecycle/cleanup even if training fails. Confirm in
        // TensorflowExecutor.
        tensorflowExecutor.addResource(tfModel);

        // provideTensor() is generic in its return type, so the target type drives inference.
        final NdArray<?> xData = x.provideTensor();
        final NdArray<?> yData = y.provideTensor();
        tfModel.train(xData, yData, this.option.getEpoch(), this.option.getBatchSize());

        output.accept(Collections.singletonList(tfModel));

        return ExecutionOperator.modelEagerExecution(inputs, outputs, operatorContext);
    }
}

