package org.apache.ctakes.temporal.ae.feature.selection;

import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.File;
import java.io.FileReader;
import java.io.FileWriter;
import java.io.IOException;
import java.io.Serializable;
import java.net.URI;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import org.apache.uima.jcas.JCas;
import org.apache.uima.jcas.tcas.Annotation;
import org.cleartk.ml.Feature;
import org.cleartk.ml.Instance;
import org.cleartk.ml.feature.extractor.CleartkExtractorException;
import org.cleartk.ml.feature.transform.OneToOneTrainableExtractor_ImplBase;

/* loaded from: input_file:org/apache/ctakes/temporal/ae/feature/selection/ZscoreNormalizationExtractor.class */
public class ZscoreNormalizationExtractor<OUTCOME_T, FOCUS_T extends Annotation> extends OneToOneTrainableExtractor_ImplBase<OUTCOME_T> {
    private boolean isTrained;
    private Map<String, MeanStdPair> meanStdMap;

    /**
     * Holder for the two statistics kept per feature name: a mean and a
     * standard deviation. Both values are fixed at construction time.
     */
    private static class MeanStdPair {
        /** Mean recorded for the feature. */
        public final double mean;

        /** Standard deviation recorded for the feature. */
        public final double std;

        /**
         * @param mean the mean to store
         * @param std  the standard deviation to store
         */
        public MeanStdPair(double mean, double std) {
            this.mean = mean;
            this.std = std;
        }
    }

    /**
     * Accumulates numeric samples and reports their count, arithmetic mean and
     * population standard deviation (variance is divided by {@code n}, not
     * {@code n - 1}).
     *
     * <p>Every sample is retained in memory because the variance is computed in
     * a second pass over the stored values using the final mean.
     */
    public static class ZscoreRunningStat implements Serializable {
        private static final long serialVersionUID = 1;

        /** All samples added since the last {@link #clear()}. */
        private List<Double> data;

        /** Running sum of the samples. */
        private double sum;

        /** Running mean, refreshed on every {@link #add(double)}. */
        private double mean;

        /** Number of samples added since the last {@link #clear()}. */
        private int n;

        /** Creates an empty accumulator. */
        public ZscoreRunningStat() {
            clear();
        }

        /**
         * Records one sample and refreshes the running sum, count and mean.
         *
         * @param d the sample value
         */
        public void add(double d) {
            // The sample must be stored: getVariance() iterates over `data`, and
            // without this line the variance (and thus the std-dev) is always 0.
            this.data.add(d);
            this.n++;
            this.sum += d;
            this.mean = this.sum / this.n;
        }

        /** Discards all samples and resets sum, count and mean to zero. */
        public void clear() {
            this.data = new ArrayList<>();
            this.sum = 0.0d;
            this.n = 0;
            this.mean = 0.0d;
        }

        /**
         * @return the number of samples added since the last {@link #clear()}
         */
        public int getNumSamples() {
            return this.n;
        }

        /**
         * Computes the population variance of the stored samples.
         *
         * @return the mean squared deviation from the mean, or {@code 0.0} when
         *         no samples have been added
         */
        private double getVariance() {
            if (this.n == 0) {
                // Avoid 0/0 = NaN for an empty accumulator.
                return 0.0d;
            }
            double squaredDeviations = 0.0d;
            for (double sample : this.data) {
                double deviation = this.mean - sample;
                squaredDeviations += deviation * deviation;
            }
            return squaredDeviations / this.n;
        }

        /**
         * @return the population standard deviation of the samples, or
         *         {@code 0.0} when no samples have been added
         */
        public double getStdDev() {
            return Math.sqrt(getVariance());
        }

        /**
         * @return the arithmetic mean of the samples, or {@code 0.0} when no
         *         samples have been added
         */
        public double getMean() {
            return this.mean;
        }
    }

    /**
     * Creates an extractor that has no normalization statistics yet.
     *
     * @param str value forwarded unchanged to the superclass constructor
     */
    public ZscoreNormalizationExtractor(String str) {
        super(str);
        // Statistics only become available through train(...) or load(...).
        isTrained = false;
    }

    /**
     * Replaces a numeric feature with its z-score, {@code (value - mean) / std},
     * using the statistics stored for the feature's name.
     *
     * <p>Non-numeric features are returned unchanged. Numeric features whose
     * name has no stored statistics receive the fixed value {@code 0.5}. When
     * the stored standard deviation is not strictly positive (zero or NaN), the
     * value is only centered, i.e. a unit scale is used, so the result never
     * becomes NaN or infinite because of a division by zero.
     *
     * <p>The statistics map is populated only by {@code train} or {@code load};
     * calling this method before either of them dereferences a {@code null} map.
     *
     * @param feature the feature to normalize
     * @return the original feature if its value is not a {@link Number};
     *         otherwise a new feature named {@code "Zscore_NORMED_" + name}
     *         carrying the normalized value as a {@link Double}
     */
    public Feature transform(Feature feature) {
        String name = feature.getName();
        Object value = feature.getValue();
        if (!(value instanceof Number)) {
            return feature;
        }
        double raw = ((Number) value).doubleValue();
        MeanStdPair stats = this.meanStdMap.get(name);
        // Default used when no statistics are stored for this feature name.
        double normed = 0.5d;
        if (stats != null) {
            double centered = raw - stats.mean;
            // Guard the division: a zero deviation would yield NaN (0/0) or
            // +/-Infinity; fall back to a scale of 1 in that case.
            normed = (stats.std > 0.0d) ? centered / stats.std : centered;
        }
        return new Feature("Zscore_NORMED_" + name, Double.valueOf(normed));
    }

