Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 57 additions & 5 deletions xds/src/main/java/io/grpc/xds/GrpcXdsTransportFactory.java
Original file line number Diff line number Diff line change
Expand Up @@ -31,39 +31,74 @@
import io.grpc.Status;
import io.grpc.xds.client.Bootstrapper;
import io.grpc.xds.client.XdsTransportFactory;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.TimeUnit;

/**
* A factory for creating gRPC-based transports for xDS communication.
*
* <p>WARNING: This class reuses channels when possible, based on the provided {@link
* Bootstrapper.ServerInfo} with important considerations. The {@link Bootstrapper.ServerInfo}
* includes {@link ChannelCredentials}, which is compared by reference equality. This means every
* {@link Bootstrapper.BootstrapInfo} would have non-equal copies of {@link
* Bootstrapper.ServerInfo}, even if they all represent the same xDS server configuration. For gRPC
* name resolution with the {@code xds} and {@code google-c2p} scheme, this transport sharing works
* as expected as it internally reuses a single {@link Bootstrapper.BootstrapInfo} instance.
* Otherwise, new transports would be created for each {@link Bootstrapper.ServerInfo} despite them
* possibly representing the same xDS server configuration and defeating the purpose of transport
* sharing.
*/
final class GrpcXdsTransportFactory implements XdsTransportFactory {

private final CallCredentials callCredentials;
// The map of xDS server info to its corresponding gRPC xDS transport.
// This enables reusing and sharing the same underlying gRPC channel.
//
// NOTE: ConcurrentHashMap is used as a per-entry lock and all reads and writes must be a mutation
// via the ConcurrentHashMap APIs to acquire the per-entry lock in order to ensure thread safety
// for reference counting of each GrpcXdsTransport instance.
private static final Map<Bootstrapper.ServerInfo, GrpcXdsTransport> xdsServerInfoToTransportMap =
new ConcurrentHashMap<>();

GrpcXdsTransportFactory(CallCredentials callCredentials) {
this.callCredentials = callCredentials;
}

@Override
public XdsTransport create(Bootstrapper.ServerInfo serverInfo) {
return new GrpcXdsTransport(serverInfo, callCredentials);
return xdsServerInfoToTransportMap.compute(
serverInfo,
(info, transport) -> {
if (transport == null) {
transport = new GrpcXdsTransport(serverInfo, callCredentials);
}
++transport.refCount;
return transport;
});
}

@VisibleForTesting
public XdsTransport createForTest(ManagedChannel channel) {
return new GrpcXdsTransport(channel, callCredentials);
return new GrpcXdsTransport(channel, callCredentials, null);
}

@VisibleForTesting
static class GrpcXdsTransport implements XdsTransport {

private final ManagedChannel channel;
private final CallCredentials callCredentials;
private final Bootstrapper.ServerInfo serverInfo;
// Must only be accessed via the ConcurrentHashMap APIs which act as the locking methods.
private int refCount = 0;

public GrpcXdsTransport(Bootstrapper.ServerInfo serverInfo) {
this(serverInfo, null);
}

@VisibleForTesting
public GrpcXdsTransport(ManagedChannel channel) {
this(channel, null);
this(channel, null, null);
}

public GrpcXdsTransport(Bootstrapper.ServerInfo serverInfo, CallCredentials callCredentials) {
Expand All @@ -73,12 +108,17 @@ public GrpcXdsTransport(Bootstrapper.ServerInfo serverInfo, CallCredentials call
.keepAliveTime(5, TimeUnit.MINUTES)
.build();
this.callCredentials = callCredentials;
this.serverInfo = serverInfo;
}

@VisibleForTesting
public GrpcXdsTransport(ManagedChannel channel, CallCredentials callCredentials) {
public GrpcXdsTransport(
ManagedChannel channel,
CallCredentials callCredentials,
Bootstrapper.ServerInfo serverInfo) {
this.channel = checkNotNull(channel, "channel");
this.callCredentials = callCredentials;
this.serverInfo = serverInfo;
}

@Override
Expand All @@ -98,7 +138,19 @@ public <ReqT, RespT> StreamingCall<ReqT, RespT> createStreamingCall(

@Override
public void shutdown() {
channel.shutdown();
if (serverInfo == null) {
channel.shutdown();
return;
}
xdsServerInfoToTransportMap.computeIfPresent(
serverInfo,
(info, transport) -> {
if (--transport.refCount == 0) { // Prefix decrement and return the updated value.
transport.channel.shutdown();
return null; // Remove mapping.
}
return transport;
});
}

private class XdsStreamingCall<ReqT, RespT> implements
Expand Down
57 changes: 57 additions & 0 deletions xds/src/test/java/io/grpc/xds/GrpcXdsTransportFactoryTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -30,20 +30,24 @@
import io.grpc.Server;
import io.grpc.Status;
import io.grpc.stub.StreamObserver;
import io.grpc.testing.GrpcCleanupRule;
import io.grpc.xds.client.Bootstrapper;
import io.grpc.xds.client.XdsTransportFactory;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.TimeUnit;
import org.junit.After;
import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;

@RunWith(JUnit4.class)
public class GrpcXdsTransportFactoryTest {

@Rule public final GrpcCleanupRule grpcCleanupRule = new GrpcCleanupRule();

private Server server;

@Before
Expand Down Expand Up @@ -118,6 +122,59 @@ public void callApis() throws Exception {
xdsTransport.shutdown();
}

@Test
public void refCountedXdsTransport_sameXdsServerAddress_returnsExistingTransport() {
Bootstrapper.ServerInfo xdsServerInfo =
Bootstrapper.ServerInfo.create(
"localhost:" + server.getPort(), InsecureChannelCredentials.create());
GrpcXdsTransportFactory xdsTransportFactory = new GrpcXdsTransportFactory(null);
// Calling create() for the first time creates a new GrpcXdsTransport instance.
// The ref count was previously 0 and now is 1.
XdsTransportFactory.XdsTransport transport1 = xdsTransportFactory.create(xdsServerInfo);
// Calling create() for the second time to the same xDS server address returns the same
// GrpcXdsTransport instance. The ref count was previously 1 and now is 2.
XdsTransportFactory.XdsTransport transport2 = xdsTransportFactory.create(xdsServerInfo);
assertThat(transport1).isSameInstanceAs(transport2);
// Calling shutdown() for the first time does not shut down the GrpcXdsTransport instance.
// The ref count was previously 2 and now is 1.
transport1.shutdown();
// Calling shutdown() for the second time shuts down the GrpcXdsTransport instance.
// The ref count was previously 1 and now is 0.
transport2.shutdown();
}

@Test
public void refCountedXdsTransport_differentXdsServerAddress_returnsDifferentTransport()
throws Exception {
// Create and start a second xDS server on a different port.
Server server2 =
grpcCleanupRule.register(
Grpc.newServerBuilderForPort(0, InsecureServerCredentials.create())
.addService(echoAdsService())
.build()
.start());
Bootstrapper.ServerInfo xdsServerInfo1 =
Bootstrapper.ServerInfo.create(
"localhost:" + server.getPort(), InsecureChannelCredentials.create());
Bootstrapper.ServerInfo xdsServerInfo2 =
Bootstrapper.ServerInfo.create(
"localhost:" + server2.getPort(), InsecureChannelCredentials.create());
GrpcXdsTransportFactory xdsTransportFactory = new GrpcXdsTransportFactory(null);
// Calling create() to the first xDS server creates a new GrpcXdsTransport instance.
// The ref count was previously 0 and now is 1.
XdsTransportFactory.XdsTransport transport1 = xdsTransportFactory.create(xdsServerInfo1);
// Calling create() to the second xDS server creates a different GrpcXdsTransport instance.
// The ref count was previously 0 and now is 1.
XdsTransportFactory.XdsTransport transport2 = xdsTransportFactory.create(xdsServerInfo2);
assertThat(transport1).isNotSameInstanceAs(transport2);
// Calling shutdown() shuts down the GrpcXdsTransport instance for the first xDS server.
// The ref count was previously 1 and now is 0.
transport1.shutdown();
// Calling shutdown() shuts down the GrpcXdsTransport instance for the second xDS server.
// The ref count was previously 1 and now is 0.
transport2.shutdown();
}

private static class FakeEventHandler implements
XdsTransportFactory.EventHandler<DiscoveryResponse> {
private final BlockingQueue<DiscoveryResponse> respQ = new LinkedBlockingQueue<>();
Expand Down
Loading