package com.linkedin.feathr.common.featurizeddataset;

import com.linkedin.feathr.common.tensor.Representable;
import com.linkedin.feathr.common.tensor.TensorCategory;
import com.linkedin.feathr.common.tensor.TensorData;
import com.linkedin.feathr.common.tensor.TensorType;
import com.linkedin.feathr.common.tensor.scalar.ScalarTensor;
import java.util.Objects;
import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema;
import scala.collection.Seq;

/**
 * Creates {@link FeatureDeserializer}s that turn Spark SQL column values into {@link TensorData},
 * choosing the implementation from the column's {@link TensorType}.
 *
 * <p>Non-instantiable utility class; use {@link #getFeatureDeserializer(TensorType)}.
 */
public final class SparkDeserializerFactory {

    /**
     * Deserializes a DENSE tensor that has at least one dimension.
     *
     * <p>The incoming value is cast to a Scala {@link Seq} and wrapped in an
     * {@code FDSDenseTensorWrapper}. A {@code null} value yields {@code null}.
     */
    public static class DenseDeserializer implements FeatureDeserializer {
        private final Representable[] _columnTypes;
        // Forwarded unchanged to FDSDenseTensorWrapper. Presumably marks a tensor whose
        // dimensions all have a fixed size. TODO confirm against FDSDenseTensorWrapper.
        private final boolean _regular;

        /**
         * @param columnTypes column types of the tensor, as returned by
         *     {@link TensorType#getColumnTypes()}
         * @param regular "regular" flag passed through to {@code FDSDenseTensorWrapper}
         */
        DenseDeserializer(Representable[] columnTypes, boolean regular) {
            this._columnTypes = columnTypes;
            this._regular = regular;
        }

        /**
         * @param obj the column value; expected to be a {@link Seq} when non-null
         * @return a dense tensor view over {@code obj}, or {@code null} if {@code obj} is null
         * @throws ClassCastException if {@code obj} is not a {@link Seq}
         */
        @Override
        public TensorData deserialize(Object obj) {
            if (obj == null) {
                return null;
            }
            return new FDSDenseTensorWrapper(this._columnTypes, this._regular, (Seq) obj);
        }
    }

    /**
     * Deserializes a DENSE tensor with no dimensions (a scalar) using {@link ScalarTensor#wrap}.
     * A {@code null} value yields {@code null}.
     */
    public static class ScalarDeserializer implements FeatureDeserializer {
        private ScalarDeserializer() {
        }

        /**
         * @param obj the column value
         * @return the result of {@code ScalarTensor.wrap(obj)}, or {@code null} if {@code obj} is null
         */
        @Override
        public TensorData deserialize(Object obj) {
            if (obj == null) {
                return null;
            }
            return ScalarTensor.wrap(obj);
        }
    }

    /**
     * Deserializes a SPARSE tensor.
     *
     * <p>The incoming value is cast to a Spark {@link GenericRowWithSchema} and wrapped in an
     * {@code FDSSparseTensorWrapper}. A {@code null} value yields {@code null}.
     */
    public static class SparseDeserializer implements FeatureDeserializer {
        private final Representable[] _columnTypes;

        /**
         * @param columnTypes column types of the tensor, as returned by
         *     {@link TensorType#getColumnTypes()}
         */
        SparseDeserializer(Representable[] columnTypes) {
            this._columnTypes = columnTypes;
        }

        /**
         * @param obj the column value; expected to be a {@link GenericRowWithSchema} when non-null
         * @return a sparse tensor view over {@code obj}, or {@code null} if {@code obj} is null
         * @throws ClassCastException if {@code obj} is not a {@link GenericRowWithSchema}
         */
        @Override
        public TensorData deserialize(Object obj) {
            if (obj == null) {
                return null;
            }
            return new FDSSparseTensorWrapper(this._columnTypes, (GenericRowWithSchema) obj);
        }
    }

    private SparkDeserializerFactory() {
    }

    /**
     * Returns a deserializer suited to values of the given tensor type.
     *
     * <ul>
     *   <li>DENSE with no dimension types: {@link ScalarDeserializer}</li>
     *   <li>DENSE with dimension types: {@link DenseDeserializer} (regular)</li>
     *   <li>SPARSE: {@link SparseDeserializer}</li>
     * </ul>
     *
     * @param tensorType the tensor type of the column to deserialize; must not be null
     * @return a new deserializer for {@code tensorType}
     * @throws NullPointerException if {@code tensorType} is null
     * @throws IllegalArgumentException if the tensor category is neither DENSE nor SPARSE
     */
    public static FeatureDeserializer getFeatureDeserializer(TensorType tensorType) {
        Objects.requireNonNull(tensorType, "tensorType");
        TensorCategory tensorCategory = tensorType.getTensorCategory();

        if (tensorCategory == TensorCategory.DENSE) {
            if (tensorType.getDimensionTypes().isEmpty()) {
                // No dimensions: the value is handled as a scalar.
                return new ScalarDeserializer();
            }
            // This branch is only reached for DENSE, so the regular flag is always true.
            // The original expression `tensorCategory == DENSE` could never be false here.
            return new DenseDeserializer(tensorType.getColumnTypes(), true);
        }

        if (tensorCategory == TensorCategory.SPARSE) {
            return new SparseDeserializer(tensorType.getColumnTypes());
        }

        throw new IllegalArgumentException("Unsupported tensor category " + tensorCategory);
    }
}
