/*
 * Decompiled with CFR 0.152.
 */
package edu.upenn.seas.mstparser;

import edu.upenn.seas.mstparser.DependencyInstance;
import edu.upenn.seas.mstparser.FeatureVector;
import java.util.regex.Pattern;

/**
 * Weight vector of a linear dependency-parsing model, trained with MIRA-style updates.
 *
 * <p>{@link #parameters} holds the current weights used for scoring. {@link #total} is a
 * second accumulator that {@link #updateParamsMIRA} passes to {@code FeatureVector.update}
 * and that {@link #averageParams} later rescales into the final (averaged) weights.
 */
public class Parameters {
    /** POS tags made up solely of the characters {@code , : . ' `} are treated as punctuation. */
    private static final Pattern PUNCT_POS = Pattern.compile("[,:.'`]+");

    /** Current model weights, indexed by feature id. */
    public double[] parameters;
    /** Accumulator that {@link #averageParams} turns into the averaged weights. */
    public double[] total;
    /** Loss used by {@link #numErrors}: {@code "nopunc"} ignores punctuation tokens. */
    public String lossType = "punc";

    /**
     * Creates an all-zero model.
     *
     * @param size number of features (length of both weight arrays)
     */
    public Parameters(int size) {
        // Java zero-initialises new double arrays, so no explicit fill is needed;
        // lossType already defaults to "punc" via its field initializer.
        this.parameters = new double[size];
        this.total = new double[size];
    }

    /**
     * Selects the loss used by {@link #numErrors}.
     *
     * @param lt {@code "nopunc"} to skip punctuation tokens; any other value counts all tokens
     */
    public void setLoss(String lt) {
        this.lossType = lt;
    }

    /**
     * Scales {@link #total} by {@code 1 / avVal} and installs it as {@link #parameters}.
     *
     * <p>NOTE: after this call {@code parameters} and {@code total} are the SAME array, so any
     * further {@link #updateParamsMIRA} call writes both updates into one array. Presumably this
     * is only meant to run once, after training finishes — confirm with callers.
     *
     * @param avVal the divisor for the accumulated weights (presumably the number of updates
     *     performed); a value of 0 yields infinite/NaN weights
     */
    public void averageParams(double avVal) {
        double scale = 1.0 / avVal;
        for (int j = 0; j < this.total.length; ++j) {
            this.total[j] *= scale;
        }
        this.parameters = this.total;
    }

    /**
     * Performs one MIRA update against a list of candidate parses.
     *
     * <p>For each candidate k the constraint target is
     * {@code loss(candidate_k) - (score(gold) - score(candidate_k))}. The per-candidate step sizes
     * come from {@link #hildreth}. Each difference vector {@code gold - candidate_k} is then
     * applied to the weights with that step size.
     *
     * @param inst the gold instance; its {@code fv} and {@code actParseTree} are used
     * @param d candidate rows: {@code d[k][0]} is the candidate's {@code FeatureVector},
     *     {@code d[k][1]} its parse string. Scanning stops at the first row whose feature vector
     *     is {@code null}.
     * @param upd passed through to {@code FeatureVector.update} (presumably the weight applied to
     *     the {@code total} accumulator — confirm in FeatureVector)
     */
    public void updateParamsMIRA(DependencyInstance inst, Object[][] d, double upd) {
        String actParseTree = inst.actParseTree;
        FeatureVector actFV = inst.fv;

        int numCandidates = 0;
        while (numCandidates < d.length && d[numCandidates][0] != null) {
            ++numCandidates;
        }
        if (numCandidates == 0) {
            return; // nothing to constrain; the original update was a no-op here as well
        }

        // The gold score is loop-invariant: weights are not modified until the final loop.
        double actScore = this.getScore(actFV);
        double[] b = new double[numCandidates];
        FeatureVector[] dist = new FeatureVector[numCandidates];
        for (int k = 0; k < numCandidates; ++k) {
            FeatureVector predFV = (FeatureVector) d[k][0];
            double lamDist = actScore - this.getScore(predFV);
            b[k] = this.numErrors(inst, (String) d[k][1], actParseTree) - lamDist;
            dist[k] = actFV.getDistVector(predFV);
        }

        double[] alpha = this.hildreth(dist, b);
        for (int k = 0; k < numCandidates; ++k) {
            dist[k].update(this.parameters, this.total, alpha[k], upd);
        }
    }

    /**
     * Scores a feature vector under the current {@link #parameters}.
     *
     * @param fv the feature vector to score
     * @return the result of {@code fv.getScore(parameters)}
     */
    public double getScore(FeatureVector fv) {
        return fv.getScore(this.parameters);
    }

