package org.apache.spark.ml.evaluation;

import java.io.IOException;
import org.apache.spark.annotation.Experimental;
import org.apache.spark.ml.param.Param;
import org.apache.spark.ml.param.ParamMap;
import org.apache.spark.ml.param.ParamPair;
import org.apache.spark.ml.param.ParamValidators$;
import org.apache.spark.ml.param.shared.HasLabelCol;
import org.apache.spark.ml.param.shared.HasPredictionCol;
import org.apache.spark.ml.util.DefaultParamsWritable;
import org.apache.spark.ml.util.Identifiable$;
import org.apache.spark.ml.util.MLReader;
import org.apache.spark.ml.util.MLWritable;
import org.apache.spark.ml.util.MLWriter;
import org.apache.spark.ml.util.SchemaUtils$;
import org.apache.spark.mllib.evaluation.MulticlassMetrics;
import org.apache.spark.rdd.RDD;
import org.apache.spark.sql.Column;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.functions$;
import org.apache.spark.sql.types.DoubleType$;
import org.apache.spark.sql.types.StructType;
import scala.MatchError;
import scala.Predef$;
import scala.Tuple2;
import scala.collection.Seq;
import scala.reflect.ClassTag$;
import scala.reflect.ScalaSignature;

/* compiled from: MulticlassClassificationEvaluator.scala */
@ScalaSignature(bytes = "\u0006\u0001\u0005ed\u0001B\u0001\u0003\u00015\u0011\u0011%T;mi&\u001cG.Y:t\u00072\f7o]5gS\u000e\fG/[8o\u000bZ\fG.^1u_JT!a\u0001\u0003\u0002\u0015\u00154\u0018\r\\;bi&|gN\u0003\u0002\u0006\r\u0005\u0011Q\u000e\u001c\u0006\u0003\u000f!\tQa\u001d9be.T!!\u0003\u0006\u0002\r\u0005\u0004\u0018m\u00195f\u0015\u0005Y\u0011aA8sO\u000e\u00011#\u0002\u0001\u000f%ii\u0002CA\b\u0011\u001b\u0005\u0011\u0011BA\t\u0003\u0005%)e/\u00197vCR|'\u000f\u0005\u0002\u001415\tAC\u0003\u0002\u0016-\u000511\u000f[1sK\u0012T!a\u0006\u0003\u0002\u000bA\f'/Y7\n\u0005e!\"\u0001\u0005%bgB\u0013X\rZ5di&|gnQ8m!\t\u00192$\u0003\u0002\u001d)\tY\u0001*Y:MC\n,GnQ8m!\tq\u0012%D\u0001 \u0015\t\u0001C!\u0001\u0003vi&d\u0017B\u0001\u0012 \u0005U!UMZ1vYR\u0004\u0016M]1ng^\u0013\u0018\u000e^1cY\u0016D\u0001\u0002\n\u0001\u0003\u0006\u0004%\t%J\u0001\u0004k&$W#\u0001\u0014\u0011\u0005\u001djcB\u0001\u0015,\u001b\u0005I#\"\u0001\u0016\u0002\u000bM\u001c\u0017\r\\1\n\u00051J\u0013A\u0002)sK\u0012,g-\u0003\u0002/_\t11\u000b\u001e:j]\u001eT!\u0001L\u0015)\u0007\r\nt\u0007\u0005\u00023k5\t1G\u0003\u00025\r\u0005Q\u0011M\u001c8pi\u0006$\u0018n\u001c8\n\u0005Y\u001a$!B*j]\u000e,\u0017%\u0001\u001d\u0002\u000bErSG\f\u0019\t\u0011i\u0002!\u0011!Q\u0001\n\u0019\nA!^5eA!\u001a\u0011(M\u001c\t\u000bu\u0002A\u0011\u0001 
\u0002\rqJg.\u001b;?)\ty\u0004\t\u0005\u0002\u0010\u0001!)A\u0005\u0010a\u0001M!\u001a\u0001)M\u001c)\u0007q\nt\u0007C\u0003>\u0001\u0011\u0005A\tF\u0001@Q\r\u0019\u0015g\u000e\u0005\b\u000f\u0002\u0011\r\u0011\"\u0001I\u0003)iW\r\u001e:jG:\u000bW.Z\u000b\u0002\u0013B\u0019!j\u0013\u0014\u000e\u0003YI!\u0001\u0014\f\u0003\u000bA\u000b'/Y7)\u0007\u0019\u000bt\u0007\u0003\u0004P\u0001\u0001\u0006I!S\u0001\f[\u0016$(/[2OC6,\u0007\u0005K\u0002Oc]BQA\u0015\u0001\u0005\u0002\u0015\nQbZ3u\u001b\u0016$(/[2OC6,\u0007fA)2o!)Q\u000b\u0001C\u0001-\u0006i1/\u001a;NKR\u0014\u0018n\u0019(b[\u0016$\"a\u0016-\u000e\u0003\u0001AQ!\u0017+A\u0002\u0019\nQA^1mk\u0016D3\u0001V\u00198\u0011\u0015a\u0006\u0001\"\u0001^\u0003A\u0019X\r\u001e)sK\u0012L7\r^5p]\u000e{G\u000e\u0006\u0002X=\")\u0011l\u0017a\u0001M!\u001a1,M\u001c\t\u000b\u0005\u0004A\u0011\u00012\u0002\u0017M,G\u000fT1cK2\u001cu\u000e\u001c\u000b\u0003/\u000eDQ!\u00171A\u0002\u0019B3\u0001Y\u00198\u0011\u00151\u0007\u0001\"\u0011h\u0003!)g/\u00197vCR,GC\u00015l!\tA\u0013.\u0003\u0002kS\t1Ai\\;cY\u0016DQ\u0001\\3A\u00025\fq\u0001Z1uCN,G\u000f\r\u0002omB\u0019qN\u001d;\u000e\u0003AT!!\u001d\u0004\u0002\u0007M\fH.\u0003\u0002ta\n9A)\u0019;bg\u0016$\bCA;w\u0019\u0001!\u0011b^6\u0002\u0002\u0003\u0005)\u0011\u0001=\u0003\u0007}#\u0013'\u0005\u0002zyB\u0011\u0001F_\u0005\u0003w&\u0012qAT8uQ&tw\r\u0005\u0002){&\u0011a0\u000b\u0002\u0004\u0003:L\b\u0006B32\u0003\u0003\t#!a\u0001\u0002\u000bIr\u0003G\f\u0019\t\u000f\u0005\u001d\u0001\u0001\"\u0011\u0002\n\u0005q\u0011n\u001d'be\u001e,'OQ3ui\u0016\u0014XCAA\u0006!\rA\u0013QB\u0005\u0004\u0003\u001fI#a\u0002\"p_2,\u0017M\u001c\u0015\u0005\u0003\u000b\tt\u0007C\u0004\u0002\u0016\u0001!\t%a\u0006\u0002\t\r|\u0007/\u001f\u000b\u0004\u007f\u0005e\u0001\u0002CA\u000e\u0003'\u0001\r!!\b\u0002\u000b\u0015DHO]1\u0011\u0007)\u000by\"C\u0002\u0002\"Y\u0011\u0001\u0002U1sC6l\u0015\r\u001d\u0015\u0005\u0003'\tt\u0007K\u0002\u0001\u0003O\u00012AMA\u0015\u0013\r\tYc\r\u0002\r\u000bb\u0004XM]5nK:$\u0018\
r\u001c\u0015\u0004\u0001E:taBA\u0019\u0005!\u0005\u00111G\u0001\"\u001bVdG/[2mCN\u001c8\t\\1tg&4\u0017nY1uS>tWI^1mk\u0006$xN\u001d\t\u0004\u001f\u0005UbAB\u0001\u0003\u0011\u0003\t9d\u0005\u0005\u00026\u0005e\u0012qHA#!\rA\u00131H\u0005\u0004\u0003{I#AB!osJ+g\r\u0005\u0003\u001f\u0003\u0003z\u0014bAA\"?\t)B)\u001a4bk2$\b+\u0019:b[N\u0014V-\u00193bE2,\u0007c\u0001\u0015\u0002H%\u0019\u0011\u0011J\u0015\u0003\u0019M+'/[1mSj\f'\r\\3\t\u000fu\n)\u0004\"\u0001\u0002NQ\u0011\u00111\u0007\u0005\t\u0003#\n)\u0004\"\u0011\u0002T\u0005!An\\1e)\ry\u0014Q\u000b\u0005\b\u0003/\ny\u00051\u0001'\u0003\u0011\u0001\u0018\r\u001e5)\u000b\u0005=\u0013'a\u0017\"\u0005\u0005u\u0013!B\u0019/m9\u0002\u0004BCA1\u0003k\t\t\u0011\"\u0003\u0002d\u0005Y!/Z1e%\u0016\u001cx\u000e\u001c<f)\t\t)\u0007\u0005\u0003\u0002h\u0005ETBAA5\u0015\u0011\tY'!\u001c\u0002\t1\fgn\u001a\u0006\u0003\u0003_\nAA[1wC&!\u00111OA5\u0005\u0019y%M[3di\"*\u0011QG\u0019\u0002\\!*\u0011qF\u0019\u0002\\\u0001")
@Experimental
/* loaded from: input_file:lib/spark-mllib_2.11-2.1.3.jar:org/apache/spark/ml/evaluation/MulticlassClassificationEvaluator.class */
/*
 * Evaluator for multiclass classification results.
 *
 * It reads a prediction column (required to be DoubleType) and a numeric label column from a
 * Dataset, builds an org.apache.spark.mllib.evaluation.MulticlassMetrics over them, and returns
 * one of the following metrics, selected by the "metricName" param:
 *   "f1"                -> MulticlassMetrics.weightedFMeasure() (the default)
 *   "weightedPrecision" -> MulticlassMetrics.weightedPrecision()
 *   "weightedRecall"    -> MulticlassMetrics.weightedRecall()
 *   "accuracy"          -> MulticlassMetrics.accuracy()
 * For every metric, larger values are better (isLargerBetter() returns true).
 *
 * NOTE(review): this file is decompiled output of Scala 2.11 bytecode. It uses the Scala trait
 * encoding: *.Cclass static helpers, *$.MODULE$ singleton objects, and $init$ trait
 * initializers. Some constructs are not legal Java source and should not be edited as if they
 * were:
 *   - final fields are assigned from the compiler-generated *_setter_$*_$eq methods;
 *   - the set(...) calls carry generic casts inserted by the decompiler.
 */
