package ai.h2o.mojos.runtime.h2o3;

import ai.h2o.mojos.runtime.MojoPipeline;
import ai.h2o.mojos.runtime.api.MojoColumnMeta;
import ai.h2o.mojos.runtime.frame.MojoColumn;
import ai.h2o.mojos.runtime.frame.MojoColumnFloat64;
import ai.h2o.mojos.runtime.frame.MojoFrame;
import ai.h2o.mojos.runtime.frame.MojoFrameBuilder;
import ai.h2o.mojos.runtime.frame.MojoFrameMeta;
import ai.h2o.mojos.runtime.frame.StringConverter;
import hex.ModelCategory;
import hex.genmodel.GenModel;
import hex.genmodel.MojoModel;
import hex.genmodel.easy.EasyPredictModelWrapper;
import hex.genmodel.easy.RowData;
import hex.genmodel.easy.exception.PredictException;
import hex.genmodel.easy.prediction.BinomialModelPrediction;
import hex.genmodel.easy.prediction.MultinomialModelPrediction;
import hex.genmodel.easy.prediction.RegressionModelPrediction;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import org.joda.time.DateTime;

/**
 * A MOJO2 pipeline implementation that uses an H2O-3 (or Sparkling Water)
 * MOJO as the predictor inside.  The intent is to make it behave as much like
 * the MOJO2 API as possible.
 *
 * A non-goal is to expose every low-level capability of the H2O-3 MOJO API.
 * If you need that, call the H2O-3 MOJO API directly instead.
 *
 * Only Binomial, Multinomial and Regression models are supported.
 */
public class MojoPipelineH2O3Impl extends MojoPipeline {
    /** Prediction wrapper around the H2O-3 model, configured by {@link #wrapModelForPrediction}. */
    private final EasyPredictModelWrapper _model;
    /** The wrapped H2O-3 model; source of column names, domains and the model category. */
    private final GenModel genModel;

    /** Feature-frame schema: one column per model input column. */
    private final MojoFrameMeta _inputMeta;
    /** Output-frame schema: one Float64 column per response class, or a single one for regression. */
    private final MojoFrameMeta _outputMeta;

    /**
     * A MOJO2 pipeline implementation that uses an H2O-3 (or Sparkling Water)
     * MOJO as the predictor inside.
     *
     * Must be given a valid Binomial, Multinomial or Regression model.
     * Other model types are not currently supported.
     *
     * @param model H2O-3 MOJO model.
     * @throws UnsupportedOperationException if the model category is not Binomial,
     *         Multinomial or Regression.
     */
    MojoPipelineH2O3Impl(MojoModel model) {
        // NOTE(review): a fixed 1970-01-01 timestamp and an empty string are passed as placeholders
        // for the remaining MojoPipeline constructor arguments — confirm what MojoPipeline expects.
        super(model.getUUID(), new DateTime(1970, 1, 1, 0, 0), "");

        _model = wrapModelForPrediction(model);
        genModel = _model.m;

        // Reject unsupported model categories before building any frame metadata.
        switch (genModel.getModelCategory()) {
            case Binomial:
            case Multinomial:
            case Regression:
                break;

            default:
                throw new UnsupportedOperationException("Unsupported ModelCategory: " + genModel.getModelCategory().toString());
        }

        {
            // Columns without a domain are numeric (Float64); columns with a domain are categorical (Str).
            final String[] names = genModel.getNames();
            final List<MojoColumnMeta> columns = new ArrayList<>();
            for (int i = 0; i < genModel.getNumCols(); i += 1) {
                final MojoColumn.Type columnType = (genModel.getDomainValues(i) == null) ? MojoColumn.Type.Float64 : MojoColumn.Type.Str;
                columns.add(MojoColumnMeta.newInput(names[i], columnType));
            }
            _inputMeta = new MojoFrameMeta(columns);
        }

        switch (genModel.getModelCategory()) {
            case Binomial:
            case Multinomial: {
                // One probability column per class, named "<response>.<class level>".
                final String[] responseDomain = genModel.getDomainValues(genModel.getResponseIdx());
                final List<MojoColumnMeta> columns = new ArrayList<>();
                for (int i = 0; i < genModel.getNumResponseClasses(); i += 1) {
                    final String columnName = genModel.getResponseName() + "." + responseDomain[i];
                    columns.add(MojoColumnMeta.newOutput(columnName, MojoColumn.Type.Float64));
                }
                _outputMeta = new MojoFrameMeta(columns);
                break;
            }
            case Regression: {
                // A single output column named after the response.
                final MojoColumnMeta column = MojoColumnMeta.newOutput(genModel.getResponseName(), MojoColumn.Type.Float64);
                _outputMeta = new MojoFrameMeta(Collections.singletonList(column));
                break;
            }
            default:
                // Unreachable after the check above; kept so the final field is definitely assigned.
                throw new UnsupportedOperationException("Unsupported ModelCategory: " + genModel.getModelCategory().toString());
        }
    }

    /**
     * Creates a frame builder for the given kind of frame. It has no extra configuration:
     * an empty string list and no custom string converters.
     *
     * @param kind frame kind; only Feature and Output are supported (see {@link #getMeta})
     * @return a new builder over the matching frame metadata
     */
    @Override
    protected MojoFrameBuilder getFrameBuilder(MojoColumn.Kind kind) {
        return new MojoFrameBuilder(getMeta(kind), Collections.<String>emptyList(), Collections.<String, StringConverter>emptyMap());
    }

