package net.loomchild.maligna.filter.aligner.align.hmm.fb;

import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import net.loomchild.maligna.calculator.Calculator;
import net.loomchild.maligna.coretypes.Alignment;
import net.loomchild.maligna.coretypes.Category;
import net.loomchild.maligna.filter.aligner.align.AlignAlgorithm;
import net.loomchild.maligna.filter.aligner.align.hmm.Util;
import net.loomchild.maligna.matrix.Matrix;
import net.loomchild.maligna.matrix.MatrixFactory;
import net.loomchild.maligna.matrix.MatrixIterator;
import net.loomchild.maligna.progress.ProgressManager;
import net.loomchild.maligna.progress.ProgressMeter;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;

/* loaded from: input_file:net/loomchild/maligna/filter/aligner/align/hmm/fb/ForwardBackwardAlgorithm.class */
/**
 * Alignment algorithm that fills a forward and a backward score matrix over all
 * (source position, target position) pairs. It then walks from (0, 0) to the end of both
 * segment lists. At each step it picks the category whose end cell has the lowest
 * combined score {@code forward + backward - total}.
 *
 * <p>The trace log prints each chosen score together with {@code exp(-score)}. Scores are
 * therefore presumably negative log-probabilities, where lower is better.
 * NOTE(review): confirm this against the {@link Calculator} implementations.
 */
public class ForwardBackwardAlgorithm implements AlignAlgorithm {

    private static final Log log = LogFactory.getLog(ForwardBackwardAlgorithm.class);

    /** Alignment categories mapped to the score added whenever that category is used. */
    private final Map<Category, Float> categoryMap;

    /** Scores a candidate pair of source/target segment sub-lists. */
    private final Calculator calculator;

    /** Creates the forward and backward matrices. */
    private final MatrixFactory matrixFactory;

    /**
     * Creates the algorithm.
     *
     * @param calculator    scores candidate source/target sub-list pairs
     * @param map           categories mapped to their per-use score
     * @param matrixFactory factory for the forward and backward matrices
     */
    public ForwardBackwardAlgorithm(Calculator calculator, Map<Category, Float> map, MatrixFactory matrixFactory) {
        this.matrixFactory = matrixFactory;
        this.calculator = calculator;
        this.categoryMap = map;
    }

    /**
     * Aligns {@code list} (source segments) with {@code list2} (target segments).
     *
     * <p>Both matrices have dimensions {@code (list.size() + 1) x (list2.size() + 1)}. The
     * forward matrix is filled in iterator order and the backward matrix in reverse
     * iterator order. The result is then decoded greedily from (0, 0) (see
     * {@link #decode}). A progress meter is registered for the duration of the call and
     * is always unregistered, even if an exception is thrown.
     *
     * @param list  source segments
     * @param list2 target segments
     * @return one {@link Alignment} per decoding step, in order; each alignment holds
     *         copies of its source and target segments
     * @throws IllegalStateException if at some position no category leads to a cell that
     *         is within bounds, present in both matrices and has a combined score below
     *         positive infinity
     */
    @Override // net.loomchild.maligna.filter.aligner.align.AlignAlgorithm
    public List<Alignment> align(List<String> list, List<String> list2) {
        Matrix<Float> forwardMatrix = this.matrixFactory.createMatrix(list.size() + 1, list2.size() + 1);
        // One task per iterated cell in each of the two passes. This presumes getSize()
        // equals the number of cells the iterator visits.
        ProgressMeter progressMeter = new ProgressMeter("Forward-Backward Align", forwardMatrix.getSize() * 2);
        ProgressManager.getInstance().registerProgressMeter(progressMeter);
        try {
            fillForwardMatrix(forwardMatrix, list, list2, progressMeter);
            Matrix<Float> backwardMatrix = this.matrixFactory.createMatrix(list.size() + 1, list2.size() + 1);
            fillBackwardMatrix(backwardMatrix, list, list2, progressMeter);
            return decode(forwardMatrix, backwardMatrix, list, list2);
        } finally {
            // Unregister on every path so a failed alignment does not leave a stale meter.
            ProgressManager.getInstance().unregisterProgressMeter(progressMeter);
        }
    }

    /**
     * Fills every iterated cell of {@code forwardMatrix} in forward iterator order,
     * completing one progress task per cell.
     */
    private void fillForwardMatrix(Matrix<Float> forwardMatrix, List<String> sourceList,
            List<String> targetList, ProgressMeter progressMeter) {
        MatrixIterator<Float> iterator = forwardMatrix.getIterator();
        while (iterator.hasNext()) {
            iterator.next();
            int x = iterator.getX();
            int y = iterator.getY();
            forwardMatrix.set(x, y, Float.valueOf(createForwardData(x, y, sourceList, targetList, forwardMatrix)));
            progressMeter.completeTask();
        }
    }

    /**
     * Fills every iterated cell of {@code backwardMatrix} in reverse iterator order,
     * completing one progress task per cell.
     */
    private void fillBackwardMatrix(Matrix<Float> backwardMatrix, List<String> sourceList,
            List<String> targetList, ProgressMeter progressMeter) {
        MatrixIterator<Float> iterator = backwardMatrix.getIterator();
        iterator.afterLast();
        while (iterator.hasPrevious()) {
            iterator.previous();
            int x = iterator.getX();
            int y = iterator.getY();
            backwardMatrix.set(x, y, Float.valueOf(createBackwardData(x, y, sourceList, targetList, backwardMatrix)));
            progressMeter.completeTask();
        }
    }

