package org.apache.mahout.classifier.df;

import com.google.common.base.Preconditions;
import com.google.common.collect.Lists;
import com.google.common.io.Closeables;
import java.io.DataInput;
import java.io.DataOutput;
import java.io.IOException;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Random;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.FSDataInputStream;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.Writable;
import org.apache.mahout.classifier.df.data.Data;
import org.apache.mahout.classifier.df.data.DataUtils;
import org.apache.mahout.classifier.df.data.Dataset;
import org.apache.mahout.classifier.df.data.Instance;
import org.apache.mahout.classifier.df.node.Node;

/* loaded from: input_file:org/apache/mahout/classifier/df/DecisionForest.class */
public class DecisionForest implements Writable {
    private final List<Node> trees;

    private DecisionForest() {
        this.trees = Lists.newArrayList();
    }

    public DecisionForest(List<Node> list) {
        Preconditions.checkArgument((list == null || list.isEmpty()) ? false : true, "trees argument must not be null or empty");
        this.trees = list;
    }

    List<Node> getTrees() {
        return this.trees;
    }

    public void classify(Data data, double[] dArr) {
        Preconditions.checkArgument(data.size() == dArr.length, "predictions.length must be equal to data.size()");
        if (data.isEmpty()) {
            return;
        }
        for (Node node : this.trees) {
            for (int i = 0; i < data.size(); i++) {
                dArr[i] = node.classify(data.get(i));
            }
        }
    }

    public double classify(Dataset dataset, Random random, Instance instance) {
        if (dataset.isNumerical(dataset.getLabelId())) {
            double d = 0.0d;
            int i = 0;
            Iterator<Node> it = this.trees.iterator();
            while (it.hasNext()) {
                double classify = it.next().classify(instance);
                if (classify != -1.0d) {
                    d += classify;
                    i++;
                }
            }
            return d / i;
        }
        int[] iArr = new int[dataset.nblabels()];
        Iterator<Node> it2 = this.trees.iterator();
        while (it2.hasNext()) {
            double classify2 = it2.next().classify(instance);
            if (classify2 != -1.0d) {
                int i2 = (int) classify2;
                iArr[i2] = iArr[i2] + 1;
            }
        }
        if (DataUtils.sum(iArr) == 0) {
            return -1.0d;
        }
        return DataUtils.maxindex(random, iArr);
    }

    public long meanNbNodes() {
        long j = 0;
        Iterator<Node> it = this.trees.iterator();
        while (it.hasNext()) {
            j += it.next().nbNodes();
        }
        return j / this.trees.size();
    }

    public long nbNodes() {
        long j = 0;
        Iterator<Node> it = this.trees.iterator();
        while (it.hasNext()) {
            j += it.next().nbNodes();
        }
        return j;
    }

    public long meanMaxDepth() {
        long j = 0;
        Iterator<Node> it = this.trees.iterator();
        while (it.hasNext()) {
            j += it.next().maxDepth();
        }
        return j / this.trees.size();
    }

    public boolean equals(Object obj) {
        if (this == obj) {
            return true;
        }
        if (!(obj instanceof DecisionForest)) {
            return false;
        }
        DecisionForest decisionForest = (DecisionForest) obj;
        return this.trees.size() == decisionForest.getTrees().size() && this.trees.containsAll(decisionForest.getTrees());
    }

    public int hashCode() {
        return this.trees.hashCode();
    }

    public void write(DataOutput dataOutput) throws IOException {
        dataOutput.writeInt(this.trees.size());
        Iterator<Node> it = this.trees.iterator();
        while (it.hasNext()) {
            it.next().write(dataOutput);
        }
    }

    public void readFields(DataInput dataInput) throws IOException {
        int readInt = dataInput.readInt();
        for (int i = 0; i < readInt; i++) {
            this.trees.add(Node.read(dataInput));
        }
    }

    private static DecisionForest read(DataInput dataInput) throws IOException {
        DecisionForest decisionForest = new DecisionForest();
        decisionForest.readFields(dataInput);
        return decisionForest;
    }

    public static DecisionForest load(Configuration configuration, Path path) throws IOException {
        FileSystem fileSystem = path.getFileSystem(configuration);
        DecisionForest decisionForest = null;
        for (Path path2 : fileSystem.getFileStatus(path).isDir() ? DFUtils.listOutputFiles(fileSystem, path) : new Path[]{path}) {
            FSDataInputStream fSDataInputStream = new FSDataInputStream(fileSystem.open(path2));
            if (decisionForest == null) {
                try {
                    decisionForest = read(fSDataInputStream);
                } finally {
                    Closeables.closeQuietly(fSDataInputStream);
                }
            } else {
                decisionForest.readFields(fSDataInputStream);
            }
        }
        return decisionForest;
    }
}
