/*
 * Decompiled with CFR 0.152.
 */
package org.apache.wayang.tensorflow.channels;

import java.util.OptionalLong;
import org.apache.wayang.core.optimizer.OptimizationContext;
import org.apache.wayang.core.plan.executionplan.Channel;
import org.apache.wayang.core.plan.wayangplan.OutputSlot;
import org.apache.wayang.core.platform.AbstractChannelInstance;
import org.apache.wayang.core.platform.ChannelDescriptor;
import org.apache.wayang.core.platform.ChannelInstance;
import org.apache.wayang.core.platform.Executor;
import org.apache.wayang.tensorflow.execution.TensorflowExecutor;
import org.tensorflow.ndarray.NdArray;

/**
 * A {@link Channel} whose runtime {@link Instance} holds a single {@link NdArray} tensor.
 *
 * <p>NOTE(review): the two {@code false} flags passed to {@link ChannelDescriptor} are assumed to be
 * the "reusable" and "suitable for breakpoint" flags -- confirm against {@code ChannelDescriptor}.
 */
public class TensorChannel extends Channel {
    /** Descriptor shared by every {@link TensorChannel}. */
    public static final ChannelDescriptor DESCRIPTOR = new ChannelDescriptor(TensorChannel.class, false, false);

    /** Marker stored in {@link Instance#cardinality} while no cardinality has been measured. */
    private static final long UNMEASURED_CARDINALITY = -1L;

    /**
     * Creates a channel that uses {@link #DESCRIPTOR}.
     *
     * @param producerSlot the output slot that feeds this channel
     */
    public TensorChannel(OutputSlot<?> producerSlot) {
        super(DESCRIPTOR, producerSlot);
    }

    /**
     * Creates a channel with an explicitly passed descriptor.
     *
     * @param descriptor   must be {@link #DESCRIPTOR} (checked by an assertion only)
     * @param producerSlot the output slot that feeds this channel
     */
    public TensorChannel(ChannelDescriptor descriptor, OutputSlot<?> producerSlot) {
        super(descriptor, producerSlot);
        assert descriptor == DESCRIPTOR;
    }

    /**
     * Copy constructor backing {@link #copy()}.
     *
     * @param parent the channel to copy
     */
    private TensorChannel(TensorChannel parent) {
        super(parent);
    }

    /**
     * Creates a copy of this channel.
     *
     * @return a new {@link TensorChannel} created from this one
     */
    public Channel copy() {
        return new TensorChannel(this);
    }

    /**
     * Creates the runtime instance of this channel.
     *
     * @param executor                must be a {@link TensorflowExecutor}; any other type fails with a
     *                                {@link ClassCastException}
     * @param producerOperatorContext context of the producing operator
     * @param producerOutputIndex     index of the producing output
     * @return a new {@link Instance} bound to this channel
     */
    public ChannelInstance createInstance(Executor executor, OptimizationContext.OperatorContext producerOperatorContext, int producerOutputIndex) {
        return new Instance((TensorflowExecutor) executor, producerOperatorContext, producerOutputIndex);
    }

    /**
     * Runtime instance of a {@link TensorChannel}: stores one tensor and, when instrumented, the size
     * of its leading dimension as the measured cardinality.
     */
    public class Instance extends AbstractChannelInstance {
        /** The stored tensor; {@code null} until {@link #accept(NdArray)} is called and after disposal. */
        private NdArray<?> tensor;

        /**
         * Size of the tensor's first dimension, or {@link #UNMEASURED_CARDINALITY} if none was recorded.
         * A dedicated negative marker is used so that a genuinely measured cardinality of {@code 0}
         * is not mistaken for "not measured".
         */
        private long cardinality = UNMEASURED_CARDINALITY;

        /**
         * @param executor                the executor handed to the superclass
         * @param producerOperatorContext context of the producing operator
         * @param producerOutputIndex     index of the producing output
         */
        public Instance(TensorflowExecutor executor, OptimizationContext.OperatorContext producerOperatorContext, int producerOutputIndex) {
            super(executor, producerOperatorContext, producerOutputIndex);
        }

        /**
         * Stores the given tensor in this instance. Asserts that no tensor has been stored before.
         *
         * <p>If this instance is marked for instrumentation, the size of the tensor's first dimension
         * is recorded as cardinality. Tensors without any dimension (scalars) and dimensions of
         * unknown (negative) size leave the cardinality unmeasured instead of failing.
         *
         * @param tensor the tensor to store
         */
        public void accept(NdArray<?> tensor) {
            assert this.tensor == null;
            this.tensor = tensor;
            // Indexing dimension 0 of a rank-0 shape would throw, so check the rank first.
            if (this.isMarkedForInstrumentation() && tensor.shape().numDimensions() > 0) {
                final long leadingDimensionSize = tensor.shape().size(0);
                if (leadingDimensionSize >= 0L) {
                    this.cardinality = leadingDimensionSize;
                }
            }
        }

        /**
         * Returns the stored tensor, cast to the type expected by the caller.
         *
         * @param <T> the tensor type the caller expects; the cast is unchecked
         * @return the stored tensor, or {@code null} if none is stored
         */
        public <T extends NdArray<?>> T provideTensor() {
            // Unchecked by design: the caller chooses T and is responsible for it matching.
            @SuppressWarnings("unchecked")
            final T typedTensor = (T) this.tensor;
            return typedTensor;
        }

        /**
         * @return the enclosing {@link TensorChannel}
         */
        public Channel getChannel() {
            return TensorChannel.this;
        }

        /**
         * @return the cardinality recorded by {@link #accept(NdArray)} if there is one (including
         *         {@code 0}), otherwise whatever the superclass reports
         */
        public OptionalLong getMeasuredCardinality() {
            if (this.cardinality == UNMEASURED_CARDINALITY) {
                return super.getMeasuredCardinality();
            }
            return OptionalLong.of(this.cardinality);
        }

        /**
         * Drops the reference to the stored tensor.
         */
        protected void doDispose() throws Throwable {
            this.tensor = null;
        }
    }
}

