package org.apache.mahout.classifier.sgd;

import java.io.DataInput;
import java.io.DataInputStream;
import java.io.DataOutput;
import java.io.DataOutputStream;
import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import org.apache.mahout.math.stats.GlobalOnlineAuc;
import org.apache.mahout.math.stats.GroupedOnlineAuc;
import org.apache.mahout.math.stats.OnlineAuc;

/* loaded from: input_file:org/apache/mahout/classifier/sgd/AdaptiveLogisticModelParameters.class */
/**
 * {@link LogisticModelParameters} extended with the settings needed to build an
 * {@link AdaptiveLogisticRegression}: evaluation interval, averaging window, thread count,
 * prior name and option, and AUC evaluator name.
 *
 * <p>The regression is built lazily from these settings by
 * {@link #createAdaptiveLogisticRegression()}. {@link #write(DataOutput)} serializes the
 * settings followed by the regression itself; {@link #readFields(DataInput)} restores both.
 */
public class AdaptiveLogisticModelParameters extends LogisticModelParameters {
    /** Built on first call to createAdaptiveLogisticRegression(), or restored by readFields(). */
    private AdaptiveLogisticRegression alr;
    private int interval = 800;
    private int averageWindow = 500;
    private int threads = 4;
    /** Prior name: L1, L2, UP, TP or EBP (case-insensitive, surrounding whitespace ignored). */
    private String prior = "L1";
    /** Argument for the TP and EBP priors; NaN means "not specified". */
    private double priorOption = Double.NaN;
    /** AUC evaluator name: GLOBAL or GROUPED (case-insensitive); null or unrecognized yields no evaluator. */
    private String auc = null;

    /**
     * Returns the regression described by these parameters, creating it on the first call.
     *
     * <p>The instance is cached: setter calls made after the first invocation do not affect the
     * regression that was already created.
     *
     * @return the (possibly newly created) adaptive logistic regression
     */
    public AdaptiveLogisticRegression createAdaptiveLogisticRegression() {
        if (this.alr == null) {
            this.alr = new AdaptiveLogisticRegression(getMaxTargetCategories(), getNumFeatures(), createPrior(this.prior, this.priorOption));
            this.alr.setInterval(this.interval);
            this.alr.setAveragingWindow(this.averageWindow);
            this.alr.setThreadCount(this.threads);
            this.alr.setAucEvaluator(createAUC(this.auc));
        }
        return this.alr;
    }

    /**
     * Validates combinations of settings that cannot be checked by the individual setters.
     *
     * @throws IllegalArgumentException if the prior is TP or EBP and no prior option was given
     */
    public void checkParameters() {
        if (this.prior != null) {
            String name = normalizeOptionName(this.prior);
            if (("TP".equals(name) || "EBP".equals(name)) && Double.isNaN(this.priorOption)) {
                throw new IllegalArgumentException("You must specify a double value for TPrior and ElasticBandPrior.");
            }
        }
    }

    /** Upper-cases (English locale) and trims a user-supplied option name for comparison. */
    private static String normalizeOptionName(String name) {
        return name.toUpperCase(Locale.ENGLISH).trim();
    }

    /**
     * Maps a prior name to a new prior function.
     *
     * @param str prior name (L1, L2, UP, TP, EBP; case-insensitive), may be null
     * @param d   argument passed to the TP and EBP priors, ignored otherwise
     * @return the prior, or null if {@code str} is null or not recognized
     */
    private static PriorFunction createPrior(String str, double d) {
        if (str == null) {
            return null;
        }
        switch (normalizeOptionName(str)) {
            case "L1":
                return new L1();
            case "L2":
                return new L2();
            case "UP":
                return new UniformPrior();
            case "TP":
                return new TPrior(d);
            case "EBP":
                return new ElasticBandPrior(d);
            default:
                return null;
        }
    }

    /**
     * Maps an AUC evaluator name to a new evaluator.
     *
     * @param str evaluator name (GLOBAL or GROUPED; case-insensitive), may be null
     * @return the evaluator, or null if {@code str} is null or not recognized
     */
    private static OnlineAuc createAUC(String str) {
        if (str == null) {
            return null;
        }
        switch (normalizeOptionName(str)) {
            case "GLOBAL":
                return new GlobalOnlineAuc();
            case "GROUPED":
                return new GroupedOnlineAuc();
            default:
                return null;
        }
    }

    /**
     * Closes the regression (if one exists) and then serializes these parameters.
     *
     * <p>The target categories are copied from the CSV record factory first. The stream is
     * flushed but NOT closed; the caller owns it.
     *
     * @param outputStream destination stream
     * @throws IOException           if writing fails
     * @throws IllegalStateException if no regression has been created or read yet
     */
    @Override // org.apache.mahout.classifier.sgd.LogisticModelParameters
    public void saveTo(OutputStream outputStream) throws IOException {
        if (this.alr != null) {
            this.alr.close();
        }
        setTargetCategories(getCsvRecordFactory().getTargetCategories());
        DataOutputStream out = new DataOutputStream(outputStream);
        write(out);
        // Push anything the underlying stream buffers; closing is left to the caller.
        out.flush();
    }

