package io.spiffe.workloadapi;

import io.spiffe.bundle.jwtbundle.JwtBundle;
import io.spiffe.bundle.jwtbundle.JwtBundleSet;
import io.spiffe.exception.BundleNotFoundException;
import io.spiffe.exception.JwtSourceException;
import io.spiffe.exception.JwtSvidException;
import io.spiffe.exception.SocketEndpointAddressException;
import io.spiffe.exception.WatcherException;
import io.spiffe.spiffeid.SpiffeId;
import io.spiffe.spiffeid.TrustDomain;
import io.spiffe.svid.jwtsvid.JwtSvid;
import io.spiffe.workloadapi.DefaultWorkloadApiClient;
import io.spiffe.workloadapi.internal.ThreadUtils;
import java.time.Clock;
import java.time.Duration;
import java.time.Instant;
import java.util.Arrays;
import java.util.Collections;
import java.util.Date;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import java.util.logging.Level;
import java.util.logging.Logger;
import lombok.Generated;
import lombok.NonNull;
import org.apache.commons.lang3.tuple.ImmutablePair;

/* loaded from: input_file:io/spiffe/workloadapi/CachedJwtSource.class */
/**
 * A {@link JwtSource} backed by a {@link WorkloadApiClient}.
 *
 * <p>JWT SVIDs fetched through the client are cached per (subject, audience set) and served from
 * the cache until the first cached token is past half of its lifetime, at which point they are
 * fetched again. The latest {@link JwtBundleSet} delivered by a JWT bundles watcher is kept and
 * used to answer {@link #getBundleForTrustDomain(TrustDomain)}.
 *
 * <p>Instances are created with {@link #newSource()} or {@link #newSource(JwtSourceOptions)} and
 * must be released with {@link #close()}.
 */
public class CachedJwtSource implements JwtSource {

    @Generated
    private static final Logger log = Logger.getLogger(CachedJwtSource.class.getName());

    /** System property holding the default initialization timeout, in ISO-8601 duration format. */
    static final String TIMEOUT_SYSTEM_PROPERTY = "spiffe.newJwtSource.timeout";

    /** Default initialization timeout; {@code PT0S} (zero) means "wait without a time limit". */
    static final Duration DEFAULT_TIMEOUT = Duration.parse(System.getProperty(TIMEOUT_SYSTEM_PROPERTY, "PT0S"));

    // Written by the watcher thread and read by caller threads without taking the monitor,
    // hence volatile.
    private volatile JwtBundleSet bundles;
    private final WorkloadApiClient workloadApiClient;
    private volatile boolean closed;
    private final Map<ImmutablePair<SpiffeId, Set<String>>, List<JwtSvid>> jwtSvids = new ConcurrentHashMap<>();
    private Clock clock = Clock.systemDefaultZone();

    private CachedJwtSource(WorkloadApiClient workloadApiClient) {
        this.workloadApiClient = workloadApiClient;
    }

    /**
     * Creates a JWT source with default options, using {@link #DEFAULT_TIMEOUT} as the
     * initialization timeout.
     *
     * @return an initialized JWT source
     * @throws JwtSourceException if the source cannot be initialized
     * @throws SocketEndpointAddressException if the Workload API client cannot be created
     */
    public static JwtSource newSource() throws JwtSourceException, SocketEndpointAddressException {
        return newSource(JwtSourceOptions.builder().initTimeout(DEFAULT_TIMEOUT).build());
    }

