package hex.api;

import hex.DataInfo;
import hex.Model;
import hex.glm.GLMModel;
import hex.gram.Gram;
import hex.schemas.DataInfoFrameV3;
import hex.schemas.GLMModelV3;
import hex.schemas.GLMRegularizationPathV3;
import hex.schemas.GramV3;
import hex.schemas.MakeGLMModelV3;
import java.util.Arrays;
import java.util.HashMap;
import water.DKV;
import water.Key;
import water.MRTask;
import water.api.Handler;
import water.api.schemas3.KeyV3;
import water.fvec.Chunk;
import water.fvec.Frame;
import water.fvec.InteractionWrappedVec;
import water.fvec.NewChunk;
import water.fvec.Vec;
import water.util.ArrayUtils;

/* loaded from: input_file:hex/api/MakeGLMModelHandler.class */
public class MakeGLMModelHandler extends Handler {

    /**
     * Creates a copy of an existing GLM model with some coefficients overridden.
     *
     * <p>The source model's coefficient map is fetched. Each name in {@code args.names} is set to
     * the value at the same index in {@code args.beta}. A new {@link GLMModel} whose output carries
     * the resulting coefficient vector is then stored in the DKV. It goes under {@code args.dest},
     * or under a freshly generated key when no destination is given.
     *
     * @param version REST API version (not used)
     * @param args source model key, coefficient names/values and optional destination key
     * @return schema filled from the newly created model
     * @throws IllegalArgumentException if the source model does not exist, or if more coefficient
     *     names than beta values were supplied
     */
    public GLMModelV3 make_model(int version, MakeGLMModelV3 args) {
        GLMModel source = DKV.getGet(args.model.key());
        if (source == null) {
            throw new IllegalArgumentException("missing source model " + args.model);
        }
        // Null names means "no overrides"; previously this crashed with an NPE.
        int nOverrides = args.names == null ? 0 : args.names.length;
        int nBetas = args.beta == null ? 0 : args.beta.length;
        if (nBetas < nOverrides) {
            // Previously surfaced as an ArrayIndexOutOfBoundsException.
            throw new IllegalArgumentException("got " + nOverrides
                    + " coefficient names but only " + nBetas + " beta values");
        }
        GLMModel.GLMOutput sourceOutput = (GLMModel.GLMOutput) source._output;
        String[] coefNames = sourceOutput.coefficientNames();
        HashMap<String, Double> coefByName = source.coefficients();
        // NOTE(review): a name that is not one of the model's coefficients is added to this map but
        // never read back below, so it is silently ignored.
        for (int k = 0; k < nOverrides; k++) {
            coefByName.put(args.names[k], Double.valueOf(args.beta[k]));
        }
        // Rebuild the coefficient vector in the model's own coefficient order.
        double[] beta = source.beta().clone();
        for (int k = 0; k < beta.length; k++) {
            beta[k] = coefByName.get(coefNames[k]).doubleValue();
        }
        GLMModel result = new GLMModel(args.dest != null ? args.dest.key() : Key.make(),
                (GLMModel.GLMParameters) source._parms, null, source._ymu,
                Double.NaN, Double.NaN, -1L);
        // NOTE(review): this changes the predictor transform on the source model's own DataInfo,
        // which is then passed to the new model's output; confirm mutating the source is intended.
        source.dinfo().setPredictorTransform(DataInfo.TransformType.NONE);
        result._output = new GLMModel.GLMOutput(source.dinfo(), sourceOutput._names,
                sourceOutput._domains, sourceOutput.coefficientNames(), sourceOutput._binomial,
                beta);
        DKV.put(result._key, result);
        GLMModelV3 schema = new GLMModelV3();
        schema.fillFromImpl(result);
        return schema;
    }

    /**
     * Returns the regularization path of an existing GLM model.
     *
     * @param version REST API version (not used)
     * @param args schema naming the model whose path is requested
     * @return a new schema filled from {@code getRegularizationPath()} of the model
     * @throws IllegalArgumentException if the model does not exist
     */
    public GLMRegularizationPathV3 extractRegularizationPath(int version, GLMRegularizationPathV3 args) {
        GLMModel model = DKV.getGet(args.model.key());
        if (model == null) {
            throw new IllegalArgumentException("missing source model " + args.model);
        }
        return new GLMRegularizationPathV3().fillFromImpl(model.getRegularizationPath());
    }

