package com.github.keenon.loglinear.simple;

import com.github.keenon.loglinear.learning.BacktrackingAdaGradOptimizer;
import com.github.keenon.loglinear.learning.LogLikelihoodDifferentiableFunction;
import com.github.keenon.loglinear.model.ConcatVector;
import com.github.keenon.loglinear.model.ConcatVectorNamespace;
import com.github.keenon.loglinear.model.GraphicalModel;
import com.github.keenon.loglinear.storage.ModelLog;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.Serializable;
import java.util.IdentityHashMap;
import java.util.Map;

/* loaded from: input_file:com/github/keenon/loglinear/simple/SimpleDurablePredictor.class */
/**
 * Base class for a predictor whose training examples, weights and feature namespace are kept in
 * files under a single directory, and whose weights are re-fit on a background thread whenever a
 * new labeled example is added.
 *
 * <p>Subclasses supply the model construction ({@link #createModelInternal}), the featurization
 * ({@link #featurizeModel}) and the recovery of the context object for models read back from the
 * log ({@link #restoreContextObjectFromModelTags}).
 *
 * <p>NOTE(review): {@link #context} is a plain {@link IdentityHashMap} that is touched both by
 * callers of {@link #createModel} and by the training thread without any locking; confirm whether
 * callers are expected to serialize access.
 *
 * @param <T> the serializable context object associated with each model
 */
public abstract class SimpleDurablePredictor<T extends Serializable> {
    /** Current weights; replaced by the training thread after every completed optimization. */
    public ConcatVector weights;
    /** Feature namespace, written to disk next to the weights after every training run. */
    public ConcatVectorNamespace namespace;
    /** Log of labeled training examples, backed by {@code model-log.ser}. */
    public ModelLog log;
    private final String weightsPath;
    private final String namespacePath;
    /** Context objects keyed by model identity (not {@code equals}). */
    protected Map<GraphicalModel, T> context = new IdentityHashMap<>();
    /** Guarded by {@code this}; true while a training thread is active. */
    private boolean trainingRunning = false;

    /**
     * Opens (creating if necessary) the backing directory and loads any previously saved state.
     * Missing or unreadable weight/namespace files fall back to fresh, empty instances.
     *
     * @param str path of the directory holding {@code model-log.ser}, {@code weights.ser} and
     *     {@code namespace.ser}
     * @throws IOException if the model log cannot be opened
     */
    public SimpleDurablePredictor(String str) throws IOException {
        File file = new File(str);
        if (!file.exists()) {
            file.mkdirs();
        }
        this.log = new ModelLog(str + "/model-log.ser");
        this.weightsPath = str + "/weights.ser";
        this.namespacePath = str + "/namespace.ser";
        this.weights = loadWeights();
        this.namespace = loadNamespace();
    }

    /** Reads the saved weights, or returns a fresh vector if the file is absent or unreadable. */
    private ConcatVector loadWeights() {
        try {
            if (!new File(this.weightsPath).exists()) {
                return new ConcatVector(1);
            }
            // try-with-resources so the file handle is released even if parsing fails.
            try (FileInputStream weightsIn = new FileInputStream(this.weightsPath)) {
                return ConcatVector.readFromStream(weightsIn);
            }
        } catch (Exception e) {
            System.err.println("weights.ser is corrupted, unable to read it");
            return new ConcatVector(1);
        }
    }

    /** Reads the saved namespace, or returns a fresh one if the file is absent or unreadable. */
    private ConcatVectorNamespace loadNamespace() {
        try {
            if (!new File(this.namespacePath).exists()) {
                return new ConcatVectorNamespace();
            }
            // The file stream is its own resource so it is closed even when the
            // ObjectInputStream constructor rejects the stream header.
            try (FileInputStream namespaceIn = new FileInputStream(this.namespacePath);
                    ObjectInputStream objectIn = new ObjectInputStream(namespaceIn)) {
                return (ConcatVectorNamespace) objectIn.readObject();
            }
        } catch (Exception e) {
            System.err.println("namespace.ser is corrupted, unable to read it");
            return new ConcatVectorNamespace();
        }
    }

    /**
     * Builds a model for the given context object and remembers the association so the model can
     * be featurized again during training.
     *
     * @param t the context object to build a model for
     * @return the model produced by {@link #createModelInternal}
     */
    public GraphicalModel createModel(T t) {
        GraphicalModel createModelInternal = createModelInternal(t);
        this.context.put(createModelInternal, t);
        return createModelInternal;
    }