    /**
     * Creates a JWT source from the given options and blocks until the JWT bundles watcher has
     * reported for the first time or the initialization timeout elapses.
     *
     * <p>Missing values are filled in on the passed options object: a Workload API client is
     * created from the configured socket path when none is set, and {@link #DEFAULT_TIMEOUT} is
     * used when no timeout is set.
     *
     * @param jwtSourceOptions the options used to build the source, must not be {@code null}
     * @return an initialized JWT source
     * @throws SocketEndpointAddressException if the Workload API client cannot be created
     * @throws JwtSourceException if initialization fails or times out; the partially created
     *                            source is closed before this is thrown
     */
    public static JwtSource newSource(@NonNull JwtSourceOptions jwtSourceOptions) throws SocketEndpointAddressException, JwtSourceException {
        if (jwtSourceOptions == null) {
            throw new NullPointerException("options is marked non-null but is null");
        }
        if (jwtSourceOptions.getWorkloadApiClient() == null) {
            jwtSourceOptions.setWorkloadApiClient(createClient(jwtSourceOptions));
        }
        if (jwtSourceOptions.getInitTimeout() == null) {
            jwtSourceOptions.setInitTimeout(DEFAULT_TIMEOUT);
        }
        CachedJwtSource cachedJwtSource = new CachedJwtSource(jwtSourceOptions.getWorkloadApiClient());
        try {
            cachedJwtSource.init(jwtSourceOptions.getInitTimeout());
            return cachedJwtSource;
        } catch (Exception e) {
            // Release the client (and its watcher) before reporting the failure.
            cachedJwtSource.close();
            throw new JwtSourceException("Error creating JWT source", e);
        }
    }

    /**
     * Returns the first JWT SVID available for the given audiences, without naming a subject.
     *
     * @param str the first audience
     * @param strArr additional audiences, may be empty
     * @return the first SVID of the cached or freshly fetched list
     * @throws JwtSvidException if fetching from the Workload API fails
     * @throws IllegalStateException if this source is closed
     */
    @Override // io.spiffe.svid.jwtsvid.JwtSvidSource
    public JwtSvid fetchJwtSvid(String str, String... strArr) throws JwtSvidException {
        if (isClosed()) {
            throw new IllegalStateException("JWT SVID source is closed");
        }
        return getJwtSvids(str, strArr).get(0);
    }

    /**
     * Returns the first JWT SVID available for the given subject and audiences.
     *
     * @param spiffeId the subject of the requested SVID
     * @param str the first audience
     * @param strArr additional audiences, may be empty
     * @return the first SVID of the cached or freshly fetched list
     * @throws JwtSvidException if fetching from the Workload API fails
     * @throws IllegalStateException if this source is closed
     */
    @Override // io.spiffe.svid.jwtsvid.JwtSvidSource
    public JwtSvid fetchJwtSvid(SpiffeId spiffeId, String str, String... strArr) throws JwtSvidException {
        if (isClosed()) {
            throw new IllegalStateException("JWT SVID source is closed");
        }
        return getJwtSvids(spiffeId, str, strArr).get(0);
    }

    /**
     * Returns the JWT SVIDs available for the given audiences, without naming a subject.
     *
     * @param str the first audience
     * @param strArr additional audiences, may be empty
     * @return the cached or freshly fetched list of SVIDs
     * @throws JwtSvidException if fetching from the Workload API fails
     * @throws IllegalStateException if this source is closed
     */
    @Override // io.spiffe.svid.jwtsvid.JwtSvidSource
    public List<JwtSvid> fetchJwtSvids(String str, String... strArr) throws JwtSvidException {
        if (isClosed()) {
            throw new IllegalStateException("JWT SVID source is closed");
        }
        return getJwtSvids(str, strArr);
    }

    /**
     * Returns the JWT SVIDs available for the given subject and audiences.
     *
     * @param spiffeId the subject of the requested SVIDs
     * @param str the first audience
     * @param strArr additional audiences, may be empty
     * @return the cached or freshly fetched list of SVIDs
     * @throws JwtSvidException if fetching from the Workload API fails
     * @throws IllegalStateException if this source is closed
     */
    @Override // io.spiffe.svid.jwtsvid.JwtSvidSource
    public List<JwtSvid> fetchJwtSvids(SpiffeId spiffeId, String str, String... strArr) throws JwtSvidException {
        if (isClosed()) {
            throw new IllegalStateException("JWT SVID source is closed");
        }
        return getJwtSvids(spiffeId, str, strArr);
    }

