package org.apache.solr.ltr.model;

import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.search.Explanation;
import org.apache.solr.ltr.feature.Feature;
import org.apache.solr.ltr.norm.Normalizer;
import org.apache.solr.util.SolrPluginUtils;

/* loaded from: input_file:org/apache/solr/ltr/model/NeuralNetworkModel.class */
/**
 * Learning-to-rank scoring model backed by a fully-connected feed-forward neural network.
 *
 * <p>The network is configured through {@link #setLayers(Object)} with a list of per-layer
 * parameter maps. Each map is applied to a freshly created {@link DefaultLayer} through
 * {@code SolrPluginUtils.invokeSetters}, which presumably maps keys such as {@code matrix},
 * {@code bias} and {@code activation} onto the layer's corresponding setters (TODO confirm
 * against {@code SolrPluginUtils}). {@link #validate()} checks the layer dimensions against the
 * number of features and requires the final layer to produce a single output.
 * {@link #score(float[])} returns that output.
 */
public class NeuralNetworkModel extends LTRScoringModel {
    /** Layers in evaluation order; {@code null} until {@link #setLayers(Object)} has been called. */
    private List<Layer> layers;

    /** Element-wise activation function applied to each unit's pre-activation value. */
    protected interface Activation {
        float apply(float in);
    }

    /**
     * Dense layer computing {@code activation(weightMatrix * input + biasVector)}.
     *
     * <p>The layer's index in the network is taken from the enclosing model's layer list size at
     * construction time. It is only used in validation error messages.
     */
    public class DefaultLayer implements Layer {
        private int layerID;
        private float[][] weightMatrix;
        private int matrixRows;
        private int matrixCols;
        private float[] biasVector;
        private int numUnits;
        protected String activationStr;
        protected Activation activation;

        public DefaultLayer() {
            // Guard against createLayer() being invoked before setLayers() initialised the list.
            final List<Layer> existing = NeuralNetworkModel.this.layers;
            this.layerID = (existing == null) ? 0 : existing.size();
        }

        /**
         * Sets the weight matrix from a list of rows, each of which is a list of numbers.
         *
         * <p>Any {@link Number} subtype is accepted, so integral JSON values (e.g. {@code Long})
         * work as well as {@code Double}. An empty list yields a 0x0 matrix, which
         * {@link #validate(int)} then reports as a dimension mismatch.
         *
         * @param matrix a {@code List} of equally sized {@code List}s of numbers
         * @throws IllegalArgumentException if the rows do not all have the same length
         */
        public void setMatrix(Object matrix) {
            final List<?> rows = (List<?>) matrix;
            this.matrixRows = rows.size();
            this.matrixCols = (this.matrixRows == 0) ? 0 : ((List<?>) rows.get(0)).size();
            this.weightMatrix = new float[this.matrixRows][this.matrixCols];
            for (int row = 0; row < this.matrixRows; row++) {
                final List<?> values = (List<?>) rows.get(row);
                if (values.size() != this.matrixCols) {
                    // Previously, longer rows were silently truncated and shorter rows threw IOOBE.
                    throw new IllegalArgumentException("Weight matrix for layer " + this.layerID
                            + " of model \"" + NeuralNetworkModel.this.name + "\" has "
                            + values.size() + " columns in row " + row + " but "
                            + this.matrixCols + " columns in row 0.");
                }
                for (int col = 0; col < this.matrixCols; col++) {
                    this.weightMatrix[row][col] = ((Number) values.get(col)).floatValue();
                }
            }
        }

        /**
         * Sets the bias vector, one entry per unit (i.e. per weight matrix row).
         *
         * @param bias a {@code List} of numbers; any {@link Number} subtype is accepted
         */
        public void setBias(Object bias) {
            final List<?> values = (List<?>) bias;
            this.numUnits = values.size();
            this.biasVector = new float[this.numUnits];
            for (int unit = 0; unit < this.numUnits; unit++) {
                this.biasVector[unit] = ((Number) values.get(unit)).floatValue();
            }
        }

