package org.apache.mahout.classifier.bayes.datastore;

import java.io.IOException;
import java.util.Collection;
import java.util.HashMap;
import java.util.Iterator;
import java.util.Map;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.mahout.classifier.bayes.exceptions.InvalidDatastoreException;
import org.apache.mahout.classifier.bayes.interfaces.Datastore;
import org.apache.mahout.classifier.bayes.io.SequenceFileModelReader;
import org.apache.mahout.common.Parameters;

/* loaded from: input_file:WEB-INF/lib/mahout-core-0.2.jar:org/apache/mahout/classifier/bayes/datastore/InMemoryBayesDatastore.class */
/**
 * {@link Datastore} implementation that keeps every model value in nested in-memory hash maps.
 *
 * <p>Two kinds of containers are held, both addressed by name:
 * <ul>
 *   <li><b>matrices</b>: name -&gt; row key -&gt; column key -&gt; value</li>
 *   <li><b>vectors</b>: name -&gt; key -&gt; value</li>
 * </ul>
 * Cells that were never written read back as {@code 0.0}.
 *
 * <p>The containers are filled through the public {@code load...}/{@code set...} methods;
 * {@link #initialize()} hands this instance to {@link SequenceFileModelReader#loadModel}, which
 * presumably invokes them while reading the model files (not visible from this class - confirm
 * against the reader).
 *
 * <p>The class performs no synchronization of its own.
 */
public class InMemoryBayesDatastore implements Datastore {
    /** Name of the only matrix this class writes to. */
    private static final String WEIGHT = "weight";
    /** Name of the vector holding the "sigma_jSigma_k" and "vocabCount" cells. */
    private static final String SUM_WEIGHT = "sumWeight";
    /** Name of the vector written by {@link #setSumLabelWeight(String, double)}. */
    private static final String LABEL_WEIGHT = "labelWeight";
    /** Name of the vector written by {@link #setThetaNormalizer(String, double)}. */
    private static final String THETA_NORMALIZER = "thetaNormalizer";

    /** Parameters handed to the constructor; the constructor also writes model path entries into it. */
    private final Parameters params;
    /** Value of the "alpha_i" parameter (defaults to 1.0 when the parameter is absent). */
    private final double alpha_i;
    private final Map<String, Map<String, Map<String, Double>>> matrices =
            new HashMap<String, Map<String, Map<String, Double>>>();
    private final Map<String, Map<String, Double>> vectors = new HashMap<String, Map<String, Double>>();
    /**
     * Divisor applied to "thetaNormalizer" lookups. Starts at 1.0 and is only ever raised by
     * {@link #initialize()}, so it is never smaller than 1.0.
     */
    private double thetaNormalizer = 1.0d;

    /**
     * Creates an empty datastore and registers the model file locations in {@code parameters}.
     *
     * <p>Side effect: the "sigma_j", "sigma_k", "sigma_kSigma_j", "thetaNormalizer" and "weight"
     * entries of the supplied {@code parameters} object are overwritten with glob patterns built
     * from its "basePath" entry.
     *
     * @param parameters configuration; "basePath" is read, and "alpha_i" is parsed as a double
     *                   (falling back to "1.0")
     * @throws NumberFormatException if the "alpha_i" entry is not a parsable double
     */
    public InMemoryBayesDatastore(Parameters parameters) {
        // Pre-register the containers so lookups on a freshly built store do not fail.
        this.matrices.put(WEIGHT, new HashMap<String, Map<String, Double>>());
        this.vectors.put(SUM_WEIGHT, new HashMap<String, Double>());
        this.vectors.put(LABEL_WEIGHT, new HashMap<String, Double>());
        this.vectors.put(THETA_NORMALIZER, new HashMap<String, Double>());

        String basePath = parameters.get("basePath");
        this.params = parameters;
        parameters.set("sigma_j", basePath + "/trainer-weights/Sigma_j/part-*");
        parameters.set("sigma_k", basePath + "/trainer-weights/Sigma_k/part-*");
        parameters.set("sigma_kSigma_j", basePath + "/trainer-weights/Sigma_kSigma_j/part-*");
        parameters.set("thetaNormalizer", basePath + "/trainer-thetaNormalizer/part-*");
        parameters.set("weight", basePath + "/trainer-tfIdf/trainer-tfIdf/part-*");
        this.alpha_i = Double.parseDouble(parameters.get("alpha_i", "1.0"));
    }

