package com.github.fairsearch.deltr;

import com.github.fairsearch.deltr.models.DeltrDocImpl;
import com.github.fairsearch.deltr.models.DeltrTopDocs;
import com.github.fairsearch.deltr.models.DeltrTopDocsImpl;
import com.github.fairsearch.deltr.models.TrainStep;
import java.io.BufferedReader;
import java.io.FileReader;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.util.ArrayList;
import java.util.List;
import java.util.Objects;
import java.util.TreeMap;
import java.util.stream.IntStream;
import junitparams.JUnitParamsRunner;
import junitparams.Parameters;
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.TopDocs;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.ops.transforms.Transforms;

/**
 * Tests for {@link Deltr}: training on a CSV fixture, training on synthetic data, and ranking
 * with hand-set weights.
 *
 * <p>Checks use the {@code assert} statement, so they only fire when the JVM runs with
 * assertions enabled ({@code -ea}; Maven Surefire enables them by default).
 *
 * <p>NOTE(review): the first {@link Deltr} constructor argument is called {@code gamma} below on
 * the assumption that it is DELTR's fairness weight; confirm against {@link Deltr}.
 */
@RunWith(JUnitParamsRunner.class)
public class DeltrTests {
    /** Tolerated increase in total cost between two consecutive training steps. */
    private static final double OFFSET = 0.001d;

    /** Test double exposing setters for the {@code omega}, {@code mu} and {@code sigma} fields. */
    private static class DeltrMock extends Deltr {
        public DeltrMock(double gamma) {
            super(gamma);
        }

        public void setOmega(double[] omega) {
            this.omega = omega;
        }

        public void setMu(double mu) {
            this.mu = mu;
        }

        public void setSigma(double sigma) {
            this.sigma = sigma;
        }
    }

    /** Smoke test that the ND4J factories and {@code Transforms.exp} run; it asserts nothing. */
    @Test
    public void testNd4j() {
        INDArray rand = Nd4j.rand(3, 5);
        Nd4j.zeros(3, 5);
        Nd4j.ones(5, 1);
        Nd4j.rand(5, 1);
        rand.put(0, 0, 1);
        rand.put(1, 0, 0);
        rand.put(2, 0, 1);
        // Prints a large value and its exponential, presumably to eyeball overflow behaviour.
        System.out.println(Nd4j.create(new double[]{152.0d}));
        System.out.println(Transforms.exp(Nd4j.create(new double[]{152.0d})));
    }

    /**
     * Trains for 10 iterations on a CSV fixture from {@code /fixtures} and checks the training log.
     *
     * @param gamma first {@link Deltr} constructor argument
     * @param fixtureName file name under the {@code /fixtures} classpath directory
     * @param standardize third {@link Deltr} constructor argument
     */
    @Test
    @Parameters({"1, test_data_1.csv, true", "1, test_data_1.csv, false"})
    public void testTrainFromFixtures(double gamma, String fixtureName, boolean standardize) {
        String resource = String.format("/fixtures/%s", fixtureName);
        // NOTE(review): URL.getFile() keeps URL escapes (e.g. %20), so paths with spaces may fail.
        String path = Objects.requireNonNull(
                getClass().getResource(resource), "missing test fixture " + resource).getFile();
        List<DeltrTopDocs> rankings = prepareData(path);
        Deltr deltr = new Deltr(gamma, 10, standardize);
        deltr.train(rankings);
        evaluateTrainer(deltr);
    }

    /**
     * Trains on a generated dataset and checks the training log.
     *
     * @param queryCount first {@code SyntheticDatasetCreator} argument (number of generated rankings)
     * @param docsPerQuery second argument (documents per ranking)
     * @param featureCount fourth argument (features per document)
     * @param gamma first {@link Deltr} constructor argument
     * @param iterations second {@link Deltr} constructor argument
     * @param standardize third {@link Deltr} constructor argument
     */
    @Test
    @Parameters({"1, 20, 5, 1, 100, false", "1, 50, 10, 0.8, 500, false", "1, 1000, 3, 1, 1000, false", "2, 200, 4, 0.9, 300, false", "3, 100, 5, 1, 200, false", "4, 50, 6, 1, 100, false", "1, 20, 5, 1, 100, true", "1, 50, 10, 0.8, 500, true", "1, 1000, 3, 1, 1000, true", "2, 200, 4, 0.9, 300, true", "3, 100, 5, 1, 200, true", "4, 50, 6, 1, 100, true"})
    public void testTrainSyntheticData(int queryCount, int docsPerQuery, int featureCount,
                                       double gamma, int iterations, boolean standardize) {
        // The third creator argument (2) is passed through unchanged; its meaning is not visible here.
        List<DeltrTopDocs> rankings =
                new SyntheticDatasetCreator(queryCount, docsPerQuery, 2, featureCount).generateDataset();
        Deltr deltr = new Deltr(gamma, iterations, standardize);
        deltr.train(rankings);
        evaluateTrainer(deltr);
    }

    /**
     * Checks that training produced weights and a non-empty log.
     *
     * <p>Also checks that the total cost never rises by more than {@link #OFFSET} from one step to
     * the next. Smaller increases are tolerated but printed.
     */
    private void evaluateTrainer(Deltr deltr) {
        assert deltr.getOmega() != null : "training produced no weights";
        assert deltr.getLog() != null : "training produced no log";
        List<?> log = deltr.getLog();
        assert !log.isEmpty() : "training log is empty";

        TrainStep previous = (TrainStep) log.get(0);
        assert previous != null : "first training step is null";
        for (int step = 1; step < log.size(); step++) {
            TrainStep current = (TrainStep) log.get(step);
            double previousCost = previous.getTotalCost();
            double currentCost = current.getTotalCost();
            if (currentCost > previousCost) {
                System.out.println("cost increased at step " + step + ": "
                        + previousCost + " -> " + currentCost);
            }
            // Phrased as `<=` so that a NaN cost fails the check instead of slipping through.
            assert currentCost <= previousCost + OFFSET
                    : "cost increased by more than " + OFFSET + " at step " + step + ": "
                    + previousCost + " -> " + currentCost;
            previous = current;
        }
    }

