/*
 * Decompiled with CFR 0.152.
 */
package org.apache.wayang.tensorflow.model;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import org.apache.wayang.basic.model.DLModel;
import org.apache.wayang.basic.model.op.Input;
import org.apache.wayang.basic.model.op.Op;
import org.apache.wayang.basic.model.optimizer.Optimizer;
import org.apache.wayang.tensorflow.model.Convertor;
import org.tensorflow.ExecutionEnvironment;
import org.tensorflow.Graph;
import org.tensorflow.Operand;
import org.tensorflow.Result;
import org.tensorflow.Session;
import org.tensorflow.Tensor;
import org.tensorflow.ndarray.NdArray;
import org.tensorflow.ndarray.Shape;
import org.tensorflow.ndarray.index.Index;
import org.tensorflow.ndarray.index.Indices;
import org.tensorflow.op.Ops;
import org.tensorflow.op.core.Placeholder;
import org.tensorflow.types.TBool;
import org.tensorflow.types.TFloat32;
import org.tensorflow.types.family.TType;

/**
 * A {@link DLModel} compiled into a TensorFlow {@link Graph} and executed through a {@link Session}.
 *
 * <p>The constructor converts the criterion (loss) op, the optional accuracy op and every op they
 * read from into graph operands via {@link Convertor}. It then builds the optimizer's
 * {@code minimize} target for the criterion. Every converted op additionally receives a scalar
 * boolean "training mode" placeholder as its last input. That placeholder is fed {@code true}
 * during {@link #train} and {@code false} during {@link #predict}.
 *
 * <p>Instances hold native TensorFlow resources (graph and session) and must be
 * {@link #close() closed} when no longer needed.
 */
public class TensorflowModel extends DLModel implements AutoCloseable {
    private final Op criterion;
    private final Optimizer optimizer;
    /** Optional op whose value is fetched and printed during training; may be {@code null}. */
    private final Op accuracyCalculation;
    private final Graph graph;
    private final Ops tf;
    /** Scalar boolean placeholder appended as the last input of every converted op. */
    private final Placeholder<TBool> trainingMode;
    private final Session session;
    /** Converted operand for every compiled op, keyed by the op's id. */
    private final Map<Integer, Operand<?>> opMap;
    /** Training target returned by the optimizer's {@code minimize} on the criterion operand. */
    private final org.tensorflow.op.Op minimize;

    /**
     * Builds the TensorFlow graph and session for the given model.
     *
     * <p>If graph construction fails, the partially built session and graph are closed before the
     * exception is rethrown, so no native resources leak.
     *
     * @param model               model whose output op becomes this model's output
     * @param criterion           loss op that the optimizer minimizes
     * @param optimizer           optimizer used to build the training target
     * @param accuracyCalculation optional op fetched alongside the loss during training; may be
     *                            {@code null}
     */
    public TensorflowModel(DLModel model, Op criterion, Optimizer optimizer, Op accuracyCalculation) {
        super(model.getOut());
        this.criterion = criterion;
        this.optimizer = optimizer;
        this.accuracyCalculation = accuracyCalculation;
        this.opMap = new HashMap<>();
        Graph g = new Graph();
        Session s = null;
        try {
            this.graph = g;
            this.tf = Ops.create(g);
            this.trainingMode = this.tf.placeholder(TBool.class, Placeholder.shape(Shape.scalar()));
            s = new Session(g);
            this.session = s;
            this.compile(criterion);
            if (accuracyCalculation != null) {
                this.compile(accuracyCalculation);
            }
            this.minimize = Convertor.convert(g, optimizer).minimize(this.opMap.get(criterion.getId()));
        } catch (RuntimeException | Error e) {
            // Release native resources acquired so far; the object is never handed out.
            if (s != null) {
                s.close();
            }
            g.close();
            throw e;
        }
    }

    /**
     * Converts {@code op} and, recursively, every op it reads from into graph operands.
     *
     * <p>Results are memoized in {@link #opMap} by op id. An op that is shared by several consumers,
     * or that is compiled again from the constructor, is therefore converted only once.
     *
     * @param op the op to convert
     * @return the operand produced by {@link Convertor} for {@code op}
     */
    private Operand<?> compile(Op op) {
        Operand<?> cached = this.opMap.get(op.getId());
        if (cached != null) {
            return cached;
        }
        // Collectors.toList() does not guarantee mutability, but the training-mode flag is
        // appended below, so collect into an explicit ArrayList.
        List<Operand<?>> inputs = op.getFromList().stream()
                .map(e -> this.compile((Op) e))
                .collect(Collectors.toCollection(ArrayList::new));
        inputs.add(this.trainingMode);
        Operand<?> ret = Convertor.convert(this.graph, this.tf, op, inputs.toArray(new Operand<?>[0]));
        this.opMap.put(op.getId(), ret);
        return ret;
    }

