package ai.catboost.spark.impl;

import ai.catboost.spark.DataHelpers$;
import ai.catboost.spark.DatasetForTraining;
import ai.catboost.spark.DatasetForTrainingWithPairs;
import ai.catboost.spark.PoolFilesPaths;
import ai.catboost.spark.SparkHelpers$;
import ai.catboost.spark.UsualDatasetForTraining;
import java.net.InetSocketAddress;
import java.time.Duration;
import org.apache.spark.TaskContext$;
import org.apache.spark.rdd.RDD;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.types.StructType;
import org.json4s.JsonAST;
import org.json4s.JsonDSL$;
import org.json4s.jackson.JsonMethods$;
import ru.yandex.catboost.spark.catboost4j_spark.core.src.native_impl.QuantizedFeaturesInfoPtr;
import ru.yandex.catboost.spark.catboost4j_spark.core.src.native_impl.TIntermediateDataMetaInfo;
import scala.MatchError;
import scala.Option;
import scala.Predef$;
import scala.Predef$ArrowAssoc$;
import scala.Tuple2;
import scala.Tuple3;
import scala.collection.Iterable;
import scala.collection.Iterator;
import scala.collection.Seq;
import scala.collection.Seq$;
import scala.collection.mutable.HashMap;
import scala.concurrent.Await$;
import scala.concurrent.Future;
import scala.concurrent.duration.Duration$;
import scala.reflect.ClassTag$;
import scala.runtime.BoxedUnit;
import scala.runtime.BoxesRunTime;
import scala.runtime.ObjectRef;
import scala.runtime.java8.JFunction1;

/* compiled from: Workers.scala */
/* loaded from: input_file:ai/catboost/spark/impl/CatBoostWorkers$.class */
/**
 * Singleton holder (decompiled from {@code Workers.scala}) exposing the factory
 * {@link #apply} that builds a {@code CatBoostWorkers} instance for a training run.
 *
 * <p>The single instance is published through {@link #MODULE$}; the private
 * constructor assigns it and the static initializer triggers that constructor once.
 * All {@code $anonfun$...} methods are the lifted bodies of the lambdas used in this
 * class; they are public and static only because of how the compiler emitted them.
 */
public final class CatBoostWorkers$ {
    /** The one instance of this class; set by the private constructor. */
    public static CatBoostWorkers$ MODULE$;

    static {
        // Instantiating the class has the side effect of populating MODULE$.
        new CatBoostWorkers$();
    }

    /**
     * Default value for the third constructor argument of {@code CatBoostWorkers}.
     *
     * @return always {@code 0}
     */
    public int $lessinit$greater$default$3() {
        return 0;
    }