    /**
     * Loads the model into memory, records the vocabulary count and derives the theta normalizer
     * divisor as the largest absolute "thetaNormalizer" value over all keys (never below 1.0).
     *
     * <p>One line per key is printed to standard output with the key, its raw value, the divisor
     * and the normalized value.
     *
     * @throws InvalidDatastoreException if reading the model fails with an {@link IOException};
     *                                   the I/O failure is attached as the cause
     */
    @Override
    public void initialize() throws InvalidDatastoreException {
        Configuration configuration = new Configuration();
        try {
            FileSystem fileSystem = FileSystem.get(new Path(this.params.get("basePath")).toUri(), configuration);
            SequenceFileModelReader.loadModel(this, fileSystem, this.params, configuration);
        } catch (IOException e) {
            // Only the no-arg and String constructors are visible here, so attach the cause explicitly
            // rather than losing the original stack trace.
            InvalidDatastoreException wrapped = new InvalidDatastoreException(e.getMessage());
            wrapped.initCause(e);
            throw wrapped;
        }

        updateVocabCount();

        Collection<String> keys = getKeys(THETA_NORMALIZER);
        for (String key : keys) {
            this.thetaNormalizer = Math.max(this.thetaNormalizer, Math.abs(vectorGetCell(THETA_NORMALIZER, key)));
        }
        // NOTE(review): diagnostic output goes straight to stdout; kept because existing users may
        // rely on it, but a logger would be the usual choice.
        for (String key : keys) {
            double value = vectorGetCell(THETA_NORMALIZER, key);
            System.out.println(key + ' ' + value + ' ' + this.thetaNormalizer + ' ' + (value / this.thetaNormalizer));
        }
    }

    /**
     * Returns the keys of the "labelWeight" vector.
     *
     * <p>NOTE(review): the {@code str} argument is not consulted; the same key set is returned for
     * every name. Kept as is because callers may depend on it - confirm whether that is intended.
     *
     * @param str requested vector name (currently ignored)
     * @return live key-set view of the "labelWeight" vector (not a copy)
     */
    @Override
    public Collection<String> getKeys(String str) throws InvalidDatastoreException {
        return this.vectors.get(LABEL_WEIGHT).keySet();
    }

    /**
     * Reads one matrix cell.
     *
     * @param str  matrix name
     * @param str2 row key
     * @param str3 column key
     * @return the stored value, or {@code 0.0} when the row or the cell does not exist
     * @throws InvalidDatastoreException if no matrix is registered under {@code str}
     */
    @Override
    public double getWeight(String str, String str2, String str3) throws InvalidDatastoreException {
        return matrixGetCell(str, str2, str3);
    }

    /**
     * Reads one vector cell, with two special names:
     * <ul>
     *   <li>"thetaNormalizer": the stored value divided by the divisor computed in
     *       {@link #initialize()}</li>
     *   <li>"params": only the key "alpha_i" is supported and yields the configured alpha_i</li>
     * </ul>
     *
     * @param str  vector name, or one of the special names above
     * @param str2 key within the vector
     * @return the (possibly normalized) value, or {@code 0.0} for a missing cell
     * @throws InvalidDatastoreException if the vector is unknown (including a {@code null} name),
     *                                   or if "params" is queried for anything but "alpha_i"
     */
    @Override
    public double getWeight(String str, String str2) throws InvalidDatastoreException {
        // Constant-first comparisons so a null name ends in InvalidDatastoreException, not an NPE.
        if (THETA_NORMALIZER.equals(str)) {
            return vectorGetCell(str, str2) / this.thetaNormalizer;
        }
        if (!"params".equals(str)) {
            return vectorGetCell(str, str2);
        }
        if ("alpha_i".equals(str2)) {
            return this.alpha_i;
        }
        throw new InvalidDatastoreException();
    }