    /**
     * Trains the model with mini-batch updates and prints progress to {@code System.out}.
     *
     * <p>Each of the {@code epoch} passes walks the rows in order, one batch at a time. A batch is
     * the rows {@code [start, start + batchSize)}, sliced along the first axis of {@code x} and
     * {@code y}; the last batch may be smaller. For each batch, the optimizer's minimize target is
     * run and the batch loss (and accuracy, if configured) is printed.
     *
     * @param x         features; assumed to have at least as many rows (first axis) as {@code y}
     * @param y         labels; the length of its first axis is the number of rows trained on
     * @param epoch     number of passes over the data
     * @param batchSize number of rows per batch; must be positive
     * @throws IllegalArgumentException if {@code batchSize} is not positive
     */
    public <XT extends NdArray<?>, YT extends NdArray<?>> void train(XT x, YT y, int epoch, int batchSize) {
        if (batchSize <= 0) {
            throw new IllegalArgumentException("batchSize must be positive, got " + batchSize);
        }
        System.out.println("Start training:");
        long n = y.shape().size(0);
        // A single flag tensor serves the whole run and is closed when training ends; creating one
        // per batch without closing it leaked native memory.
        try (TBool training = TBool.scalarOf(true)) {
            for (int i = 0; i < epoch; ++i) {
                for (long start = 0; start < n; start += batchSize) {
                    long end = Math.min(start + batchSize, n);
                    try (Tensor tx = Convertor.ndArrayToTensor(x.slice(Indices.slice(start, end)));
                         Tensor ty = Convertor.ndArrayToTensor(y.slice(Indices.slice(start, end)))) {
                        this.runTrainingStep(tx, ty, training, i + 1, start / batchSize + 1);
                    }
                }
            }
        }
        System.out.println("Finish training.\n");
    }

    /**
     * Runs one optimizer step on a single batch and prints its loss (and accuracy, if configured).
     *
     * @param features tensor fed under the FEATURES input name
     * @param labels   tensor fed under the LABEL input name
     * @param training scalar {@code true} fed to the training-mode placeholder
     * @param epochNo  1-based epoch number, used only for printing
     * @param batchNo  1-based batch number within the epoch, used only for printing
     */
    private void runTrainingStep(Tensor features, Tensor labels, TBool training, int epochNo, long batchNo) {
        Session.Runner runner = this.session.runner()
                .feed(Input.Type.FEATURES.getName(), features)
                .feed(Input.Type.LABEL.getName(), labels)
                .feed(this.trainingMode, training)
                .addTarget(this.minimize)
                .fetch(this.criterion.getName());
        if (this.accuracyCalculation != null) {
            runner.fetch(this.accuracyCalculation.getName());
        }
        try (Result ret = runner.run()) {
            TFloat32 loss = (TFloat32) ret.get(0);
            System.out.printf("[epoch %d, batch %d] loss: %f ", epochNo, batchNo, loss.getFloat());
            if (this.accuracyCalculation != null) {
                TFloat32 acc = (TFloat32) ret.get(1);
                System.out.printf("accuracy: %f ", acc.getFloat());
            }
        }
        System.out.println();
    }

    /**
     * Runs inference on {@code x} with the training-mode flag set to {@code false}.
     *
     * <p>The returned tensor is not closed by this method; the caller owns it and must close it.
     *
     * @param x features fed under the FEATURES input name
     * @return the value of the model's output op
     */
    public <XT extends NdArray<?>, PT extends NdArray<?> & TType> PT predict(XT x) {
        try (Tensor tx = Convertor.ndArrayToTensor(x);
             TBool inference = TBool.scalarOf(false)) {
            Tensor predicted = this.session.runner()
                    .feed(Input.Type.FEATURES.getName(), tx)
                    .feed(this.trainingMode, inference)
                    .fetch(this.out.getName())
                    .run()
                    .get(0);
            // PT is chosen by the caller; because of erasure, a type mismatch with the graph's
            // actual output surfaces as a ClassCastException at the call site.
            @SuppressWarnings("unchecked")
            PT result = (PT) predicted;
            return result;
        }
    }

    /** Returns the loss op this model was built with. */
    public Op getCriterion() {
        return this.criterion;
    }

    /** Returns the optimizer this model was built with. */
    public Optimizer getOptimizer() {
        return this.optimizer;
    }

    /** Returns the accuracy op this model was built with, or {@code null} if none was given. */
    public Op getAccuracyCalculation() {
        return this.accuracyCalculation;
    }

    /** Closes the session and then the graph; the graph is closed even if closing the session fails. */
    @Override
    public void close() {
        try {
            this.session.close();
        } finally {
            this.graph.close();
        }
    }
}

