Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ dependencies {
implementation libs.aws.s3
implementation libs.aws.s3.transfer.manager
implementation libs.aws.crt
implementation libs.aws.netty.nio.client
runtimeOnly libs.aws.sts
// AWS SDK v1 for backward compatibility with existing plugins
implementation libs.aws.java.sdk.core.legacy
Expand Down
1 change: 1 addition & 0 deletions gradle/libs.versions.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ aws-s3 = { module = "software.amazon.awssdk:s3", version.ref = "aws" }
aws-sts = { module = "software.amazon.awssdk:sts", version.ref = "aws" }
aws-s3-transfer-manager = { module = "software.amazon.awssdk:s3-transfer-manager", version.ref = "aws" }
aws-crt = { module = "software.amazon.awssdk.crt:aws-crt", version = "0.43.9" }
aws-netty-nio-client = { module = "software.amazon.awssdk:netty-nio-client", version.ref = "aws" }
aws-java-sdk-core-legacy = { module = "com.amazonaws:aws-java-sdk-core", version.ref = "aws-legacy" }
aws-java-sdk-s3-legacy = { module = "com.amazonaws:aws-java-sdk-s3", version.ref = "aws-legacy" }
aws-java-sdk-sts-legacy = { module = "com.amazonaws:aws-java-sdk-sts", version.ref = "aws-legacy" }
Expand Down
212 changes: 200 additions & 12 deletions src/main/java/com/yelp/nrtsearch/server/remote/s3/S3Util.java
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@
import software.amazon.awssdk.core.exception.SdkException;
import software.amazon.awssdk.core.retry.RetryMode;
import software.amazon.awssdk.core.retry.RetryPolicy;
import software.amazon.awssdk.http.nio.netty.NettyNioAsyncHttpClient;
import software.amazon.awssdk.http.nio.netty.SdkEventLoopGroup;
import software.amazon.awssdk.regions.Region;
import software.amazon.awssdk.services.s3.S3AsyncClient;
import software.amazon.awssdk.services.s3.S3Client;
Expand All @@ -39,6 +41,8 @@
import software.amazon.awssdk.services.s3.model.HeadObjectRequest;
import software.amazon.awssdk.services.s3.model.NoSuchKeyException;
import software.amazon.awssdk.services.s3.model.S3Exception;
import software.amazon.awssdk.services.s3.multipart.MultipartConfiguration;
import software.amazon.awssdk.services.s3.multipart.ParallelConfiguration;

/** Utility class for working with S3. */
public class S3Util {
Expand Down Expand Up @@ -201,6 +205,128 @@ static long sizeStrToBytes(String sizeStr) {
}
}

/**
 * Tuning parameters for the Java-based S3 async client, read from configuration keys that share
 * the {@code remoteConfig.s3.java.} prefix.
 *
 * <p>Instances are immutable; every field is assigned once in the constructor.
 */
public static class S3JavaAsyncConfig {
  private static final String CONFIG_PREFIX = "remoteConfig.s3.java.";

  private final long minimumPartSizeInBytes;
  private final long thresholdSizeInBytes;
  private final long apiCallBufferSizeInBytes;
  private final int maxInFlightParts;
  private final int ioThreads;

  /**
   * Constructor.
   *
   * @param minimumPartSizeInBytes minimum multipart part size in bytes
   * @param thresholdSizeInBytes size threshold to trigger multipart upload in bytes
   * @param apiCallBufferSizeInBytes API call buffer size in bytes (0 means SDK default)
   * @param maxInFlightParts max in-flight multipart parts (0 means SDK default)
   * @param ioThreads number of Netty I/O threads (0 means SDK default)
   * @throws IllegalArgumentException if minimumPartSizeInBytes or thresholdSizeInBytes are <= 0,
   *     or if apiCallBufferSizeInBytes, maxInFlightParts, or ioThreads are < 0
   */
  public S3JavaAsyncConfig(
      long minimumPartSizeInBytes,
      long thresholdSizeInBytes,
      long apiCallBufferSizeInBytes,
      int maxInFlightParts,
      int ioThreads) {
    // Checks run in parameter order, so the first offending argument is the one reported.
    require(minimumPartSizeInBytes > 0, "minimumPartSizeInBytes must be > 0");
    require(thresholdSizeInBytes > 0, "thresholdSizeInBytes must be > 0");
    require(apiCallBufferSizeInBytes >= 0, "apiCallBufferSizeInBytes must be >= 0");
    require(maxInFlightParts >= 0, "maxInFlightParts must be >= 0");
    require(ioThreads >= 0, "ioThreads must be >= 0");

    this.minimumPartSizeInBytes = minimumPartSizeInBytes;
    this.thresholdSizeInBytes = thresholdSizeInBytes;
    this.apiCallBufferSizeInBytes = apiCallBufferSizeInBytes;
    this.maxInFlightParts = maxInFlightParts;
    this.ioThreads = ioThreads;
  }

  /**
   * Create S3JavaAsyncConfig from NrtsearchConfig.
   *
   * <p>Keys read (each prefixed with {@code remoteConfig.s3.java.}) and their defaults: {@code
   * minimumPartSize} ("8mb"), {@code thresholdSize} ("8mb"), {@code apiCallBufferSize} ("0"),
   * {@code maxInFlightParts} (0) and {@code ioThreads} (0). The three size values are strings
   * converted to bytes by {@code S3CrtConfig.sizeStrToBytes}.
   *
   * @param configuration server configuration
   * @return S3JavaAsyncConfig
   */
  public static S3JavaAsyncConfig fromConfig(NrtsearchConfig configuration) {
    // Arguments are evaluated left to right, so the settings are read in declaration order.
    return new S3JavaAsyncConfig(
        readSizeInBytes(configuration, "minimumPartSize", "8mb"),
        readSizeInBytes(configuration, "thresholdSize", "8mb"),
        readSizeInBytes(configuration, "apiCallBufferSize", "0"),
        readInteger(configuration, "maxInFlightParts", 0),
        readInteger(configuration, "ioThreads", 0));
  }

  /** Reads a size string stored under the prefixed key and converts it to a byte count. */
  private static long readSizeInBytes(
      NrtsearchConfig configuration, String key, String defaultValue) {
    return S3CrtConfig.sizeStrToBytes(
        configuration.getConfigReader().getString(CONFIG_PREFIX + key, defaultValue));
  }

  /** Reads an integer stored under the prefixed key, using the default when it is not set. */
  private static int readInteger(NrtsearchConfig configuration, String key, int defaultValue) {
    return configuration.getConfigReader().getInteger(CONFIG_PREFIX + key, defaultValue);
  }

  /** Throws {@link IllegalArgumentException} carrying the message when the condition is false. */
  private static void require(boolean condition, String message) {
    if (!condition) {
      throw new IllegalArgumentException(message);
    }
  }

  /**
   * Get the minimum part size in bytes.
   *
   * @return minimum part size in bytes
   */
  public long getMinimumPartSizeInBytes() {
    return minimumPartSizeInBytes;
  }

  /**
   * Get the threshold size in bytes to trigger multipart upload.
   *
   * @return threshold size in bytes
   */
  public long getThresholdSizeInBytes() {
    return thresholdSizeInBytes;
  }

  /**
   * Get the API call buffer size in bytes (0 means SDK default).
   *
   * @return API call buffer size in bytes
   */
  public long getApiCallBufferSizeInBytes() {
    return apiCallBufferSizeInBytes;
  }

  /**
   * Get the max in-flight multipart parts (0 means SDK default).
   *
   * @return max in-flight multipart parts
   */
  public int getMaxInFlightParts() {
    return maxInFlightParts;
  }

  /**
   * Get the number of Netty I/O threads (0 means SDK default).
   *
   * @return number of I/O threads
   */
  public int getIoThreads() {
    return ioThreads;
  }
}

/**
* Create a new S3 client bundle from the given configuration.
*
Expand Down Expand Up @@ -269,10 +395,11 @@ public static S3ClientBundle buildS3ClientBundle(NrtsearchConfig configuration)
}

int maxRetries = configuration.getMaxS3ClientRetries();
software.amazon.awssdk.core.client.config.ClientOverrideConfiguration overrideConfig = null;
if (maxRetries > 0) {
RetryPolicy retryPolicy =
RetryPolicy.builder(RetryMode.STANDARD).numRetries(maxRetries).build();
software.amazon.awssdk.core.client.config.ClientOverrideConfiguration overrideConfig =
overrideConfig =
software.amazon.awssdk.core.client.config.ClientOverrideConfiguration.builder()
.retryPolicy(retryPolicy)
.build();
Expand All @@ -281,6 +408,72 @@ public static S3ClientBundle buildS3ClientBundle(NrtsearchConfig configuration)

S3Client s3Client = clientBuilder.build();

String asyncClientType =
configuration.getConfigReader().getString("remoteConfig.s3.asyncClientType", "crt");
S3AsyncClient s3AsyncClient;
if (asyncClientType.equalsIgnoreCase("java")) {
s3AsyncClient =
buildJavaAsyncClient(configuration, s3Client, overrideConfig, globalBucketAccess);
} else if (asyncClientType.equalsIgnoreCase("crt")) {
s3AsyncClient = buildCrtAsyncClient(configuration, s3Client, globalBucketAccess);
} else {
throw new IllegalArgumentException(
"Unknown asyncClientType: '" + asyncClientType + "'. Valid values are 'crt' or 'java'.");
}
return new S3ClientBundle(s3Client, s3AsyncClient);
}

/**
 * Build the Java-based S3 async client with multipart transfers enabled.
 *
 * <p>Tuning values are loaded through {@link S3JavaAsyncConfig#fromConfig}. The minimum part size
 * and multipart threshold are always applied. The API call buffer size, the in-flight part limit
 * and the Netty I/O thread count are applied only when greater than zero; otherwise the
 * corresponding builder setting is left untouched.
 *
 * @param configuration server configuration, used for the tuning values and the credentials
 *     provider
 * @param s3Client already-built synchronous client whose region is reused for the async client
 * @param overrideConfig client override configuration to apply, or null to apply none
 * @param globalBucketAccess when true, cross-region access is enabled on the client
 * @return the built S3 async client
 */
private static S3AsyncClient buildJavaAsyncClient(
    NrtsearchConfig configuration,
    S3Client s3Client,
    software.amazon.awssdk.core.client.config.ClientOverrideConfiguration overrideConfig,
    boolean globalBucketAccess) {
  S3JavaAsyncConfig settings = S3JavaAsyncConfig.fromConfig(configuration);
  long minimumPartSize = settings.getMinimumPartSizeInBytes();
  long thresholdSize = settings.getThresholdSizeInBytes();
  long apiCallBufferSize = settings.getApiCallBufferSizeInBytes();
  int maxInFlightParts = settings.getMaxInFlightParts();
  int ioThreads = settings.getIoThreads();
  logger.info(
      "S3 Java async client config: minimumPartSizeInBytes={}, thresholdSizeInBytes={}, apiCallBufferSizeInBytes={}, maxInFlightParts={}, ioThreads={}",
      minimumPartSize,
      thresholdSize,
      apiCallBufferSize,
      maxInFlightParts,
      ioThreads);

  // Multipart settings: the two sizes are mandatory, the remaining knobs are opt-in.
  MultipartConfiguration.Builder multipartSettings = MultipartConfiguration.builder();
  multipartSettings.minimumPartSizeInBytes(minimumPartSize);
  multipartSettings.thresholdInBytes(thresholdSize);
  if (apiCallBufferSize > 0) {
    multipartSettings.apiCallBufferSizeInBytes(apiCallBufferSize);
  }
  if (maxInFlightParts > 0) {
    ParallelConfiguration parallelism =
        ParallelConfiguration.builder().maxInFlightParts(maxInFlightParts).build();
    multipartSettings.parallelConfiguration(parallelism);
  }

  software.amazon.awssdk.services.s3.S3AsyncClientBuilder asyncBuilder = S3AsyncClient.builder();
  asyncBuilder.credentialsProvider(createCredentialsProvider(configuration));
  // Reuse the region already resolved for the synchronous client.
  asyncBuilder.region(s3Client.serviceClientConfiguration().region());
  asyncBuilder.multipartEnabled(true);
  asyncBuilder.multipartConfiguration(multipartSettings.build());

  if (ioThreads > 0) {
    // Only supply an explicit Netty HTTP client builder when a thread count was requested.
    SdkEventLoopGroup.Builder eventLoops = SdkEventLoopGroup.builder().numberOfThreads(ioThreads);
    asyncBuilder.httpClientBuilder(
        NettyNioAsyncHttpClient.builder().eventLoopGroupBuilder(eventLoops));
  }
  if (overrideConfig != null) {
    asyncBuilder.overrideConfiguration(overrideConfig);
  }
  if (globalBucketAccess) {
    asyncBuilder.crossRegionAccessEnabled(true);
  }
  return asyncBuilder.build();
}

private static S3AsyncClient buildCrtAsyncClient(
NrtsearchConfig configuration, S3Client s3Client, boolean globalBucketAccess) {
S3CrtConfig crtConfig = S3CrtConfig.fromConfig(configuration);
logger.info(
"S3 CRT client config: targetThroughputInGbps={}, maxConcurrency={}, minimumPartSizeInBytes={}, maxNativeMemoryLimitInBytes={}",
Expand All @@ -289,29 +482,24 @@ public static S3ClientBundle buildS3ClientBundle(NrtsearchConfig configuration)
crtConfig.getMinimumPartSizeInBytes(),
crtConfig.getMaxNativeMemoryLimitInBytes());

S3CrtAsyncClientBuilder s3CrtAsyncClientBuilder =
S3CrtAsyncClientBuilder builder =
S3CrtAsyncClient.builder()
.credentialsProvider(createCredentialsProvider(configuration))
.region(s3Client.serviceClientConfiguration().region())
.minimumPartSizeInBytes(crtConfig.getMinimumPartSizeInBytes());

if (crtConfig.getTargetThroughputInGbps() > 0) {
s3CrtAsyncClientBuilder.targetThroughputInGbps(crtConfig.getTargetThroughputInGbps());
builder.targetThroughputInGbps(crtConfig.getTargetThroughputInGbps());
} else if (crtConfig.getMaxConcurrency() > 0) {
s3CrtAsyncClientBuilder.maxConcurrency(crtConfig.getMaxConcurrency());
builder.maxConcurrency(crtConfig.getMaxConcurrency());
}

if (crtConfig.getMaxNativeMemoryLimitInBytes() > 0) {
s3CrtAsyncClientBuilder.maxNativeMemoryLimitInBytes(
crtConfig.getMaxNativeMemoryLimitInBytes());
builder.maxNativeMemoryLimitInBytes(crtConfig.getMaxNativeMemoryLimitInBytes());
}

if (globalBucketAccess) {
s3CrtAsyncClientBuilder.crossRegionAccessEnabled(true);
builder.crossRegionAccessEnabled(true);
}

S3AsyncClient s3AsyncClient = s3CrtAsyncClientBuilder.build();
return new S3ClientBundle(s3Client, s3AsyncClient);
return builder.build();
}

/**
Expand Down
Loading
Loading