    /**
     * Prepares the training data and worker parameters and returns a
     * {@code CatBoostWorkers} whose callback launches one {@code CatBoostWorker}
     * per partition of the prepared (cached) RDD.
     *
     * <p>Steps performed, in order:
     * <ol>
     *   <li>Select the columns needed for training from the main dataset and from each
     *       additional dataset via {@code DataHelpers.selectColumnsForTrainingAndReturnIndex}.</li>
     *   <li>Extend {@code jObject} with {@code "thread_count"} and, when an executor native
     *       memory limit is defined, {@code "used_ram_limit"}; serialize it to a compact
     *       JSON string.</li>
     *   <li>Combine all datasets into one RDD, cache it and call {@code count()} on it.</li>
     *   <li>Block until {@code future} completes (infinite timeout); its value is discarded.</li>
     *   <li>Build the per-port callback that runs the workers with {@code foreachPartition}.</li>
     * </ol>
     *
     * @param sparkSession       session used to query the task thread count, the executor native
     *                           memory limit, the driver host and the {@code SparkContext}
     * @param i                  not referenced anywhere in this method body
     * @param duration           forwarded unchanged to {@code CatBoostWorker.processPartition}
     *                           (meaning not visible here)
     * @param duration2          forwarded unchanged to {@code CatBoostWorker.processPartition}
     *                           (meaning not visible here)
     * @param datasetForTraining main dataset; its {@code srcPool()} supplies the quantized features
     *                           info, the data meta info and, in the pairs case, the pairs schema
     * @param seq                additional datasets, processed the same way as the main one and
     *                           merged with it
     * @param jObject            base JSON parameters that are extended for the workers
     * @param str                forwarded unchanged to {@code CatBoostWorker.processPartition}
     * @param future             awaited before the callback is created
     * @return a new {@code CatBoostWorkers} built from the Spark context, the callback and the
     *         default third argument
     * @throws MatchError if column selection returns {@code null} for any dataset
     */
    public CatBoostWorkers apply(SparkSession sparkSession, int i, Duration duration, Duration duration2, DatasetForTraining datasetForTraining, Seq<DatasetForTraining> seq, JsonAST.JObject jObject, String str, Future<Tuple2<PoolFilesPaths, PoolFilesPaths[]>> future) {
        // Callback taking an int (used below as the driver-side port) and returning nothing.
        JFunction1.mcVI.sp spVar;
        QuantizedFeaturesInfoPtr quantizedFeaturesInfo = datasetForTraining.srcPool().quantizedFeaturesInfo();
        // The third flag is true only for datasets with pairs; the names of the boolean flags
        // are not visible in this decompiled form.
        Tuple3<DatasetForTraining, HashMap<String, Object>, Option<Object>> selectColumnsForTrainingAndReturnIndex = DataHelpers$.MODULE$.selectColumnsForTrainingAndReturnIndex(datasetForTraining, true, datasetForTraining instanceof DatasetForTrainingWithPairs, true, true);
        if (selectColumnsForTrainingAndReturnIndex == null) {
            throw new MatchError(selectColumnsForTrainingAndReturnIndex);
        }
        // Tuple destructuring as emitted by the compiler: selected dataset, column-name map, optional value.
        Tuple3 tuple3 = new Tuple3((DatasetForTraining) selectColumnsForTrainingAndReturnIndex._1(), (HashMap) selectColumnsForTrainingAndReturnIndex._2(), (Option) selectColumnsForTrainingAndReturnIndex._3());
        DatasetForTraining datasetForTraining2 = (DatasetForTraining) tuple3._1();
        HashMap hashMap = (HashMap) tuple3._2();
        Option option = (Option) tuple3._3();
        // For the additional datasets only the selected dataset (first tuple element) is kept.
        Seq seq2 = (Seq) seq.map(datasetForTraining3 -> {
            Tuple3<DatasetForTraining, HashMap<String, Object>, Option<Object>> selectColumnsForTrainingAndReturnIndex2 = DataHelpers$.MODULE$.selectColumnsForTrainingAndReturnIndex(datasetForTraining3, true, datasetForTraining3 instanceof DatasetForTrainingWithPairs, true, true);
            if (selectColumnsForTrainingAndReturnIndex2 != null) {
                return (DatasetForTraining) selectColumnsForTrainingAndReturnIndex2._1();
            }
            throw new MatchError(selectColumnsForTrainingAndReturnIndex2);
        }, Seq$.MODULE$.canBuildFrom());
        int threadCountForTask = SparkHelpers$.MODULE$.getThreadCountForTask(sparkSession);
        // Worker params = caller's params plus "thread_count".
        JsonAST.JObject $tilde = JsonDSL$.MODULE$.jobject2assoc(jObject).$tilde(JsonDSL$.MODULE$.pair2jvalue(Predef$ArrowAssoc$.MODULE$.$minus$greater$extension(Predef$.MODULE$.ArrowAssoc("thread_count"), BoxesRunTime.boxToInteger(threadCountForTask)), obj -> {
            return $anonfun$apply$2(BoxesRunTime.unboxToInt(obj));
        }));
        Option<Object> executorNativeMemoryLimit = SparkHelpers$.MODULE$.getExecutorNativeMemoryLimit(sparkSession);
        if (executorNativeMemoryLimit.isDefined()) {
            // "used_ram_limit" is the limit integer-divided by 1024 with a "KB" suffix
            // (so the limit is presumably expressed in bytes -- confirm against SparkHelpers).
            $tilde = JsonDSL$.MODULE$.jobject2assoc($tilde).$tilde(JsonDSL$.MODULE$.pair2jvalue(Predef$ArrowAssoc$.MODULE$.$minus$greater$extension(Predef$.MODULE$.ArrowAssoc("used_ram_limit"), new StringBuilder(2).append(BoxesRunTime.unboxToLong(executorNativeMemoryLimit.get()) / 1024).append("KB").toString()), str2 -> {
                return JsonDSL$.MODULE$.string2jvalue(str2);
            }));
        }
        String compact = JsonMethods$.MODULE$.compact($tilde);
        TIntermediateDataMetaInfo createDataMetaInfo = datasetForTraining.srcPool().createDataMetaInfo(datasetForTraining.srcPool().createDataMetaInfo$default$1());
        StructType mainDataSchema = datasetForTraining2.mainDataSchema();
        // Total number of datasets: the main one plus every additional one.
        int size = 1 + seq2.size();
        if (datasetForTraining2 instanceof DatasetForTrainingWithPairs) {
            // Pairs case: every additional dataset is cast to DatasetForTrainingWithPairs as well.
            RDD cache = getCogroupedMainAndPairsRDDForAllDatasets((DatasetForTrainingWithPairs) datasetForTraining2, (Seq) seq2.map(datasetForTraining4 -> {
                return (DatasetForTrainingWithPairs) datasetForTraining4;
            }, Seq$.MODULE$.canBuildFrom())).cache();
            // Result of count() is discarded; presumably called to materialize the cache
            // before any worker starts -- confirm.
            cache.count();
            // Blocks without a timeout until the future completes; the value is not used.
            Await$.MODULE$.result(future, Duration$.MODULE$.Inf());
            StructType schema = datasetForTraining.srcPool().pairsData().schema();
            spVar = i2 -> {
                // i2 is used as the port of the address on the driver host.
                InetSocketAddress inetSocketAddress = new InetSocketAddress(SparkHelpers$.MODULE$.getDriverHost(sparkSession), i2);
                cache.foreachPartition(iterator -> {
                    $anonfun$apply$6(inetSocketAddress, compact, quantizedFeaturesInfo, str, threadCountForTask, duration, duration2, size, hashMap, createDataMetaInfo, mainDataSchema, schema, option, iterator);
                    return BoxedUnit.UNIT;
                });
            };
        } else {
            // Non-pairs case: all datasets are cast to UsualDatasetForTraining.
            RDD cache2 = getMergedDataFrameForAllDatasets((UsualDatasetForTraining) datasetForTraining2, (Seq) seq2.map(datasetForTraining5 -> {
                return (UsualDatasetForTraining) datasetForTraining5;
            }, Seq$.MODULE$.canBuildFrom())).cache();
            // Same count()/await sequence as in the pairs branch.
            cache2.count();
            Await$.MODULE$.result(future, Duration$.MODULE$.Inf());
            spVar = i3 -> {
                // i3 is used as the port of the address on the driver host.
                InetSocketAddress inetSocketAddress = new InetSocketAddress(SparkHelpers$.MODULE$.getDriverHost(sparkSession), i3);
                cache2.foreachPartition(iterator -> {
                    $anonfun$apply$10(inetSocketAddress, compact, quantizedFeaturesInfo, str, threadCountForTask, duration, duration2, size, hashMap, createDataMetaInfo, mainDataSchema, option, iterator);
                    return BoxedUnit.UNIT;
                });
            };
        }
        return new CatBoostWorkers(sparkSession.sparkContext(), spVar, $lessinit$greater$default$3());
    }