    /**
     * Walks from (0, 0) until both lists are consumed. At each position it applies the
     * category whose end cell has the smallest {@code forward + backward - total}, where
     * {@code total} is the forward score of the final cell.
     *
     * @throws IllegalStateException if no category is applicable at some position
     */
    private List<Alignment> decode(Matrix<Float> forwardMatrix, Matrix<Float> backwardMatrix,
            List<String> sourceList, List<String> targetList) {
        List<Alignment> alignmentList = new ArrayList<Alignment>();
        float totalScore = forwardMatrix.get(sourceList.size(), targetList.size()).floatValue();
        int i = 0;
        int j = 0;
        while (i < sourceList.size() || j < targetList.size()) {
            float bestScore = Float.POSITIVE_INFINITY;
            Category bestCategory = null;
            for (Category category : this.categoryMap.keySet()) {
                int sourceCount = category.getSourceSegmentCount();
                int targetCount = category.getTargetSegmentCount();
                if (sourceCount == 0 && targetCount == 0) {
                    // A category that consumes nothing would never advance and loop forever.
                    continue;
                }
                int x = i + sourceCount;
                int y = j + targetCount;
                if (x > sourceList.size() || y > targetList.size()) {
                    continue;
                }
                Float forward = forwardMatrix.get(x, y);
                Float backward = backwardMatrix.get(x, y);
                // The matrices may hold null for cells they do not store.
                if (forward != null && backward != null) {
                    float score = (forward.floatValue() + backward.floatValue()) - totalScore;
                    if (score < bestScore) {
                        bestScore = score;
                        bestCategory = category;
                    }
                }
            }
            if (bestCategory == null) {
                throw new IllegalStateException("No alignment category is applicable at position ("
                        + i + ", " + j + ") of (" + sourceList.size() + ", " + targetList.size() + ")");
            }
            int nextI = i + bestCategory.getSourceSegmentCount();
            int nextJ = j + bestCategory.getTargetSegmentCount();
            alignmentList.add(new Alignment(createSubList(sourceList, i, nextI),
                    createSubList(targetList, j, nextJ), bestScore));
            i = nextI;
            j = nextJ;
            if (log.isTraceEnabled()) {
                log.trace("(" + i + ", " + j + ") - s: " + bestScore + " (" + Math.exp(-bestScore) + ")");
            }
        }
        return alignmentList;
    }

    /**
     * Computes the forward score of cell ({@code x}, {@code y}). Each category whose start
     * cell ({@code x - sourceCount}, {@code y - targetCount}) passes
     * {@code Util.elementExists} contributes one term: the category score, plus the
     * calculator score of the covered sub-lists, plus the start cell's forward score. The
     * terms are combined with {@code net.loomchild.maligna.util.Util.scoreSum}. The start
     * cells are assumed to have been filled already by the forward iteration order.
     */
    private float createForwardData(int x, int y, List<String> sourceList, List<String> targetList, Matrix<Float> matrix) {
        List<Float> scoreList = new ArrayList<Float>(this.categoryMap.size());
        for (Map.Entry<Category, Float> entry : this.categoryMap.entrySet()) {
            Category category = entry.getKey();
            float categoryScore = entry.getValue().floatValue();
            int startX = x - category.getSourceSegmentCount();
            int startY = y - category.getTargetSegmentCount();
            if (Util.elementExists(matrix, startX, startY)) {
                float segmentScore = this.calculator.calculateScore(
                        sourceList.subList(startX, x), targetList.subList(startY, y));
                scoreList.add(Float.valueOf(categoryScore + segmentScore + matrix.get(startX, startY).floatValue()));
            }
        }
        return net.loomchild.maligna.util.Util.scoreSum(scoreList);
    }

    /**
     * Computes the backward score of cell ({@code x}, {@code y}). It mirrors
     * {@link #createForwardData}, but uses the end cell
     * ({@code x + sourceCount}, {@code y + targetCount}) of each category and the
     * sub-lists starting at ({@code x}, {@code y}).
     */
    private float createBackwardData(int x, int y, List<String> sourceList, List<String> targetList, Matrix<Float> matrix) {
        List<Float> scoreList = new ArrayList<Float>(this.categoryMap.size());
        for (Map.Entry<Category, Float> entry : this.categoryMap.entrySet()) {
            Category category = entry.getKey();
            float categoryScore = entry.getValue().floatValue();
            int endX = x + category.getSourceSegmentCount();
            int endY = y + category.getTargetSegmentCount();
            if (Util.elementExists(matrix, endX, endY)) {
                float segmentScore = this.calculator.calculateScore(
                        sourceList.subList(x, endX), targetList.subList(y, endY));
                scoreList.add(Float.valueOf(categoryScore + segmentScore + matrix.get(endX, endY).floatValue()));
            }
        }
        return net.loomchild.maligna.util.Util.scoreSum(scoreList);
    }

    /** Returns an independent, mutable copy of {@code list[from, to)}. */
    private List<String> createSubList(List<String> list, int from, int to) {
        return new ArrayList<String>(list.subList(from, to));
    }
}
