package ai.h2o.mojos.runtime.h2o3;

import ai.h2o.mojos.runtime.frame.MojoColumnFloat64;
import ai.h2o.mojos.runtime.frame.MojoFrame;
import ai.h2o.mojos.runtime.frame.MojoFrameMeta;
import ai.h2o.mojos.runtime.transforms.MojoTransform;
import hex.ModelCategory;
import hex.genmodel.GenModel;
import hex.genmodel.easy.EasyPredictModelWrapper;
import hex.genmodel.easy.RowData;
import hex.genmodel.easy.exception.PredictException;
import hex.genmodel.easy.prediction.BinomialModelPrediction;
import hex.genmodel.easy.prediction.MultinomialModelPrediction;
import hex.genmodel.easy.prediction.RegressionModelPrediction;

/**
 * A MOJO2 pipeline implementation that uses an H2O-3 (or Sparkling Water)
 * MOJO as the predictor inside.  The intent is to provide as identical an
 * experience to the MOJO2 API as possible.
 * <p>
 * A non-goal is to expose every possible low-level H2O-3 MOJO API capability.
 * If you want to do that, call the H2O-3 MOJO API directly, instead.
 */
public class H2O3Transform extends MojoTransform {
    private final GenModel genModel;
    private final EasyPredictModelWrapper easyPredictModelWrapper;

    /**
     * A MOJO2 transformer implementation that uses an H2O-3 (or Sparkling Water) MOJO as the predictor inside.
     * <p>
     * Must provide a valid Binomial, Multinomial or Regression model.
     * Other model types not currently supported.
     * <p>
     * Note: later, we might consider splitting this into one class per each supported model type, to more closely represent underlying H2O-3 algos.
     *
     * @param meta                    frame metadata; not used by this constructor
     * @param iindices                indices of the frame columns read as model inputs (forwarded to the superclass)
     * @param oindices                indices of the frame columns that receive predictions (forwarded to the superclass)
     * @param easyPredictModelWrapper H2O-3 MOJO model.
     */
    H2O3Transform(MojoFrameMeta meta, int[] iindices, int[] oindices, EasyPredictModelWrapper easyPredictModelWrapper) {
        super(iindices, oindices);
        this.easyPredictModelWrapper = easyPredictModelWrapper;
        this.genModel = easyPredictModelWrapper.m;
    }

    /**
     * Scores every row of {@code frame} with the wrapped H2O-3 model and writes the
     * predictions into the output columns of the same frame.
     * <p>
     * Input columns are read as strings and passed to the model keyed by column name;
     * {@code null} values are left out of the row. For Binomial and Multinomial models the
     * class probabilities are written to the output columns in order; for Regression models
     * the predicted value is written to the first output column.
     *
     * @param frame frame holding both the input columns and the output columns
     * @throws UnsupportedOperationException if the model category is not Binomial, Multinomial
     *                                       or Regression, or if scoring a row fails; in the
     *                                       latter case the original exception is attached as
     *                                       the cause
     */
    @Override
    public void transform(final MojoFrame frame) {
        final ModelCategory modelCategory = genModel.getModelCategory();
        final int colCount = iindices.length;
        final int rowCount = frame.getNrows();
        final String[][] columns = new String[colCount][];
        // Column names do not depend on the row, so resolve them once up front
        // instead of once per cell.
        final String[] columnNames = new String[colCount];
        for (int j = 0; j < colCount; j += 1) {
            final int iidx = iindices[j];
            columns[j] = frame.getColumn(iidx).getDataAsStrings();
            columnNames[j] = frame.getColumnName(iidx);
        }

        for (int rowIdx = 0; rowIdx < rowCount; rowIdx++) {
            final RowData rowData = new RowData();
            for (int colIdx = 0; colIdx < colCount; colIdx++) {
                final String value = columns[colIdx][rowIdx];
                // Null cells are simply omitted from the row handed to the model.
                if (value != null) {
                    rowData.put(columnNames[colIdx], value);
                }
            }

            try {
                switch (modelCategory) {
                    case Binomial: {
                        final BinomialModelPrediction p = easyPredictModelWrapper.predictBinomial(rowData);
                        setPrediction(frame, rowIdx, p.classProbabilities);
                    }
                    break;
                    case Multinomial: {
                        final MultinomialModelPrediction p = easyPredictModelWrapper.predictMultinomial(rowData);
                        setPrediction(frame, rowIdx, p.classProbabilities);
                    }
                    break;
                    case Regression: {
                        final RegressionModelPrediction p = easyPredictModelWrapper.predictRegression(rowData);
                        final MojoColumnFloat64 col = (MojoColumnFloat64) frame.getColumn(oindices[0]);
                        final double[] darr = (double[]) col.getData();
                        darr[rowIdx] = p.value;
                    }
                    break;
                    default:
                        throw new UnsupportedOperationException("Unsupported ModelCategory: " + modelCategory);
                }
            } catch (UnsupportedOperationException e) {
                // Already the exception type callers expect; do not re-wrap.
                throw e;
            } catch (PredictException e) {
                if (ai.h2o.mojos.runtime.utils.Debug.getPrintH2O3Exceptions()) e.printStackTrace();
                // Keep the original exception as the cause so its stack trace is not lost.
                throw new UnsupportedOperationException(String.format("%s failed: %s", modelCategory, e.getMessage()), e);
            } catch (Exception e) {
                if (ai.h2o.mojos.runtime.utils.Debug.getPrintH2O3Exceptions()) e.printStackTrace();
                // Keep the original exception as the cause so its stack trace is not lost.
                throw new UnsupportedOperationException(String.format("%s failed with %s: %s", modelCategory, e.getClass().getName(), e.getMessage()), e);
            }
        }
    }

    /**
     * Writes one row of class probabilities into the output columns.
     * The k-th output column receives {@code classProbabilities[k]}.
     *
     * @param frame              frame whose output columns are written
     * @param rowIdx             index of the row being written
     * @param classProbabilities per-class probabilities produced by the model for this row
     */
    private void setPrediction(MojoFrame frame, int rowIdx, double[] classProbabilities) {
        for (int outputColIdx = 0; outputColIdx < oindices.length; outputColIdx++) {
            final int oidx = oindices[outputColIdx];
            final MojoColumnFloat64 col = (MojoColumnFloat64) frame.getColumn(oidx);
            final double[] darr = (double[]) col.getData();
            darr[rowIdx] = classProbabilities[outputColIdx];
        }
    }
}