    /** Ranking with an untrained model and a null input must throw a NullPointerException. */
    @Test
    public void testRankEmptyDeltr() {
        try {
            new Deltr(1.0d).rank((DeltrTopDocs) null);
            assert false : "expected NullPointerException from rank(null)";
        } catch (NullPointerException expected) {
            // expected outcome
        } catch (Exception unexpected) {
            assert false : "expected NullPointerException but got " + unexpected;
        }
    }

    /**
     * Ranks one synthetic ranking with hand-set weights {@code omega[f] = 10 * f}.
     *
     * <p>Compares the result against an independent ordering. That ordering re-judges every input
     * document with the dot product of {@code omega} and its features, then reorders the input.
     *
     * @param docCount documents in the generated ranking
     * @param featureCount features per document (and length of {@code omega})
     * @param standardize whether {@code mu} and {@code sigma} are set to 1 on the model
     */
    @Test
    @Parameters({"20, 5, false", "50, 10, false", "1000, 3, false", "20, 5, true", "50, 10, true", "1000, 3, true"})
    public void testRankDeltr(int docCount, int featureCount, boolean standardize) {
        DeltrTopDocs ranking = new SyntheticDatasetCreator(1, docCount, 2, featureCount).generateDataset().get(0);
        double[] omega = IntStream.range(0, featureCount).mapToDouble(f -> 10 * f).toArray();
        DeltrMock deltr = new DeltrMock(1.0d);
        deltr.setOmega(omega);
        if (standardize) {
            deltr.setMu(1.0d);
            deltr.setSigma(1.0d);
        }
        DeltrTopDocs ranked = deltr.rank(ranking);

        // Expected score of document d: sum over features f of omega[f] * feature_f(d).
        // Distinct lambda parameter names are required; the original reused one name (a compile error).
        IntStream.range(0, docCount).forEach(d -> ranking.doc(d).rejudge(
                IntStream.range(0, featureCount)
                        .mapToDouble(f -> omega[f] * ranking.doc(d).feature(f).doubleValue())
                        .sum()));
        ranking.reorder();

        for (int position = 0; position < docCount; position++) {
            assert ranking.doc(position).id() == ranked.doc(position).id()
                    : "ranking differs at position " + position;
        }
    }

    /**
     * Reads a fixture CSV into one {@link DeltrTopDocs} per group of consecutive rows that share
     * the same value in column 0.
     *
     * <p>The first line is skipped as a header. Each remaining line must have at least four
     * comma-separated columns:
     * <ul>
     *   <li>column 0 (int) is passed to the {@code DeltrTopDocsImpl} constructor;</li>
     *   <li>column 1 equal to 1 is the boolean passed to {@code DeltrDocImpl} and stored under key
     *       {@code "0"};</li>
     *   <li>column 2 (double) is stored under key {@code "1"};</li>
     *   <li>column 3 (float) is the document score. The first row of each group also supplies the
     *       group's max score.</li>
     * </ul>
     *
     * @param path file-system path of the CSV
     * @return the groups in file order; empty if the file has no data rows
     * @throws UncheckedIOException if the file cannot be read, so the test fails instead of training
     *     on partial data
     */
    private List<DeltrTopDocs> prepareData(String path) {
        List<DeltrTopDocs> rankings = new ArrayList<>();
        try (BufferedReader reader = new BufferedReader(new FileReader(path))) {
            if (reader.readLine() == null) {
                return rankings; // no header, hence no data
            }
            DeltrTopDocsImpl current = null;
            List<ScoreDoc> currentDocs = new ArrayList<>();
            int currentId = -1;
            String line;
            while ((line = reader.readLine()) != null) {
                String[] columns = line.split(",");
                int groupId = Integer.parseInt(columns[0]);
                boolean flag = Integer.parseInt(columns[1]) == 1;
                double feature = Double.parseDouble(columns[2]);
                float score = Float.parseFloat(columns[3]);
                if (current == null || groupId != currentId) {
                    finishGroup(current, currentDocs, rankings);
                    current = new DeltrTopDocsImpl(groupId);
                    current.setMaxScore(score);
                    currentDocs = new ArrayList<>();
                    currentId = groupId;
                }
                // The doc number is the document's position inside its group.
                DeltrDocImpl doc = new DeltrDocImpl(currentDocs.size(), score, flag);
                doc.put("0", Boolean.valueOf(flag));
                doc.put("1", Double.valueOf(feature));
                currentDocs.add(doc);
            }
            finishGroup(current, currentDocs, rankings);
        } catch (IOException e) {
            throw new UncheckedIOException("could not read fixture " + path, e);
        }
        return rankings;
    }

    /**
     * Completes {@code group} and appends it to {@code out}; does nothing when {@code group} is null.
     *
     * <p>Stores {@code docs} as its {@code scoreDocs} and sets {@code totalHits} to their count.
     */
    private static void finishGroup(DeltrTopDocsImpl group, List<ScoreDoc> docs, List<DeltrTopDocs> out) {
        if (group == null) {
            return;
        }
        group.scoreDocs = docs.toArray(new ScoreDoc[0]);
        group.totalHits = docs.size();
        out.add(group);
    }
}
