/*
 * Decompiled with CFR 0.152.
 */
package de.jungblut.math.minimize;

import com.google.common.base.Preconditions;
import com.google.common.util.concurrent.ThreadFactoryBuilder;
import de.jungblut.datastructure.ArrayUtils;
import de.jungblut.math.DoubleMatrix;
import de.jungblut.math.DoubleVector;
import de.jungblut.math.dense.DenseDoubleMatrix;
import de.jungblut.math.dense.DenseDoubleVector;
import de.jungblut.math.minimize.CostFunction;
import de.jungblut.math.minimize.CostGradientTuple;
import de.jungblut.math.sparse.SparseDoubleRowMatrix;
import de.jungblut.math.tuple.Tuple;
import de.jungblut.partition.Boundaries;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.Executor;
import java.util.concurrent.ExecutorCompletionService;
import java.util.concurrent.Executors;
import java.util.concurrent.ThreadFactory;

/**
 * Base class for cost functions that are evaluated over mini batches of the
 * training data.
 *
 * <p>The constructor slices the input (and, if present, outcome) vectors into
 * contiguous batches and materializes each batch as a matrix pair. A call to
 * {@link #evaluateCost(DoubleVector)} hands the batches to
 * {@link #evaluateBatch(DoubleVector, DoubleMatrix, DoubleMatrix)} and returns
 * the average cost and average gradient over the evaluated batches.
 *
 * <p>In stochastic mode every call evaluates exactly one batch and advances
 * an internal cursor, wrapping around after the last batch. The cursor is a
 * plain field without synchronization.
 */