    /**
     * Expands a frame into its DataInfo coefficient columns and reports the resulting frame key.
     *
     * <p>Pairwise interactions are built from {@code args.interactions}. Rows that DataInfo flags as
     * bad are skipped (the {@code skipMissing} argument of {@link #oneHot} is always {@code true}).
     *
     * @param version REST API version (not used)
     * @param args request schema; {@code args.result} is set to the key of the encoded frame
     * @return {@code args}, with {@code result} filled in
     * @throws IllegalArgumentException if the frame does not exist
     */
    public DataInfoFrameV3 getDataInfoFrame(int version, DataInfoFrameV3 args) {
        Frame frame = DKV.getGet(args.frame.key());
        if (frame == null) {
            throw new IllegalArgumentException("no frame found");
        }
        Frame encoded = oneHot(frame, Model.InteractionSpec.allPairwise(args.interactions),
                args.use_all, args.standardize, args.interactions_only, true);
        args.result = new KeyV3.FrameKeyV3(encoded._key);
        return args;
    }

    /**
     * Expands {@code frame} into numeric columns, one per DataInfo coefficient name.
     *
     * <p>A temporary {@link DataInfo} is built over the frame. Its dense rows are then written into
     * a new frame, created under a fresh key. The temporary interaction vecs and the DataInfo are
     * always released, even when the expansion fails.
     *
     * @param frame input frame
     * @param interactionSpec interactions handed to DataInfo (may presumably be null — TODO confirm)
     * @param useAll passed to DataInfo as its use-all-factor-levels flag
     * @param standardize whether DataInfo should standardize numeric predictors
     * @param interactionsOnly when true, only the columns produced by interaction vecs are output
     * @param skipMissing passed to DataInfo; when true, rows whose dense row {@code isBad()} are
     *     dropped from the output
     * @return the expanded frame
     * @throws IllegalArgumentException if {@code interactionsOnly} is set but DataInfo built no
     *     interaction vecs
     */
    public static Frame oneHot(Frame frame, Model.InteractionSpec interactionSpec, boolean useAll,
                               boolean standardize, boolean interactionsOnly, final boolean skipMissing) {
        // NOTE(review): the third argument (1) is DataInfo's response-column count; confirm that
        // treating one column of the input as a response is intended here.
        final DataInfo dinfo = new DataInfo(frame, (Frame) null, 1, useAll,
                standardize ? DataInfo.TransformType.STANDARDIZE : DataInfo.TransformType.NONE,
                DataInfo.TransformType.NONE, skipMissing, false, false, false, false, false,
                interactionSpec);
        try {
            return interactionsOnly
                    ? expandInteractionColumns(dinfo, useAll, skipMissing)
                    : expandAllColumns(dinfo, skipMissing);
        } finally {
            // Release the temporary interaction vecs and the DataInfo even if expansion failed.
            dinfo.dropInteractions();
            dinfo.remove();
        }
    }

    /** Writes every expanded coefficient column of {@code dinfo} into a new frame. */
    private static Frame expandAllColumns(final DataInfo dinfo, final boolean skipMissing) {
        byte[] types = new byte[dinfo.fullN()];
        Arrays.fill(types, Vec.T_NUM);
        return new MRTask() {
            @Override
            public void map(Chunk[] cs, NewChunk[] ncs) {
                DataInfo.Row row = dinfo.newDenseRow();
                for (int r = 0; r < cs[0]._len; r++) {
                    row = dinfo.extractDenseRow(cs, r, row);
                    if (skipMissing && row.isBad()) {
                        continue;
                    }
                    for (int c = 0; c < ncs.length; c++) {
                        ncs[c].addNum(row.get(c));
                    }
                }
            }
        }.doAll(types, dinfo._adaptedFrame.vecs())
                .outputFrame(Key.make("OneHot" + Key.make().toString()), dinfo.coefNames(),
                        (String[][]) null);
    }

    /** Writes only the coefficient columns that come from interaction vecs into a new frame. */
    private static Frame expandInteractionColumns(final DataInfo dinfo, boolean useAll,
                                                  final boolean skipMissing) {
        if (dinfo._interactionVecs == null) {
            throw new IllegalArgumentException("no interactions");
        }
        int nInteractions = dinfo._interactionVecs.length;
        // For the k-th interaction vec: first coefficient index and number of expanded columns.
        final int[] starts = new int[nInteractions];
        final int[] lengths = new int[nInteractions];
        int totalWidth = 0;
        for (int k = 0; k < nInteractions; k++) {
            lengths[k] = dinfo._adaptedFrame.vec(dinfo._interactionVecs[k]).expandedLength();
            totalWidth += lengths[k];
        }
        // Walk the adapted frame, tracking the coefficient offset of each column, and record
        // where each interaction vec's expanded columns begin. Stop once all have been seen.
        String[] allNames = dinfo.coefNames();
        String[] names = new String[totalWidth];
        int seen = 0;
        int offset = 0;
        int nameIdx = 0;
        for (int col = 0; col < dinfo._adaptedFrame.numCols() && seen < nInteractions; col++) {
            Vec vec = dinfo._adaptedFrame.vec(col);
            if (vec instanceof InteractionWrappedVec) {
                starts[seen] = offset;
                for (int j = 0; j < lengths[seen]; j++) {
                    names[nameIdx++] = allNames[offset++];
                }
                seen++;
            } else if (vec.isCategorical()) {
                // One column per level, minus the reference level unless all levels are used.
                offset += vec.domain().length - (useAll ? 0 : 1);
            } else {
                offset++;
            }
        }
        return new MRTask() {
            @Override
            public void map(Chunk[] cs, NewChunk[] ncs) {
                DataInfo.Row row = dinfo.newDenseRow();
                for (int r = 0; r < cs[0]._len; r++) {
                    row = dinfo.extractDenseRow(cs, r, row);
                    if (skipMissing && row.isBad()) {
                        continue;
                    }
                    int out = 0;
                    for (int k = 0; k < starts.length; k++) {
                        int end = starts[k] + lengths[k];
                        for (int c = starts[k]; c < end; c++) {
                            ncs[out++].addNum(row.get(c));
                        }
                    }
                }
            }
        }.doAll(totalWidth, Vec.T_NUM, dinfo._adaptedFrame)
                .outputFrame(Key.make(), names, (String[][]) null);
    }