public class MulticlassClassificationEvaluator extends Evaluator implements HasPredictionCol, HasLabelCol, DefaultParamsWritable {
    // Unique identifier of this evaluator instance (returned by uid()).
    private final String uid;
    // Selects which metric evaluate() computes; values are restricted by the validator
    // installed in the constructor.
    private final Param<String> metricName;
    // Backing fields for the params declared by the HasLabelCol / HasPredictionCol traits.
    // They are written by the *_setter_$*_$eq methods below, presumably during the
    // trait $init$ calls in the constructor.
    private final Param<String> labelCol;
    private final Param<String> predictionCol;

    /** Returns a reader for loading a saved evaluator; delegates to the companion object. */
    public static MLReader<MulticlassClassificationEvaluator> read() {
        return MulticlassClassificationEvaluator$.MODULE$.read();
    }

    /** Loads an evaluator saved at {@code str}; delegates to the companion object. */
    public static MulticlassClassificationEvaluator load(String str) {
        return MulticlassClassificationEvaluator$.MODULE$.load(str);
    }

    /** Returns a writer for this evaluator; delegates to the DefaultParamsWritable trait implementation. */
    @Override // org.apache.spark.ml.util.DefaultParamsWritable, org.apache.spark.ml.util.MLWritable
    public MLWriter write() {
        return DefaultParamsWritable.Cclass.write(this);
    }

