package ai.h2o.mojos.runtime.h2o3;

import ai.h2o.mojos.runtime.frame.MojoColumnFloat64;
import ai.h2o.mojos.runtime.frame.MojoFrame;
import ai.h2o.mojos.runtime.frame.MojoFrameMeta;
import ai.h2o.mojos.runtime.transforms.MojoTransform;
import hex.ModelCategory;
import hex.genmodel.GenModel;
import hex.genmodel.easy.EasyPredictModelWrapper;
import hex.genmodel.easy.RowData;
import hex.genmodel.easy.exception.PredictException;
import hex.genmodel.easy.prediction.KLimeModelPrediction;

/**
 * A MOJO2 pipeline transform that uses a k-LIME MOJO (built on the H2O-3 MOJO framework)
 * as the predictor inside. The intent is to provide an experience as close to the
 * MOJO2 API as possible.
 */
public class KlimeTransform extends MojoTransform {
    private final EasyPredictModelWrapper easyPredictModelWrapper;
    private final GenModel genModel;

    /**
     * Creates a MOJO2 transformer that uses a k-LIME MOJO (built on the H2O-3 MOJO framework) as the predictor inside.
     *
     * @param meta                    frame metadata; not used by this constructor, kept so the signature stays stable.
     * @param iindices                indices of the frame columns that are fed to the model, forwarded to the superclass.
     * @param oindices                indices of the frame columns that receive the prediction, forwarded to the superclass.
     * @param easyPredictModelWrapper H2O-3 MOJO model; its wrapped {@link GenModel} is read from the {@code m} field.
     */
    KlimeTransform(MojoFrameMeta meta, int[] iindices, int[] oindices, EasyPredictModelWrapper easyPredictModelWrapper) {
        super(iindices, oindices);
        this.easyPredictModelWrapper = easyPredictModelWrapper;
        this.genModel = easyPredictModelWrapper.m;
    }

    /**
     * Scores every row of {@code frame} with the k-LIME model and writes the result into the output columns.
     *
     * <p>Input columns are read as strings; {@code null} values are left out of the row passed to the model.
     * Output column 0 receives the predicted value, output column 1 the cluster, and every further
     * output column {@code i} receives {@code reasonCodes[i - 2]}. Output columns are accessed as
     * {@link MojoColumnFloat64} backed by a {@code double[]}.
     *
     * @param frame frame holding both the input columns and the output columns to fill.
     * @throws UnsupportedOperationException if scoring a row or storing its result fails; the original
     *                                       exception is attached as the cause.
     */
    @Override
    public void transform(final MojoFrame frame) {
        final ModelCategory modelCategory = genModel.getModelCategory();
        final int colCount = iindices.length;
        final int rowCount = frame.getNrows();
        final String[][] columns = new String[colCount][];
        // Column names do not depend on the row, so look them up once instead of once per cell.
        final String[] names = new String[colCount];
        for (int j = 0; j < colCount; j++) {
            final int iidx = iindices[j];
            columns[j] = frame.getColumn(iidx).getDataAsStrings();
            names[j] = frame.getColumnName(iidx);
        }

        for (int rowIdx = 0; rowIdx < rowCount; rowIdx++) {
            final RowData rowData = buildRowData(names, columns, rowIdx);
            try {
                final KLimeModelPrediction p = easyPredictModelWrapper.predictKLime(rowData);
                storePrediction(frame, p, rowIdx);
            } catch (PredictException e) {
                throw failure(String.format("%s failed: %s", modelCategory, e.getMessage()), e);
            } catch (Exception e) {
                throw failure(String.format("%s failed with %s: %s", modelCategory, e.getClass().getName(), e.getMessage()), e);
            }
        }
    }

    /** Assembles the model input for one row, skipping cells whose string value is {@code null}. */
    private static RowData buildRowData(final String[] names, final String[][] columns, final int rowIdx) {
        final RowData rowData = new RowData();
        for (int colIdx = 0; colIdx < names.length; colIdx++) {
            final String value = columns[colIdx][rowIdx];
            if (value != null) {
                rowData.put(names[colIdx], value);
            }
        }
        return rowData;
    }

    /** Writes one row's prediction (value, cluster, then reason codes) into the output columns. */
    private void storePrediction(final MojoFrame frame, final KLimeModelPrediction p, final int rowIdx) {
        final int predsSize = genModel.getPredsSize();
        for (int outputColIdx = 0; outputColIdx < predsSize; outputColIdx++) {
            final MojoColumnFloat64 col = (MojoColumnFloat64) frame.getColumn(oindices[outputColIdx]);
            final double[] darr = (double[]) col.getData();
            switch (outputColIdx) {
                case 0:
                    darr[rowIdx] = p.value;
                    break;
                case 1:
                    darr[rowIdx] = p.cluster;
                    break;
                default:
                    // Outputs 0 and 1 are value and cluster, so reason codes start at offset 2.
                    darr[rowIdx] = p.reasonCodes[outputColIdx - 2];
                    break;
            }
        }
    }

    /**
     * Builds the exception reported for a failed row, keeping {@code cause} attached so the
     * original stack trace is not lost; optionally prints the trace when the debug flag is set.
     */
    private static UnsupportedOperationException failure(final String message, final Exception cause) {
        if (ai.h2o.mojos.runtime.utils.Debug.getPrintH2O3Exceptions()) {
            cause.printStackTrace();
        }
        return new UnsupportedOperationException(message, cause);
    }
}
