package aima.core.learning.reinforcement.agent;

import aima.core.agent.Action;
import aima.core.learning.reinforcement.PerceptStateReward;
import aima.core.probability.mdp.ActionsFunction;
import aima.core.util.FrequencyCounter;
import aima.core.util.datastructure.Pair;
import java.util.HashMap;
import java.util.Iterator;
import java.util.Map;

/* loaded from: input_file:aima/core/learning/reinforcement/agent/QLearningAgent.class */
/**
 * Exploratory Q-learning agent.
 *
 * <p>The agent keeps a table {@code Q} of action values indexed by
 * (state, action) pairs and a table {@code Nsa} counting how often each pair
 * has been updated. On every percept it applies the temporal-difference update
 *
 * <pre>
 *   Q[s,a] &lt;- Q[s,a] + alpha(Nsa, s, a) * (r + gamma * max_a' Q[s',a'] - Q[s,a])
 * </pre>
 *
 * to the previously visited state/action pair and then picks the next action
 * by maximising the exploration function {@link #f(Double, int)}.
 *
 * <p>A state is considered terminal when the actions function returns no
 * actions for it. Table entries that have never been written are treated as
 * {@code 0}.
 *
 * @param <S> the state type.
 * @param <A> the action type.
 */
public class QLearningAgent<S, A extends Action> extends ReinforcementAgent<S, A> {
    /** Action values indexed by (state, action); a missing entry counts as 0. */
    Map<Pair<S, A>, Double> Q = new HashMap<>();
    /** Number of updates applied so far to each (state, action) pair. */
    private final FrequencyCounter<Pair<S, A>> Nsa = new FrequencyCounter<>();
    /** Previous state, or {@code null} at the start of a trial. */
    private S s = null;
    /** Action returned for the previous state, or {@code null}. */
    private A a = null;
    /** Reward received in the previous state, or {@code null}. */
    private Double r = null;
    private final ActionsFunction<S, A> actionsFunction;
    private final A noneAction;
    private final double alpha;
    private final double gamma;
    private final int Ne;
    private final double Rplus;

    /**
     * Creates a Q-learning agent.
     *
     * @param actionsFunction supplies the actions available in a state; a
     *                        state with no actions is treated as terminal.
     * @param noneAction      placeholder action used as the key under which the
     *                        value of a terminal state is stored.
     * @param alpha           fixed step size returned by
     *                        {@link #alpha(FrequencyCounter, Object, Action)}.
     * @param gamma           factor applied to the best next-state action value
     *                        in the update.
     * @param Ne              update-count threshold used by
     *                        {@link #f(Double, int)}: pairs counted fewer than
     *                        {@code Ne} times are valued at {@code Rplus}.
     * @param Rplus           value {@link #f(Double, int)} returns for pairs
     *                        with no stored value or a count below {@code Ne}.
     */
    public QLearningAgent(ActionsFunction<S, A> actionsFunction, A noneAction, double alpha, double gamma, int Ne,
            double Rplus) {
        this.actionsFunction = actionsFunction;
        this.noneAction = noneAction;
        this.alpha = alpha;
        this.gamma = gamma;
        this.Ne = Ne;
        this.Rplus = Rplus;
    }

    /**
     * Processes one percept: records the value of a terminal state, updates
     * the action value of the previous (state, action) pair, and selects the
     * next action.
     *
     * @param perceptStateReward the current state s' and its reward r'.
     * @return the action chosen for the current state, or {@code null} when
     *         the current state is terminal.
     */
    @Override
    public A execute(PerceptStateReward<S> perceptStateReward) {
        S sPrime = perceptStateReward.state();
        double rPrime = perceptStateReward.reward();
        boolean terminal = isTerminal(sPrime);

        // A terminal state has no outgoing actions, so its value is simply
        // its reward, stored under the placeholder action.
        if (terminal) {
            Q.put(new Pair<>(sPrime, noneAction), Double.valueOf(rPrime));
        }

        if (null != s) {
            Pair<S, A> sa = new Pair<>(s, a);
            Nsa.incrementFor(sa);
            Double stored = Q.get(sa);
            double qsa = (null == stored) ? 0.0d : stored.doubleValue();
            // r is always set together with s, so it is non-null here.
            double target = r.doubleValue() + (gamma * maxAPrime(sPrime));
            Q.put(sa, Double.valueOf(qsa + (alpha(Nsa, s, a) * (target - qsa))));
        }

        if (terminal) {
            s = null;
            a = null;
            r = null;
        } else {
            s = sPrime;
            a = argmaxAPrime(sPrime);
            r = Double.valueOf(rPrime);
        }
        return a;
    }