    /**
     * Computes the Gram matrix (X'X, with an extra "Intercept" column) of a frame and stores it.
     *
     * <p>The optional weight column named by {@code args.W} is detached from a shallow copy of the
     * frame and handed to DataInfo as weights. The result is stored under a key derived from the
     * input key: a trailing ".hex" is stripped, then "_gram" and optionally "_" + weight name are
     * appended. If that key is taken, "_0" … "_999" is tried.
     *
     * @param version REST API version (not used)
     * @param args request schema; {@code destination_frame} is set to the result key
     * @return {@code args}, with {@code destination_frame} filled in
     * @throws IllegalArgumentException if the frame or weight column does not exist, or if no
     *     unique destination key could be found
     */
    public GramV3 computeGram(int version, GramV3 args) {
        if (DKV.get(args.X.key()) == null) {
            throw new IllegalArgumentException("Frame " + args.X.key() + " does not exist.");
        }
        Frame source = args.X.key().get();
        // Shallow copy so the weight column can be detached without touching the stored frame.
        Frame predictors = new Frame(source._names.clone(), source.vecs().clone());
        String weightName = null;
        Vec weights = null;
        if (args.W != null && !args.W.column_name.isEmpty()) {
            weightName = args.W.column_name;
            if (source.find(weightName) == -1) {
                throw new IllegalArgumentException("Did not find weight vector " + weightName);
            }
            weights = predictors.remove(weightName);
        }
        DataInfo dinfo = new DataInfo(predictors, (Frame) null, 0, args.use_all_factor_levels,
                args.standardize ? DataInfo.TransformType.STANDARDIZE : DataInfo.TransformType.NONE,
                DataInfo.TransformType.NONE, args.skip_missing, false, !args.skip_missing,
                false, false, false, true);
        DKV.put(dinfo);
        double[][] xx;
        try {
            if (weights != null) {
                dinfo.setWeights(weightName, weights);
            }
            xx = ((Gram.GramTask) new Gram.GramTask(null, dinfo, false, true)
                    .doAll(dinfo._adaptedFrame))._gram.getXX();
        } finally {
            // Always drop the temporary DataInfo from the DKV, even if the Gram task fails.
            dinfo.remove();
        }
        String[] colNames = (String[]) ArrayUtils.append(dinfo.coefNames(), new String[]{"Intercept"});
        Vec[] cols = new Vec[xx.length];
        Key[] colKeys = new Vec.VectorGroup().addVecs(cols.length);
        for (int k = 0; k < cols.length; k++) {
            cols[k] = Vec.makeVec(xx[k], colKeys[k]);
        }
        args.destination_frame = new KeyV3.FrameKeyV3();
        String baseName = args.X.key().toString();
        if (baseName.endsWith(".hex")) {
            baseName = baseName.substring(0, baseName.length() - ".hex".length());
        }
        baseName = baseName + "_gram";
        if (weights != null) {
            baseName = baseName + "_" + weightName;
        }
        Key dest = uniqueKey(baseName);
        args.destination_frame.fillFromImpl(dest);
        DKV.put(new Frame(dest, colNames, cols));
        return args;
    }

    /**
     * Returns {@code base} as a key if it is free in the DKV; otherwise returns the first free key
     * among {@code base_0} … {@code base_999}.
     *
     * @throws IllegalArgumentException if all candidate keys are taken
     */
    private static Key uniqueKey(String base) {
        Key candidate = Key.make(base);
        if (DKV.get(candidate) == null) {
            return candidate;
        }
        for (int suffix = 0; suffix < 1000; suffix++) {
            candidate = Key.make(base + "_" + suffix);
            if (DKV.get(candidate) == null) {
                return candidate;
            }
        }
        throw new IllegalArgumentException("unable to make unique key");
    }
}