        /**
         * Selects the activation function by name: {@code "relu"}, {@code "sigmoid"} or
         * {@code "identity"}.
         *
         * <p>An unknown or {@code null} name leaves {@link #activation} {@code null}.
         * {@link #validate(int)} then rejects the layer with a {@link ModelException}.
         *
         * @param activationName the activation function name (expected to be a {@code String})
         */
        public void setActivation(Object activationName) {
            this.activationStr = (String) activationName;
            if (this.activationStr == null) {
                // switch on a null String would throw NPE; let validate() report it instead.
                this.activation = null;
                return;
            }
            switch (this.activationStr) {
                case "relu":
                    this.activation = new Activation() {
                        @Override
                        public float apply(float in) {
                            return (in < 0.0f) ? 0.0f : in;
                        }
                    };
                    break;
                case "sigmoid":
                    this.activation = new Activation() {
                        @Override
                        public float apply(float in) {
                            return (float) (1.0d / (1.0d + Math.exp(-in)));
                        }
                    };
                    break;
                case "identity":
                    this.activation = new Activation() {
                        @Override
                        public float apply(float in) {
                            return in;
                        }
                    };
                    break;
                default:
                    this.activation = null;
                    break;
            }
        }

        /**
         * Computes {@code activation(W * input + b)} for every unit of this layer.
         *
         * @param inputVec input values; at least {@code matrixCols} entries are read
         * @return a new array with one entry per weight matrix row
         */
        @Override
        public float[] calculateOutput(float[] inputVec) {
            final float[] outputVec = new float[this.matrixRows];
            for (int row = 0; row < this.matrixRows; row++) {
                float sum = this.biasVector[row];
                for (int col = 0; col < this.matrixCols; col++) {
                    sum += this.weightMatrix[row][col] * inputVec[col];
                }
                outputVec[row] = this.activation.apply(sum);
            }
            return outputVec;
        }

        /**
         * Checks this layer's internal consistency and that it accepts {@code inputDim} inputs.
         *
         * @param inputDim the output size of the previous layer (or the feature count for layer 0)
         * @return this layer's output size (the number of weight matrix rows)
         * @throws ModelException if the bias and matrix sizes disagree, the activation is
         *     unknown, or the matrix column count does not equal {@code inputDim}
         */
        @Override
        public int validate(int inputDim) throws ModelException {
            if (this.numUnits != this.matrixRows) {
                throw new ModelException("Dimension mismatch in model \"" + NeuralNetworkModel.this.name
                        + "\". Layer " + this.layerID + " has " + this.numUnits
                        + " bias weights but " + this.matrixRows + " weight matrix rows.");
            }
            if (this.activation == null) {
                throw new ModelException("Invalid activation function (\"" + this.activationStr
                        + "\") in layer " + this.layerID + " of model \""
                        + NeuralNetworkModel.this.name + "\".");
            }
            if (inputDim == this.matrixCols) {
                return this.matrixRows;
            }
            if (this.layerID == 0) {
                throw new ModelException("Dimension mismatch in model \"" + NeuralNetworkModel.this.name
                        + "\". The input has " + inputDim
                        + " features, but the weight matrix for layer 0 has " + this.matrixCols
                        + " columns.");
            }
            throw new ModelException("Dimension mismatch in model \"" + NeuralNetworkModel.this.name
                    + "\". The weight matrix for layer " + (this.layerID - 1) + " has " + inputDim
                    + " rows, but the weight matrix for layer " + this.layerID + " has "
                    + this.matrixCols + " columns.");
        }

        /** Returns a short {@code (matrix=RxC,activation=NAME)} summary used by {@code explain}. */
        @Override
        public String describe() {
            return new StringBuilder()
                    .append("(matrix=").append(this.matrixRows)
                    .append('x').append(this.matrixCols)
                    .append(",activation=").append(this.activationStr)
                    .append(")")
                    .toString();
        }
    }

