package org.apache.joshua.decoder.ff.fragmentlm;

import java.io.IOException;
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Deque;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Stack;
import org.apache.joshua.decoder.JoshuaConfiguration;
import org.apache.joshua.decoder.chart_parser.SourcePath;
import org.apache.joshua.decoder.ff.FeatureFunction;
import org.apache.joshua.decoder.ff.FeatureVector;
import org.apache.joshua.decoder.ff.StatefulFF;
import org.apache.joshua.decoder.ff.state_maintenance.DPState;
import org.apache.joshua.decoder.ff.tm.OwnerId;
import org.apache.joshua.decoder.ff.tm.OwnerMap;
import org.apache.joshua.decoder.ff.tm.Rule;
import org.apache.joshua.decoder.ff.tm.format.HieroFormatReader;
import org.apache.joshua.decoder.hypergraph.HGNode;
import org.apache.joshua.decoder.hypergraph.HyperEdge;
import org.apache.joshua.decoder.segment_file.Sentence;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Stateful feature function that fires one feature per "fragment language model" tree fragment
 * found inside the tree built for each hyperedge.
 *
 * <p>Fragments are read at construction time from the file given by the {@code lm} argument and
 * indexed by {@link Tree#getRule()}. During decoding, {@link #compute} builds a tree for the
 * applied rule and its tail nodes. It then walks every non-boundary subtree and adds
 * {@code fragment.escapedString()} with value 1 for each stored fragment that {@link #match}es
 * that subtree.
 *
 * <p>Recognized arguments: {@code lm} (required fragment file), {@code build-depth} (default 1),
 * {@code max-depth} (default 0 = unlimited), {@code min-lex-depth} (default 1).
 */
public class FragmentLMFF extends StatefulFF {
    private static final Logger LOG = LoggerFactory.getLogger(FragmentLMFF.class);

    private static final int DEFAULT_BUILD_DEPTH = 1;
    /** 0 disables the maximum-depth filter in {@link #addLMFragment}. */
    private static final int DEFAULT_MAX_DEPTH = 0;
    /** Values {@code <= 1} disable the lexicalized-depth filter in {@link #addLMFragment}. */
    private static final int DEFAULT_MIN_LEX_DEPTH = 1;

    /** Depth passed to {@link Tree#buildTree} when building the tree for a hyperedge. */
    private final int buildDepth;
    /** Fragments deeper than this are skipped at load time (0 = no limit). */
    private final int maxDepth;
    /** Lexicalized fragments shallower than this are skipped at load time (when > 1). */
    private final int minLexDepth;
    /** Loaded fragments, keyed by {@link Tree#getRule()}. */
    private final Map<String, List<Tree>> lmFragments = new HashMap<>();
    /** Number of fragments accepted by {@link #addLMFragment}. */
    private int numFragments = 0;
    /** Path of the fragment file read by the constructor. */
    private final String fragmentLMFile;

    /**
     * DP state for this feature: wraps the tree built for a hyperedge.
     *
     * <p>{@link #equals} is identity-based, so two distinct states never compare equal.
     * NOTE(review): this presumably prevents hypotheses from recombining on this feature's state —
     * confirm that this is intentional before changing it.
     */
    public class FragmentState extends DPState {
        private final Tree tree;

        /** @param tree the tree built for the hyperedge; may be {@code null} */
        public FragmentState(Tree tree) {
            this.tree = tree;
        }

        @Override
        public int hashCode() {
            // Null-safe: compute() wraps whatever Tree.buildTree returned.
            return Objects.hashCode(this.tree);
        }

        @Override
        public boolean equals(Object obj) {
            // Identity equality only (see class comment).
            return this == obj;
        }

        @Override
        public String toString() {
            return String.format("[FragmentState %s]", this.tree);
        }
    }

    /**
     * Creates the feature and loads all fragments from the {@code lm} file.
     *
     * @param featureVector weights vector passed to the superclass
     * @param strArr feature arguments (parsed by the superclass into {@code parsedArgs})
     * @param joshuaConfiguration decoder configuration passed to the superclass
     * @throws IllegalArgumentException if {@code lm} is missing or a numeric argument is malformed
     * @throws RuntimeException if the fragment file cannot be read
     */
    public FragmentLMFF(FeatureVector featureVector, String[] strArr, JoshuaConfiguration joshuaConfiguration) {
        super(featureVector, "FragmentLMFF", strArr, joshuaConfiguration);

        // Optional numeric arguments fall back to their defaults instead of failing in
        // Integer.parseInt(null) when absent.
        this.buildDepth = intArg("build-depth", DEFAULT_BUILD_DEPTH);
        this.maxDepth = intArg("max-depth", DEFAULT_MAX_DEPTH);
        this.minLexDepth = intArg("min-lex-depth", DEFAULT_MIN_LEX_DEPTH);

        String lmFile = this.parsedArgs.get("lm");
        if (lmFile == null) {
            throw new IllegalArgumentException("FragmentLMFF: missing required argument 'lm'");
        }
        this.fragmentLMFile = lmFile;

        // The depth limits must be set before this point: addLMFragment reads them.
        try {
            PennTreebankReader.readTrees(this.fragmentLMFile).forEach(this::addLMFragment);
            LOG.info("FragmentLMFF: Read {} LM fragments from '{}'", this.numFragments, this.fragmentLMFile);
        } catch (IOException e) {
            throw new RuntimeException(String.format("* WARNING: couldn't read fragment LM file '%s'", this.fragmentLMFile), e);
        }
    }

    /**
     * Reads an integer feature argument.
     *
     * @param key argument name
     * @param defaultValue value returned when the argument is absent
     * @return the parsed value, or {@code defaultValue} when absent
     * @throws IllegalArgumentException (a {@link NumberFormatException}) if the value is not an int
     */
    private int intArg(String key, int defaultValue) {
        String value = this.parsedArgs.get(key);
        return value == null ? defaultValue : Integer.parseInt(value.trim());
    }

    /**
     * Adds one fragment to the index unless it is filtered out by depth.
     *
     * <p>Fragments deeper than {@code max-depth} (when non-zero) are skipped. Lexicalized fragments
     * shallower than {@code min-lex-depth} (when > 1) are skipped. Each skip is logged at WARN.
     *
     * @param tree the fragment to add
     */
    public void addLMFragment(Tree tree) {
        if (this.lmFragments == null) {
            return;
        }
        int depth = tree.getDepth();
        if (this.maxDepth != 0 && depth > this.maxDepth) {
            LOG.warn("Skipping fragment {} (depth {} > {})", tree, depth, this.maxDepth);
            return;
        }
        if (this.minLexDepth > 1 && tree.isLexicalized() && depth < this.minLexDepth) {
            LOG.warn("Skipping fragment {} (lex depth {} < {})", tree, depth, this.minLexDepth);
            return;
        }
        this.lmFragments.computeIfAbsent(tree.getRule(), k -> new ArrayList<>()).add(tree);
        this.numFragments++;
    }

    /**
     * Builds the tree for the applied rule and fires a feature for every matching fragment.
     *
     * <p>All non-boundary subtrees of the built tree are visited (depth-first, LIFO). For each
     * subtree, the fragments indexed under its {@link Tree#getRule()} are tested with
     * {@link #match}. Each match adds {@code fragment.escapedString()} with value 1.
     *
     * @return a {@link FragmentState} wrapping the built tree
     */
    @Override
    public DPState compute(Rule rule, List<HGNode> tailNodes, int i, int j, SourcePath sourcePath, Sentence sentence, FeatureFunction.Accumulator accumulator) {
        Tree root = Tree.buildTree(rule, tailNodes, this.buildDepth);

        // ArrayDeque rejects nulls, so only non-null trees are ever pushed.
        Deque<Tree> pending = new ArrayDeque<>();
        if (root != null) {
            pending.push(root);
        }

        while (!pending.isEmpty()) {
            Tree tree = pending.pop();

            List<Tree> candidates = this.lmFragments.get(tree.getRule());
            if (candidates != null) {
                for (Tree fragment : candidates) {
                    // match() checks the root labels first.
                    if (match(fragment, tree)) {
                        accumulator.add(fragment.escapedString(), 1.0f);
                    }
                }
            }

            List<Tree> children = tree.getChildren();
            if (children != null) {
                for (Tree child : children) {
                    // Boundary children are not expanded.
                    if (child != null && !child.isBoundary()) {
                        pending.push(child);
                    }
                }
            }
        }
        return new FragmentState(root);
    }

    /**
     * Tests whether {@code fragment} matches the top of {@code tree}.
     *
     * <p>The root labels must be equal. A fragment node with no children matches any subtree.
     * Otherwise both nodes must have the same number of children, the child labels must agree
     * position by position, and every child pair must match recursively.
     *
     * @param fragment the stored LM fragment
     * @param tree the subtree being scored
     * @return {@code true} if the fragment matches
     */
    private boolean match(Tree fragment, Tree tree) {
        if (fragment.getLabel() != tree.getLabel()) {
            return false;
        }
        List<Tree> fragmentKids = childrenOf(fragment);
        if (fragmentKids.isEmpty()) {
            return true;
        }
        List<Tree> treeKids = childrenOf(tree);
        if (fragmentKids.size() != treeKids.size()) {
            return false;
        }
        // Cheap label-only pass before recursing.
        for (int k = 0; k < fragmentKids.size(); k++) {
            if (fragmentKids.get(k).getLabel() != treeKids.get(k).getLabel()) {
                return false;
            }
        }
        for (int k = 0; k < fragmentKids.size(); k++) {
            if (!match(fragmentKids.get(k), treeKids.get(k))) {
                return false;
            }
        }
        return true;
    }

    /** Returns the node's children, treating a {@code null} list as empty. */
    private static List<Tree> childrenOf(Tree node) {
        List<Tree> kids = node.getChildren();
        return kids == null ? Collections.emptyList() : kids;
    }

    /** This feature contributes nothing at the final (goal) node; returns {@code null}. */
    @Override
    public DPState computeFinal(HGNode hGNode, int i, int i2, SourcePath sourcePath, Sentence sentence, FeatureFunction.Accumulator accumulator) {
        return null;
    }

    /** No future-cost estimate; always 0. */
    @Override
    public float estimateFutureCost(Rule rule, DPState dPState, Sentence sentence) {
        return 0.0f;
    }

    /** No rule-level cost estimate; always 0. */
    @Override
    public float estimateCost(Rule rule) {
        return 0.0f;
    }

    /**
     * Manual smoke test. Builds a small hypergraph by hand and logs whether a sample fragment
     * matches the tree built from it. Reads {@code test/fragments.txt} relative to the working
     * directory.
     */
    public static void main(String[] strArr) {
        FragmentLMFF fragmentLMFF = new FragmentLMFF(new FeatureVector(),
                new String[]{"-lm", "test/fragments.txt", "-map", "test/mapping.txt"}, null);
        Tree fragment = Tree.fromString("(S NP (VP (VBD \"said\") SBAR) (. \".\"))");

        Rule sRule = new HieroFormatReader().parseLine("[S] ||| the man [VP,1] [.,2] ||| the man [VP,1] [.,2] ||| 0");
        Rule vpRule = new HieroFormatReader().parseLine("[VP] ||| said [SBAR,1] ||| said [SBAR,1] ||| 0");
        Rule sbarRule = new HieroFormatReader().parseLine("[SBAR] ||| that he was done ||| that he was done ||| 0");
        Rule periodRule = new HieroFormatReader().parseLine("[.] ||| . ||| . ||| 0");

        OwnerId owner = OwnerMap.register("0");
        sRule.setOwner(owner);
        vpRule.setOwner(owner);
        sbarRule.setOwner(owner);
        periodRule.setOwner(owner);

        HGNode sbarNode = new HGNode(3, 7, sbarRule.getLHS(), (List<DPState>) null,
                new HyperEdge(sbarRule, 0.0f, 0.0f, null, null), 0.0f);
        ArrayList<HGNode> vpTails = new ArrayList<>();
        vpTails.add(sbarNode);
        HGNode vpNode = new HGNode(2, 7, vpRule.getLHS(), (List<DPState>) null,
                new HyperEdge(vpRule, 0.0f, 0.0f, vpTails, null), 0.0f);
        HGNode periodNode = new HGNode(7, 8, periodRule.getLHS(), (List<DPState>) null,
                new HyperEdge(periodRule, 0.0f, 0.0f, null, null), 0.0f);

        ArrayList<HGNode> sTails = new ArrayList<>();
        Collections.addAll(sTails, vpNode, periodNode);
        Tree built = Tree.buildTree(sRule, sTails, 1);
        LOG.info("Does\n  {} match\n  {}??\n  -> {}", fragment, built, fragmentLMFF.match(fragment, built));
    }
}
