diff --git a/xds/src/main/java/io/grpc/xds/GrpcXdsTransportFactory.java b/xds/src/main/java/io/grpc/xds/GrpcXdsTransportFactory.java
index 0da51bf47f7..5100537aea2 100644
--- a/xds/src/main/java/io/grpc/xds/GrpcXdsTransportFactory.java
+++ b/xds/src/main/java/io/grpc/xds/GrpcXdsTransportFactory.java
@@ -31,11 +31,35 @@
import io.grpc.Status;
import io.grpc.xds.client.Bootstrapper;
import io.grpc.xds.client.XdsTransportFactory;
+import java.util.Map;
+import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.TimeUnit;
+/**
+ * A factory for creating gRPC-based transports for xDS communication.
+ *
+ *
WARNING: This class reuses channels when possible, based on the provided {@link
+ * Bootstrapper.ServerInfo} with important considerations. The {@link Bootstrapper.ServerInfo}
+ * includes {@link ChannelCredentials}, which is compared by reference equality. This means every
+ * {@link Bootstrapper.BootstrapInfo} would have non-equal copies of {@link
+ * Bootstrapper.ServerInfo}, even if they all represent the same xDS server configuration. For gRPC
+ * name resolution with the {@code xds} and {@code google-c2p} scheme, this transport sharing works
+ * as expected as it internally reuses a single {@link Bootstrapper.BootstrapInfo} instance.
+ * Otherwise, new transports would be created for each {@link Bootstrapper.ServerInfo} despite them
+ * possibly representing the same xDS server configuration and defeating the purpose of transport
+ * sharing.
+ */
final class GrpcXdsTransportFactory implements XdsTransportFactory {
private final CallCredentials callCredentials;
+ // The map of xDS server info to its corresponding gRPC xDS transport.
+ // This enables reusing and sharing the same underlying gRPC channel.
+ //
+ // NOTE: ConcurrentHashMap is used as a per-entry lock and all reads and writes must be a mutation
+ // via the ConcurrentHashMap APIs to acquire the per-entry lock in order to ensure thread safety
+ // for reference counting of each GrpcXdsTransport instance.
+ private static final Map xdsServerInfoToTransportMap =
+ new ConcurrentHashMap<>();
GrpcXdsTransportFactory(CallCredentials callCredentials) {
this.callCredentials = callCredentials;
@@ -43,12 +67,20 @@ final class GrpcXdsTransportFactory implements XdsTransportFactory {
@Override
public XdsTransport create(Bootstrapper.ServerInfo serverInfo) {
- return new GrpcXdsTransport(serverInfo, callCredentials);
+ return xdsServerInfoToTransportMap.compute(
+ serverInfo,
+ (info, transport) -> {
+ if (transport == null) {
+ transport = new GrpcXdsTransport(serverInfo, callCredentials);
+ }
+ ++transport.refCount;
+ return transport;
+ });
}
@VisibleForTesting
public XdsTransport createForTest(ManagedChannel channel) {
- return new GrpcXdsTransport(channel, callCredentials);
+ return new GrpcXdsTransport(channel, callCredentials, null);
}
@VisibleForTesting
@@ -56,6 +88,9 @@ static class GrpcXdsTransport implements XdsTransport {
private final ManagedChannel channel;
private final CallCredentials callCredentials;
+ private final Bootstrapper.ServerInfo serverInfo;
+ // Must only be accessed via the ConcurrentHashMap APIs which act as the locking methods.
+ private int refCount = 0;
public GrpcXdsTransport(Bootstrapper.ServerInfo serverInfo) {
this(serverInfo, null);
@@ -63,7 +98,7 @@ public GrpcXdsTransport(Bootstrapper.ServerInfo serverInfo) {
@VisibleForTesting
public GrpcXdsTransport(ManagedChannel channel) {
- this(channel, null);
+ this(channel, null, null);
}
public GrpcXdsTransport(Bootstrapper.ServerInfo serverInfo, CallCredentials callCredentials) {
@@ -73,12 +108,17 @@ public GrpcXdsTransport(Bootstrapper.ServerInfo serverInfo, CallCredentials call
.keepAliveTime(5, TimeUnit.MINUTES)
.build();
this.callCredentials = callCredentials;
+ this.serverInfo = serverInfo;
}
@VisibleForTesting
- public GrpcXdsTransport(ManagedChannel channel, CallCredentials callCredentials) {
+ public GrpcXdsTransport(
+ ManagedChannel channel,
+ CallCredentials callCredentials,
+ Bootstrapper.ServerInfo serverInfo) {
this.channel = checkNotNull(channel, "channel");
this.callCredentials = callCredentials;
+ this.serverInfo = serverInfo;
}
@Override
@@ -98,7 +138,19 @@ public StreamingCall createStreamingCall(
@Override
public void shutdown() {
- channel.shutdown();
+ if (serverInfo == null) {
+ channel.shutdown();
+ return;
+ }
+ xdsServerInfoToTransportMap.computeIfPresent(
+ serverInfo,
+ (info, transport) -> {
+ if (--transport.refCount == 0) { // Prefix decrement and return the updated value.
+ transport.channel.shutdown();
+ return null; // Remove mapping.
+ }
+ return transport;
+ });
}
private class XdsStreamingCall implements
diff --git a/xds/src/test/java/io/grpc/xds/GrpcXdsTransportFactoryTest.java b/xds/src/test/java/io/grpc/xds/GrpcXdsTransportFactoryTest.java
index 66e0d4b3198..9c606a962f6 100644
--- a/xds/src/test/java/io/grpc/xds/GrpcXdsTransportFactoryTest.java
+++ b/xds/src/test/java/io/grpc/xds/GrpcXdsTransportFactoryTest.java
@@ -30,6 +30,7 @@
import io.grpc.Server;
import io.grpc.Status;
import io.grpc.stub.StreamObserver;
+import io.grpc.testing.GrpcCleanupRule;
import io.grpc.xds.client.Bootstrapper;
import io.grpc.xds.client.XdsTransportFactory;
import java.util.concurrent.BlockingQueue;
@@ -37,6 +38,7 @@
import java.util.concurrent.TimeUnit;
import org.junit.After;
import org.junit.Before;
+import org.junit.Rule;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
@@ -44,6 +46,8 @@
@RunWith(JUnit4.class)
public class GrpcXdsTransportFactoryTest {
+ @Rule public final GrpcCleanupRule grpcCleanupRule = new GrpcCleanupRule();
+
private Server server;
@Before
@@ -118,6 +122,59 @@ public void callApis() throws Exception {
xdsTransport.shutdown();
}
+ @Test
+ public void refCountedXdsTransport_sameXdsServerAddress_returnsExistingTransport() {
+ Bootstrapper.ServerInfo xdsServerInfo =
+ Bootstrapper.ServerInfo.create(
+ "localhost:" + server.getPort(), InsecureChannelCredentials.create());
+ GrpcXdsTransportFactory xdsTransportFactory = new GrpcXdsTransportFactory(null);
+ // Calling create() for the first time creates a new GrpcXdsTransport instance.
+ // The ref count was previously 0 and now is 1.
+ XdsTransportFactory.XdsTransport transport1 = xdsTransportFactory.create(xdsServerInfo);
+ // Calling create() for the second time to the same xDS server address returns the same
+ // GrpcXdsTransport instance. The ref count was previously 1 and now is 2.
+ XdsTransportFactory.XdsTransport transport2 = xdsTransportFactory.create(xdsServerInfo);
+ assertThat(transport1).isSameInstanceAs(transport2);
+ // Calling shutdown() for the first time does not shut down the GrpcXdsTransport instance.
+ // The ref count was previously 2 and now is 1.
+ transport1.shutdown();
+ // Calling shutdown() for the second time shuts down the GrpcXdsTransport instance.
+ // The ref count was previously 1 and now is 0.
+ transport2.shutdown();
+ }
+
+ @Test
+ public void refCountedXdsTransport_differentXdsServerAddress_returnsDifferentTransport()
+ throws Exception {
+ // Create and start a second xDS serverĀ on a different port.
+ Server server2 =
+ grpcCleanupRule.register(
+ Grpc.newServerBuilderForPort(0, InsecureServerCredentials.create())
+ .addService(echoAdsService())
+ .build()
+ .start());
+ Bootstrapper.ServerInfo xdsServerInfo1 =
+ Bootstrapper.ServerInfo.create(
+ "localhost:" + server.getPort(), InsecureChannelCredentials.create());
+ Bootstrapper.ServerInfo xdsServerInfo2 =
+ Bootstrapper.ServerInfo.create(
+ "localhost:" + server2.getPort(), InsecureChannelCredentials.create());
+ GrpcXdsTransportFactory xdsTransportFactory = new GrpcXdsTransportFactory(null);
+ // Calling create() to the first xDS server creates a new GrpcXdsTransport instance.
+ // The ref count was previously 0 and now is 1.
+ XdsTransportFactory.XdsTransport transport1 = xdsTransportFactory.create(xdsServerInfo1);
+ // Calling create() to the second xDS server creates a different GrpcXdsTransport instance.
+ // The ref count was previously 0 and now is 1.
+ XdsTransportFactory.XdsTransport transport2 = xdsTransportFactory.create(xdsServerInfo2);
+ assertThat(transport1).isNotSameInstanceAs(transport2);
+ // Calling shutdown() shuts down the GrpcXdsTransport instance for the first xDS server.
+ // The ref count was previously 1 and now is 0.
+ transport1.shutdown();
+ // Calling shutdown() shuts down the GrpcXdsTransport instance for the second xDS server.
+ // The ref count was previously 1 and now is 0.
+ transport2.shutdown();
+ }
+
private static class FakeEventHandler implements
XdsTransportFactory.EventHandler {
private final BlockingQueue respQ = new LinkedBlockingQueue<>();