Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 6 additions & 9 deletions netty/src/test/java/io/grpc/netty/AdvancedTlsTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -436,16 +436,13 @@ public void onFileReloadingTrustManagerBadInitialContentTest() throws Exception
}

@Test
public void keyManagerAliasesTest() {
public void keyManagerAliasesTest() throws Exception {
AdvancedTlsX509KeyManager km = new AdvancedTlsX509KeyManager();
assertArrayEquals(
new String[] {"default"}, km.getClientAliases("", null));
assertEquals(
"default", km.chooseClientAlias(new String[] {"default"}, null, null));
assertArrayEquals(
new String[] {"default"}, km.getServerAliases("", null));
assertEquals(
"default", km.chooseServerAlias("default", null, null));
km.updateIdentityCredentials(serverCert0, serverKey0);
assertArrayEquals(new String[] {"key-1"}, km.getClientAliases("", null));
assertEquals("key-1", km.chooseClientAlias(new String[] {"key-1"}, null, null));
assertArrayEquals(new String[] {"key-1"}, km.getServerAliases("", null));
assertEquals("key-1", km.chooseServerAlias("key-1", null, null));
}

@Test
Expand Down
84 changes: 66 additions & 18 deletions util/src/main/java/io/grpc/util/AdvancedTlsX509KeyManager.java
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.ScheduledFuture;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.logging.Level;
import java.util.logging.Logger;
import javax.net.ssl.SSLEngine;
Expand All @@ -40,59 +41,97 @@
/**
* AdvancedTlsX509KeyManager is an {@code X509ExtendedKeyManager} that allows users to configure
* advanced TLS features, such as private key and certificate chain reloading.
*
* <p>The key material alias increments on every credential load (e.g. {@code "key-1"},
* {@code "key-2"}, ...), ensuring the same alias always maps to the same key material. This is
* required by Netty's {@code OpenSslCachingX509KeyManagerFactory} to correctly cache key
* material and create a new cache entry on cert reload.
*
* <p>When using {@code SslProvider.OPENSSL}, wrap this key manager in Netty's
* {@code OpenSslCachingX509KeyManagerFactory} to avoid per-handshake key material encoding
* overhead, e.g. {@code new OpenSslCachingX509KeyManagerFactory(
* new KeyManagerFactoryWrapper(advancedTlsKeyManager))}, and pass the factory to
* {@code SslContextBuilder} instead of the key manager directly.
*/
public final class AdvancedTlsX509KeyManager extends X509ExtendedKeyManager {
private static final Logger log = Logger.getLogger(AdvancedTlsX509KeyManager.class.getName());
// Minimum allowed period for refreshing files with credential information.
private static final int MINIMUM_REFRESH_PERIOD_IN_MINUTES = 1 ;
private static final int MINIMUM_REFRESH_PERIOD_IN_MINUTES = 1;
// Prefix for the key material alias; revision counter appended on each credential load.
static final String ALIAS_PREFIX = "key-";

private final AtomicInteger revision = new AtomicInteger(0);
private final int revisionWarningThreshold;
// The credential information to be sent to peers to prove our identity.
private volatile KeyInfo keyInfo;

public AdvancedTlsX509KeyManager() {
// Netty's default OpenSslCachingX509KeyManagerFactory maxCachedEntries.
this(1024);
}

/**
* Creates a key manager with a custom revision warning threshold.
* @param revisionWarningThreshold the number of credential loads after which a warning is logged.
* Only relevant when using {@code SslProvider.OPENSSL} with
* {@code OpenSslCachingX509KeyManagerFactory}.
*/
public AdvancedTlsX509KeyManager(int revisionWarningThreshold) {
this.revisionWarningThreshold = revisionWarningThreshold;
}

private String alias() {
KeyInfo info = this.keyInfo;
if (info == null) {
return null;
}
return info.alias;
}

@Override
public PrivateKey getPrivateKey(String alias) {
if (alias.equals("default")) {
return this.keyInfo.key;
}
return null;
KeyInfo info = this.keyInfo;
return info != null && info.alias.equals(alias) ? info.key : null;
}

@Override
public X509Certificate[] getCertificateChain(String alias) {
if (alias.equals("default")) {
return Arrays.copyOf(this.keyInfo.certs, this.keyInfo.certs.length);
}
return null;
KeyInfo info = this.keyInfo;
return info != null && info.alias.equals(alias)
? Arrays.copyOf(info.certs, info.certs.length) : null;
}

@Override
public String[] getClientAliases(String keyType, Principal[] issuers) {
return new String[] {"default"};
String alias = alias();
return alias != null ? new String[] {alias} : null;
}

@Override
public String chooseClientAlias(String[] keyType, Principal[] issuers, Socket socket) {
return "default";
return alias();
}

@Override
public String chooseEngineClientAlias(String[] keyType, Principal[] issuers, SSLEngine engine) {
return "default";
return alias();
}

@Override
public String[] getServerAliases(String keyType, Principal[] issuers) {
return new String[] {"default"};
String alias = alias();
return alias != null ? new String[] {alias} : null;
}

@Override
public String chooseServerAlias(String keyType, Principal[] issuers, Socket socket) {
return "default";
return alias();
}

@Override
public String chooseEngineServerAlias(String keyType, Principal[] issuers,
SSLEngine engine) {
return "default";
return alias();
}

/**
Expand All @@ -116,7 +155,15 @@ public void updateIdentityCredentials(PrivateKey key, X509Certificate[] certs) {
* @param key the private key that is going to be used
*/
public void updateIdentityCredentials(X509Certificate[] certs, PrivateKey key) {
this.keyInfo = new KeyInfo(checkNotNull(certs, "certs"), checkNotNull(key, "key"));
// When using SslProvider.OPENSSL with OpenSslCachingX509KeyManagerFactory, its cache stops
// accepting new aliases once maxCachedEntries is reached (default: 1024). Beyond this,
// handshakes still succeed but per-handshake re-encoding overhead resumes.
if (revision.get() >= revisionWarningThreshold) {
log.warning("AdvancedTlsX509KeyManager: revision counter has reached "
+ revisionWarningThreshold);
}
this.keyInfo = new KeyInfo(checkNotNull(certs, "certs"), checkNotNull(key, "key"),
ALIAS_PREFIX + revision.incrementAndGet());
}

/**
Expand Down Expand Up @@ -218,10 +265,12 @@ private static class KeyInfo {
// The private key and the cert chain we will use to send to peers to prove our identity.
final X509Certificate[] certs;
final PrivateKey key;
final String alias;

public KeyInfo(X509Certificate[] certs, PrivateKey key) {
public KeyInfo(X509Certificate[] certs, PrivateKey key, String alias) {
this.certs = certs;
this.key = key;
this.alias = alias;
}
}

Expand Down Expand Up @@ -309,4 +358,3 @@ public interface Closeable extends java.io.Closeable {
void close();
}
}

100 changes: 89 additions & 11 deletions util/src/test/java/io/grpc/util/AdvancedTlsX509KeyManagerTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@

import static org.junit.Assert.assertArrayEquals;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertNull;
import static org.junit.Assert.assertThrows;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;
Expand Down Expand Up @@ -48,7 +50,6 @@ public class AdvancedTlsX509KeyManagerTest {
private static final String SERVER_0_PEM_FILE = "server0.pem";
private static final String CLIENT_0_KEY_FILE = "client.key";
private static final String CLIENT_0_PEM_FILE = "client.pem";
private static final String ALIAS = "default";

private ScheduledExecutorService executor;

Expand Down Expand Up @@ -79,22 +80,99 @@ public void setUp() throws Exception {
public void updateTrustCredentials_replacesIssuers() throws Exception {
// Overall happy path checking of public API.
AdvancedTlsX509KeyManager serverKeyManager = new AdvancedTlsX509KeyManager();

serverKeyManager.updateIdentityCredentials(serverCert0, serverKey0);
assertEquals(serverKey0, serverKeyManager.getPrivateKey(ALIAS));
assertArrayEquals(serverCert0, serverKeyManager.getCertificateChain(ALIAS));
String alias1 = serverKeyManager.chooseEngineServerAlias(null, null, null);
assertEquals(AdvancedTlsX509KeyManager.ALIAS_PREFIX + "1", alias1);
assertEquals(serverKey0, serverKeyManager.getPrivateKey(alias1));
assertArrayEquals(serverCert0, serverKeyManager.getCertificateChain(alias1));

serverKeyManager.updateIdentityCredentials(clientCert0File, clientKey0File);
assertEquals(clientKey0, serverKeyManager.getPrivateKey(ALIAS));
assertArrayEquals(clientCert0, serverKeyManager.getCertificateChain(ALIAS));

serverKeyManager.updateIdentityCredentials(serverCert0File, serverKey0File,1,
String alias2 = serverKeyManager.chooseEngineServerAlias(null, null, null);
assertEquals(AdvancedTlsX509KeyManager.ALIAS_PREFIX + "2", alias2);
assertEquals(clientKey0, serverKeyManager.getPrivateKey(alias2));
assertArrayEquals(clientCert0, serverKeyManager.getCertificateChain(alias2));
// Old alias no longer resolves — ensures alias stability contract is enforced.
assertNull(serverKeyManager.getPrivateKey(alias1));

serverKeyManager.updateIdentityCredentials(serverCert0File, serverKey0File, 1,
TimeUnit.MINUTES, executor);
assertEquals(serverKey0, serverKeyManager.getPrivateKey(ALIAS));
assertArrayEquals(serverCert0, serverKeyManager.getCertificateChain(ALIAS));
String alias3 = serverKeyManager.chooseEngineServerAlias(null, null, null);
assertEquals(serverKey0, serverKeyManager.getPrivateKey(alias3));
assertArrayEquals(serverCert0, serverKeyManager.getCertificateChain(alias3));

serverKeyManager.updateIdentityCredentials(serverCert0, serverKey0);
assertEquals(serverKey0, serverKeyManager.getPrivateKey(ALIAS));
assertArrayEquals(serverCert0, serverKeyManager.getCertificateChain(ALIAS));
String alias4 = serverKeyManager.chooseEngineServerAlias(null, null, null);
assertEquals(serverKey0, serverKeyManager.getPrivateKey(alias4));
assertArrayEquals(serverCert0, serverKeyManager.getCertificateChain(alias4));
}

@Test
public void allAliasMethods_returnNullBeforeCredentialsLoaded() {
AdvancedTlsX509KeyManager keyManager = new AdvancedTlsX509KeyManager();

assertNull(keyManager.chooseClientAlias(null, null, null));
assertNull(keyManager.chooseServerAlias(null, null, null));
assertNull(keyManager.chooseEngineClientAlias(null, null, null));
assertNull(keyManager.chooseEngineServerAlias(null, null, null));
assertNull(keyManager.getClientAliases(null, null));
assertNull(keyManager.getServerAliases(null, null));
assertNull(keyManager.getPrivateKey("key-1"));
assertNull(keyManager.getCertificateChain("key-1"));
}

@Test
public void allAliasMethods_agreeAfterCredentialLoad() throws Exception {
AdvancedTlsX509KeyManager keyManager = new AdvancedTlsX509KeyManager();
keyManager.updateIdentityCredentials(serverCert0, serverKey0);

String expectedAlias = AdvancedTlsX509KeyManager.ALIAS_PREFIX + "1";
assertEquals(expectedAlias, keyManager.chooseClientAlias(null, null, null));
assertEquals(expectedAlias, keyManager.chooseServerAlias(null, null, null));
assertEquals(expectedAlias, keyManager.chooseEngineClientAlias(null, null, null));
assertEquals(expectedAlias, keyManager.chooseEngineServerAlias(null, null, null));
assertArrayEquals(new String[]{expectedAlias}, keyManager.getClientAliases(null, null));
assertArrayEquals(new String[]{expectedAlias}, keyManager.getServerAliases(null, null));
}

@Test
public void revisionWarningThreshold_logsWarningAtThreshold() throws Exception {
Logger log = Logger.getLogger(AdvancedTlsX509KeyManager.class.getName());
TestHandler handler = new TestHandler();
log.addHandler(handler);
log.setUseParentHandlers(false);
log.setLevel(Level.ALL);

try {
// Custom threshold: warning when revision reaches threshold.
int threshold = 3;
AdvancedTlsX509KeyManager customKeyManager = new AdvancedTlsX509KeyManager(threshold);
for (int i = 0; i < threshold; i++) {
customKeyManager.updateIdentityCredentials(serverCert0, serverKey0);
}
assertFalse(hasRevisionWarning(handler));
customKeyManager.updateIdentityCredentials(serverCert0, serverKey0);
assertTrue(hasRevisionWarning(handler));

// Key manager must still provide credentials correctly after soft threshold is exceeded.
String alias = customKeyManager.chooseEngineServerAlias(null, null, null);
assertEquals(serverKey0, customKeyManager.getPrivateKey(alias));
assertArrayEquals(serverCert0, customKeyManager.getCertificateChain(alias));

// Further credential updates must also work.
customKeyManager.updateIdentityCredentials(clientCert0File, clientKey0File);
String newAlias = customKeyManager.chooseEngineServerAlias(null, null, null);
assertEquals(clientKey0, customKeyManager.getPrivateKey(newAlias));
assertArrayEquals(clientCert0, customKeyManager.getCertificateChain(newAlias));
} finally {
log.removeHandler(handler);
}
}

private static boolean hasRevisionWarning(TestHandler handler) {
return handler.getRecords().stream()
.anyMatch(r -> Level.WARNING.equals(r.getLevel())
&& r.getMessage().contains("revision counter has reached"));
}

@Test
Expand Down
Loading