diff --git a/netty/src/test/java/io/grpc/netty/AdvancedTlsTest.java b/netty/src/test/java/io/grpc/netty/AdvancedTlsTest.java index f34e336553b..66591cda153 100644 --- a/netty/src/test/java/io/grpc/netty/AdvancedTlsTest.java +++ b/netty/src/test/java/io/grpc/netty/AdvancedTlsTest.java @@ -436,16 +436,13 @@ public void onFileReloadingTrustManagerBadInitialContentTest() throws Exception } @Test - public void keyManagerAliasesTest() { + public void keyManagerAliasesTest() throws Exception { AdvancedTlsX509KeyManager km = new AdvancedTlsX509KeyManager(); - assertArrayEquals( - new String[] {"default"}, km.getClientAliases("", null)); - assertEquals( - "default", km.chooseClientAlias(new String[] {"default"}, null, null)); - assertArrayEquals( - new String[] {"default"}, km.getServerAliases("", null)); - assertEquals( - "default", km.chooseServerAlias("default", null, null)); + km.updateIdentityCredentials(serverCert0, serverKey0); + assertArrayEquals(new String[] {"key-1"}, km.getClientAliases("", null)); + assertEquals("key-1", km.chooseClientAlias(new String[] {"key-1"}, null, null)); + assertArrayEquals(new String[] {"key-1"}, km.getServerAliases("", null)); + assertEquals("key-1", km.chooseServerAlias("key-1", null, null)); } @Test diff --git a/util/src/main/java/io/grpc/util/AdvancedTlsX509KeyManager.java b/util/src/main/java/io/grpc/util/AdvancedTlsX509KeyManager.java index 1f807cd405d..f6034b613f2 100644 --- a/util/src/main/java/io/grpc/util/AdvancedTlsX509KeyManager.java +++ b/util/src/main/java/io/grpc/util/AdvancedTlsX509KeyManager.java @@ -32,6 +32,7 @@ import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.ScheduledFuture; import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicInteger; import java.util.logging.Level; import java.util.logging.Logger; import javax.net.ssl.SSLEngine; @@ -40,59 +41,97 @@ /** * AdvancedTlsX509KeyManager is an {@code X509ExtendedKeyManager} that allows users to configure * advanced TLS features, such as private key and certificate chain reloading. + * + *

The key material alias increments on every credential load (e.g. {@code "key-1"}, + * {@code "key-2"}, ...), ensuring the same alias always maps to the same key material. This is + * required by Netty's {@code OpenSslCachingX509KeyManagerFactory} to correctly cache key + * material and create a new cache entry on cert reload. + * + *

When using {@code SslProvider.OPENSSL}, wrap this key manager in Netty's + * {@code OpenSslCachingX509KeyManagerFactory} to avoid per-handshake key material encoding + * overhead, e.g. {@code new OpenSslCachingX509KeyManagerFactory( + * new KeyManagerFactoryWrapper(advancedTlsKeyManager))}, and pass the factory to + * {@code SslContextBuilder} instead of the key manager directly. */ public final class AdvancedTlsX509KeyManager extends X509ExtendedKeyManager { private static final Logger log = Logger.getLogger(AdvancedTlsX509KeyManager.class.getName()); // Minimum allowed period for refreshing files with credential information. - private static final int MINIMUM_REFRESH_PERIOD_IN_MINUTES = 1 ; + private static final int MINIMUM_REFRESH_PERIOD_IN_MINUTES = 1; + // Prefix for the key material alias; revision counter appended on each credential load. + static final String ALIAS_PREFIX = "key-"; + + private final AtomicInteger revision = new AtomicInteger(0); + private final int revisionWarningThreshold; // The credential information to be sent to peers to prove our identity. private volatile KeyInfo keyInfo; + public AdvancedTlsX509KeyManager() { + // Netty's default OpenSslCachingX509KeyManagerFactory maxCachedEntries. + this(1024); + } + + /** + * Creates a key manager with a custom revision warning threshold. + * @param revisionWarningThreshold the number of credential loads after which a warning is logged. + * Only relevant when using {@code SslProvider.OPENSSL} with + * {@code OpenSslCachingX509KeyManagerFactory}. + */ + public AdvancedTlsX509KeyManager(int revisionWarningThreshold) { + this.revisionWarningThreshold = revisionWarningThreshold; + } + + private String alias() { + KeyInfo info = this.keyInfo; + if (info == null) { + return null; + } + return info.alias; + } + @Override public PrivateKey getPrivateKey(String alias) { - if (alias.equals("default")) { - return this.keyInfo.key; - } - return null; + KeyInfo info = this.keyInfo; + return info != null && info.alias.equals(alias) ? info.key : null; } @Override public X509Certificate[] getCertificateChain(String alias) { - if (alias.equals("default")) { - return Arrays.copyOf(this.keyInfo.certs, this.keyInfo.certs.length); - } - return null; + KeyInfo info = this.keyInfo; + return info != null && info.alias.equals(alias) + ? Arrays.copyOf(info.certs, info.certs.length) : null; } @Override public String[] getClientAliases(String keyType, Principal[] issuers) { - return new String[] {"default"}; + String alias = alias(); + return alias != null ? new String[] {alias} : null; } @Override public String chooseClientAlias(String[] keyType, Principal[] issuers, Socket socket) { - return "default"; + return alias(); } @Override public String chooseEngineClientAlias(String[] keyType, Principal[] issuers, SSLEngine engine) { - return "default"; + return alias(); } @Override public String[] getServerAliases(String keyType, Principal[] issuers) { - return new String[] {"default"}; + String alias = alias(); + return alias != null ? new String[] {alias} : null; } @Override public String chooseServerAlias(String keyType, Principal[] issuers, Socket socket) { - return "default"; + return alias(); } @Override public String chooseEngineServerAlias(String keyType, Principal[] issuers, SSLEngine engine) { - return "default"; + return alias(); } /** @@ -116,7 +155,15 @@ public void updateIdentityCredentials(PrivateKey key, X509Certificate[] certs) { * @param key the private key that is going to be used */ public void updateIdentityCredentials(X509Certificate[] certs, PrivateKey key) { - this.keyInfo = new KeyInfo(checkNotNull(certs, "certs"), checkNotNull(key, "key")); + // When using SslProvider.OPENSSL with OpenSslCachingX509KeyManagerFactory, its cache stops + // accepting new aliases once maxCachedEntries is reached (default: 1024). Beyond this, + // handshakes still succeed but per-handshake re-encoding overhead resumes. + if (revision.get() >= revisionWarningThreshold) { + log.warning("AdvancedTlsX509KeyManager: revision counter has reached " + + revisionWarningThreshold); + } + this.keyInfo = new KeyInfo(checkNotNull(certs, "certs"), checkNotNull(key, "key"), + ALIAS_PREFIX + revision.incrementAndGet()); } /** @@ -218,10 +265,12 @@ private static class KeyInfo { // The private key and the cert chain we will use to send to peers to prove our identity. final X509Certificate[] certs; final PrivateKey key; + final String alias; - public KeyInfo(X509Certificate[] certs, PrivateKey key) { + public KeyInfo(X509Certificate[] certs, PrivateKey key, String alias) { this.certs = certs; this.key = key; + this.alias = alias; } } @@ -309,4 +358,3 @@ public interface Closeable extends java.io.Closeable { void close(); } } - diff --git a/util/src/test/java/io/grpc/util/AdvancedTlsX509KeyManagerTest.java b/util/src/test/java/io/grpc/util/AdvancedTlsX509KeyManagerTest.java index f96c85e4f4f..2974da32fc3 100644 --- a/util/src/test/java/io/grpc/util/AdvancedTlsX509KeyManagerTest.java +++ b/util/src/test/java/io/grpc/util/AdvancedTlsX509KeyManagerTest.java @@ -18,6 +18,8 @@ import static org.junit.Assert.assertArrayEquals; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNull; import static org.junit.Assert.assertThrows; import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; @@ -48,7 +50,6 @@ public class AdvancedTlsX509KeyManagerTest { private static final String SERVER_0_PEM_FILE = "server0.pem"; private static final String CLIENT_0_KEY_FILE = "client.key"; private static final String CLIENT_0_PEM_FILE = "client.pem"; - private static final String ALIAS = "default"; private ScheduledExecutorService executor; @@ -79,22 +80,99 @@ public void setUp() throws Exception { public void updateTrustCredentials_replacesIssuers() throws Exception { // Overall happy path checking of public API. AdvancedTlsX509KeyManager serverKeyManager = new AdvancedTlsX509KeyManager(); + serverKeyManager.updateIdentityCredentials(serverCert0, serverKey0); - assertEquals(serverKey0, serverKeyManager.getPrivateKey(ALIAS)); - assertArrayEquals(serverCert0, serverKeyManager.getCertificateChain(ALIAS)); + String alias1 = serverKeyManager.chooseEngineServerAlias(null, null, null); + assertEquals(AdvancedTlsX509KeyManager.ALIAS_PREFIX + "1", alias1); + assertEquals(serverKey0, serverKeyManager.getPrivateKey(alias1)); + assertArrayEquals(serverCert0, serverKeyManager.getCertificateChain(alias1)); serverKeyManager.updateIdentityCredentials(clientCert0File, clientKey0File); - assertEquals(clientKey0, serverKeyManager.getPrivateKey(ALIAS)); - assertArrayEquals(clientCert0, serverKeyManager.getCertificateChain(ALIAS)); - - serverKeyManager.updateIdentityCredentials(serverCert0File, serverKey0File,1, + String alias2 = serverKeyManager.chooseEngineServerAlias(null, null, null); + assertEquals(AdvancedTlsX509KeyManager.ALIAS_PREFIX + "2", alias2); + assertEquals(clientKey0, serverKeyManager.getPrivateKey(alias2)); + assertArrayEquals(clientCert0, serverKeyManager.getCertificateChain(alias2)); + // Old alias no longer resolves — ensures alias stability contract is enforced. + assertNull(serverKeyManager.getPrivateKey(alias1)); + + serverKeyManager.updateIdentityCredentials(serverCert0File, serverKey0File, 1, TimeUnit.MINUTES, executor); - assertEquals(serverKey0, serverKeyManager.getPrivateKey(ALIAS)); - assertArrayEquals(serverCert0, serverKeyManager.getCertificateChain(ALIAS)); + String alias3 = serverKeyManager.chooseEngineServerAlias(null, null, null); + assertEquals(serverKey0, serverKeyManager.getPrivateKey(alias3)); + assertArrayEquals(serverCert0, serverKeyManager.getCertificateChain(alias3)); serverKeyManager.updateIdentityCredentials(serverCert0, serverKey0); - assertEquals(serverKey0, serverKeyManager.getPrivateKey(ALIAS)); - assertArrayEquals(serverCert0, serverKeyManager.getCertificateChain(ALIAS)); + String alias4 = serverKeyManager.chooseEngineServerAlias(null, null, null); + assertEquals(serverKey0, serverKeyManager.getPrivateKey(alias4)); + assertArrayEquals(serverCert0, serverKeyManager.getCertificateChain(alias4)); + } + + @Test + public void allAliasMethods_returnNullBeforeCredentialsLoaded() { + AdvancedTlsX509KeyManager keyManager = new AdvancedTlsX509KeyManager(); + + assertNull(keyManager.chooseClientAlias(null, null, null)); + assertNull(keyManager.chooseServerAlias(null, null, null)); + assertNull(keyManager.chooseEngineClientAlias(null, null, null)); + assertNull(keyManager.chooseEngineServerAlias(null, null, null)); + assertNull(keyManager.getClientAliases(null, null)); + assertNull(keyManager.getServerAliases(null, null)); + assertNull(keyManager.getPrivateKey("key-1")); + assertNull(keyManager.getCertificateChain("key-1")); + } + + @Test + public void allAliasMethods_agreeAfterCredentialLoad() throws Exception { + AdvancedTlsX509KeyManager keyManager = new AdvancedTlsX509KeyManager(); + keyManager.updateIdentityCredentials(serverCert0, serverKey0); + + String expectedAlias = AdvancedTlsX509KeyManager.ALIAS_PREFIX + "1"; + assertEquals(expectedAlias, keyManager.chooseClientAlias(null, null, null)); + assertEquals(expectedAlias, keyManager.chooseServerAlias(null, null, null)); + assertEquals(expectedAlias, keyManager.chooseEngineClientAlias(null, null, null)); + assertEquals(expectedAlias, keyManager.chooseEngineServerAlias(null, null, null)); + assertArrayEquals(new String[]{expectedAlias}, keyManager.getClientAliases(null, null)); + assertArrayEquals(new String[]{expectedAlias}, keyManager.getServerAliases(null, null)); + } + + @Test + public void revisionWarningThreshold_logsWarningAtThreshold() throws Exception { + Logger log = Logger.getLogger(AdvancedTlsX509KeyManager.class.getName()); + TestHandler handler = new TestHandler(); + log.addHandler(handler); + log.setUseParentHandlers(false); + log.setLevel(Level.ALL); + + try { + // Custom threshold: warning when revision reaches threshold. + int threshold = 3; + AdvancedTlsX509KeyManager customKeyManager = new AdvancedTlsX509KeyManager(threshold); + for (int i = 0; i < threshold; i++) { + customKeyManager.updateIdentityCredentials(serverCert0, serverKey0); + } + assertFalse(hasRevisionWarning(handler)); + customKeyManager.updateIdentityCredentials(serverCert0, serverKey0); + assertTrue(hasRevisionWarning(handler)); + + // Key manager must still provide credentials correctly after soft threshold is exceeded. + String alias = customKeyManager.chooseEngineServerAlias(null, null, null); + assertEquals(serverKey0, customKeyManager.getPrivateKey(alias)); + assertArrayEquals(serverCert0, customKeyManager.getCertificateChain(alias)); + + // Further credential updates must also work. + customKeyManager.updateIdentityCredentials(clientCert0File, clientKey0File); + String newAlias = customKeyManager.chooseEngineServerAlias(null, null, null); + assertEquals(clientKey0, customKeyManager.getPrivateKey(newAlias)); + assertArrayEquals(clientCert0, customKeyManager.getCertificateChain(newAlias)); + } finally { + log.removeHandler(handler); + } + } + + private static boolean hasRevisionWarning(TestHandler handler) { + return handler.getRecords().stream() + .anyMatch(r -> Level.WARNING.equals(r.getLevel()) + && r.getMessage().contains("revision counter has reached")); } @Test