package com.github.fairsearch.deltr;

import com.fasterxml.jackson.annotation.JsonIgnore;
import com.fasterxml.jackson.annotation.JsonProperty;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.annotation.JsonDeserialize;
import com.github.fairsearch.deltr.models.DeltrDoc;
import com.github.fairsearch.deltr.models.DeltrTopDocs;
import com.github.fairsearch.deltr.models.TrainStep;
import com.github.fairsearch.deltr.parsers.DeltrDeserializer;
import com.google.common.primitives.Doubles;
import java.io.IOException;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import java.util.logging.Level;
import java.util.logging.Logger;
import java.util.stream.IntStream;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

/**
 * DELTR ranking model.
 *
 * <p>A {@link Trainer} learns a linear weight vector {@code omega} over document features.
 * {@link #rank(DeltrTopDocs)} then scores each document as the dot product of its features with
 * {@code omega}. When standardization is enabled, features are shifted by {@code mu} and scaled
 * by {@code sigma}, both learned during training. The model can be round-tripped through JSON
 * with {@link #toJson()} and {@link #createFromJson(String)}.
 */
@JsonDeserialize(using = DeltrDeserializer.class)
public class Deltr {
    protected static final Logger LOGGER = Logger.getLogger(Deltr.class.getName());

    /** Number of training iterations used by the short constructors. */
    private static final int DEFAULT_NUMBER_OF_ITERATIONS = 3000;

    // These default hyper-parameters are the exact double widenings of the float literals
    // 0.001f and 0.01f. They are kept bit-for-bit so that models trained with the defaults
    // stay reproducible.
    private static final double DEFAULT_LEARNING_RATE = 0.0010000000474974513d;
    private static final double DEFAULT_LAMBDA = 0.0010000000474974513d;
    private static final double DEFAULT_INIT_VAR = 0.009999999776482582d;

    /** Hyper-parameter forwarded to the {@link Trainer}. */
    @JsonProperty
    private double gamma;

    /** Number of iterations forwarded to the {@link Trainer}. */
    @JsonProperty("number_of_iterations")
    private int numberOfIterations;

    /** Learning rate forwarded to the {@link Trainer}. */
    @JsonProperty("learning_rate")
    private double learningRate;

    /** Hyper-parameter forwarded to the {@link Trainer}. */
    @JsonProperty
    private double lambda;

    /** Hyper-parameter forwarded to the {@link Trainer}. */
    @JsonProperty("init_var")
    private double initVar;

    /** Whether features are standardized with {@code mu}/{@code sigma} before training and ranking. */
    @JsonProperty("standardize")
    protected boolean shouldStandardize;

    /** Mean of the training feature matrix; only computed when standardizing. */
    @JsonProperty
    protected double mu;

    /** Standard deviation of the training feature matrix; only computed when standardizing. */
    @JsonProperty
    protected double sigma;

    /** Learned feature weights; {@code null} until the model is trained or restored. */
    @JsonProperty
    protected double[] omega;

    /** Per-step training log produced by the {@link Trainer}; not serialized. */
    @JsonIgnore
    protected List<TrainStep> log;

    /**
     * Training input flattened across one or more rankings: one array entry / matrix row per
     * document.
     */
    public static class TrainerData {
        /** Id of the ranking each document came from. */
        private int[] queryIds;
        /** 1 if the document is protected, 0 otherwise. */
        private int[] protectedElementFeature;
        /** Document features, one row per document. */
        private INDArray featureMatrix;
        /** Document judgements as a single-column matrix. */
        private INDArray trainingScores;
        /**
         * Column of the protected feature. It is taken from the first ranking only;
         * {@link #append} assumes every ranking uses the same index.
         */
        private int protectedElementFeatureIndex;

        private TrainerData() {
        }

        /**
         * Appends the documents of {@code other} after the documents of this instance.
         *
         * @param other the data to append; it is not modified
         * @return this instance, for chaining
         */
        public TrainerData append(TrainerData other) {
            this.queryIds = concat(this.queryIds, other.queryIds);
            this.protectedElementFeature = concat(this.protectedElementFeature, other.protectedElementFeature);
            this.featureMatrix = stackRows(this.featureMatrix, other.featureMatrix);
            this.trainingScores = stackRows(this.trainingScores, other.trainingScores);
            return this;
        }

        /** Returns a new array holding {@code first} followed by {@code second}. */
        private static int[] concat(int[] first, int[] second) {
            int[] result = Arrays.copyOf(first, first.length + second.length);
            System.arraycopy(second, 0, result, first.length, second.length);
            return result;
        }