    /** Looks up a matrix cell; unknown matrix is an error, unknown row or column reads as 0.0. */
    private double matrixGetCell(String matrixName, String row, String column) throws InvalidDatastoreException {
        Map<String, Map<String, Double>> matrix = this.matrices.get(matrixName);
        if (matrix == null) {
            throw new InvalidDatastoreException();
        }
        Map<String, Double> rowCells = matrix.get(row);
        if (rowCells == null) {
            return 0.0d;
        }
        return nullToZero(rowCells.get(column));
    }

    /** Looks up a vector cell; unknown vector is an error, unknown key reads as 0.0. */
    private double vectorGetCell(String vectorName, String key) throws InvalidDatastoreException {
        Map<String, Double> vector = this.vectors.get(vectorName);
        if (vector == null) {
            throw new InvalidDatastoreException();
        }
        return nullToZero(vector.get(key));
    }

    /** Writes a matrix cell, creating the matrix and the row on first use. */
    private void matrixPutCell(String matrixName, String row, String column, double value) {
        Map<String, Map<String, Double>> matrix = this.matrices.get(matrixName);
        if (matrix == null) {
            matrix = new HashMap<String, Map<String, Double>>();
            this.matrices.put(matrixName, matrix);
        }
        Map<String, Double> rowCells = matrix.get(row);
        if (rowCells == null) {
            rowCells = new HashMap<String, Double>();
            matrix.put(row, rowCells);
        }
        rowCells.put(column, Double.valueOf(value));
    }

    /** Writes a vector cell, creating the vector on first use. */
    private void vectorPutCell(String vectorName, String key, double value) {
        Map<String, Double> vector = this.vectors.get(vectorName);
        if (vector == null) {
            vector = new HashMap<String, Double>();
            this.vectors.put(vectorName, vector);
        }
        vector.put(key, Double.valueOf(value));
    }

    /** Returns the number of rows (first-level keys) of the named matrix, or 0 if it does not exist. */
    private long sizeOfMatrix(String matrixName) {
        Map<String, Map<String, Double>> matrix = this.matrices.get(matrixName);
        if (matrix == null) {
            return 0L;
        }
        return matrix.size();
    }

    /**
     * Stores {@code d} in the "weight" matrix at row {@code str}, column {@code str2}.
     *
     * @param str  row key
     * @param str2 column key
     * @param d    value to store
     */
    public void loadFeatureWeight(String str, String str2, double d) {
        matrixPutCell(WEIGHT, str, str2, d);
    }

    /**
     * Stores {@code d} in the "weight" matrix at row {@code str} under the reserved column "sigma_j".
     *
     * @param str row key
     * @param d   value to store
     */
    public void setSumFeatureWeight(String str, double d) {
        matrixPutCell(WEIGHT, str, "sigma_j", d);
    }

    /**
     * Stores {@code d} in the "labelWeight" vector under key {@code str}.
     *
     * @param str key
     * @param d   value to store
     */
    public void setSumLabelWeight(String str, double d) {
        vectorPutCell(LABEL_WEIGHT, str, d);
    }

    /**
     * Stores the raw (un-normalized) {@code d} in the "thetaNormalizer" vector under key {@code str}.
     *
     * @param str key
     * @param d   value to store
     */
    public void setThetaNormalizer(String str, double d) {
        vectorPutCell(THETA_NORMALIZER, str, d);
    }

    /**
     * Stores {@code d} in the "sumWeight" vector under key "sigma_jSigma_k".
     *
     * @param d value to store
     */
    public void setSigma_jSigma_k(double d) {
        vectorPutCell(SUM_WEIGHT, "sigma_jSigma_k", d);
    }

    /** Records the current row count of the "weight" matrix as "vocabCount" in the "sumWeight" vector. */
    public void updateVocabCount() {
        vectorPutCell(SUM_WEIGHT, "vocabCount", sizeOfMatrix(WEIGHT));
    }

    /** Unboxes {@code d}, mapping {@code null} (absent cell) to 0.0. */
    private static double nullToZero(Double d) {
        return d == null ? 0.0d : d.doubleValue();
    }
}
