package dragon.ir.classification;

import dragon.ir.classification.featureselection.FeatureSelector;
import dragon.ir.classification.multiclass.CodeMatrix;
import dragon.ir.classification.multiclass.HingeLoss;
import dragon.ir.classification.multiclass.LossMultiClassDecoder;
import dragon.ir.classification.multiclass.MultiClassDecoder;
import dragon.ir.classification.multiclass.OVACodeMatrix;
import dragon.ir.index.IndexReader;
import dragon.matrix.Row;
import dragon.matrix.SparseMatrix;
import java.io.Closeable;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.util.ArrayList;
import java.util.List;
import jnisvmlight.KernelParam;
import jnisvmlight.LabeledFeatureVector;
import jnisvmlight.LearnParam;
import jnisvmlight.SVMLightInterface;
import jnisvmlight.SVMLightModel;
import jnisvmlight.TrainingParameters;

/* loaded from: input_file:dragon/ir/classification/SVMLightClassifier.class */
public class SVMLightClassifier extends AbstractClassifier {
    private SVMLightModel[] arrModel;
    private LearnParam learnParam;
    private KernelParam kernelParam;
    private CodeMatrix codeMatrix;
    private MultiClassDecoder classDecoder;
    private double[] arrConfidence;
    private boolean scale;