    /**
     * Combines the row RDDs of the main dataset and all additional datasets into one RDD.
     *
     * <p>Starts from {@code usualDatasetForTraining.data().rdd()} and, for each element of
     * {@code seq} in order, replaces the accumulated RDD with the result of zipping it
     * partition-wise with that dataset's RDD (see
     * {@link #$anonfun$getMergedDataFrameForAllDatasets$1}).
     *
     * @param usualDatasetForTraining main dataset
     * @param seq                     additional datasets to merge in
     * @return the accumulated RDD; equal to the main dataset's RDD when {@code seq} is empty
     */
    private RDD<Row> getMergedDataFrameForAllDatasets(UsualDatasetForTraining usualDatasetForTraining, Seq<UsualDatasetForTraining> seq) {
        // ObjectRef is the mutable cell the compiler uses for a local that is reassigned inside a lambda.
        ObjectRef create = ObjectRef.create(usualDatasetForTraining.data().rdd());
        seq.foreach(usualDatasetForTraining2 -> {
            $anonfun$getMergedDataFrameForAllDatasets$1(create, usualDatasetForTraining2);
            return BoxedUnit.UNIT;
        });
        return (RDD) create.elem;
    }

    /**
     * Pairs-dataset counterpart of {@link #getMergedDataFrameForAllDatasets}: combines the
     * {@code data()} RDDs of the main dataset and all additional datasets into one RDD.
     *
     * @param datasetForTrainingWithPairs main dataset with pairs
     * @param seq                         additional datasets with pairs to merge in
     * @return the accumulated RDD; equal to the main dataset's {@code data()} when {@code seq} is empty
     */
    private RDD<Tuple2<Tuple2<Object, Object>, Tuple2<Iterable<Iterable<Row>>, Iterable<Iterable<Row>>>>> getCogroupedMainAndPairsRDDForAllDatasets(DatasetForTrainingWithPairs datasetForTrainingWithPairs, Seq<DatasetForTrainingWithPairs> seq) {
        // Mutable cell holding the accumulated RDD while folding over seq.
        ObjectRef create = ObjectRef.create(datasetForTrainingWithPairs.data());
        seq.foreach(datasetForTrainingWithPairs2 -> {
            $anonfun$getCogroupedMainAndPairsRDDForAllDatasets$1(create, datasetForTrainingWithPairs2);
            return BoxedUnit.UNIT;
        });
        return (RDD) create.elem;
    }