        /** Returns a new matrix with the rows of {@code top} followed by the rows of {@code bottom}. */
        private static INDArray stackRows(INDArray top, INDArray bottom) {
            INDArray result = Nd4j.create(top.rows() + bottom.rows(), top.columns());
            for (int row = 0; row < top.rows(); row++) {
                result.putRow(row, top.getRow(row));
            }
            for (int row = 0; row < bottom.rows(); row++) {
                result.putRow(row + top.rows(), bottom.getRow(row));
            }
            return result;
        }
    }

    /** Creates an untrained model with default hyper-parameters and no standardization. */
    public Deltr(double gamma) {
        this(gamma, false);
    }

    /** Creates an untrained model with default hyper-parameters. */
    public Deltr(double gamma, boolean standardize) {
        this(gamma, DEFAULT_NUMBER_OF_ITERATIONS, standardize);
    }

    /** Creates an untrained model with default learning rate, lambda and initial variance. */
    public Deltr(double gamma, int numberOfIterations, boolean standardize) {
        this(gamma, numberOfIterations, DEFAULT_LEARNING_RATE, DEFAULT_LAMBDA, DEFAULT_INIT_VAR, standardize);
    }

    /** Creates an untrained model with every hyper-parameter given explicitly. */
    public Deltr(double gamma, int numberOfIterations, double learningRate, double lambda, double initVar,
                 boolean standardize) {
        this.mu = 0.0d;
        this.sigma = 0.0d;
        this.omega = null;
        this.log = null;
        this.gamma = gamma;
        this.numberOfIterations = numberOfIterations;
        this.learningRate = learningRate;
        this.lambda = lambda;
        this.initVar = initVar;
        this.shouldStandardize = standardize;
    }

    /**
     * Creates a model with explicit hyper-parameters and already-learned state.
     * Presumably used when restoring a serialized model.
     */
    public Deltr(double gamma, int numberOfIterations, double learningRate, double lambda, double initVar,
                 boolean standardize, double mu, double sigma, double[] omega) {
        this(gamma, numberOfIterations, learningRate, lambda, initVar, standardize);
        this.mu = mu;
        this.sigma = sigma;
        this.omega = omega;
    }

    /**
     * Trains the model and stores the learned weights in {@code omega} and the training log in
     * {@code log}.
     *
     * <p>All rankings are flattened into one feature matrix. When standardization is enabled,
     * two steps follow. First, a single mean and standard deviation are computed over every entry
     * of that matrix (including the protected column) and used to standardize all entries. Second,
     * the protected column is reset to its raw 0/1 values.
     *
     * @param list the training rankings; empty rankings are skipped
     * @throws IllegalArgumentException if {@code list} contains no non-empty ranking
     */
    public void train(List<DeltrTopDocs> list) {
        Trainer trainer = new Trainer(this.gamma, this.numberOfIterations, this.learningRate, this.lambda, this.initVar);
        TrainerData trainerData = null;
        for (DeltrTopDocs topDocs : list) {
            if (topDocs.size() == 0) {
                // An empty ranking contributes no rows, and prepareData() needs at least one doc.
                continue;
            }
            TrainerData queryData = prepareData(topDocs);
            trainerData = trainerData == null ? queryData : trainerData.append(queryData);
        }
        if (trainerData == null) {
            throw new IllegalArgumentException("Cannot train DELTR: no non-empty rankings were given");
        }
        if (this.shouldStandardize) {
            this.mu = trainerData.featureMatrix.meanNumber().doubleValue();
            this.sigma = trainerData.featureMatrix.stdNumber().doubleValue();
            // NOTE(review): if every entry is identical, sigma is 0 and this produces NaN/Infinity.
            trainerData.featureMatrix = trainerData.featureMatrix.sub(this.mu).div(this.sigma);
            // The protected flag must stay 0/1, so its raw values are written back over the
            // standardized column.
            double[] protectedColumn = IntStream.of(trainerData.protectedElementFeature).asDoubleStream().toArray();
            trainerData.featureMatrix.putColumn(trainerData.protectedElementFeatureIndex, Nd4j.create(protectedColumn));
        }
        this.omega = trainer.train(trainerData.queryIds, trainerData.protectedElementFeature,
                trainerData.featureMatrix, trainerData.trainingScores);
        this.log = trainer.getLog();
    }