    /**
     * Coordinate-ascent (Hildreth-style) solver for the dual problem.
     *
     * <p>It solves: maximize {@code sum_k alpha_k b_k - 1/2 alpha^T G alpha} subject to
     * {@code alpha >= 0}, where {@code G[i][j] = a[i].dotProduct(a[j])}. Each iteration picks the
     * coordinate with the largest KKT violation and takes an exact step on it, clipped so that
     * alpha stays non-negative. It stops once the largest violation drops below {@code 1e-8}, or
     * after 10000 iterations.
     *
     * @param a constraint vectors
     * @param b constraint targets (same length as {@code a})
     * @return the non-negative step sizes, one per constraint
     */
    private double[] hildreth(FeatureVector[] a, double[] b) {
        final int maxIter = 10000;
        final double eps = 1.0E-8;
        final double zero = 1.0E-12;

        int n = a.length;
        double[] alpha = new double[b.length];
        double[] grad = new double[b.length]; // grad = b - G * alpha
        double[] kkt = new double[b.length];
        double[][] gram = new double[n][n];
        boolean[] columnComputed = new boolean[n];
        for (int i = 0; i < n; ++i) {
            gram[i][i] = a[i].dotProduct(a[i]);
        }

        // alpha starts at zero, so the gradient and the KKT violations both equal b.
        for (int i = 0; i < grad.length; ++i) {
            grad[i] = b[i];
            kkt[i] = grad[i];
        }
        int maxKktI = argMax(kkt);
        double maxKkt = maxKktI < 0 ? Double.NEGATIVE_INFINITY : kkt[maxKktI];

        for (int iter = 0; maxKkt >= eps && iter < maxIter; ++iter) {
            int j = maxKktI;
            double diffAlpha = gram[j][j] <= zero ? 0.0 : grad[j] / gram[j][j];
            // Clip the step so alpha[j] never goes negative.
            double addAlpha = alpha[j] + diffAlpha < 0.0 ? -alpha[j] : diffAlpha;
            alpha[j] += addAlpha;

            // Off-diagonal Gram entries are computed lazily, one column at a time.
            if (!columnComputed[j]) {
                for (int i = 0; i < n; ++i) {
                    gram[i][j] = a[i].dotProduct(a[j]);
                }
                columnComputed[j] = true;
            }

            for (int i = 0; i < grad.length; ++i) {
                grad[i] -= addAlpha * gram[i][j];
                // Active coordinates (alpha > 0) must have zero gradient; inactive ones only a
                // non-positive one.
                kkt[i] = alpha[i] > zero ? Math.abs(grad[i]) : grad[i];
            }

            maxKktI = argMax(kkt);
            maxKkt = maxKktI < 0 ? Double.NEGATIVE_INFINITY : kkt[maxKktI];
        }
        return alpha;
    }

    /** Index of the first strictly-largest value greater than -infinity, or -1 if there is none. */
    private static int argMax(double[] values) {
        double best = Double.NEGATIVE_INFINITY;
        int bestI = -1;
        for (int i = 0; i < values.length; ++i) {
            if (values[i] > best) {
                best = values[i];
                bestI = i;
            }
        }
        return bestI;
    }

    /**
     * Loss of a predicted parse against the gold parse, according to {@link #lossType}.
     *
     * @param inst the instance (its POS tags are read only for the {@code "nopunc"} loss)
     * @param pred predicted parse: space-separated tokens of the form {@code field0:field1}
     * @param act gold parse, same format
     * @return the number of first-field errors plus the number of second-field errors;
     *     punctuation tokens are skipped when {@code lossType} is {@code "nopunc"}
     */
    public double numErrors(DependencyInstance inst, String pred, String act) {
        if ("nopunc".equals(this.lossType)) {
            return this.numErrorsDepNoPunc(inst, pred, act) + this.numErrorsLabelNoPunc(inst, pred, act);
        }
        return this.numErrorsDep(inst, pred, act) + this.numErrorsLabel(inst, pred, act);
    }

    /**
     * Counts tokens whose first {@code ':'}-field (presumably the head index) differs.
     * {@code inst} is not read.
     */
    public double numErrorsDep(DependencyInstance inst, String pred, String act) {
        return countErrors(inst, pred, act, 0, false);
    }

    /**
     * Counts tokens whose second {@code ':'}-field (presumably the arc label) differs.
     * {@code inst} is not read.
     */
    public double numErrorsLabel(DependencyInstance inst, String pred, String act) {
        return countErrors(inst, pred, act, 1, false);
    }

    /** Like {@link #numErrorsDep}, but ignores tokens whose POS tag is punctuation. */
    public double numErrorsDepNoPunc(DependencyInstance inst, String pred, String act) {
        return countErrors(inst, pred, act, 0, true);
    }

    /** Like {@link #numErrorsLabel}, but ignores tokens whose POS tag is punctuation. */
    public double numErrorsLabelNoPunc(DependencyInstance inst, String pred, String act) {
        return countErrors(inst, pred, act, 1, true);
    }

    /**
     * Shared implementation of the {@code numErrors*} methods.
     *
     * <p>Compares the given {@code ':'}-field of each token in {@code pred} with the token at the
     * same position in {@code act}. The result is {@code act}'s token count minus the skipped
     * punctuation tokens minus the matches. Token i's POS tag is read from
     * {@code inst.postags[i + 1]}; the offset of one presumably skips an artificial root — confirm.
     * Throws {@code ArrayIndexOutOfBoundsException} if {@code pred} has more tokens than
     * {@code act}, or if a token lacks the requested field.
     */
    private static double countErrors(
            DependencyInstance inst, String pred, String act, int field, boolean skipPunc) {
        String[] actSpans = act.split(" ");
        String[] predSpans = pred.split(" ");
        String[] pos = skipPunc ? inst.postags : null;
        int correct = 0;
        int numPunc = 0;
        for (int i = 0; i < predSpans.length; ++i) {
            String p = predSpans[i].split(":")[field];
            String a = actSpans[i].split(":")[field];
            if (skipPunc && PUNCT_POS.matcher(pos[i + 1]).matches()) {
                ++numPunc;
                continue;
            }
            if (p.equals(a)) {
                ++correct;
            }
        }
        return (double) actSpans.length - (double) numPunc - (double) correct;
    }
}

