/*
 * Decompiled with CFR 0.152.
 */
package ciir.umass.edu.learning.tree;

import ciir.umass.edu.learning.DataPoint;
import ciir.umass.edu.learning.tree.RegressionTree;
import ciir.umass.edu.learning.tree.Split;
import ciir.umass.edu.utilities.RankLibError;
import java.io.ByteArrayInputStream;
import java.io.StringReader;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import javax.xml.parsers.DocumentBuilder;
import javax.xml.parsers.DocumentBuilderFactory;
import javax.xml.parsers.ParserConfigurationException;
import org.w3c.dom.Document;
import org.w3c.dom.Node;
import org.w3c.dom.NodeList;
import org.xml.sax.InputSource;

public class Ensemble {
    protected List<RegressionTree> trees = null;
    protected List<Float> weights = null;
    protected int[] features = null;

    public Ensemble() {
        this.trees = new ArrayList<RegressionTree>();
        this.weights = new ArrayList<Float>();
    }

    public Ensemble(Ensemble e) {
        this.trees = new ArrayList<RegressionTree>();
        this.weights = new ArrayList<Float>();
        this.trees.addAll(e.trees);
        this.weights.addAll(e.weights);
    }

    public Ensemble(String xmlRep) {
        try {
            int i;
            this.trees = new ArrayList<RegressionTree>();
            this.weights = new ArrayList<Float>();
            DocumentBuilderFactory dbFactory = DocumentBuilderFactory.newInstance();
            DocumentBuilder dBuilder = dbFactory.newDocumentBuilder();
            byte[] xmlDATA = xmlRep.getBytes();
            ByteArrayInputStream in = new ByteArrayInputStream(xmlDATA);
            Document doc = dBuilder.parse(in);
            NodeList nl = doc.getElementsByTagName("tree");
            HashMap<Integer, Integer> fids = new HashMap<Integer, Integer>();
            for (i = 0; i < nl.getLength(); ++i) {
                Node n = nl.item(i);
                Split root = this.create(n.getFirstChild(), fids);
                float weight = Float.parseFloat(n.getAttributes().getNamedItem("weight").getNodeValue());
                this.trees.add(new RegressionTree(root));
                this.weights.add(Float.valueOf(weight));
            }
            this.features = new int[fids.keySet().size()];
            i = 0;
            for (Integer fid : fids.keySet()) {
                this.features[i++] = fid;
            }
        }
        catch (Exception ex) {
            throw RankLibError.create("Error in Emsemble(xmlRepresentation): ", ex);
        }
    }

    public void add(RegressionTree tree, float weight) {
        this.trees.add(tree);
        this.weights.add(Float.valueOf(weight));
    }

    public RegressionTree getTree(int k) {
        return this.trees.get(k);
    }

    public float getWeight(int k) {
        return this.weights.get(k).floatValue();
    }

    public double variance() {
        double var = 0.0;
        for (RegressionTree tree : this.trees) {
            var += tree.variance();
        }
        return var;
    }

    public void remove(int k) {
        this.trees.remove(k);
        this.weights.remove(k);
    }

    public int treeCount() {
        return this.trees.size();
    }

    public int leafCount() {
        int count = 0;
        for (RegressionTree tree : this.trees) {
            count += tree.leaves().size();
        }
        return count;
    }

    public float eval(DataPoint dp) {
        float s2 = 0.0f;
        for (int i = 0; i < this.trees.size(); ++i) {
            s2 = (float)((double)s2 + this.trees.get(i).eval(dp) * (double)this.weights.get(i).floatValue());
        }
        return s2;
    }

    public String toString() {
        String strRep = "<ensemble>\n";
        for (int i = 0; i < this.trees.size(); ++i) {
            strRep = strRep + "\t<tree id=\"" + (i + 1) + "\" weight=\"" + this.weights.get(i) + "\">\n";
            strRep = strRep + this.trees.get(i).toString("\t\t");
            strRep = strRep + "\t</tree>\n";
        }
        strRep = strRep + "</ensemble>\n";
        return strRep;
    }

    public int[] getFeatures() {
        return this.features;
    }

    private Split create(Node n, HashMap<Integer, Integer> fids) {
        Split s2 = null;
        if (n.getFirstChild().getNodeName().compareToIgnoreCase("feature") == 0) {
            NodeList nl = n.getChildNodes();
            int fid = Integer.parseInt(nl.item(0).getFirstChild().getNodeValue().trim());
            fids.put(fid, 0);
            float threshold = Float.parseFloat(nl.item(1).getFirstChild().getNodeValue().trim());
            s2 = new Split(fid, threshold, 0.0);
            s2.setLeft(this.create(nl.item(2), fids));
            s2.setRight(this.create(nl.item(3), fids));
        } else {
            float output = Float.parseFloat(n.getFirstChild().getFirstChild().getNodeValue().trim());
            s2 = new Split();
            s2.setOutput(output);
        }
        return s2;
    }
}