    /**
     * Returns the frame metadata for the given kind.
     *
     * @param kind Feature for the input schema, Output for the prediction schema
     * @return the matching metadata
     * @throws UnsupportedOperationException for any other kind
     */
    @Override
    protected MojoFrameMeta getMeta(MojoColumn.Kind kind) {
        switch (kind) {
            case Feature:
                return _inputMeta;
            case Output:
                return _outputMeta;
            default:
                throw new UnsupportedOperationException("Cannot generate meta for interim frame");
        }
    }

    /**
     * Scores every row of {@code inputFrame} and writes the predictions into {@code outputFrame}.
     *
     * Each input column is read as strings. For every row, a {@link RowData} is built from the
     * non-null values (null values are left out of the row) and passed to the H2O-3 wrapper.
     * Class probabilities (Binomial/Multinomial) or the regression value are written in place
     * into the Float64 columns of {@code outputFrame}.
     *
     * @param inputFrame  frame holding the feature values
     * @param outputFrame frame receiving the predictions; presumably built from the Output meta
     *                    with at least as many rows as {@code inputFrame} — TODO confirm with callers
     * @return {@code outputFrame}, filled in
     * @throws UnsupportedOperationException if the model category is unsupported or predicting
     *         any row fails; the underlying failure is attached as the cause
     */
    @Override
    public MojoFrame transform(MojoFrame inputFrame, MojoFrame outputFrame) {
        final ModelCategory modelCategory = genModel.getModelCategory();
        final int colCount = inputFrame.getNcols();
        final int rowCount = inputFrame.getNrows();
        final String[] columnNames = inputFrame.getColumnNames();
        final String[][] columns = new String[colCount][];
        for (int j = 0; j < colCount; j += 1) {
            columns[j] = inputFrame.getColumn(j).getDataAsStrings();
        }

        for (int rowIdx = 0; rowIdx < rowCount; rowIdx++) {
            final RowData rowData = new RowData();
            for (int colIdx = 0; colIdx < colCount; colIdx++) {
                final String value = columns[colIdx][rowIdx];
                // Null values are omitted so the key is simply absent from the row.
                if (value != null) {
                    rowData.put(columnNames[colIdx], value);
                }
            }

            try {
                switch (modelCategory) {
                    case Binomial: {
                        final BinomialModelPrediction p = _model.predictBinomial(rowData);
                        setPrediction(outputFrame, rowIdx, p.classProbabilities);
                        break;
                    }
                    case Multinomial: {
                        final MultinomialModelPrediction p = _model.predictMultinomial(rowData);
                        setPrediction(outputFrame, rowIdx, p.classProbabilities);
                        break;
                    }
                    case Regression: {
                        final RegressionModelPrediction p = _model.predictRegression(rowData);
                        final MojoColumnFloat64 col = (MojoColumnFloat64) outputFrame.getColumn(0);
                        final double[] darr = (double[]) col.getData();
                        darr[rowIdx] = p.value;
                        break;
                    }
                    default:
                        throw new UnsupportedOperationException("Unsupported ModelCategory: " + modelCategory.toString());
                }
            } catch (UnsupportedOperationException e) {
                // Already the type callers expect; let it through without re-wrapping below.
                throw e;
            } catch (PredictException e) {
                if (ai.h2o.mojos.runtime.utils.Debug.getPrintH2O3Exceptions()) {
                    e.printStackTrace();
                }
                // Keep the original exception as the cause so its stack trace is not lost.
                throw new UnsupportedOperationException(String.format("%s failed: %s", modelCategory, e.getMessage()), e);
            } catch (Exception e) {
                if (ai.h2o.mojos.runtime.utils.Debug.getPrintH2O3Exceptions()) {
                    e.printStackTrace();
                }
                throw new UnsupportedOperationException(String.format("%s failed with %s: %s", modelCategory, e.getClass().getName(), e.getMessage()), e);
            }
        }

        return outputFrame;
    }

    /**
     * Writes one row of class probabilities into the output frame. Column {@code i} receives
     * {@code classProbabilities[i]} for each of the model's response classes.
     *
     * @param outputFrame        frame whose first getNumResponseClasses() columns are Float64
     * @param rowIdx             row to write
     * @param classProbabilities per-class probabilities returned by the wrapper
     */
    private void setPrediction(MojoFrame outputFrame, int rowIdx, double[] classProbabilities) {
        for (int outputColIdx = 0; outputColIdx < genModel.getNumResponseClasses(); outputColIdx++) {
            final MojoColumnFloat64 col = (MojoColumnFloat64) outputFrame.getColumn(outputColIdx);
            final double[] darr = (double[]) col.getData();
            darr[rowIdx] = classProbabilities[outputColIdx];
        }
    }

    /**
     * Wraps the specified {@link MojoModel} as an {@link EasyPredictModelWrapper} configured
     * to behave similarly to MOJO2.
     *
     * The wrapper is configured to tolerate bad input without throwing: unknown categorical
     * levels and invalid numbers are converted to NA.
     *
     * @param model H2O-3 MOJO model to wrap
     * @return the configured prediction wrapper
     */
    private static EasyPredictModelWrapper wrapModelForPrediction(MojoModel model) {
        EasyPredictModelWrapper.Config config = new EasyPredictModelWrapper.Config()
            .setModel(model)
            .setConvertUnknownCategoricalLevelsToNa(true)
            .setConvertInvalidNumbersToNa(true);

        return new EasyPredictModelWrapper(config);
    }
}
