JCL-349: Better concurrency support for OpenID session refresh by acoburn · Pull Request #456 · inrupt/solid-client-java

Expand Up @@ -40,6 +40,8 @@ import java.util.UUID; import java.util.concurrent.CompletableFuture; import java.util.concurrent.CompletionStage; import java.util.concurrent.ForkJoinPool; import java.util.concurrent.ForkJoinTask; import java.util.concurrent.atomic.AtomicReference; import java.util.function.Supplier;
Expand All @@ -54,19 +56,24 @@ import org.jose4j.jwt.consumer.JwtConsumerBuilder; import org.jose4j.keys.resolvers.HttpsJwksVerificationKeyResolver; import org.jose4j.keys.resolvers.VerificationKeyResolver; import org.slf4j.Logger; import org.slf4j.LoggerFactory;
/** * A session implementation for use with OpenID Connect ID Tokens. * */ public final class OpenIdSession implements Session {
private static final Logger LOGGER = LoggerFactory.getLogger(OpenIdSession.class);
public static final URI ID_TOKEN = URI.create("http://openid.net/specs/openid-connect-core-1_0.html#IDToken");
private final String id; private final Set<String> schemes; private final Supplier<CompletionStage<Credential>> authenticator; private final AtomicReference<Credential> credential = new AtomicReference<>(); private final ForkJoinPool executor = new ForkJoinPool(1); private final DPoP dpop;
private OpenIdSession(final String id, final DPoP dpop, Expand Down Expand Up @@ -182,15 +189,11 @@ public Set<String> supportedSchemes() { @Override public Optional<Credential> getCredential(final URI name, final URI uri) { if (ID_TOKEN.equals(name)) { final Credential c = credential.get(); if (!hasExpired(c)) { return Optional.of(c); } final Credential freshCredential = authenticator.get().toCompletableFuture().join(); if (!hasExpired(freshCredential)) { credential.set(freshCredential); return Optional.of(freshCredential); final Credential cred = credential.get(); if (!hasExpired(cred)) { return Optional.of(cred); } return Optional.ofNullable(executor.invoke(ForkJoinTask.adapt(this::synchronizedFetch))); } return Optional.empty(); } Expand Down Expand Up @@ -222,7 +225,7 @@ public Optional<Credential> fromCache(final Request request) { @Override public CompletionStage<Optional<Credential>> authenticate(final Request request, final Set<String> algorithms) { return authenticator.get().thenApply(Optional::ofNullable); return CompletableFuture.completedFuture(getCredential(ID_TOKEN, null)); }
boolean hasExpired(final Credential credential) { Expand All @@ -232,6 +235,22 @@ boolean hasExpired(final Credential credential) { return true; }
private synchronized Credential synchronizedFetch() { // Check again inside the synchronized method final Credential cred = credential.get(); if (!hasExpired(cred)) { return cred; }
// Fetch the refreshed credentials final Credential refreshed = authenticator.get().toCompletableFuture().join(); if (!hasExpired(refreshed)) { credential.set(refreshed); return refreshed; } return null; }
static String getSessionIdentifier(final JwtClaims claims) { final String webid = claims.getClaimValueAsString("webid"); if (webid != null) { Expand Down