    /**
     * Looks up the JWT bundle of a trust domain in the most recently received bundle set.
     *
     * @param trustDomain the trust domain to look up, must not be {@code null}
     * @return the bundle associated with the trust domain
     * @throws BundleNotFoundException if the current bundle set has no bundle for the trust domain
     * @throws IllegalStateException if this source is closed, or if no bundle set has been
     *                               received (the watcher reported an error before any update)
     */
    @Override // io.spiffe.bundle.BundleSource
    public JwtBundle getBundleForTrustDomain(@NonNull TrustDomain trustDomain) throws BundleNotFoundException {
        if (trustDomain == null) {
            throw new NullPointerException("trustDomain is marked non-null but is null");
        }
        if (isClosed()) {
            throw new IllegalStateException("JWT bundle source is closed");
        }
        // Read the volatile field once so the null check and the lookup see the same set.
        JwtBundleSet currentBundles = this.bundles;
        if (currentBundles == null) {
            // Initialization is also released by a watcher error, in which case no set was stored.
            throw new IllegalStateException("JWT bundle set has not been received");
        }
        return currentBundles.getBundleForTrustDomain(trustDomain);
    }

    /**
     * Closes the underlying Workload API client and marks this source as closed.
     * Calling it again after the source is closed has no effect.
     */
    @Override // java.io.Closeable, java.lang.AutoCloseable
    public void close() {
        if (!this.closed) {
            synchronized (this) {
                if (!this.closed) {
                    this.workloadApiClient.close();
                    this.closed = true;
                }
            }
        }
    }

    /**
     * Returns the SVIDs cached for (subject, audiences) while they are still usable, otherwise
     * fetches a new list from the Workload API, stores it in the cache and returns it.
     *
     * @param spiffeId the subject, or {@code null} to fetch without naming a subject
     * @param str the first audience
     * @param strArr additional audiences
     */
    private List<JwtSvid> getJwtSvids(SpiffeId spiffeId, String str, String... strArr) throws JwtSvidException {
        ImmutablePair<SpiffeId, Set<String>> cacheKey = new ImmutablePair<>(spiffeId, getAudienceSet(str, strArr));

        // Fast path without locking.
        List<JwtSvid> cached = this.jwtSvids.get(cacheKey);
        if (isUsable(cached)) {
            return cached;
        }

        synchronized (this) {
            // Re-check under the lock: another thread may have refreshed the entry meanwhile.
            cached = this.jwtSvids.get(cacheKey);
            if (isUsable(cached)) {
                return cached;
            }
            List<JwtSvid> fetched;
            if (spiffeId == null) {
                fetched = this.workloadApiClient.fetchJwtSvids(str, strArr);
            } else {
                fetched = this.workloadApiClient.fetchJwtSvids(spiffeId, str, strArr);
            }
            this.jwtSvids.put(cacheKey, fetched);
            return fetched;
        }
    }

    private List<JwtSvid> getJwtSvids(String str, String... strArr) throws JwtSvidException {
        return getJwtSvids(null, str, strArr);
    }

    /**
     * Tells whether a cached list can be served: it must exist, be non-empty, and its first
     * token must not be past half of its lifetime. An empty list is treated as stale so that it
     * triggers a new fetch instead of an {@link IndexOutOfBoundsException}.
     */
    private boolean isUsable(List<JwtSvid> cached) {
        return cached != null && !cached.isEmpty() && !isTokenPastHalfLifetime(cached.get(0));
    }

    /** Builds the set of audiences used as part of the cache key. */
    private static Set<String> getAudienceSet(String str, String[] strArr) {
        if (strArr == null || strArr.length == 0) {
            return Collections.singleton(str);
        }
        Set<String> audiences = new HashSet<>(Arrays.asList(strArr));
        audiences.add(str);
        return audiences;
    }