    /** A single network layer: forward computation, dimension validation and description. */
    public interface Layer {
        float[] calculateOutput(float[] inputVec);

        int validate(int inputDim) throws ModelException;

        String describe();
    }

    /**
     * Creates a {@link DefaultLayer} and applies the given parameter map to it via reflection.
     *
     * @param layerParams a {@code Map} of setter name to value, or {@code null} for no parameters
     * @return the configured layer
     */
    @SuppressWarnings({"rawtypes", "unchecked"}) // params arrive untyped from the model JSON
    protected Layer createLayer(Object layerParams) {
        final DefaultLayer layer = new DefaultLayer();
        if (layerParams != null) {
            SolrPluginUtils.invokeSetters(layer, ((Map) layerParams).entrySet());
        }
        return layer;
    }

    /**
     * Replaces the network's layers with ones built from the given parameter maps, in order.
     *
     * @param layersParam a {@code List} of per-layer parameter maps
     */
    public void setLayers(Object layersParam) {
        final List<?> layerParamsList = (List<?>) layersParam;
        // Assign the list before creating layers: DefaultLayer's constructor reads its size.
        this.layers = new ArrayList<>(layerParamsList.size());
        for (Object layerParams : layerParamsList) {
            this.layers.add(createLayer(layerParams));
        }
    }

    public NeuralNetworkModel(String name, List<Feature> features, List<Normalizer> norms,
            String featureStoreName, List<Feature> allFeatures, Map<String, Object> params) {
        super(name, features, norms, featureStoreName, allFeatures, params);
    }

    /**
     * Validates the model by chaining each layer's dimension check.
     *
     * <p>The check starts from the feature count and requires a single final output.
     * NOTE(review): the base method appears to be {@code protected} (the decompiler widened it);
     * keeping it {@code public} here stays source-compatible.
     *
     * @throws ModelException if no layers are configured, any layer is invalid, or the network
     *     does not produce exactly one output
     */
    @Override
    public void validate() throws ModelException {
        super.validate();
        if (this.layers == null) {
            throw new ModelException("Model \"" + this.name + "\" has no layers.");
        }
        int dim = this.features.size();
        for (Layer layer : this.layers) {
            dim = layer.validate(dim);
        }
        if (dim != 1) {
            throw new ModelException("The output matrix for model \"" + this.name + "\" has "
                    + dim + " rows, but should only have one.");
        }
    }

    /**
     * Runs the feature vector through every layer in order.
     *
     * @param modelFeatureValuesNormalized the (normalized) feature values, one per model feature
     * @return the single output of the final layer
     */
    @Override
    public float score(float[] modelFeatureValuesNormalized) {
        float[] values = modelFeatureValuesNormalized;
        for (Layer layer : this.layers) {
            values = layer.calculateOutput(values);
        }
        return values[0];
    }

    /**
     * Builds an explanation listing each feature value and a description of every layer.
     *
     * <p>Feature explanations are paired with {@code features} by index. This presumably relies
     * on the caller passing them in model-feature order (TODO confirm).
     */
    @Override
    public Explanation explain(LeafReaderContext context, int doc, float finalScore,
            List<Explanation> featureExplanations) {
        final StringBuilder sb = new StringBuilder();
        sb.append("(name=").append(getName());
        sb.append(",featureValues=[");
        for (int i = 0; i < featureExplanations.size(); i++) {
            if (i > 0) {
                sb.append(',');
            }
            sb.append(this.features.get(i).getName())
                    .append('=')
                    .append(featureExplanations.get(i).getValue());
        }
        sb.append("],layers=[");
        for (int i = 0; i < this.layers.size(); i++) {
            if (i > 0) {
                sb.append(',');
            }
            sb.append(this.layers.get(i).describe());
        }
        sb.append("])");
        return Explanation.match(finalScore, sb.toString(), new Explanation[0]);
    }
}
