package com.github.fairsearch.deltr;

import com.github.fairsearch.deltr.models.TrainStep;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.logging.Logger;
import java.util.stream.IntStream;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.ops.transforms.Transforms;
import org.nd4j.linalg.util.ArrayUtil;

/**
 * Batch gradient-descent trainer for a linear scoring model.
 *
 * <p>Each iteration scores every row of the feature matrix with the current weight vector and
 * evaluates, per query, (a) the cross-entropy between the softmax of the training scores and the
 * softmax of the predictions and (b), unless {@code gamma} is 0, {@code gamma} times the squared
 * amount by which the non-protected group's normalized exposure exceeds the protected group's.
 *
 * <p>Instances are not thread-safe: they keep mutable memoization caches that are filled during
 * {@link #train} and emptied when it returns.
 */
public class Trainer {
    private static final Logger LOGGER = Logger.getLogger(Trainer.class.getName());

    /** Weight of the squared exposure-disparity term; 0 disables that term entirely. */
    private final double gamma;
    /** Cached {@code gamma == 0}: when true the exposure term is skipped in loss and gradient. */
    private final boolean noExposure;
    private final int numberOfIterations;
    private final double learningRate;
    /**
     * Regularization weight. NOTE(review): not used by the weight update; the earlier code only
     * folded it into a per-iteration cost history that was never returned or read.
     */
    private final double lambda;
    /** Multiplier applied to the randomly drawn initial weights. */
    private final double initVar;

    /** Per-query row slices of the feature matrix and of the training scores. */
    private final Map<String, INDArray> dataPerQuery = new HashMap<>();
    private final Map<String, INDArray> itemsPerQueryCache = new HashMap<>();
    private final Map<String, ItemGroup> itemsPerGroupPerQueryCache = new HashMap<>();
    private final Map<String, INDArray> normalizedToppProtDerivPerGroupDiffCache = new HashMap<>();
    private final Map<String, Double> exposureDiffCache = new HashMap<>();
    /** Softmax results; per instance so separate trainers cannot clear each other's entries. */
    private final Map<String, INDArray> toppCache = new HashMap<>();
    private List<TrainStep> log;

    /**
     * All rows of one query taken from some matrix, plus the subsets of those rows whose protected
     * flag is non-zero ({@code protectedItemsPerQuery}) and zero ({@code nonprotectedItemsPerQuery}).
     */
    public static class ItemGroup {
        private final INDArray judgementsPerQuery;
        private final INDArray protectedItemsPerQuery;
        private final INDArray nonprotectedItemsPerQuery;

        public ItemGroup(INDArray judgementsPerQuery, INDArray protectedItemsPerQuery,
                INDArray nonprotectedItemsPerQuery) {
            this.judgementsPerQuery = judgementsPerQuery;
            this.protectedItemsPerQuery = protectedItemsPerQuery;
            this.nonprotectedItemsPerQuery = nonprotectedItemsPerQuery;
        }

        /** Returns every row of the query. */
        public INDArray getJudgementsPerQuery() {
            return this.judgementsPerQuery;
        }

        /** Returns the query's rows whose protected flag is non-zero. */
        public INDArray getProtectedItemsPerQuery() {
            return this.protectedItemsPerQuery;
        }

        /** Returns the query's rows whose protected flag is zero. */
        public INDArray getNonprotectedItemsPerQuery() {
            return this.nonprotectedItemsPerQuery;
        }

        // NOTE(review): equals() is not overridden; within this class ItemGroup is only used as a
        // map value, so this hash is never used for lookups here.
        @Override
        public int hashCode() {
            int result = 1;
            result = 23 * result + (judgementsPerQuery == null ? 0 : judgementsPerQuery.hashCode());
            result = 23 * result + (protectedItemsPerQuery == null ? 0 : protectedItemsPerQuery.hashCode());
            result = 23 * result + (nonprotectedItemsPerQuery == null ? 0 : nonprotectedItemsPerQuery.hashCode());
            return result;
        }
    }