    public SVMLightClassifier(String str) {
        try {
            ObjectInputStream objectInputStream = new ObjectInputStream(new FileInputStream(str));
            this.arrModel = new SVMLightModel[objectInputStream.readInt()];
            for (int i = 0; i < this.arrModel.length; i++) {
                this.arrModel[i] = (SVMLightModel) objectInputStream.readObject();
            }
            this.codeMatrix = (CodeMatrix) objectInputStream.readObject();
            this.classDecoder = (MultiClassDecoder) objectInputStream.readObject();
            this.classNum = objectInputStream.readInt();
            this.scale = objectInputStream.readBoolean();
            this.featureSelector = (FeatureSelector) objectInputStream.readObject();
            this.arrLabel = new String[this.classNum];
            for (int i2 = 0; i2 < this.arrLabel.length; i2++) {
                this.arrLabel[i2] = (String) objectInputStream.readObject();
            }
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    public SVMLightClassifier(IndexReader indexReader) {
        super(indexReader);
        this.learnParam = new LearnParam();
        this.kernelParam = new KernelParam();
        this.classDecoder = new LossMultiClassDecoder(new HingeLoss());
        this.codeMatrix = new OVACodeMatrix(1);
        this.classNum = 0;
        this.scale = false;
    }

    public SVMLightClassifier(SparseMatrix sparseMatrix) {
        super(sparseMatrix);
        this.learnParam = new LearnParam();
        this.kernelParam = new KernelParam();
        this.classDecoder = new LossMultiClassDecoder(new HingeLoss());
        this.codeMatrix = new OVACodeMatrix(1);
        this.classNum = 0;
        this.scale = false;
    }

    public void setUseLinearKernel() {
        this.kernelParam.kernel_type = 0L;
    }

    public void setUseRBFKernel() {
        this.kernelParam.kernel_type = 2L;
    }

    public void setUsePolynomialKernel() {
        this.kernelParam.kernel_type = 1L;
    }

    public void setUserSigmoidKernel() {
        this.kernelParam.kernel_type = 3L;
    }

    public void setScalingOption(boolean z) {
        this.scale = z;
    }

    public void setCodeMatrix(CodeMatrix codeMatrix) {
        this.codeMatrix = codeMatrix;
    }

    public void setMultiClassDecoder(MultiClassDecoder multiClassDecoder) {
        this.classDecoder = multiClassDecoder;
    }

    @Override // dragon.ir.classification.Classifier
    public int[] rank() {
        return this.classDecoder.rank();
    }

    @Override // dragon.ir.classification.Classifier
    public void train(DocClassSet docClassSet) {
        if (this.indexReader == null && this.doctermMatrix == null) {
            return;
        }
        try {
            trainFeatureSelector(docClassSet);
            this.arrLabel = new String[docClassSet.getClassNum()];
            for (int i = 0; i < docClassSet.getClassNum(); i++) {
                this.arrLabel[i] = docClassSet.getDocClass(i).getClassName();
            }
            this.classNum = docClassSet.getClassNum();
            this.codeMatrix.setClassNum(this.classNum);
            ArrayList[] arrayListArr = new ArrayList[this.classNum];
            TrainingParameters trainingParameters = new TrainingParameters(this.learnParam, this.kernelParam);
            SVMLightInterface sVMLightInterface = new SVMLightInterface();
            this.arrModel = new SVMLightModel[this.codeMatrix.getClassifierNum()];
            for (int i2 = 0; i2 < this.classNum; i2++) {
                arrayListArr[i2] = loadData(docClassSet.getDocClass(i2));
            }
            for (int i3 = 0; i3 < this.codeMatrix.getClassifierNum(); i3++) {
                LabeledFeatureVector[] loadData = loadData(arrayListArr, this.codeMatrix, i3);
                int i4 = 0;
                int i5 = 0;
                for (LabeledFeatureVector labeledFeatureVector : loadData) {
                    if (labeledFeatureVector.getLabel() > 0.0d) {
                        i4++;
                    } else {
                        i5++;
                    }
                }
                trainingParameters.getLearningParameters().svm_costratio = 1.0d;
                this.arrModel[i3] = sVMLightInterface.trainModel(loadData, trainingParameters);
            }
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    @Override // dragon.ir.classification.Classifier
    public int classify(Row row) {
        LabeledFeatureVector loadData;
        if (this.arrModel == null || (loadData = loadData(row)) == null) {
            return -1;
        }
        if (this.arrConfidence == null || this.arrConfidence.length != this.codeMatrix.getClassifierNum()) {
            this.arrConfidence = new double[this.codeMatrix.getClassifierNum()];
        }
        for (int i = 0; i < this.codeMatrix.getClassifierNum(); i++) {
            this.arrConfidence[i] = this.arrModel[i].classify(loadData);
        }
        return this.classDecoder.decode(this.codeMatrix, this.arrConfidence);
    }

    public double[] getBinaryClassifierConfidence() {
        return this.arrConfidence;
    }

    @Override // dragon.ir.classification.Classifier
    public void saveModel(String str) {
        try {
            if (this.arrModel == null) {
                return;
            }
            ObjectOutputStream objectOutputStream = new ObjectOutputStream(new FileOutputStream(str));
            objectOutputStream.writeInt(this.arrModel.length);
            for (int i = 0; i < this.arrModel.length; i++) {
                this.arrModel[i].removeTrainingData();
                objectOutputStream.writeObject(this.arrModel[i]);
            }
            objectOutputStream.writeObject(this.codeMatrix);
            objectOutputStream.writeObject(this.classDecoder);
            objectOutputStream.writeInt(this.classNum);
            objectOutputStream.writeBoolean(this.scale);
            objectOutputStream.writeObject(this.featureSelector);
            for (int i2 = 0; i2 < this.classNum; i2++) {
                objectOutputStream.writeObject(getClassLabel(i2));
            }
            objectOutputStream.flush();
            objectOutputStream.close();
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    private LabeledFeatureVector[] loadData(ArrayList[] arrayListArr, CodeMatrix codeMatrix, int i) {
        ArrayList arrayList = new ArrayList();
        for (int i2 = 0; i2 < this.classNum; i2++) {
            int code = this.codeMatrix.getCode(i2, i);
            if (code != 0) {
                for (int i3 = 0; i3 < arrayListArr[i2].size(); i3++) {
                    LabeledFeatureVector labeledFeatureVector = (LabeledFeatureVector) arrayListArr[i2].get(i3);
                    labeledFeatureVector.setLabel(code);
                    arrayList.add(labeledFeatureVector);
                }
            }
        }
        LabeledFeatureVector[] labeledFeatureVectorArr = new LabeledFeatureVector[arrayList.size()];
        for (int i4 = 0; i4 < arrayList.size(); i4++) {
            labeledFeatureVectorArr[i4] = (LabeledFeatureVector) arrayList.get(i4);
        }
        arrayList.clear();
        return labeledFeatureVectorArr;
    }

    private ArrayList loadData(DocClass docClass) {
        ArrayList arrayList = new ArrayList(docClass.getDocNum());
        for (int i = 0; i < docClass.getDocNum(); i++) {
            LabeledFeatureVector loadData = loadData(getRow(docClass.getDoc(i).getIndex()));
            if (loadData != null) {
                arrayList.add(loadData);
            }
        }
        return arrayList;
    }

    protected LabeledFeatureVector loadData(Row row) {
        int i;
        if (row == null) {
            return null;
        }
        int i2 = 0;
        for (int i3 = 0; i3 < row.getNonZeroNum(); i3++) {
            if (this.featureSelector.map(row.getNonZeroColumn(i3)) >= 0) {
                i2++;
            }
        }
        if (i2 == 0) {
            return null;
        }
        int[] iArr = new int[i2];
        double[] dArr = new double[i2];
        int i4 = 0;
        for (int i5 = 0; i5 < row.getNonZeroNum(); i5++) {
            int map = this.featureSelector.map(row.getNonZeroColumn(i5));
            if (map >= 0) {
                iArr[i4] = map + 1;
                dArr[i4] = row.getNonZeroDoubleScore(i5);
                i4++;
            }
        }
        if (this.scale) {
            double d = 0.0d;
            for (int i6 = 0; i6 < i4; i6 = i + 1) {
                d = Math.sqrt(d + (dArr[i6] * dArr[i6]));
                i = 0;
                while (i < i4) {
                    dArr[i] = dArr[i] / d;
                    i++;
                }
            }
        }
        return new LabeledFeatureVector(1.0d, iArr, dArr);
    }
}