    /** Saves this evaluator to {@code str}; delegates to the MLWritable trait implementation. */
    @Override // org.apache.spark.ml.util.MLWritable
    public void save(String str) throws IOException {
        MLWritable.Cclass.save(this, str);
    }

    /** Returns the param naming the label column. */
    @Override // org.apache.spark.ml.param.shared.HasLabelCol
    public final Param<String> labelCol() {
        return this.labelCol;
    }

    // Compiler-generated trait field setter that installs the labelCol param. Assigning a
    // final field here is a decompilation artifact, not valid Java.
    @Override // org.apache.spark.ml.param.shared.HasLabelCol
    public final void org$apache$spark$ml$param$shared$HasLabelCol$_setter_$labelCol_$eq(Param param) {
        this.labelCol = param;
    }

    /** Returns the configured label column name; delegates to the HasLabelCol trait implementation. */
    @Override // org.apache.spark.ml.param.shared.HasLabelCol
    public final String getLabelCol() {
        return HasLabelCol.Cclass.getLabelCol(this);
    }

    /** Returns the param naming the prediction column. */
    @Override // org.apache.spark.ml.param.shared.HasPredictionCol
    public final Param<String> predictionCol() {
        return this.predictionCol;
    }

    // Compiler-generated trait field setter that installs the predictionCol param. Assigning a
    // final field here is a decompilation artifact, not valid Java.
    @Override // org.apache.spark.ml.param.shared.HasPredictionCol
    public final void org$apache$spark$ml$param$shared$HasPredictionCol$_setter_$predictionCol_$eq(Param param) {
        this.predictionCol = param;
    }