    /**
     * Creates a trainer.
     *
     * @param gamma              weight of the exposure term; exactly 0 disables it
     * @param numberOfIterations number of gradient-descent steps performed by {@link #train}
     * @param learningRate       step size applied to the summed gradient
     * @param lambda             regularization weight (currently not used by the weight update)
     * @param initVar            multiplier for the random initial weights
     */
    public Trainer(double gamma, int numberOfIterations, double learningRate, double lambda, double initVar) {
        this.gamma = gamma;
        this.numberOfIterations = numberOfIterations;
        this.learningRate = learningRate;
        this.lambda = lambda;
        this.initVar = initVar;
        this.noExposure = gamma == 0.0d;
        cleanLog();
    }

    /**
     * Learns one weight per feature column by batch gradient descent.
     *
     * <p>The per-iteration {@link TrainStep}s are available from {@link #getLog()} afterwards.
     *
     * @param queryIds       query id of every row of {@code features}
     * @param protectedFlags group flag of every row; 0 = non-protected, anything else = protected
     * @param features       feature matrix with one row per item
     * @param judgments      training scores with one row per item (presumably a column vector —
     *                       confirm with callers)
     * @return the learned weight vector, one entry per feature column
     * @throws IllegalArgumentException if {@code queryIds}, {@code protectedFlags} and the feature
     *                                  rows disagree in length
     */
    public double[] train(int[] queryIds, int[] protectedFlags, INDArray features, INDArray judgments) {
        int numItems = features.shape()[0];
        int numFeatures = features.shape()[1];
        // The grouping code indexes queryIds and protectedFlags with the same row index, and the
        // gradient writes one row per query id; mismatched lengths would silently mis-group rows.
        if (queryIds.length != numItems || protectedFlags.length != numItems) {
            throw new IllegalArgumentException(String.format(
                    "queryIds (%d), protectedFlags (%d) and feature rows (%d) must all have the same length",
                    Integer.valueOf(queryIds.length), Integer.valueOf(protectedFlags.length),
                    Integer.valueOf(numItems)));
        }
        try {
            // queryIds repeats each id once per row; the repeated puts store the same cached slice.
            for (int queryId : queryIds) {
                dataPerQuery.put(keyGen(queryId, judgments),
                        findItemsPerGroupPerQuery(judgments, queryIds, queryId, protectedFlags).getJudgementsPerQuery());
                dataPerQuery.put(keyGen(queryId, features),
                        findItemsPerGroupPerQuery(features, queryIds, queryId, protectedFlags).getJudgementsPerQuery());
            }

            INDArray omega = Nd4j.rand(numFeatures, 1).mul(Double.valueOf(initVar));
            cleanLog();
            for (int iteration = 0; iteration < numberOfIterations; iteration++) {
                long iterationStart = System.currentTimeMillis();

                // These caches are keyed on the previous iteration's prediction array, which is never
                // scored again; dropping them bounds memory and avoids reusing a stale value if a
                // new array happens to hash like an old one.
                exposureDiffCache.clear();
                normalizedToppProtDerivPerGroupDiffCache.clear();

                long predictionStart = System.currentTimeMillis();
                INDArray predictions = features.mmul(omega).reshape(numItems, 1);
                Map<String, INDArray> predictionsPerQuery = new HashMap<>();
                for (int queryId : queryIds) {
                    predictionsPerQuery.put(keyGen(queryId, predictions),
                            findItemsPerGroupPerQuery(predictions, queryIds, queryId, protectedFlags).getJudgementsPerQuery());
                }
                LOGGER.info(String.format("Prediction computed in %d ms",
                        Long.valueOf(System.currentTimeMillis() - predictionStart)));

                long costStart = System.currentTimeMillis();
                TrainStep step = calculateCost(judgments, predictions, queryIds, protectedFlags, predictionsPerQuery);
                LOGGER.info("Cost: " + (System.currentTimeMillis() - costStart));

                long gradientStart = System.currentTimeMillis();
                INDArray gradient = calculateGradient(features, judgments, predictions, queryIds, protectedFlags,
                        predictionsPerQuery);
                LOGGER.info(String.format("Gradient computed in %d ms",
                        Long.valueOf(System.currentTimeMillis() - gradientStart)));

                step.setOmega(omega);
                step.setGrad(gradient);
                step.setTotalCost(step.getCost().sumNumber().doubleValue());

                // Batch update: the per-row gradients are summed column-wise before the step.
                omega = omega.sub(gradient.sum(new int[]{0}).mul(Double.valueOf(learningRate)))
                        .reshape(numFeatures, 1);
                log.add(step);
                LOGGER.info(String.format("Iteration %d done in %d ms", Integer.valueOf(iteration),
                        Long.valueOf(System.currentTimeMillis() - iterationStart)));
            }
            return omega.data().asDouble();
        } finally {
            // Release the (potentially large) cached slices even when training fails.
            cleanCache();
        }
    }