    /** Returns whether the current time is after the midpoint between issuance and expiry. */
    private boolean isTokenPastHalfLifetime(JwtSvid jwtSvid) {
        long expiryMillis = jwtSvid.getExpiry().getTime();
        long issuedAtMillis = jwtSvid.getIssuedAt().getTime();
        long halfLifeMillis = expiryMillis - ((expiryMillis - issuedAtMillis) / 2);
        return this.clock.instant().isAfter(Instant.ofEpochMilli(halfLifeMillis));
    }

    /**
     * Registers the JWT bundles watcher and waits for its first notification.
     *
     * <p>Note that the watcher releases the wait on an error as well as on an update, so this
     * method may return without a bundle set having been stored.
     *
     * @param duration maximum time to wait; zero means wait without a time limit
     * @throws TimeoutException if no notification arrives within the given duration
     */
    private void init(Duration duration) throws TimeoutException {
        CountDownLatch firstNotification = new CountDownLatch(1);
        setJwtBundlesWatcher(firstNotification);
        if (duration.isZero()) {
            ThreadUtils.await(firstNotification);
            return;
        }
        // Wait with millisecond precision: using whole seconds would turn any sub-second
        // timeout into a zero wait.
        boolean notified = ThreadUtils.await(firstNotification, toMillisSaturated(duration), TimeUnit.MILLISECONDS);
        if (!notified) {
            throw new TimeoutException("Timeout waiting for JWT bundles update");
        }
    }

    /** Converts a duration to milliseconds, clamping to the {@code long} range on overflow. */
    private static long toMillisSaturated(Duration duration) {
        try {
            return duration.toMillis();
        } catch (ArithmeticException overflow) {
            return duration.isNegative() ? Long.MIN_VALUE : Long.MAX_VALUE;
        }
    }

    /**
     * Subscribes to JWT bundle updates. Each update replaces the stored bundle set; both updates
     * and errors count down the given latch so a thread waiting in {@code init} is released.
     */
    private void setJwtBundlesWatcher(final CountDownLatch countDownLatch) {
        this.workloadApiClient.watchJwtBundles(new Watcher<JwtBundleSet>() {
            @Override // io.spiffe.workloadapi.Watcher
            public void onUpdate(JwtBundleSet jwtBundleSet) {
                CachedJwtSource.log.log(Level.INFO, "Received JwtBundleSet update");
                CachedJwtSource.this.setJwtBundleSet(jwtBundleSet);
                countDownLatch.countDown();
            }

            @Override // io.spiffe.workloadapi.Watcher
            public void onError(Throwable th) {
                CachedJwtSource.log.log(Level.SEVERE, "Error in JwtBundleSet watcher", th);
                // Release the initializing thread before propagating the failure.
                countDownLatch.countDown();
                throw new WatcherException("Error fetching JwtBundleSet", th);
            }
        });
    }

    /**
     * Replaces the stored bundle set. Invoked by the bundles watcher on every update.
     * (Decompiler note: the access modifier of this method was widened from private.)
     */
    public void setJwtBundleSet(JwtBundleSet jwtBundleSet) {
        synchronized (this) {
            this.bundles = jwtBundleSet;
        }
    }

    private boolean isClosed() {
        synchronized (this) {
            return this.closed;
        }
    }

    /** Creates a default Workload API client for the socket path configured in the options. */
    private static WorkloadApiClient createClient(JwtSourceOptions jwtSourceOptions) throws SocketEndpointAddressException {
        return DefaultWorkloadApiClient.newClient(DefaultWorkloadApiClient.ClientOptions.builder().spiffeSocketPath(jwtSourceOptions.getSpiffeSocketPath()).build());
    }

    /** Overrides the clock used to evaluate token lifetimes. */
    void setClock(Clock clock) {
        this.clock = clock;
    }
}