    /** Returns the configured prediction column name; delegates to the HasPredictionCol trait implementation. */
    @Override // org.apache.spark.ml.param.shared.HasPredictionCol
    public final String getPredictionCol() {
        return HasPredictionCol.Cclass.getPredictionCol(this);
    }

    /** Returns the unique identifier passed to the constructor. */
    @Override // org.apache.spark.ml.util.Identifiable
    public String uid() {
        return this.uid;
    }

    /** Returns the param selecting which metric evaluate() computes. */
    public Param<String> metricName() {
        return this.metricName;
    }

    /** Returns the current metric name, read through the inherited $(...) param accessor. */
    public String getMetricName() {
        return (String) $(metricName());
    }

    /**
     * Sets the metric name. Must be one of "f1", "weightedPrecision", "weightedRecall" or
     * "accuracy", per the validator installed in the constructor.
     *
     * @return this evaluator, for chaining
     */
    public MulticlassClassificationEvaluator setMetricName(String str) {
        // The nested generic casts are decompiler artifacts; the call simply sets the param.
        return (MulticlassClassificationEvaluator) set((Param<Param<String>>) metricName(), (Param<String>) str);
    }

    /**
     * Sets the name of the prediction column.
     *
     * @return this evaluator, for chaining
     */
    public MulticlassClassificationEvaluator setPredictionCol(String str) {
        return (MulticlassClassificationEvaluator) set((Param<Param<String>>) predictionCol(), (Param<String>) str);
    }

    /**
     * Sets the name of the label column.
     *
     * @return this evaluator, for chaining
     */
    public MulticlassClassificationEvaluator setLabelCol(String str) {
        return (MulticlassClassificationEvaluator) set((Param<Param<String>>) labelCol(), (Param<String>) str);
    }