    /** Empties every memoization cache. */
    private void cleanCache() {
        dataPerQuery.clear();
        itemsPerQueryCache.clear();
        itemsPerGroupPerQueryCache.clear();
        normalizedToppProtDerivPerGroupDiffCache.clear();
        exposureDiffCache.clear();
        toppCache.clear();
    }

    /** Starts a fresh step log; lists returned earlier by {@link #getLog()} are left untouched. */
    private void cleanLog() {
        this.log = new ArrayList<>();
    }

    /**
     * Computes one gradient row per item. The value depends only on the row's query id, so every
     * row of a query receives that query's gradient and the later column sum weights each query by
     * its number of rows.
     */
    private INDArray calculateGradient(INDArray features, INDArray judgments, INDArray predictions, int[] queryIds,
            int[] protectedFlags, Map<String, INDArray> predictionsPerQuery) {
        INDArray gradient = Nd4j.create(features.shape());
        double logNumPredictions = Math.log(predictions.length());
        for (int row = 0; row < queryIds.length; row++) {
            int queryId = queryIds[row];
            // Timing logs are throttled to every 100th row.
            boolean logTimings = row % 100 == 0;
            long stepStart = System.currentTimeMillis();
            if (logTimings) {
                LOGGER.info(String.format("Gradient Step 1 computed in %d ms",
                        Long.valueOf(System.currentTimeMillis() - stepStart)));
                stepStart = System.currentTimeMillis();
            }
            INDArray expPredictions = Transforms.exp(predictionsPerQuery.get(keyGen(queryId, predictions)));
            double inverseExpSum = 1.0d / expPredictions.sumNumber().doubleValue();
            if (logTimings) {
                LOGGER.info(String.format("Gradient Step 2 computed in %d ms",
                        Long.valueOf(System.currentTimeMillis() - stepStart)));
                stepStart = System.currentTimeMillis();
            }
            INDArray queryFeaturesT = dataPerQuery.get(keyGen(queryId, features)).transpose();
            INDArray predictedTerm = queryFeaturesT.mmul(expPredictions).mul(Double.valueOf(inverseExpSum));
            if (logTimings) {
                LOGGER.info(String.format("Gradient Step 3 computed in %d ms",
                        Long.valueOf(System.currentTimeMillis() - stepStart)));
                stepStart = System.currentTimeMillis();
            }
            INDArray rowGradient = predictedTerm
                    .sub(queryFeaturesT.mmul(topp(dataPerQuery.get(keyGen(queryId, judgments)))))
                    .div(Double.valueOf(logNumPredictions));
            if (logTimings) {
                LOGGER.info(String.format("Gradient Step 4 computed in %d ms",
                        Long.valueOf(System.currentTimeMillis() - stepStart)));
                stepStart = System.currentTimeMillis();
            }
            if (!noExposure) {
                // Derivative of gamma * exposureDiff^2: 2 * gamma * exposureDiff * d(exposureDiff).
                rowGradient = rowGradient.add(
                        normalizedToppProtDerivPerGroupDiff(features, predictions, queryIds, queryId, protectedFlags)
                                .mul(Double.valueOf(gamma))
                                .mul(2)
                                .mul(Double.valueOf(exposureDiff(predictions, queryIds, queryId, protectedFlags)))
                                .transpose());
            }
            if (logTimings) {
                LOGGER.info(String.format("Gradient Step 5 computed in %d ms",
                        Long.valueOf(System.currentTimeMillis() - stepStart)));
            }
            gradient.putRow(row, rowGradient);
        }
        return gradient;
    }

