package com.github.keenon.loglinear.learning;

import com.github.keenon.loglinear.model.ConcatVector;
import com.github.keenon.loglinear.model.GraphicalModel;
import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.lang.management.ManagementFactory;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.concurrent.Executors;
import java.util.concurrent.ThreadPoolExecutor;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Base class for batch optimizers that train a {@link ConcatVector} of weights against an
 * {@link AbstractDifferentiableFunction} evaluated over a whole dataset.
 *
 * <p>Subclasses supply the update rule via {@link #updateWeights} and any per-run state via
 * {@link #getFreshOptimizationState}. Callers may pin parts of the weight vector with
 * {@link #addSparseConstraint} and {@link #addDenseConstraint}.
 *
 * <p>NOTE(review): this file was recovered with a decompiler (JADX), which could not decompile
 * {@code TrainingWorker.run()}. That method currently throws {@link UnsupportedOperationException},
 * so {@link #optimize} cannot complete training until the original implementation is restored.
 */
public abstract class AbstractBatchOptimizer {
    private static final Logger log = LoggerFactory.getLogger(AbstractBatchOptimizer.class);

    /** Upper bound (ms) between stdin polls while waiting for interactive training to finish. */
    private static final long STDIN_POLL_MILLIS = 100;

    /** Constraints registered through {@link #addSparseConstraint} / {@link #addDenseConstraint}. */
    List<Constraint> constraints = new ArrayList<>();

    /**
     * Pins either a single entry of a component (sparse) or an entire dense component of a
     * {@link ConcatVector} to fixed values.
     */
    private static class Constraint {
        int component;
        boolean isSparse = true;
        int index;
        double value;
        double[] arr;

        /** Sparse constraint: entry {@code index} of {@code component} is held at {@code value}. */
        public Constraint(int component, int index, double value) {
            this.component = component;
            this.index = index;
            this.value = value;
        }

        /** Dense constraint: the whole of {@code component} is held at the values in {@code arr}. */
        public Constraint(int component, double[] arr) {
            this.component = component;
            // Copy so later mutation of the caller's array cannot change the pinned values.
            this.arr = arr.clone();
            // Must be cleared explicitly: the field defaults to true, and a dense constraint left
            // marked sparse would be applied as "entry 0 = 0.0", ignoring arr entirely.
            this.isSparse = false;
        }

        /** Overwrites the constrained entry/component of {@code weights} with the pinned value(s). */
        public void applyToWeights(ConcatVector weights) {
            if (this.isSparse) {
                weights.setSparseComponent(this.component, this.index, this.value);
            } else {
                // Hand out a copy: whether setDenseComponent copies its argument is not visible
                // here, and an in-place weight update must not alter the constraint's own array.
                weights.setDenseComponent(this.component, this.arr.clone());
            }
        }

        /** Zeroes the constrained entry/component of {@code derivative}. */
        public void applyToDerivative(ConcatVector derivative) {
            if (this.isSparse) {
                derivative.setSparseComponent(this.component, this.index, 0.0d);
            } else {
                // NOTE(review): a length-1 zero array presumably relies on ConcatVector treating
                // entries beyond a dense array's length as zero — confirm.
                derivative.setDenseComponent(this.component, new double[]{0.0d});
            }
        }
    }

    /**
     * Accumulates log-likelihood and derivative contributions for one slice ({@code queue}) of the
     * dataset into thread-local totals, recording when it finished and how much CPU time it used.
     */
    private static class GradientWorker<T> implements Runnable {
        ConcatVector localDerivative;
        TrainingWorker<T> mainWorker;
        int threadIdx;
        int numThreads;
        List<T> queue;
        AbstractDifferentiableFunction<T> fn;
        ConcatVector weights;
        double localLogLikelihood = 0.0d;
        long finishedAtTime = 0;
        long cpuTimeRequired = 0;

        public GradientWorker(TrainingWorker<T> mainWorker, int threadIdx, int numThreads, List<T> queue,
                AbstractDifferentiableFunction<T> fn, ConcatVector weights) {
            this.mainWorker = mainWorker;
            this.threadIdx = threadIdx;
            this.numThreads = numThreads;
            this.queue = queue;
            this.fn = fn;
            this.weights = weights;
            this.localDerivative = weights.newEmptyClone();
        }

        @Override
        public void run() {
            long cpuAtStart = ManagementFactory.getThreadMXBean().getCurrentThreadCpuTime();
            for (T datum : this.queue) {
                this.localLogLikelihood += this.fn.getSummaryForInstance(datum, this.weights, this.localDerivative);
                // Abandon the pass promptly if training was stopped; timing fields stay at 0.
                if (this.mainWorker.isFinished) {
                    return;
                }
            }
            this.finishedAtTime = System.currentTimeMillis();
            this.cpuTimeRequired = ManagementFactory.getThreadMXBean().getCurrentThreadCpuTime() - cpuAtStart;
        }
    }

    /** Optimizer-specific state for one run, obtained from {@link #getFreshOptimizationState}. */
    protected abstract class OptimizationState {
        public OptimizationState() {
        }
    }

    /**
     * Runs the training loop on a background thread (declared private in the original source, per
     * decompiler metadata).
     */
    public class TrainingWorker<T> implements Runnable {
        ConcatVector weights;
        OptimizationState optimizationState;
        /** Stop flag, written and read by different threads, hence {@code volatile}. */
        volatile boolean isFinished = false;
        boolean useThreads;
        T[] dataset;
        AbstractDifferentiableFunction<T> fn;
        double l2regularization;
        double convergenceDerivativeNorm;
        boolean quiet;
        ThreadPoolExecutor executor;
        final Object naturalTerminationBarrier;

        public TrainingWorker(T[] dataset, AbstractDifferentiableFunction<T> fn, ConcatVector initialWeights,
                double l2regularization, double convergenceDerivativeNorm, boolean quiet, ThreadPoolExecutor executor) {
            this.useThreads = Runtime.getRuntime().availableProcessors() > 1;
            this.naturalTerminationBarrier = new Object();
            this.optimizationState = AbstractBatchOptimizer.this.getFreshOptimizationState(initialWeights);
            // Train on a private copy so the caller's initial vector is left untouched.
            this.weights = initialWeights.deepClone();
            this.dataset = dataset;
            this.fn = fn;
            this.l2regularization = l2regularization;
            this.convergenceDerivativeNorm = convergenceDerivativeNorm;
            this.quiet = quiet;
            this.executor = executor;
        }

        /**
         * Relative cost estimate for one datum. For a {@link GraphicalModel} this is the sum of
         * {@code combinatorialNeighborStatesCount()} over its factors' feature tables; otherwise 1.
         */
        private int estimateRelativeRuntime(T datum) {
            if (!(datum instanceof GraphicalModel)) {
                return 1;
            }
            int cost = 0;
            for (GraphicalModel.Factor factor : ((GraphicalModel) datum).factors) {
                cost += factor.featuresTable.combinatorialNeighborStatesCount();
            }
            return cost;
        }

        /**
         * Runs the training loop.
         *
         * <p>NOTE(review): the decompiler could not recover this method (1029 bytecode instructions),
         * so it currently throws. The fragments it did recover show that on natural termination the
         * method calls {@code naturalTerminationBarrier.notifyAll()} while holding that monitor, then
         * sets {@code isFinished = true} and returns. Restore the original implementation from source.
         */
        @Override
        public void run() {
            throw new UnsupportedOperationException("Method not decompiled: com.github.keenon.loglinear.learning.AbstractBatchOptimizer.TrainingWorker.run():void");
        }
    }

    /**
     * Optimizes with defaults: empty initial weights ({@code new ConcatVector(0)}), no L2
     * regularization, derivative-norm threshold {@code 0.001}, interactive (non-quiet) mode, and a
     * fixed thread pool sized to the available processors, which is always shut down afterwards.
     *
     * @param dataset training instances
     * @param fn      function evaluated per instance
     * @return the trained weights
     */
    public <T> ConcatVector optimize(T[] dataset, AbstractDifferentiableFunction<T> fn) {
        ThreadPoolExecutor executor =
                (ThreadPoolExecutor) Executors.newFixedThreadPool(Runtime.getRuntime().availableProcessors());
        try {
            return optimize(dataset, fn, new ConcatVector(0), 0.0d, 0.001d, false, executor);
        } finally {
            // Release the pool's threads even if training fails.
            executor.shutdown();
        }
    }

    /**
     * Trains on a background thread and blocks until training ends.
     *
     * <p>In quiet mode this waits for the training thread to finish. Otherwise it also polls stdin;
     * any pending input stops training early and returns the current weights.
     *
     * @param dataset                   training instances
     * @param fn                        function evaluated per instance
     * @param initialWeights            starting weights (a deep copy is trained, not this object)
     * @param l2regularization          L2 regularization strength handed to the training worker
     * @param convergenceDerivativeNorm derivative-norm threshold handed to the training worker
     * @param quiet                     whether to run without the interactive stdin quit check
     * @param executor                  executor handed to the training worker
     * @return the trained weights
     * @throws IllegalStateException if the training thread terminates without finishing (e.g. it
     *                               threw an exception)
     */
    public <T> ConcatVector optimize(T[] dataset, AbstractDifferentiableFunction<T> fn, ConcatVector initialWeights,
            double l2regularization, double convergenceDerivativeNorm, boolean quiet, ThreadPoolExecutor executor) {
        if (quiet) {
            log.info("[Beginning quiet training]");
        } else {
            log.info("\n**************\nBeginning training\n");
        }
        TrainingWorker<T> worker = new TrainingWorker<>(dataset, fn, initialWeights, l2regularization,
                convergenceDerivativeNorm, quiet, executor);
        Thread trainingThread = new Thread(worker);
        trainingThread.start();

        if (quiet) {
            // Joining avoids the lost-wakeup race of waiting on naturalTerminationBarrier: the
            // worker notifies *before* setting isFinished.
            joinUninterruptibly(trainingThread);
            requireFinished(worker);
            log.info("[Quiet training complete]");
            return worker.weights;
        }

        log.info("NOTE: you can press any key (and maybe ENTER afterwards to jog stdin) to terminate learning early.");
        log.info("The convergence criteria are quite aggressive if left uninterrupted, and will run for a while");
        log.info("if left to their own devices.\n");
        BufferedReader stdin = new BufferedReader(new InputStreamReader(System.in));
        boolean pollStdin = true;
        boolean interrupted = false;
        try {
            while (!worker.isFinished && trainingThread.isAlive()) {
                if (pollStdin) {
                    try {
                        if (stdin.ready()) {
                            log.info("received quit command: quitting");
                            log.info("training completed by interruption");
                            worker.isFinished = true;
                            // NOTE(review): the training thread is not joined, so it may still be
                            // writing to these weights when they are returned — confirm acceptable.
                            return worker.weights;
                        }
                    } catch (IOException e) {
                        // Stop polling rather than failing (or logging) on every iteration.
                        log.warn("Unable to poll stdin for a quit command; waiting for training to finish", e);
                        pollStdin = false;
                    }
                }
                try {
                    // Sleeps at most STDIN_POLL_MILLIS, returning early if the worker exits.
                    trainingThread.join(STDIN_POLL_MILLIS);
                } catch (InterruptedException e) {
                    interrupted = true;
                }
            }
        } finally {
            if (interrupted) {
                Thread.currentThread().interrupt();
            }
        }
        requireFinished(worker);
        log.info("training completed without interruption");
        return worker.weights;
    }

    /** Waits for {@code thread} to die, deferring any interrupt until the wait is over. */
    private static void joinUninterruptibly(Thread thread) {
        boolean interrupted = false;
        try {
            while (true) {
                try {
                    thread.join();
                    return;
                } catch (InterruptedException e) {
                    interrupted = true;
                }
            }
        } finally {
            if (interrupted) {
                Thread.currentThread().interrupt();
            }
        }
    }

    /** Fails loudly if the training thread stopped without marking the run finished. */
    private static void requireFinished(TrainingWorker<?> worker) {
        if (!worker.isFinished) {
            throw new IllegalStateException(
                    "Training thread terminated before training finished; see its uncaught exception");
        }
    }

    /** Pins entry {@code index} of component {@code component} to {@code value}. */
    public void addSparseConstraint(int component, int index, double value) {
        this.constraints.add(new Constraint(component, index, value));
    }

    /** Pins dense component {@code component} to a copy of {@code values}. */
    public void addDenseConstraint(int component, double[] values) {
        this.constraints.add(new Constraint(component, values));
    }

    /**
     * Applies one optimizer step to {@code weights}.
     *
     * <p>NOTE(review): presumably the second vector is the current derivative, the double the
     * current log-likelihood, and the result {@code true} when converged — confirm against
     * implementations.
     */
    public abstract boolean updateWeights(ConcatVector weights, ConcatVector derivative, double logLikelihood,
            OptimizationState optimizationState, boolean quiet);

    /** Creates the per-run optimizer state for a training run starting at {@code initialWeights}. */
    protected abstract OptimizationState getFreshOptimizationState(ConcatVector initialWeights);
}
