package com.github.keenon.loglinear.learning;

import com.github.keenon.loglinear.learning.AbstractFunction;
import com.github.keenon.loglinear.model.ConcatVector;
import com.github.keenon.loglinear.model.GraphicalModel;
import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.lang.management.ManagementFactory;
import java.util.Iterator;
import java.util.List;

/* loaded from: input_file:com/github/keenon/loglinear/learning/AbstractBatchOptimizer.class */
public abstract class AbstractBatchOptimizer {

    /* loaded from: input_file:com/github/keenon/loglinear/learning/AbstractBatchOptimizer$GradientWorker.class */
    /**
     * Runnable that sums the gradient and function value of every instance in
     * its {@code queue} into worker-local accumulators.
     *
     * <p>Results are exposed through the package-private fields
     * {@link #localDerivative} and {@link #localLogLikelihood}; timing
     * information is left in {@link #finishedAtTime} and
     * {@link #cpuTimeRequired} once the whole queue has been processed.
     *
     * @param <T> type of the training instances handed to the function
     */
    private static class GradientWorker<T> implements Runnable {
        /** Sum of the gradients of all instances with a finite value. */
        ConcatVector localDerivative;
        /** Owning training worker; its {@code isFinished} flag is polled to stop early. */
        TrainingWorker<T> mainWorker;
        /** Stored from the constructor; not read inside this class. */
        int threadIdx;
        /** Stored from the constructor; not read inside this class. */
        int numThreads;
        /** Instances this worker is responsible for. */
        List<T> queue;
        /** Function evaluated on every instance. */
        AbstractFunction<T> fn;
        /** Weights at which the function is evaluated. */
        ConcatVector weights;
        /** Sum of the finite function values seen so far. */
        double localLogLikelihood = 0.0d;
        /**
         * Id of the JVM thread whose CPU time is sampled. Not set by the
         * constructor; when it is still non-positive at run time the id of the
         * executing thread is used instead.
         */
        long jvmThreadId = 0;
        /** Wall-clock time ({@link System#currentTimeMillis()}) at which the queue was exhausted. */
        long finishedAtTime = 0;
        /** Difference of the thread CPU-time readings taken at the start and end of {@link #run()}. */
        long cpuTimeRequired = 0;

        /**
         * @param trainingWorker   owning worker whose {@code isFinished} flag aborts this worker
         * @param i                stored as {@code threadIdx}
         * @param i2               stored as {@code numThreads}
         * @param list             instances to process
         * @param abstractFunction function to evaluate on each instance
         * @param concatVector     weights to evaluate at; also determines the number of
         *                         components of the local derivative accumulator
         */
        public GradientWorker(TrainingWorker<T> trainingWorker, int i, int i2, List<T> list, AbstractFunction<T> abstractFunction, ConcatVector concatVector) {
            this.mainWorker = trainingWorker;
            this.threadIdx = i;
            this.numThreads = i2;
            this.queue = list;
            this.fn = abstractFunction;
            this.weights = concatVector;
            this.localDerivative = new ConcatVector(concatVector.getNumberOfComponents());
        }

        /**
         * Evaluates every queued instance, accumulating gradient and value for
         * those whose value is finite. Returns early, without recording the
         * timing fields, as soon as the owning worker reports it is finished.
         */
        @Override // java.lang.Runnable
        public void run() {
            // ThreadMXBean.getThreadCpuTime rejects ids <= 0 with an
            // IllegalArgumentException, so fall back to the executing thread
            // when nobody assigned jvmThreadId before the worker was started.
            long measuredThreadId = this.jvmThreadId > 0 ? this.jvmThreadId : Thread.currentThread().getId();
            long cpuTimeAtStart = ManagementFactory.getThreadMXBean().getThreadCpuTime(measuredThreadId);
            for (T datum : this.queue) {
                AbstractFunction.FunctionSummaryAtPoint summary = this.fn.getSummaryForInstance(datum, this.weights);
                // Non-finite values (NaN / infinity) are skipped so they cannot poison the sums.
                if (Double.isFinite(summary.value)) {
                    this.localDerivative.addVectorInPlace(summary.gradient, 1.0d);
                    this.localLogLikelihood += summary.value;
                }
                // Checked after each instance so an external stop request is honoured promptly.
                if (this.mainWorker.isFinished) {
                    return;
                }
            }
            this.finishedAtTime = System.currentTimeMillis();
            this.cpuTimeRequired = ManagementFactory.getThreadMXBean().getThreadCpuTime(measuredThreadId) - cpuTimeAtStart;
        }
    }