    /**
     * Returns (non-protected minus protected) normalized exposure derivative for one query,
     * memoized per (query, features, predictions).
     */
    private INDArray normalizedToppProtDerivPerGroupDiff(INDArray features, INDArray predictions, int[] queryIds,
            int queryId, int[] protectedFlags) {
        String key = keyGen(queryId, queryIds, protectedFlags, features, predictions);
        INDArray cached = normalizedToppProtDerivPerGroupDiffCache.get(key);
        if (cached == null) {
            ItemGroup featureGroups = findItemsPerGroupPerQuery(features, queryIds, queryId, protectedFlags);
            ItemGroup predictionGroups = findItemsPerGroupPerQuery(predictions, queryIds, queryId, protectedFlags);
            INDArray nonProtected = normalizedToppProtDerivPerGroup(featureGroups.getNonprotectedItemsPerQuery(),
                    featureGroups.getJudgementsPerQuery(), predictionGroups.getNonprotectedItemsPerQuery(),
                    predictionGroups.getJudgementsPerQuery());
            INDArray protectedPart = normalizedToppProtDerivPerGroup(featureGroups.getProtectedItemsPerQuery(),
                    featureGroups.getJudgementsPerQuery(), predictionGroups.getProtectedItemsPerQuery(),
                    predictionGroups.getJudgementsPerQuery());
            cached = nonProtected.sub(protectedPart);
            normalizedToppProtDerivPerGroupDiffCache.put(key, cached);
        }
        return cached;
    }

    /** Column sum of the group's softmax derivative, divided by log(2) and by the group size. */
    private INDArray normalizedToppProtDerivPerGroup(INDArray groupFeatures, INDArray queryFeatures,
            INDArray groupPredictions, INDArray queryPredictions) {
        return toppProtFirstDerivative(groupFeatures, queryFeatures, groupPredictions, queryPredictions)
                .div(Double.valueOf(Math.log(2.0d)))
                .sum(new int[]{0})
                .div(Integer.valueOf(groupPredictions.length()));
    }

    /**
     * Quotient-rule derivative of exp(groupPrediction) / sum(exp(queryPredictions)), one row per
     * group item. NOTE(review): the second numerator term is summed over all feature columns to a
     * scalar; confirm this matches the reference formula.
     */
    private INDArray toppProtFirstDerivative(INDArray groupFeatures, INDArray queryFeatures,
            INDArray groupPredictions, INDArray queryPredictions) {
        INDArray expGroupPredictions = Transforms.exp(groupPredictions);
        INDArray expQueryPredictions = Transforms.exp(queryPredictions);
        double expSum = expQueryPredictions.sumNumber().doubleValue();
        double weightedFeatureSum = expQueryPredictions.transpose().mmul(queryFeatures).sumNumber().doubleValue();
        double denominator = Math.pow(expSum, 2.0d);
        INDArray firstTerm = expGroupPredictions.transpose().mmul(groupFeatures).mul(Double.valueOf(expSum));
        INDArray secondTerm = expGroupPredictions.mul(Double.valueOf(weightedFeatureSum));
        INDArray derivative = Nd4j.create(groupFeatures.shape());
        for (int row = 0; row < derivative.rows(); row++) {
            // getDouble (not getFloat) so the subtracted value keeps full precision.
            derivative.putRow(row, firstTerm.sub(Double.valueOf(secondTerm.getDouble(row, 0)))
                    .div(Double.valueOf(denominator)));
        }
        return derivative;
    }

