package ai.h2o.mojos.runtime.h2o3;

import ai.h2o.mojos.runtime.MojoPipeline;
import ai.h2o.mojos.runtime.api.MojoColumnMeta;
import ai.h2o.mojos.runtime.frame.MojoColumn;
import ai.h2o.mojos.runtime.frame.MojoColumnFloat64;
import ai.h2o.mojos.runtime.frame.MojoFrame;
import ai.h2o.mojos.runtime.frame.MojoFrameBuilder;
import ai.h2o.mojos.runtime.frame.MojoFrameMeta;
import ai.h2o.mojos.runtime.frame.StringConverter;
import hex.ModelCategory;
import hex.genmodel.GenModel;
import hex.genmodel.MojoModel;
import hex.genmodel.algos.klime.KLimeMojoModel;
import hex.genmodel.easy.EasyPredictModelWrapper;
import hex.genmodel.easy.RowData;
import hex.genmodel.easy.exception.PredictException;
import hex.genmodel.easy.prediction.KLimeModelPrediction;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import org.joda.time.DateTime;

/**
 * A MOJO2 pipeline implementation that uses a k-LIME MOJO (built on H2O-3 MOJO framework)
 * as the predictor inside.  The intent is to provide as identical an
 * experience to the MOJO2 API as possible.
 */
public class MojoPipelineKlimeImpl extends MojoPipeline {
    private final EasyPredictModelWrapper _model;
    private final GenModel genModel;

    private final MojoFrameMeta _inputMeta;
    private final MojoFrameMeta _outputMeta;

    /**
     * A MOJO2 pipeline implementation that uses a k-LIME MOJO (built on H2O-3 MOJO framework)
     * as the predictor inside.
     *
     * @param model k-LIME MOJO model.
     */
    MojoPipelineKlimeImpl(MojoModel model) {
        super(model.getUUID(), new DateTime(1970, 1, 1, 0, 0), "");

        _model = wrapModelForPrediction(model);
        genModel = _model.m;

        {
            final List<MojoColumnMeta> columns = new ArrayList<>();
            for (int i = 0; i < genModel.getNumCols(); i += 1) {
                final String columnName = genModel.getNames()[i];
                final MojoColumn.Type columnType = (genModel.getDomainValues(i) == null) ? MojoColumn.Type.Float64 : MojoColumn.Type.Str;
                columns.add(MojoColumnMeta.newInput(columnName, columnType));
            }
            _inputMeta = new MojoFrameMeta(columns);
        }

        {
            // TODO following is a bit strange exercise, let's check with MM and/or Navdeep why is that
            final List<String> mypredictorsList = new ArrayList<>(Arrays.asList(genModel.getNames()));
            mypredictorsList.remove(genModel.getResponseName());
            final String[] predictors = mypredictorsList.toArray(new String[0]);

            final List<MojoColumnMeta> columns = new ArrayList<>();
            columns.add(MojoColumnMeta.newOutput(genModel.getResponseName(), MojoColumn.Type.Float64));
            columns.add(MojoColumnMeta.newOutput("cluster", MojoColumn.Type.Float64));
            for (int i = 2; i < genModel.getPredsSize(); i += 1) {
                columns.add(MojoColumnMeta.newOutput(predictors[i - 2], MojoColumn.Type.Float64));
            }
            _outputMeta = new MojoFrameMeta(columns);
        }

    }

    @Override
    protected MojoFrameBuilder getFrameBuilder(MojoColumn.Kind kind) {
        return new MojoFrameBuilder(getMeta(kind), Collections.<String>emptyList(), Collections.<String, StringConverter>emptyMap());
    }

    @Override
    protected MojoFrameMeta getMeta(MojoColumn.Kind kind) {
        switch(kind) {
            case Feature:
                return _inputMeta;
            case Output:
                return _outputMeta;
            default:
                throw new UnsupportedOperationException("Cannot generate meta for interim frame");
        }
    }

    @Override
    public MojoFrame transform(MojoFrame inputFrame, MojoFrame outputFrame) {
        final ModelCategory modelCategory = genModel.getModelCategory();
        final int colCount = inputFrame.getNcols();
        final int rowCount = inputFrame.getNrows();
        final String[] columnNames = inputFrame.getColumnNames();
        final String[][] columns = new String[colCount][];
        for (int j = 0; j < colCount; j += 1) {
            columns[j] = inputFrame.getColumn(j).getDataAsStrings();
        }

        for (int rowIdx = 0; rowIdx < rowCount; rowIdx++) {
            final RowData rowData = new RowData();
            for (int colIdx = 0; colIdx < colCount; colIdx++) {
                final String key = columnNames[colIdx];
                final String value = columns[colIdx][rowIdx];
                if (value != null) {
                    rowData.put(key, value);
                }
            }
            try {
                final KLimeModelPrediction p = _model.predictKLime(rowData);
                for (int outputColIdx = 0; outputColIdx < genModel.getPredsSize(); outputColIdx++){
                    final MojoColumnFloat64 col = (MojoColumnFloat64) outputFrame.getColumn(outputColIdx);
                    final double[] darr = (double[]) col.getData();
                    switch (outputColIdx) {
                        case 0:
                            darr[rowIdx] = p.value;
                            break;
                        case 1:
                            darr[rowIdx] = p.cluster;
                            break;
                        default:
                            darr[rowIdx] = p.reasonCodes[outputColIdx - 2];
                            break;
                    }
                }
            } catch (PredictException e) {
                if (ai.h2o.mojos.runtime.utils.Debug.getPrintH2O3Exceptions()) e.printStackTrace();
                throw new UnsupportedOperationException(String.format("%s failed: %s", modelCategory, e.getMessage()));
            } catch (Exception e) {
                if (ai.h2o.mojos.runtime.utils.Debug.getPrintH2O3Exceptions()) e.printStackTrace();
                throw new UnsupportedOperationException(String.format("%s failed with %s: %s", modelCategory, e.getClass().getName(), e.getMessage()));
            }
        }

        return outputFrame;
    }

    /**
     * Wraps the specified {@link KLimeMojoModel} as an {@link EasyPredictModelWrapper} with
     * configuration to behave similar to Mojo2 behavior.
     *
     * This includes configuring the wrapper to tolerate and ignore (by forcing to NA) bad input
     * without throwing an exception.
     */
    private static EasyPredictModelWrapper wrapModelForPrediction(MojoModel model) {
        EasyPredictModelWrapper.Config config = new EasyPredictModelWrapper.Config()
                .setModel(model)
                .setConvertUnknownCategoricalLevelsToNa(true)
                .setConvertInvalidNumbersToNa(true);

        return new EasyPredictModelWrapper(config) {
            // This is hotfix for https://0xdata.atlassian.net/browse/PUBDEV-7627
            // The klime model implementation returns the model type 'Regression',
            // but the method KLimeMojoModel#predictKLime setup type to Klime.
            // That leads to a failed check in the method
            // EasyPredictModelWrapper#validateModelCategory.
            @Override
            protected double[] preamble(ModelCategory c, RowData data) throws PredictException {
                assert c == ModelCategory.KLime : "Unexpected model category, should be K-Lime! "
                                                  + "If you upgraded h2o-genmodel runtime,"
                                                  + "please check https://0xdata.atlassian.net/browse/PUBDEV-7627"
                                                  + " and remove this override of the `preamble` method";
                return preamble(ModelCategory.Regression, data, 0.0);
            }
        };
    }
}