    /* loaded from: input_file:com/github/keenon/loglinear/learning/AbstractBatchOptimizer$OptimizationState.class */
    /**
     * Base type for the optimizer-specific state object that is created by
     * {@code getFreshOptimizationState} and handed back to
     * {@code updateWeights}. It declares no members of its own; concrete
     * optimizers subclass it to carry whatever they need. As a non-static
     * inner class, every instance is tied to an enclosing optimizer instance.
     */
    protected abstract class OptimizationState {
        /* JADX INFO: Access modifiers changed from: protected */
        // Empty constructor; the decompiler note above records that it was
        // declared protected before decompilation.
        public OptimizationState() {
        }
    }

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:com/github/keenon/loglinear/learning/AbstractBatchOptimizer$TrainingWorker.class */
    /**
     * Runnable that holds the state of one training run: a private copy of the
     * weights, the optimizer state, the dataset, the function being optimized
     * and the L2 regularization constant.
     *
     * <p>{@code optimize} starts this worker on its own thread and then polls
     * and writes {@link #isFinished} from the calling thread, while
     * {@code GradientWorker} reads the same flag from its thread, so the flag
     * is declared {@code volatile}.
     *
     * @param <T> type of the training instances
     */
    public class TrainingWorker<T> implements Runnable {
        /** Working copy of the weights, obtained via {@code deepClone()} of the initial vector. */
        ConcatVector weights;
        /** State object obtained from the enclosing optimizer for this run. */
        OptimizationState optimizationState;
        /**
         * Stop/completion flag shared between threads. Volatile so that a
         * write on one thread is guaranteed to become visible to the others.
         */
        volatile boolean isFinished = false;
        /** True when the runtime reports more than one available processor. */
        boolean useThreads;
        /** Training instances. */
        T[] dataset;
        /** Function being optimized. */
        AbstractFunction<T> fn;
        /** L2 regularization constant supplied by the caller. */
        double l2regularization;
        static final /* synthetic */ boolean $assertionsDisabled;

        /**
         * @param tArr             training instances
         * @param abstractFunction function to optimize
         * @param concatVector     initial weights; cloned, and also passed to
         *                         {@code getFreshOptimizationState}
         * @param d                L2 regularization constant
         */
        public TrainingWorker(T[] tArr, AbstractFunction<T> abstractFunction, ConcatVector concatVector, double d) {
            this.useThreads = Runtime.getRuntime().availableProcessors() > 1;
            this.optimizationState = AbstractBatchOptimizer.this.getFreshOptimizationState(concatVector);
            this.weights = concatVector.deepClone();
            this.dataset = tArr;
            this.fn = abstractFunction;
            this.l2regularization = d;
        }

        /**
         * Returns a relative cost estimate for one instance: 1 for anything
         * that is not a {@code GraphicalModel}, otherwise the sum of
         * {@code combinatorialNeighborStatesCount()} over the model's factors.
         */
        /* JADX WARN: Multi-variable type inference failed */
        private int estimateRelativeRuntime(T t) {
            if (!(t instanceof GraphicalModel)) {
                return 1;
            }
            int totalStates = 0;
            Iterator<GraphicalModel.Factor> factorIterator = ((GraphicalModel) t).factors.iterator();
            while (factorIterator.hasNext()) {
                totalStates += factorIterator.next().featuresTable.combinatorialNeighborStatesCount();
            }
            return totalStates;
        }