    /**
     * Builds the per-row loss vector and the step record. The exposure-difference total is
     * accumulated sequentially: exposureDiff writes to non-thread-safe HashMap caches, so it must
     * not be driven from a parallel stream.
     */
    private TrainStep calculateCost(INDArray judgments, INDArray predictions, int[] queryIds, int[] protectedFlags,
            Map<String, INDArray> predictionsPerQuery) {
        INDArray losses = Nd4j.create(predictions.shape());
        for (int row = 0; row < queryIds.length; row++) {
            losses.putRow(row,
                    calculateLoss(queryIds[row], judgments, predictions, queryIds, protectedFlags, predictionsPerQuery));
        }
        long timestamp = System.currentTimeMillis();
        double totalLoss = losses.sumNumber().doubleValue();
        double exposureDiffTotal = 0.0d;
        for (int queryId : queryIds) {
            exposureDiffTotal += exposureDiff(predictions, queryIds, queryId, protectedFlags);
        }
        return new TrainStep(timestamp, losses, totalLoss, exposureDiffTotal);
    }

    /**
     * Loss of one query: -softmax(scores)^T * log(softmax(predictions)) / log(#predictions), plus
     * gamma * exposureDiff^2 unless the exposure term is disabled.
     */
    private INDArray calculateLoss(int queryId, INDArray judgments, INDArray predictions, int[] queryIds,
            int[] protectedFlags, Map<String, INDArray> predictionsPerQuery) {
        INDArray loss = topp(dataPerQuery.get(keyGen(queryId, judgments))).transpose()
                .mmul(Transforms.log(topp(predictionsPerQuery.get(keyGen(queryId, predictions)))))
                .div(Double.valueOf(Math.log(predictions.length())))
                .mul(-1);
        if (!noExposure) {
            loss = loss.add(Double.valueOf(Math.pow(exposureDiff(predictions, queryIds, queryId, protectedFlags), 2.0d) * gamma));
        }
        return loss;
    }

    /**
     * Returns max(0, non-protected exposure - protected exposure) for one query, memoized per
     * (query, predictions). Only a deficit of the protected group is penalized.
     * NOTE(review): a query with an empty group divides by a zero length in normalizedExposure.
     */
    private double exposureDiff(INDArray predictions, int[] queryIds, int queryId, int[] protectedFlags) {
        String key = keyGen(queryId, queryIds, protectedFlags, predictions);
        Double cached = exposureDiffCache.get(key);
        if (cached == null) {
            ItemGroup groups = findItemsPerGroupPerQuery(predictions, queryIds, queryId, protectedFlags);
            double protectedExposure = normalizedExposure(groups.getProtectedItemsPerQuery(),
                    groups.getJudgementsPerQuery());
            double nonProtectedExposure = normalizedExposure(groups.getNonprotectedItemsPerQuery(),
                    groups.getJudgementsPerQuery());
            cached = Double.valueOf(Math.max(0.0d, nonProtectedExposure - protectedExposure));
            exposureDiffCache.put(key, cached);
        }
        return cached.doubleValue();
    }

    /** Sum of the group's softmax shares (relative to the whole query) / log(2), averaged over the group. */
    private double normalizedExposure(INDArray groupPredictions, INDArray queryPredictions) {
        return toppProt(groupPredictions, queryPredictions).div(Double.valueOf(Math.log(2.0d)))
                .sumNumber().doubleValue() / groupPredictions.length();
    }

    /** exp(group) divided by the sum of exp over the whole query. */
    private INDArray toppProt(INDArray groupPredictions, INDArray queryPredictions) {
        return Transforms.exp(groupPredictions).div(Transforms.exp(queryPredictions).sumNumber());
    }