    /** Clears both tables and forgets the previous state, action and reward. */
    @Override
    public void reset() {
        Q.clear();
        Nsa.clear();
        s = null;
        a = null;
        r = null;
    }

    /**
     * Derives a utility per state as the largest action value stored for it.
     *
     * @return a new map from each state present in {@code Q} to the maximum
     *         of its stored action values.
     */
    @Override
    public Map<S, Double> getUtility() {
        Map<S, Double> utility = new HashMap<>();
        for (Map.Entry<Pair<S, A>, Double> entry : Q.entrySet()) {
            S state = entry.getKey().getFirst();
            Double value = entry.getValue();
            Double best = utility.get(state);
            if (null == best || best.doubleValue() < value.doubleValue()) {
                utility.put(state, value);
            }
        }
        return utility;
    }

    /**
     * Step size used in the update. This implementation ignores its arguments
     * and returns the constant given at construction.
     *
     * @param frequencyCounter the update counts per (state, action) pair.
     * @param s                the state being updated.
     * @param a                the action being updated.
     * @return the step size.
     */
    protected double alpha(FrequencyCounter<Pair<S, A>> frequencyCounter, S s, A a) {
        return this.alpha;
    }

    /**
     * Exploration function used for action selection.
     *
     * @param d the stored action value, or {@code null} if none is stored.
     * @param i the number of times the pair has been updated.
     * @return {@code Rplus} if {@code d} is {@code null} or {@code i < Ne},
     *         otherwise {@code d}.
     */
    protected double f(Double d, int i) {
        return (null == d || i < this.Ne) ? this.Rplus : d.doubleValue();
    }

    /** Returns true for a non-null state in which no actions are available. */
    private boolean isTerminal(S s) {
        return null != s && this.actionsFunction.actions(s).size() == 0;
    }

    /**
     * Returns max over a' of Q[s', a'].
     *
     * <p>When the state has no actions the value stored under the placeholder
     * action is used instead. If nothing is stored at all the result is 0,
     * matching the treatment of unseen entries elsewhere in this class.
     */
    private double maxAPrime(S sPrime) {
        double best = Double.NEGATIVE_INFINITY;
        boolean hasActions = false;
        for (A aPrime : this.actionsFunction.actions(sPrime)) {
            hasActions = true;
            Double q = Q.get(new Pair<>(sPrime, aPrime));
            if (null != q && q.doubleValue() > best) {
                best = q.doubleValue();
            }
        }
        if (!hasActions) {
            // Guard against a missing entry instead of unboxing null.
            Double q = Q.get(new Pair<>(sPrime, noneAction));
            if (null != q) {
                best = q.doubleValue();
            }
        }
        return (best == Double.NEGATIVE_INFINITY) ? 0.0d : best;
    }

    /**
     * Returns the action a' maximising f(Q[s', a'], Nsa[s', a']); the first
     * action reaching the maximum wins. Returns {@code null} if the state has
     * no actions.
     */
    private A argmaxAPrime(S sPrime) {
        A bestAction = null;
        double bestValue = Double.NEGATIVE_INFINITY;
        for (A aPrime : this.actionsFunction.actions(sPrime)) {
            Pair<S, A> key = new Pair<>(sPrime, aPrime);
            double explorationValue = f(Q.get(key), Nsa.getCount(key).intValue());
            if (explorationValue > bestValue) {
                bestValue = explorationValue;
                bestAction = aPrime;
            }
        }
        return bestAction;
    }
}
