package water.rapids.ast.prims.mungers;

import java.util.Arrays;
import org.apache.commons.math3.optimization.direct.CMAESOptimizer;
import water.DKV;
import water.MRTask;
import water.fvec.Chunk;
import water.fvec.Frame;
import water.fvec.Vec;
import water.rapids.Env;
import water.rapids.ast.AstPrimitive;
import water.rapids.ast.AstRoot;
import water.rapids.vals.ValFrame;
import water.util.ArrayUtils;
import water.util.VecUtils;

/**
 * Rapids primitive {@code (relevel.by.freq frame weights topn)}.
 *
 * <p>Returns a new frame in which every categorical column has been re-leveled so that its
 * levels are ordered by their (optionally weighted) frequency. When {@code topn} is a positive
 * integer N, only the N most frequent levels are moved to the front (in frequency order) and the
 * remaining levels follow in their original relative order; {@code topn == -1} orders all levels.
 * Non-categorical columns are shared with the input frame unchanged; categorical columns are
 * copied before being re-leveled, so the input frame's vecs are not modified.
 */
public class AstRelevelByFreq extends AstPrimitive<AstRelevelByFreq> {

    /**
     * MRTask that rewrites, in place, every non-NA categorical code {@code c} of the chunk as
     * {@code mapping[c]}.
     */
    public static class RemapDomain extends MRTask<RemapDomain> {
        /** Old level index -> new level index. */
        private final int[] _mapping;

        public RemapDomain(int[] mapping) {
            this._mapping = mapping;
        }

        @Override // water.MRTask
        public void map(Chunk c) {
            for (int row = 0; row < c._len; row++) {
                if (c.isNA(row)) {
                    continue; // NAs carry no level code; leave them as-is
                }
                c.set(row, _mapping[(int) c.at8(row)]);
            }
        }
    }

    @Override // water.rapids.ast.AstPrimitive
    public String[] args() {
        return new String[]{"frame", "weights", "topn"};
    }

    /** Three arguments plus the primitive name itself. */
    @Override // water.rapids.ast.AstPrimitive
    public int nargs() {
        return 4;
    }

    @Override // water.rapids.ast.AstRoot
    public String str() {
        return "relevel.by.freq";
    }

    /**
     * Evaluates the primitive.
     *
     * @param env evaluation environment
     * @param stk stack helper used to track the input frame
     * @param asts {@code [name, frame, weights, topn]}; {@code weights} is a column name of
     *     {@code frame} or null; {@code topn} is -1 or a positive integer
     * @return a new frame with every categorical column re-leveled by frequency
     * @throws IllegalArgumentException if the weights column is missing from the frame, or if
     *     {@code topn} is neither -1 nor a positive integer
     */
    @Override // water.rapids.ast.AstPrimitive
    public ValFrame apply(Env env, Env.StackHelp stk, AstRoot[] asts) {
        Frame frame = stk.track(asts[1].exec(env)).getFrame();
        String weightsColumn = asts[2].exec(env).getStr();
        Vec weights = frame.vec(weightsColumn);
        if (weightsColumn != null && weights == null) {
            throw new IllegalArgumentException("Frame doesn't contain weights column '" + weightsColumn + "'.");
        }
        double topN = asts[3].exec(env).getNum();
        // -1 means "order all levels"; anything else must be a positive integer. The (int)
        // round-trip check also rejects NaN, infinities, fractions and values beyond int range.
        if ((topN != -1 && topN <= 0) || (int) topN != topN) {
            throw new IllegalArgumentException("TopN argument needs to be a positive integer number, got: " + topN);
        }
        Frame result = new Frame(frame);
        for (int i = 0; i < result.numCols(); i++) {
            Vec column = result.vec(i);
            if (!column.isCategorical()) {
                continue;
            }
            // Re-level a copy so the caller's frame keeps its original encoding.
            Vec copy = column.makeCopy();
            result.replace(i, copy);
            relevelByFreq(copy, weights, (int) topN);
        }
        return new ValFrame(result);
    }

    /**
     * Re-levels categorical vec {@code v} in place by level frequency and publishes it to the DKV.
     *
     * @param v categorical vec to re-level (its codes and domain are rewritten)
     * @param weights optional weights vec passed to {@code VecUtils.collectDomainWeights};
     *     presumably null means unweighted counts — TODO confirm against VecUtils
     * @param topN -1 to order all levels, otherwise the number of most frequent levels to move
     *     to the front
     */
    static void relevelByFreq(Vec v, Vec weights, int topN) {
        double[] levelWeights = VecUtils.collectDomainWeights(v, weights);
        int[] newOrder = ArrayUtils.seq(0, levelWeights.length);
        // Sort level indices by their weight; the trailing -1 selects the sort direction
        // (expected to be descending so the most frequent level comes first).
        ArrayUtils.sort(newOrder, levelWeights, 0, -1);
        // When topN >= length - 1 the full sort already yields the same order, so skip the work.
        if (topN != -1 && topN < newOrder.length - 1) {
            newOrder = takeTopNMostFrequentDomain(newOrder, topN);
        }
        String[] oldDomain = v.domain();
        String[] newDomain = oldDomain.clone();
        for (int i = 0; i < newOrder.length; i++) {
            newDomain[i] = oldDomain[newOrder[i]];
        }
        new RemapDomain(getMapping(newOrder)).doAll(v);
        v.setDomain(newDomain);
        DKV.put(v);
    }

    /**
     * Inverts a permutation.
     *
     * @param order permutation where {@code order[newIdx] == oldIdx}
     * @return array {@code m} with {@code m[oldIdx] == newIdx}
     */
    static int[] getMapping(int[] order) {
        int[] mapping = new int[order.length];
        for (int newIdx = 0; newIdx < order.length; newIdx++) {
            mapping[order[newIdx]] = newIdx;
        }
        return mapping;
    }

    /**
     * Keeps the first {@code topN} entries of a frequency-sorted level order and appends all other
     * levels in ascending index (i.e. original) order.
     *
     * @param domain permutation of {@code 0..domain.length-1}, most frequent level first
     * @param topN number of leading entries to keep; must satisfy {@code 0 <= topN <= domain.length}
     * @return the new level order
     */
    static int[] takeTopNMostFrequentDomain(int[] domain, int topN) {
        final int domainSize = domain.length;
        int[] newDomain = new int[domainSize];
        // The top-N levels keep their frequency order at the front.
        System.arraycopy(domain, 0, newDomain, 0, topN);
        // Sorted copy of the top-N level indices, used for binary-search membership tests.
        int[] topNSorted = Arrays.copyOf(domain, topN);
        Arrays.sort(topNSorted);
        int next = topN;
        for (int level = 0; level < domainSize; level++) {
            if (Arrays.binarySearch(topNSorted, level) < 0) {
                newDomain[next++] = level;
            }
        }
        assert next == domainSize : "domain must be a permutation of 0.." + (domainSize - 1);
        return newDomain;
    }
}