public abstract class AbstractMiniBatchCostFunction
implements CostFunction {
    private final Executor pool;
    private final List<Tuple<DoubleMatrix, DoubleMatrix>> batches;
    private final boolean stochastic;
    // index of the next batch to evaluate; only ever advanced in stochastic mode
    private int batchOffset = 0;

    /**
     * Creates a non-stochastic mini batch cost function.
     *
     * @param inputMatrix the feature vectors, one per example.
     * @param outcomeMatrix the outcome vectors, or {@code null} if there are none.
     * @param batchSize the number of examples per batch; 0 puts all examples
     *        into a single batch.
     * @param numThreads the number of worker threads, at least 1.
     */
    public AbstractMiniBatchCostFunction(DoubleVector[] inputMatrix, DoubleVector[] outcomeMatrix, int batchSize, int numThreads) {
        this(inputMatrix, outcomeMatrix, batchSize, numThreads, false);
    }

    /**
     * Creates a mini batch cost function.
     *
     * @param inputMatrix the feature vectors, one per example; must not be empty.
     * @param outcomeMatrix the outcome vectors, or {@code null} if there are
     *        none. When {@code null}, the outcome handed to
     *        {@link #evaluateBatch(DoubleVector, DoubleMatrix, DoubleMatrix)}
     *        is {@code null} as well.
     * @param batchSize the number of examples per batch, between 0 and
     *        {@code inputMatrix.length}; 0 puts all examples into a single
     *        batch.
     * @param numThreads the number of worker threads, at least 1. Ignored in
     *        stochastic mode, which always uses a single worker thread.
     * @param stochastic if true, each cost evaluation uses only one batch and
     *        successive evaluations cycle through the batches in input order.
     * @throws IllegalArgumentException if {@code batchSize} or
     *         {@code numThreads} is out of range, or if {@code inputMatrix} is
     *         empty.
     */
    public AbstractMiniBatchCostFunction(DoubleVector[] inputMatrix, DoubleVector[] outcomeMatrix, int batchSize, int numThreads, boolean stochastic) {
        Preconditions.checkArgument(batchSize >= 0 && batchSize <= inputMatrix.length, "Batchsize wasn't in range of 0-" + inputMatrix.length);
        Preconditions.checkArgument(numThreads >= 1, "#Threads need to be at least > 0");
        // Only reachable with batchSize == 0; without this guard the first
        // element access on the (empty) batch below would fail obscurely.
        Preconditions.checkArgument(inputMatrix.length > 0, "Input matrix must contain at least one example");
        this.stochastic = stochastic;
        ThreadFactory factory = new ThreadFactoryBuilder().setDaemon(true).setNameFormat("MiniBatch Worker %d").build();

        // An ordered list (rather than a hash set) keeps the batches in input
        // order, so the stochastic cycle is deterministic across runs.
        List<Boundaries.Range> partitions = new ArrayList<Boundaries.Range>();
        if (batchSize == 0) {
            // a single batch spanning the whole input needs only one thread
            numThreads = 1;
            batchSize = inputMatrix.length;
            partitions.add(new Boundaries.Range(0, batchSize - 1));
        } else {
            for (int offset = 0; offset < inputMatrix.length; offset += batchSize) {
                // the last batch may be shorter than batchSize
                int last = Math.min(inputMatrix.length - 1, offset + (batchSize - 1));
                partitions.add(new Boundaries.Range(offset, last));
            }
        }

        this.pool = Executors.newFixedThreadPool(stochastic ? 1 : numThreads, factory);
        this.batches = new ArrayList<Tuple<DoubleMatrix, DoubleMatrix>>(partitions.size());
        for (Boundaries.Range range : partitions) {
            this.batches.add(buildBatch(inputMatrix, outcomeMatrix, range.getStart(), range.getEnd()));
        }
    }

    /**
     * Materializes one batch as a (features, outcome) matrix pair.
     *
     * <p>The matrix implementation (sparse or dense) is chosen by the first
     * feature vector of the batch. The feature matrix is combined with a
     * vector of ones via the matrix constructor — presumably adding a bias
     * column; confirm against the matrix classes.
     */
    private static Tuple<DoubleMatrix, DoubleMatrix> buildBatch(DoubleVector[] inputMatrix, DoubleVector[] outcomeMatrix, int start, int end) {
        DoubleVector[] featureSubArray = ArrayUtils.subArray(inputMatrix, start, end);
        boolean sparse = featureSubArray[0].isSparse();

        DoubleMatrix outcomeMat = null;
        if (outcomeMatrix != null) {
            DoubleVector[] outcomeSubArray = ArrayUtils.subArray(outcomeMatrix, start, end);
            outcomeMat = new DenseDoubleMatrix(outcomeSubArray);
        }

        DenseDoubleVector bias = DenseDoubleVector.ones(featureSubArray.length);
        DoubleMatrix featuresWithBias;
        if (sparse) {
            DoubleMatrix featureMatrix = new SparseDoubleRowMatrix(featureSubArray);
            featuresWithBias = new SparseDoubleRowMatrix(bias, featureMatrix);
        } else {
            DoubleMatrix featureMatrix = new DenseDoubleMatrix(featureSubArray);
            featuresWithBias = new DenseDoubleMatrix(bias, featureMatrix);
        }
        return new Tuple<DoubleMatrix, DoubleMatrix>(featuresWithBias, outcomeMat);
    }

    /**
     * Evaluates the cost and gradient for the given parameters.
     *
     * <p>With a single batch the evaluation runs on the calling thread.
     * Otherwise the batches (all of them, or just the next one in stochastic
     * mode) are submitted to the worker pool and the results are averaged
     * over the number of submitted batches.
     *
     * @param input the current parameters.
     * @return the averaged cost and gradient, or {@code null} if a batch
     *         evaluation failed or the calling thread was interrupted (the
     *         stack trace is printed in that case).
     */
    @Override
    public final CostGradientTuple evaluateCost(DoubleVector input) {
        if (this.batches.size() == 1) {
            // no need to go through the pool for a single batch
            try {
                return new CallableMiniBatch(this.batches.get(0), input).call();
            }
            catch (Exception e) {
                e.printStackTrace();
                return null;
            }
        }

        ExecutorCompletionService<CostGradientTuple> completionService = new ExecutorCompletionService<CostGradientTuple>(this.pool);
        int submittedBatches;
        if (this.stochastic) {
            // evaluate only the batch under the cursor, then advance with wrap-around
            completionService.submit(new CallableMiniBatch(this.batches.get(this.batchOffset), input));
            this.batchOffset = (this.batchOffset + 1) % this.batches.size();
            submittedBatches = 1;
        } else {
            for (Tuple<DoubleMatrix, DoubleMatrix> batch : this.batches) {
                completionService.submit(new CallableMiniBatch(batch, input));
            }
            submittedBatches = this.batches.size();
        }

        double costSum = 0.0;
        DoubleVector gradientSum = new DenseDoubleVector(input.getLength());
        try {
            // results are consumed in completion order, not submission order
            for (int i = 0; i < submittedBatches; ++i) {
                CostGradientTuple result = completionService.take().get();
                costSum += result.getCost();
                gradientSum = gradientSum.add(result.getGradient());
            }
        }
        catch (InterruptedException e) {
            // restore the interrupt flag so code further up the stack can react to it
            Thread.currentThread().interrupt();
            e.printStackTrace();
            return null;
        }
        catch (ExecutionException e) {
            e.printStackTrace();
            return null;
        }
        return new CostGradientTuple(costSum / (double)submittedBatches, gradientSum.divide((double)submittedBatches));
    }

    /**
     * Evaluates cost and gradient on a single batch. In non-stochastic mode
     * with more than one worker thread this may be invoked from several pool
     * threads at the same time.
     *
     * @param parameters the current parameters, as passed to
     *        {@link #evaluateCost(DoubleVector)}.
     * @param features the feature matrix of the batch, built together with a
     *        vector of ones.
     * @param outcome the outcome matrix of the batch, or {@code null} if no
     *        outcome was supplied at construction time.
     * @return the cost and gradient for this batch.
     */
    protected abstract CostGradientTuple evaluateBatch(DoubleVector parameters, DoubleMatrix features, DoubleMatrix outcome);

    /**
     * Task that evaluates one batch against a fixed parameter vector by
     * delegating to the enclosing instance's {@code evaluateBatch}.
     */
    class CallableMiniBatch
    implements Callable<CostGradientTuple> {
        private final DoubleVector parameters;
        private final Tuple<DoubleMatrix, DoubleMatrix> featureOutcome;

        /**
         * @param featureOutcome the (features, outcome) pair of the batch.
         * @param parameters the parameters to evaluate the batch with.
         */
        public CallableMiniBatch(Tuple<DoubleMatrix, DoubleMatrix> featureOutcome, DoubleVector parameters) {
            this.featureOutcome = featureOutcome;
            this.parameters = parameters;
        }

        @Override
        public CostGradientTuple call() throws Exception {
            return AbstractMiniBatchCostFunction.this.evaluateBatch(this.parameters, this.featureOutcome.getFirst(), this.featureOutcome.getSecond());
        }
    }
}