        /* JADX WARN: Code restructure failed: missing block: B:69:0x0395, code lost:
        
            r11.isFinished = true;
         */
        /* JADX WARN: Code restructure failed: missing block: B:70:0x039a, code lost:
        
            return;
         */
        /**
         * Training loop. The body was not recovered by the decompiler, so this
         * stub always throws. According to the decompiler notes above, the
         * real method sets {@code isFinished} to true before returning.
         *
         * @throws UnsupportedOperationException always, in this decompiled form
         */
        @Override // java.lang.Runnable
        /*
            Code decompiled incorrectly, please refer to instructions dump.
            To view partially-correct add '--show-bad-code' argument
        */
        public void run() {
            /*
                Method dump skipped, instructions count: 923
                To view this dump add '--comments-level debug' option
            */
            throw new UnsupportedOperationException("Method not decompiled: com.github.keenon.loglinear.learning.AbstractBatchOptimizer.TrainingWorker.run():void");
        }

        static {
            $assertionsDisabled = !AbstractBatchOptimizer.class.desiredAssertionStatus();
        }
    }

    /**
     * Convenience overload: runs the optimizer starting from a
     * {@code ConcatVector} constructed with 0 components and with an L2
     * regularization constant of 0.
     *
     * @param tArr             training instances
     * @param abstractFunction function to optimize
     * @param <T>              type of the training instances
     * @return whatever the four-argument overload returns for these defaults
     */
    public <T> ConcatVector optimize(T[] tArr, AbstractFunction<T> abstractFunction) {
        final ConcatVector startingWeights = new ConcatVector(0);
        final double noRegularization = 0.0d;
        return optimize(tArr, abstractFunction, startingWeights, noRegularization);
    }

    /**
     * Starts a training worker on a new thread and blocks until either the
     * worker marks itself finished or input becomes available on standard
     * input, which is treated as a request to stop early.
     *
     * @param tArr             training instances
     * @param abstractFunction function to optimize
     * @param concatVector     initial weights
     * @param d                L2 regularization constant
     * @param <T>              type of the training instances
     * @return the worker's weight vector at the moment this method returns
     */
    public <T> ConcatVector optimize(T[] tArr, AbstractFunction<T> abstractFunction, ConcatVector concatVector, double d) {
        System.err.println("\n**************\nBeginning training\n");
        TrainingWorker<T> trainingWorker = new TrainingWorker<>(tArr, abstractFunction, concatVector, d);
        new Thread(trainingWorker).start();
        // Deliberately not closed: closing the reader would also close System.in
        // for the rest of the process.
        BufferedReader bufferedReader = new BufferedReader(new InputStreamReader(System.in));
        System.err.println("NOTE: you can press any key (and maybe ENTER afterwards to jog stdin) to terminate learning early.");
        System.err.println("The convergence criteria are quite aggressive if left uninterrupted, and will run for a while");
        System.err.println("if left to their own devices.\n");
        // NOTE: this loop polls continuously until the worker finishes or stdin has data.
        while (!trainingWorker.isFinished) {
            try {
                // ready() is the call that can throw IOException, so it must sit
                // inside the try for the catch below to be meaningful.
                if (bufferedReader.ready()) {
                    System.err.println("received quit command: quitting");
                    System.err.println("training completed by interruption");
                    // Signal the training thread (and its gradient workers) to stop.
                    trainingWorker.isFinished = true;
                    return trainingWorker.weights;
                }
            } catch (IOException e) {
                // A failed poll of stdin is reported but does not abort training.
                e.printStackTrace();
            }
        }
        System.err.println("training completed without interruption");
        return trainingWorker.weights;
    }

    /**
     * Optimizer-specific update step, implemented by subclasses.
     *
     * <p>NOTE(review): the only caller in this file is the training loop whose
     * body was not recovered by the decompiler, so the exact meaning of the
     * arguments and of the boolean result cannot be confirmed here. Presumably
     * the arguments are the current weights, the gradient, the objective value
     * and the state from {@code getFreshOptimizationState}, and the result
     * signals whether training should stop; verify against a concrete subclass.
     */
    public abstract boolean updateWeights(ConcatVector concatVector, ConcatVector concatVector2, double d, OptimizationState optimizationState);

    /**
     * Creates the optimizer-specific state object for a new training run.
     * Called once from the {@code TrainingWorker} constructor with the initial
     * weight vector passed to {@code optimize}.
     *
     * @param concatVector the initial weights of the run
     * @return the state object stored by the training worker for that run
     */
    protected abstract OptimizationState getFreshOptimizationState(ConcatVector concatVector);
}