    /**
     * Lifted lambda: converts the thread count to a JSON value for the
     * {@code "thread_count"} entry built in {@link #apply}.
     */
    public static final /* synthetic */ JsonAST.JValue $anonfun$apply$2(int i) {
        return JsonDSL$.MODULE$.int2jvalue(i);
    }

    /**
     * Lifted lambda executed once per partition in the pairs branch of {@link #apply}.
     *
     * <p>Creates a {@code CatBoostWorker} for the current task's partition id and calls
     * {@code processPartition}, handing it a loader that builds the quantized datasets
     * (with pairs) from {@code iterator}. For an empty partition the loader returns a
     * {@code Tuple3} of three {@code null}s instead of calling the data loader.
     *
     * @param i2 number of datasets (main plus additional), as computed in {@code apply}
     */
    public static final /* synthetic */ void $anonfun$apply$6(InetSocketAddress inetSocketAddress, String str, QuantizedFeaturesInfoPtr quantizedFeaturesInfoPtr, String str2, int i, Duration duration, Duration duration2, int i2, HashMap hashMap, TIntermediateDataMetaInfo tIntermediateDataMetaInfo, StructType structType, StructType structType2, Option option, Iterator iterator) {
        new CatBoostWorker(TaskContext$.MODULE$.getPartitionId()).processPartition(inetSocketAddress, str, quantizedFeaturesInfoPtr, str2, i, duration, duration2, tLocalExecutor -> {
            // NOTE(review): the meaning of the leading literal 0 is not visible here; the last two
            // arguments are the loader's own default values.
            return iterator.hasNext() ? DataHelpers$.MODULE$.loadQuantizedDatasetsWithPairs(0, i2, quantizedFeaturesInfoPtr, hashMap, tIntermediateDataMetaInfo, structType, structType2, option, tLocalExecutor, iterator, DataHelpers$.MODULE$.loadQuantizedDatasetsWithPairs$default$11(), DataHelpers$.MODULE$.loadQuantizedDatasetsWithPairs$default$12()) : new Tuple3((Object) null, (Object) null, (Object) null);
        });
    }