    /**
     * Computes the configured metric over {@code dataset}.
     *
     * The prediction column must be DoubleType and the label column must be numeric; both are
     * checked via SchemaUtils before any computation. The label column is cast to double.
     *
     * @param dataset data containing the prediction and label columns
     * @return the value of the metric selected by metricName
     * @throws MatchError if the metric name is not one of the four supported values. This is
     *     presumably unreachable, because the metricName param validator rejects other values.
     */
    @Override // org.apache.spark.ml.evaluation.Evaluator
    public double evaluate(Dataset<?> dataset) {
        // Despite its decompiler-chosen name, this local holds whichever metric is selected.
        double accuracy;
        StructType schema = dataset.schema();
        SchemaUtils$.MODULE$.checkColumnType(schema, (String) $(predictionCol()), DoubleType$.MODULE$, SchemaUtils$.MODULE$.checkColumnType$default$4());
        SchemaUtils$.MODULE$.checkNumericType(schema, (String) $(labelCol()), SchemaUtils$.MODULE$.checkNumericType$default$3());
        // Select (prediction, label-as-double) and convert each row to a Tuple2 for the
        // RDD-based metrics. The row mapping lives in the anonymous function class
        // MulticlassClassificationEvaluator$$anonfun$1, which is not shown here.
        // Presumably it yields (prediction, label) pairs — TODO confirm.
        MulticlassMetrics multiclassMetrics = new MulticlassMetrics((RDD<Tuple2<Object, Object>>) dataset.select((Seq<Column>) Predef$.MODULE$.wrapRefArray(new Column[]{functions$.MODULE$.col((String) $(predictionCol())), functions$.MODULE$.col((String) $(labelCol())).cast(DoubleType$.MODULE$)})).rdd().map(new MulticlassClassificationEvaluator$$anonfun$1(this), ClassTag$.MODULE$.apply(Tuple2.class)));
        String str = (String) $(metricName());
        // Decompiled form of a Scala `match` on the metric name; "f1" maps to the weighted F-measure.
        if ("f1".equals(str)) {
            accuracy = multiclassMetrics.weightedFMeasure();
        } else if ("weightedPrecision".equals(str)) {
            accuracy = multiclassMetrics.weightedPrecision();
        } else if ("weightedRecall".equals(str)) {
            accuracy = multiclassMetrics.weightedRecall();
        } else {
            if (!"accuracy".equals(str)) {
                throw new MatchError(str);
            }
            accuracy = multiclassMetrics.accuracy();
        }
        return accuracy;
    }

    /** All supported metrics are "higher is better", so this always returns true. */
    @Override // org.apache.spark.ml.evaluation.Evaluator
    public boolean isLargerBetter() {
        return true;
    }

    /** Returns a copy of this evaluator with {@code paramMap} applied, via defaultCopy. */
    @Override // org.apache.spark.ml.evaluation.Evaluator, org.apache.spark.ml.param.Params
    public MulticlassClassificationEvaluator copy(ParamMap paramMap) {
        return (MulticlassClassificationEvaluator) defaultCopy(paramMap);
    }

    /**
     * Creates an evaluator with the given unique identifier.
     *
     * Statement order matters:
     *   1. The trait $init$ calls (which install the labelCol / predictionCol params) run
     *      after uid is set.
     *   2. metricName is created next.
     *   3. setDefault runs last, because it reads metricName().
     *
     * @param str unique identifier for this instance
     */
    public MulticlassClassificationEvaluator(String str) {
        this.uid = str;
        HasPredictionCol.Cclass.$init$(this);
        HasLabelCol.Cclass.$init$(this);
        MLWritable.Cclass.$init$(this);
        DefaultParamsWritable.Cclass.$init$(this);
        // Restrict the metric name to the four values handled in evaluate().
        this.metricName = new Param<>(this, "metricName", "metric name in evaluation (f1|weightedPrecision|weightedRecall|accuracy)", ParamValidators$.MODULE$.inArray(new String[]{"f1", "weightedPrecision", "weightedRecall", "accuracy"}));
        // Default metric is "f1" (weighted F-measure).
        setDefault(Predef$.MODULE$.wrapRefArray(new ParamPair[]{metricName().$minus$greater("f1")}));
    }

    /** Creates an evaluator with a random UID prefixed by "mcEval". */
    public MulticlassClassificationEvaluator() {
        this(Identifiable$.MODULE$.randomUID("mcEval"));
    }
}