    /**
     * Blocks the calling thread until no training run is in progress. Returns immediately if
     * none is running. If the thread is interrupted while waiting, the interrupt status is
     * restored and the method returns early.
     */
    public void blockForRetraining() {
        synchronized (this) {
            // Loop rather than a single check: wait() may wake spuriously, and a finished run
            // may immediately start a follow-up run before this thread reacquires the lock.
            while (this.trainingRunning) {
                try {
                    wait();
                } catch (InterruptedException e) {
                    Thread.currentThread().interrupt();
                    return;
                }
            }
        }
    }

    /**
     * Creates the model for a context object; called by {@link #createModel}.
     *
     * @param t the context object
     * @return a new model
     */
    protected abstract GraphicalModel createModelInternal(T t);

    /**
     * Appends a labeled model to the training log and starts a background training run unless
     * one is already active.
     *
     * <p>NOTE(review): the decompiler reported that this method's access modifier was changed
     * from {@code protected}; it is kept {@code public} here to match the decompiled interface.
     *
     * @param graphicalModel the labeled example to record
     */
    public void addLabeledTrainingExample(GraphicalModel graphicalModel) {
        this.log.add(graphicalModel);
        launchTrainingRunIfNotRunning();
    }

    /**
     * Recreates the context object for a model that has no entry in {@link #context}; called by
     * the training thread before featurization.
     *
     * @param graphicalModel the model whose context is missing
     * @return the context object to associate with the model
     */
    protected abstract T restoreContextObjectFromModelTags(GraphicalModel graphicalModel);

    /**
     * Populates the factors of a model; the training thread clears {@code factors} before each
     * call.
     *
     * @param graphicalModel the model to featurize
     * @param t the context object associated with the model
     */
    protected abstract void featurizeModel(GraphicalModel graphicalModel, T t);

    /**
     * Starts a background thread that re-fits the weights on the whole log, unless such a thread
     * is already active. When the run finishes, waiters in {@link #blockForRetraining} are
     * notified, and another run is started if the log grew in the meantime.
     */
    protected void launchTrainingRunIfNotRunning() {
        synchronized (this) {
            if (this.trainingRunning) {
                return;
            }
            this.trainingRunning = true;
            new Thread(() -> {
                // -1 marks "the pass did not complete", which suppresses the follow-up run.
                int trainedOn = -1;
                try {
                    trainedOn = runTrainingPass();
                } finally {
                    // Always clear the flag and wake waiters, even when subclass code or the
                    // optimizer throws; otherwise blockForRetraining() would hang forever and
                    // no further training run could ever start.
                    synchronized (this) {
                        this.trainingRunning = false;
                        this.notifyAll();
                        if (trainedOn >= 0 && this.log.size() > trainedOn) {
                            launchTrainingRunIfNotRunning();
                        }
                    }
                }
            }).start();
        }
    }

    /**
     * Snapshots the log, re-featurizes every example, optimizes the weights and saves the result.
     *
     * @return the number of examples that were in the snapshot
     */
    private int runTrainingPass() {
        GraphicalModel[] examples = new GraphicalModel[this.log.size()];
        for (int i = 0; i < examples.length; i++) {
            examples[i] = this.log.get(i);
        }
        for (GraphicalModel example : examples) {
            if (!this.context.containsKey(example)) {
                this.context.put(example, restoreContextObjectFromModelTags(example));
            }
            example.factors.clear();
            featurizeModel(example, this.context.get(example));
        }
        // NOTE(review): the meaning of 0.01, 0.001 and false is defined by
        // BacktrackingAdaGradOptimizer.optimize, which is not visible here; confirm there.
        this.weights = new BacktrackingAdaGradOptimizer().optimize(examples, new LogLikelihoodDifferentiableFunction(), this.weights, 0.01d, 0.001d, false);
        persistModel();
        return examples.length;
    }

    /** Writes the weights and then the namespace to disk; I/O failures are reported, not thrown. */
    private void persistModel() {
        try {
            try (FileOutputStream weightsOut = new FileOutputStream(this.weightsPath)) {
                this.weights.writeToStream(weightsOut);
            }
            try (FileOutputStream namespaceOut = new FileOutputStream(this.namespacePath);
                    ObjectOutputStream objectOut = new ObjectOutputStream(namespaceOut)) {
                objectOut.writeObject(this.namespace);
            }
        } catch (IOException e) {
            e.printStackTrace();
        }
    }
}