    /**
     * Serializes the settings followed by the regression.
     *
     * <p>Format: target variable, type map, feature count, max target categories, target
     * categories, interval, averaging window, threads, prior, prior option, AUC name, and then
     * the regression's own serialized form. A null prior or AUC name is written as an empty
     * string. That form reads back as "" and is treated like null when building priors and
     * evaluators.
     *
     * @throws IllegalStateException if no regression exists; checked before anything is written
     *                               so no truncated record is produced
     */
    @Override // org.apache.mahout.classifier.sgd.LogisticModelParameters, org.apache.hadoop.io.Writable
    public void write(DataOutput dataOutput) throws IOException {
        if (this.alr == null) {
            throw new IllegalStateException(
                "No AdaptiveLogisticRegression to write; call createAdaptiveLogisticRegression() or readFields() first");
        }
        dataOutput.writeUTF(getTargetVariable());
        dataOutput.writeInt(getTypeMap().size());
        for (Map.Entry<String, String> entry : getTypeMap().entrySet()) {
            dataOutput.writeUTF(entry.getKey());
            dataOutput.writeUTF(entry.getValue());
        }
        dataOutput.writeInt(getNumFeatures());
        dataOutput.writeInt(getMaxTargetCategories());
        dataOutput.writeInt(getTargetCategories().size());
        for (String category : getTargetCategories()) {
            dataOutput.writeUTF(category);
        }
        dataOutput.writeInt(this.interval);
        dataOutput.writeInt(this.averageWindow);
        dataOutput.writeInt(this.threads);
        // writeUTF rejects null; "" keeps the wire format and maps to "no prior/evaluator".
        dataOutput.writeUTF(this.prior == null ? "" : this.prior);
        dataOutput.writeDouble(this.priorOption);
        dataOutput.writeUTF(this.auc == null ? "" : this.auc);
        this.alr.write(dataOutput);
    }

    /**
     * Restores the state written by {@link #write(DataOutput)}, including a fresh regression.
     *
     * @throws IOException if reading fails or a collection count in the stream is negative
     */
    @Override // org.apache.mahout.classifier.sgd.LogisticModelParameters, org.apache.hadoop.io.Writable
    public void readFields(DataInput dataInput) throws IOException {
        setTargetVariable(dataInput.readUTF());
        int typeCount = readNonNegativeCount(dataInput, "type map");
        Map<String, String> typeMap = new HashMap<>(typeCount);
        for (int i = 0; i < typeCount; i++) {
            String key = dataInput.readUTF();
            typeMap.put(key, dataInput.readUTF());
        }
        setTypeMap(typeMap);
        setNumFeatures(dataInput.readInt());
        setMaxTargetCategories(dataInput.readInt());
        int categoryCount = readNonNegativeCount(dataInput, "target category");
        List<String> categories = new ArrayList<>(categoryCount);
        for (int i = 0; i < categoryCount; i++) {
            categories.add(dataInput.readUTF());
        }
        setTargetCategories(categories);
        this.interval = dataInput.readInt();
        this.averageWindow = dataInput.readInt();
        this.threads = dataInput.readInt();
        this.prior = dataInput.readUTF();
        this.priorOption = dataInput.readDouble();
        this.auc = dataInput.readUTF();
        this.alr = new AdaptiveLogisticRegression();
        this.alr.readFields(dataInput);
    }

    /** Reads an int element count, rejecting negative values as stream corruption. */
    private static int readNonNegativeCount(DataInput in, String what) throws IOException {
        int count = in.readInt();
        if (count < 0) {
            throw new IOException("Corrupt stream: negative " + what + " count " + count);
        }
        return count;
    }

    /** Deserializes a new instance from {@code inputStream}; the stream is not closed. */
    private static AdaptiveLogisticModelParameters loadFromStream(InputStream inputStream) throws IOException {
        AdaptiveLogisticModelParameters parameters = new AdaptiveLogisticModelParameters();
        parameters.readFields(new DataInputStream(inputStream));
        return parameters;
    }

    /**
     * Reads parameters (and the trained regression) from {@code file}.
     *
     * <p>The file is always closed. If closing fails after a read failure, the close failure is
     * attached to the read exception as a suppressed exception.
     *
     * @param file file previously produced by {@link #saveTo(OutputStream)}
     * @return the loaded parameters
     * @throws IOException if the file cannot be opened, read or closed
     */
    public static AdaptiveLogisticModelParameters loadFromFile(File file) throws IOException {
        try (InputStream in = new FileInputStream(file)) {
            return loadFromStream(in);
        }
    }

    public int getInterval() {
        return this.interval;
    }

    public void setInterval(int i) {
        this.interval = i;
    }

    public int getAverageWindow() {
        return this.averageWindow;
    }

    public void setAverageWindow(int i) {
        this.averageWindow = i;
    }

    public int getThreads() {
        return this.threads;
    }

    public void setThreads(int i) {
        this.threads = i;
    }

    public String getPrior() {
        return this.prior;
    }

    public void setPrior(String str) {
        this.prior = str;
    }

    public String getAuc() {
        return this.auc;
    }

    public void setAuc(String str) {
        this.auc = str;
    }

    public double getPriorOption() {
        return this.priorOption;
    }

    public void setPriorOption(double d) {
        this.priorOption = d;
    }
}