    /**
     * Collects per-feature-name statistics over all numeric features of the
     * given instances and stores the resulting mean / standard-deviation pairs,
     * replacing any previously stored statistics.
     *
     * <p>Features whose value is not a {@link Number} are reported on
     * {@code System.err} and skipped. The extractor is flagged as trained once
     * all instances have been processed.
     *
     * @param iterable the instances whose features provide the samples
     */
    public void train(Iterable<Instance<OUTCOME_T>> iterable) {
        Map<String, ZscoreRunningStat> statsByName = new HashMap<>();
        for (Instance<OUTCOME_T> instance : iterable) {
            for (Feature feature : instance.getFeatures()) {
                String name = feature.getName();
                Object value = feature.getValue();
                if (!(value instanceof Number)) {
                    System.err.println("Ignore non-numeric feature from normalization: " + name + " with Value: " + value);
                    continue;
                }
                // One accumulator per feature name, created on first sight.
                ZscoreRunningStat accumulator = statsByName.get(name);
                if (accumulator == null) {
                    accumulator = new ZscoreRunningStat();
                    statsByName.put(name, accumulator);
                }
                accumulator.add(((Number) value).doubleValue());
            }
        }

        // Freeze each accumulator into the (mean, std) pair used by transform.
        this.meanStdMap = new HashMap<>();
        for (Map.Entry<String, ZscoreRunningStat> entry : statsByName.entrySet()) {
            ZscoreRunningStat accumulator = entry.getValue();
            MeanStdPair pair = new MeanStdPair(accumulator.getMean(), accumulator.getStdDev());
            this.meanStdMap.put(entry.getKey(), pair);
        }
        this.isTrained = true;
    }

    /**
     * Writes the stored statistics to the file identified by {@code uri}, one
     * line per feature name in the form {@code name<TAB>mean<TAB>std}.
     *
     * <p>The numbers are written with {@link Double#toString(double)} so that
     * reading them back with {@link Double#parseDouble(String)} restores the
     * exact values; a fixed six-decimal format would round small deviations
     * down to zero. The writer is closed even when writing fails.
     *
     * @param uri location of the file to write; passed to {@link File#File(URI)}
     * @throws IOException if the file cannot be opened or written
     */
    public void save(URI uri) throws IOException {
        try (BufferedWriter writer = new BufferedWriter(new FileWriter(new File(uri)))) {
            for (Map.Entry<String, MeanStdPair> entry : this.meanStdMap.entrySet()) {
                MeanStdPair stats = entry.getValue();
                writer.append(entry.getKey())
                        .append('\t')
                        .append(Double.toString(stats.mean))
                        .append('\t')
                        .append(Double.toString(stats.std))
                        .append('\n');
            }
        }
    }

    /**
     * Reads statistics from the file identified by {@code uri} and replaces the
     * currently stored ones. Each non-empty line must contain at least three
     * tab-separated fields: feature name, mean and standard deviation.
     *
     * <p>The stored statistics and the trained flag are only updated after the
     * whole file has been read successfully, and the reader is closed even when
     * reading fails. Empty lines are skipped.
     *
     * @param uri location of the file to read; passed to {@link File#File(URI)}
     * @throws IOException if the file cannot be read, or if a line has fewer
     *                     than three fields or a non-numeric mean / deviation
     */
    public void load(URI uri) throws IOException {
        Map<String, MeanStdPair> loaded = new HashMap<>();
        try (BufferedReader reader = new BufferedReader(new FileReader(new File(uri)))) {
            int lineNumber = 0;
            String line;
            while ((line = reader.readLine()) != null) {
                lineNumber++;
                if (line.isEmpty()) {
                    continue;
                }
                String[] fields = line.split("\\t");
                if (fields.length < 3) {
                    throw new IOException("Malformed statistics at line " + lineNumber
                            + " of " + uri + ": expected name<TAB>mean<TAB>std but got \"" + line + "\"");
                }
                try {
                    double mean = Double.parseDouble(fields[1]);
                    double std = Double.parseDouble(fields[2]);
                    loaded.put(fields[0], new MeanStdPair(mean, std));
                } catch (NumberFormatException e) {
                    throw new IOException("Non-numeric statistics at line " + lineNumber
                            + " of " + uri + ": \"" + line + "\"", e);
                }
            }
        }
        // Publish only a completely parsed map.
        this.meanStdMap = loaded;
        this.isTrained = true;
    }

    /**
     * Produces no features: this class does not extract features from the CAS
     * itself and only normalizes existing ones through {@code transform}.
     *
     * @param jCas    unused
     * @param focus_t unused
     * @return a new empty list (never {@code null}), so callers can iterate or
     *         {@code addAll} the result without a null check
     * @throws CleartkExtractorException never thrown by this implementation
     */
    public List<Feature> extract(JCas jCas, FOCUS_T focus_t) throws CleartkExtractorException {
        return new ArrayList<>();
    }
}
