Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,18 @@
package software.amazon.awssdk.awscore.internal.identity;

import software.amazon.awssdk.annotations.SdkInternalApi;
import software.amazon.awssdk.auth.credentials.AwsCredentials;
import software.amazon.awssdk.auth.credentials.StaticCredentialsProvider;
import software.amazon.awssdk.auth.signer.AwsSignerExecutionAttribute;
import software.amazon.awssdk.awscore.AwsRequestOverrideConfiguration;
import software.amazon.awssdk.core.SdkRequest;
import software.amazon.awssdk.core.interceptor.ExecutionAttributes;
import software.amazon.awssdk.core.spi.identity.IdentityProviderUpdater;
import software.amazon.awssdk.identity.spi.IdentityProviders;

/**
* AWS implementation of {@link IdentityProviderUpdater} that reads credential overrides
* from {@link AwsRequestOverrideConfiguration}.
* from {@link AwsRequestOverrideConfiguration} and deprecated {@link AwsSignerExecutionAttribute#AWS_CREDENTIALS}.
*/
@SdkInternalApi
public final class AwsIdentityProviderUpdater implements IdentityProviderUpdater {
Expand All @@ -38,17 +42,32 @@ public static AwsIdentityProviderUpdater create() {
}

@Override
public IdentityProviders update(SdkRequest request, IdentityProviders base) {
public IdentityProviders update(SdkRequest request, IdentityProviders base, ExecutionAttributes executionAttributes) {
if (base == null) {
return null;
}
return request.overrideConfiguration()

IdentityProviders updated = request.overrideConfiguration()
.filter(c -> c instanceof AwsRequestOverrideConfiguration)
.map(c -> (AwsRequestOverrideConfiguration) c)
.map(c -> base.copy(b -> {
c.credentialsIdentityProvider().ifPresent(b::putIdentityProvider);
c.tokenIdentityProvider().ifPresent(b::putIdentityProvider);
}))
.orElse(base);
.orElse(null);

if (updated != null) {
return updated;
}

// Support deprecated AWS_CREDENTIALS execution attribute for backwards compatibility
// with interceptors that set credentials via AwsSignerExecutionAttribute.AWS_CREDENTIALS
AwsCredentials credentials = executionAttributes.getOptionalAttribute(AwsSignerExecutionAttribute.AWS_CREDENTIALS)
.orElse(null);
if (credentials != null) {
return base.copy(b -> b.putIdentityProvider(StaticCredentialsProvider.create(credentials)));
}

return base;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ private IdentityProviders updateIdentityProvidersIfNeeded(ExecutionAttributes ex
IdentityProviderUpdater updater =
executionAttributes.getAttribute(SdkInternalExecutionAttribute.IDENTITY_PROVIDER_UPDATER);
if (updater != null) {
identityProviders = updater.update(request, identityProviders);
identityProviders = updater.update(request, identityProviders, executionAttributes);
}
return identityProviders;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -104,10 +104,11 @@ private static boolean interceptorModifiedEndpoint(SdkHttpFullRequest.Builder re
return false;
}
String requestHost = request.host();
Integer requestPort = request.port();
return requestHost != null
&& (!requestHost.equals(preModifyUri.getHost())
|| !String.valueOf(request.protocol()).equals(preModifyUri.getScheme())
|| request.port() != preModifyUri.getPort());
|| (requestPort != null && requestPort != preModifyUri.getPort()));
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can this be replaced with Objects.equals()?

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm, when no port is set, the SDK returns null but URI.getPort() returns -1, so Objects.equals() would be incorrect and needs special handling for this when no port is set case. So just skipping the comparison when port is null.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah I see!

}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

import software.amazon.awssdk.annotations.SdkProtectedApi;
import software.amazon.awssdk.core.SdkRequest;
import software.amazon.awssdk.core.interceptor.ExecutionAttributes;
import software.amazon.awssdk.identity.spi.IdentityProviders;

/**
Expand All @@ -29,11 +30,13 @@
@SdkProtectedApi
public interface IdentityProviderUpdater {
/**
* Updates identity providers based on request-level overrides.
* Updates identity providers by applying request-level credential overrides or
* credentials set via {@code AwsSignerExecutionAttribute.AWS_CREDENTIALS} by interceptors.
*
* @param request The request (after interceptors have modified it)
* @param base The base identity providers from client configuration
* @return Updated identity providers, or base if no overrides
* @param executionAttributes The execution attributes, checked for interceptor-set AWS_CREDENTIALS
* @return Updated identity providers, or base if no overrides apply
*/
IdentityProviders update(SdkRequest request, IdentityProviders base);
IdentityProviders update(SdkRequest request, IdentityProviders base, ExecutionAttributes executionAttributes);
}
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ void execute_withIdentityProviderUpdater_callsUpdaterWithRequest() throws Except
IdentityProviders updatedProviders = mock(IdentityProviders.class);

IdentityProviderUpdater updater = mock(IdentityProviderUpdater.class);
doReturn(updatedProviders).when(updater).update(sdkRequest, baseProviders);
doReturn(updatedProviders).when(updater).update(sdkRequest, baseProviders, executionAttributes);

// Setup so that auth scheme uses the updated providers
@SuppressWarnings("unchecked")
Expand All @@ -162,7 +162,7 @@ void execute_withIdentityProviderUpdater_callsUpdaterWithRequest() throws Except

stage.execute(httpRequestBuilder, context);

verify(updater).update(sdkRequest, baseProviders);
verify(updater).update(sdkRequest, baseProviders, executionAttributes);
}

@Test
Expand Down
Loading