    /**
     * Re-scores and reorders a ranking using the learned weights.
     *
     * <p>The documents of {@code deltrTopDocs} are modified in place. If standardization is
     * enabled, every non-protected feature is overwritten with its standardized value. Each
     * document is then passed its model score via {@code rejudge}. Calling this twice on the
     * same documents therefore standardizes them twice.
     *
     * @param deltrTopDocs the ranking to re-score
     * @return the same {@code deltrTopDocs} instance after {@code reorder()}
     * @throws NullPointerException if the model has not been trained or restored
     */
    public DeltrTopDocs rank(DeltrTopDocs deltrTopDocs) {
        if (this.omega == null) {
            throw new NullPointerException("You need to train a model first!");
        }
        if (this.shouldStandardize) {
            for (int i = 0; i < deltrTopDocs.size(); i++) {
                standardizeInPlace(deltrTopDocs.doc(i));
            }
        }
        for (int i = 0; i < deltrTopDocs.size(); i++) {
            DeltrDoc doc = deltrTopDocs.doc(i);
            doc.rejudge(score(doc));
        }
        deltrTopDocs.reorder();
        return deltrTopDocs;
    }

    /** Overwrites every non-protected feature of {@code doc} with {@code (value - mu) / sigma}. */
    private void standardizeInPlace(DeltrDoc doc) {
        for (String key : doc.keys()) {
            if (!key.equals(doc.protectedFeatureName())) {
                doc.put(key, Double.valueOf((doc.feature(key).doubleValue() - this.mu) / this.sigma));
            }
        }
    }

    /**
     * Returns the dot product of the features of {@code doc} with {@code omega}.
     * Assumes the document has the same feature layout as the training data.
     */
    private double score(DeltrDoc doc) {
        double total = 0.0d;
        for (int i = 0; i < doc.size(); i++) {
            total += doc.feature(i).doubleValue() * this.omega[i];
        }
        return total;
    }

    /**
     * Converts one non-empty ranking into {@link TrainerData}.
     *
     * <p>The matrix width and the protected-feature index are taken from the first document.
     * Presumably every document has the same number of features (TODO confirm).
     */
    private TrainerData prepareData(DeltrTopDocs deltrTopDocs) {
        int docCount = deltrTopDocs.size();
        DeltrDoc firstDoc = deltrTopDocs.doc(0);
        TrainerData data = new TrainerData();
        data.queryIds = new int[docCount];
        Arrays.fill(data.queryIds, deltrTopDocs.id());
        data.protectedElementFeature = new int[docCount];
        data.featureMatrix = Nd4j.create(docCount, firstDoc.size());
        data.trainingScores = Nd4j.create(docCount, 1);
        data.protectedElementFeatureIndex = firstDoc.protectedFeatureIndex();
        for (int i = 0; i < docCount; i++) {
            DeltrDoc doc = deltrTopDocs.doc(i);
            data.protectedElementFeature[i] = doc.isProtected() ? 1 : 0;
            data.featureMatrix.putRow(i, Nd4j.create(Doubles.toArray(doc.features())));
            data.trainingScores.putScalar(i, doc.judgement());
        }
        return data;
    }

    /** Returns the learned weights (the internal array, not a copy), or {@code null} if untrained. */
    public double[] getOmega() {
        return this.omega;
    }

    /** Returns the training log of the last {@link #train} call, or {@code null} if untrained. */
    public List<TrainStep> getLog() {
        return this.log;
    }

    /**
     * Serializes this model to JSON.
     *
     * @return the JSON string, or {@code null} if serialization failed (the failure is logged)
     */
    public String toJson() {
        try {
            return new ObjectMapper().writeValueAsString(this);
        } catch (JsonProcessingException e) {
            LOGGER.log(Level.SEVERE, String.format("Exception in parsing: '%s'", e.getMessage()), e);
            return null;
        }
    }

    /**
     * Restores a model from JSON produced by {@link #toJson()}.
     *
     * @param str the JSON text
     * @return the restored model, or {@code null} if parsing failed (the failure is logged)
     */
    public static Deltr createFromJson(String str) {
        try {
            return new ObjectMapper().readValue(str, Deltr.class);
        } catch (IOException e) {
            LOGGER.log(Level.SEVERE, String.format("IOException in parsing: '%s'", e.getMessage()), e);
            return null;
        }
    }

    @Override
    public String toString() {
        return "Deltr{gamma=" + this.gamma
                + ", numberOfIterations=" + this.numberOfIterations
                + ", learningRate=" + this.learningRate
                + ", lambda=" + this.lambda
                + ", initVar=" + this.initVar
                + ", shouldStandardize=" + this.shouldStandardize
                + ", mu=" + this.mu
                + ", sigma=" + this.sigma
                + ", omega=" + Arrays.toString(this.omega)
                + '}';
    }
}