    /**
     * Lifted lambda executed once per partition in the non-pairs branch of {@link #apply}.
     *
     * <p>Same structure as {@link #$anonfun$apply$6} but loads the data with
     * {@code loadQuantizedDatasets} and takes no pairs schema. An empty partition yields a
     * {@code Tuple3} of three {@code null}s.
     *
     * @param i2 number of datasets (main plus additional), as computed in {@code apply}
     */
    public static final /* synthetic */ void $anonfun$apply$10(InetSocketAddress inetSocketAddress, String str, QuantizedFeaturesInfoPtr quantizedFeaturesInfoPtr, String str2, int i, Duration duration, Duration duration2, int i2, HashMap hashMap, TIntermediateDataMetaInfo tIntermediateDataMetaInfo, StructType structType, Option option, Iterator iterator) {
        new CatBoostWorker(TaskContext$.MODULE$.getPartitionId()).processPartition(inetSocketAddress, str, quantizedFeaturesInfoPtr, str2, i, duration, duration2, tLocalExecutor -> {
            // The last two arguments are the loader's own default values.
            return iterator.hasNext() ? DataHelpers$.MODULE$.loadQuantizedDatasets(i2, quantizedFeaturesInfoPtr, hashMap, tIntermediateDataMetaInfo, structType, option, tLocalExecutor, iterator, DataHelpers$.MODULE$.loadQuantizedDatasets$default$9(), DataHelpers$.MODULE$.loadQuantizedDatasets$default$10()) : new Tuple3((Object) null, (Object) null, (Object) null);
        });
    }

    /**
     * Lifted lambda: one fold step of {@link #getMergedDataFrameForAllDatasets}.
     *
     * <p>Replaces the RDD held in {@code objectRef} with one whose partitions contain the
     * accumulated partition's rows followed by the corresponding partition's rows of
     * {@code usualDatasetForTraining}. The {@code true} argument is Spark's
     * {@code preservesPartitioning} flag; Spark's {@code zipPartitions} expects both RDDs to
     * have the same number of partitions.
     */
    public static final /* synthetic */ void $anonfun$getMergedDataFrameForAllDatasets$1(ObjectRef objectRef, UsualDatasetForTraining usualDatasetForTraining) {
        objectRef.elem = ((RDD) objectRef.elem).zipPartitions(usualDatasetForTraining.data().rdd(), true, (iterator, iterator2) -> {
            // Concatenate: first the accumulated partition, then the new dataset's partition.
            return iterator.$plus$plus(() -> {
                return iterator2;
            });
        }, ClassTag$.MODULE$.apply(Row.class), ClassTag$.MODULE$.apply(Row.class));
    }

    /**
     * Lifted lambda: one fold step of {@link #getCogroupedMainAndPairsRDDForAllDatasets}.
     *
     * <p>Identical in shape to {@link #$anonfun$getMergedDataFrameForAllDatasets$1}, but zips
     * with {@code datasetForTrainingWithPairs.data()} directly and uses {@code Tuple2} class tags.
     */
    public static final /* synthetic */ void $anonfun$getCogroupedMainAndPairsRDDForAllDatasets$1(ObjectRef objectRef, DatasetForTrainingWithPairs datasetForTrainingWithPairs) {
        objectRef.elem = ((RDD) objectRef.elem).zipPartitions(datasetForTrainingWithPairs.data(), true, (iterator, iterator2) -> {
            // Concatenate: first the accumulated partition, then the new dataset's partition.
            return iterator.$plus$plus(() -> {
                return iterator2;
            });
        }, ClassTag$.MODULE$.apply(Tuple2.class), ClassTag$.MODULE$.apply(Tuple2.class));
    }

    /** Private so the only instance is the one created by the static initializer. */
    private CatBoostWorkers$() {
        MODULE$ = this;
    }
}