    /**
     * Splits the rows of {@code matrix} belonging to {@code queryId} into protected and
     * non-protected subsets, memoized per (query, matrix).
     */
    private ItemGroup findItemsPerGroupPerQuery(INDArray matrix, int[] queryIds, int queryId, int[] protectedFlags) {
        String key = keyGen(queryId, queryIds, protectedFlags, matrix);
        ItemGroup cached = itemsPerGroupPerQueryCache.get(key);
        if (cached == null) {
            INDArray queryRows = findItemsPerQuery(matrix, queryIds, queryId);
            List<Integer> protectedRows = new ArrayList<>();
            List<Integer> nonProtectedRows = new ArrayList<>();
            // Positions are relative to queryRows, i.e. counted only over this query's rows.
            int position = 0;
            for (int row = 0; row < protectedFlags.length; row++) {
                if (queryIds[row] == queryId) {
                    if (protectedFlags[row] == 0) {
                        nonProtectedRows.add(Integer.valueOf(position));
                    } else {
                        protectedRows.add(Integer.valueOf(position));
                    }
                    position++;
                }
            }
            cached = new ItemGroup(queryRows, queryRows.getRows(ArrayUtil.toArray(protectedRows)),
                    queryRows.getRows(ArrayUtil.toArray(nonProtectedRows)));
            itemsPerGroupPerQueryCache.put(key, cached);
        }
        return cached;
    }

    /** Returns the rows of {@code matrix} whose query id equals {@code queryId}, memoized. */
    private INDArray findItemsPerQuery(INDArray matrix, int[] queryIds, int queryId) {
        String key = keyGen(queryId, queryIds, matrix);
        INDArray cached = itemsPerQueryCache.get(key);
        if (cached == null) {
            int[] rows = IntStream.range(0, queryIds.length).filter(row -> queryIds[row] == queryId).toArray();
            cached = matrix.getRows(rows);
            itemsPerQueryCache.put(key, cached);
        }
        return cached;
    }

    // Cache-key builders. NOTE(review): keys are built from hashCode() values only, so two distinct
    // arrays with equal hashes share a cache entry; an identity-keyed cache would rule this out.

    private static String keyGen(INDArray array) {
        return String.format("%d", Integer.valueOf(array.hashCode()));
    }

    private static String keyGen(int queryId, INDArray array) {
        return String.format("%d-%d", Integer.valueOf(queryId), Integer.valueOf(array.hashCode()));
    }

    private static String keyGen(int queryId, int[] queryIds, INDArray array) {
        return String.format("%d-%d-%d", Integer.valueOf(queryId), Integer.valueOf(queryIds.hashCode()),
                Integer.valueOf(array.hashCode()));
    }

    private static String keyGen(int queryId, int[] queryIds, int[] protectedFlags, INDArray array) {
        return String.format("%d-%d-%d-%d", Integer.valueOf(queryId), Integer.valueOf(queryIds.hashCode()),
                Integer.valueOf(protectedFlags.hashCode()), Integer.valueOf(array.hashCode()));
    }

    private static String keyGen(int queryId, int[] queryIds, int[] protectedFlags, INDArray first,
            INDArray second) {
        return String.format("%d-%d-%d-%d-%d", Integer.valueOf(queryId), Integer.valueOf(queryIds.hashCode()),
                Integer.valueOf(protectedFlags.hashCode()), Integer.valueOf(first.hashCode()),
                Integer.valueOf(second.hashCode()));
    }

    /**
     * Returns the steps recorded by the most recent {@link #train} call. Each call starts a new
     * list, so lists returned earlier are not modified by later training.
     */
    public List<TrainStep> getLog() {
        return this.log;
    }

    /** Softmax of {@code scores}: exp(x) / sum(exp(x)), memoized per instance. */
    private INDArray topp(INDArray scores) {
        String key = keyGen(scores);
        INDArray cached = toppCache.get(key);
        if (cached == null) {
            INDArray exp = Transforms.exp(scores);
            cached = exp.div(exp.sumNumber());
            toppCache.put(key, cached);
        }
        return cached;
    }
}
