Currently, the TLS session ticket keys introduced by [JDK-8211018](https://bugs.openjdk.org/browse/JDK-8211018) in JDK 13 (i.e. `SessionTicketExtension$StatelessKey`) are generated in the class `SessionTicketExtension`, and they use a single, global key ID (`currentKeyID`) for all `SSLContext`s.
This is problematic if more than one `SSLContext` is used, because every context that creates a session ticket key increments the global ID `currentKeyID`. As a result, none of the other contexts can find a key under the new ID in their `SSLContextImpl`, so each of them creates a new key of its own (again incrementing `currentKeyID`). In fact, every time a session ticket is requested from a different context, this transitively triggers the creation of a new key in all the other contexts. We've observed millions of session ticket keys accumulating for some workloads.
Another issue with the current implementation is that cleanup is racy, because the underlying data structure (i.e. `keyHashMap` in `SSLContextImpl`) as well as the cleanup code itself are not thread-safe.
I therefore propose to move `currentKeyID` into `SSLContextImpl` to solve these issues.
The following test program (contributed by Steven Collison (https://raycoll.com/)) can be used to demonstrate the current behaviour. It outputs the number of `StatelessKey` instances at the end of the program. Opening 1000 connections with a single `SSLContext` results in a single `StatelessKey` instance being created:
```
$ java -XX:+UseSerialGC -Xmx16m -cp ~/Java/ SSLSocketServerMultipleSSLContext 9999 1 1000
605: 1 32 sun.security.ssl.SessionTicketExtension$StatelessKey (java.base@20-internal)
```
The same example, with the 1000 connections being opened alternately on two different contexts, will instead create 1000 `StatelessKey` instances:
```
$ java -XX:+UseSerialGC -Xmx16m -cp ~/Java/ SSLSocketServerMultipleSSLContext 9999 2 1000
11: 1000 32000 sun.security.ssl.SessionTicketExtension$StatelessKey (java.base@20-internal)
```
With my proposed patch, the number goes back down to two instances:
```
$ java -XX:+UseSerialGC -Xmx16m -cp ~/Java/ SSLSocketServerMultipleSSLContext 9999 2 1000
611: 2 64 sun.security.ssl.SessionTicketExtension$StatelessKey (java.base@20-internal)
```
```
// Contributed by Steven Collison (https://raycoll.com/)
//
// Requires the key store 'testkeys.jks' in the current working directory; it is used both as
// the server's key store and as the client's trust store. It can be created as follows:
//
// keytool -genkey -alias test -keyalg RSA -keypass testkeys -storepass testkeys -keystore testkeys.jks -keysize 2048 -validity 1461
//
// When prompted for input, always press <return> and answer the last question with "yes".
// The newly created, self-signed certificate can be verified with (use 'testkeys' as password):
// keytool -list -v -keystore testkeys.jks
import java.io.*;
import java.net.*;
import java.nio.charset.StandardCharsets;
import java.security.GeneralSecurityException;
import java.security.KeyStore;
import java.security.SecureRandom;
import java.util.ArrayList;
import java.util.Enumeration;
import java.util.List;
import javax.net.ServerSocketFactory;
import javax.net.ssl.*;
import javax.net.ssl.SSLServerSocket;
import javax.net.ssl.SSLServerSocketFactory;
import javax.net.ssl.SSLSocket;
import javax.net.ssl.SSLSocketFactory;
import javax.security.cert.*;
/**
 * A basic TLS server that allows multiple SSLContexts to be attached to the same port.
 * This is used to demonstrate a memory leak when session tickets are used and handshakes are
 * processed by multiple SSLContext objects.
 *
 * <p>For each requested connection the program:
 * <ul>
 *   <li>starts a client thread, which uses the default SSLContext restricted to TLSv1.2;</li>
 *   <li>accepts the plain TCP connection;</li>
 *   <li>upgrades it to TLS with one of the server-side SSLContexts, chosen round robin.</li>
 * </ul>
 * Finally it runs {@code jcmd <pid> GC.class_histogram} against itself and prints the
 * histogram line(s) mentioning {@code StatelessKey}.
 *
 * <p>Usage: SSLSocketServerMultipleSSLContext port num_ssl_contexts num_connections_to_accept
 */
public class SSLSocketServerMultipleSSLContext {
    /** Key store file, used both as the server key store and as the client trust store. */
    private static final String KEYSTORE_FILENAME = "testkeys.jks";
    /** Password of the key store and of the key it contains. */
    private static final String KEYSTORE_PASSWORD = "testkeys";
    /** Verbose output, enabled with {@code -Ddebug=true}. */
    private static final boolean DEBUG = Boolean.getBoolean("debug");
    private static final String USAGE =
            "Usage: SSLSocketServerMultipleSSLContext port num_ssl_contexts num_connections_to_accept";

    /**
     * Runs the demo.
     *
     * @param args {@code port num_ssl_contexts num_connections_to_accept}. If fewer than three
     *     arguments are given, or {@code num_ssl_contexts < 1}, the usage is printed to stderr
     *     and the method returns without side effects.
     * @throws Exception if a number cannot be parsed, an SSLContext cannot be created (e.g.
     *     the key store is missing), the server socket cannot be opened, or the jcmd call fails
     */
    public static void main(String[] args) throws Exception {
        if (args.length < 3) {
            System.err.println(USAGE);
            return;
        }
        int port = Integer.parseInt(args[0]);
        int numContexts = Integer.parseInt(args[1]);
        int iterations = Integer.parseInt(args[2]);
        if (numContexts < 1) {
            // At least one context is required for the round-robin selection below.
            System.err.println(USAGE);
            return;
        }

        System.setProperty("javax.net.ssl.trustStore", KEYSTORE_FILENAME);
        System.setProperty("javax.net.ssl.trustStorePassword", KEYSTORE_PASSWORD);
        System.setProperty("jdk.tls.client.protocols", "TLSv1.2");

        // Create the requested number of socket factories, each backed by its own SSLContext.
        // Any failure (e.g. missing key store) aborts here instead of failing per connection.
        List<SSLSocketFactory> sslSocketFactories = new ArrayList<>(numContexts);
        for (int i = 0; i < numContexts; i++) {
            sslSocketFactories.add(getSocketFactory("TLS"));
        }

        // Plain server socket; each accepted connection is upgraded to TLS with the chosen factory.
        try (ServerSocket serverSocket = new ServerSocket(port)) {
            for (int i = 0; i < iterations; i++) {
                // The following line can be commented out to run this test from the command line with:
                // echo "Q" | openssl s_client -tls1_2 -sess_in sess_1.sess -sess_out sess_2.sess -connect localhost:9999
                startClient(port, i);
                // Round robin across the list of SSLSocketFactories we have.
                SSLSocketFactory selectedFactory = sslSocketFactories.get(i % sslSocketFactories.size());
                serveOneConnection(serverSocket, selectedFactory);
            }
        }
        if (DEBUG) {
            System.out.println("server stopped");
        }
        printStatelessKeyHistogram();
    }

    /**
     * Accepts one plain connection and performs a server-side TLS handshake on it using
     * {@code factory}, then reads the first chunk of application data. I/O errors are printed
     * and swallowed, so the remaining connections are still served.
     */
    private static void serveOneConnection(ServerSocket serverSocket, SSLSocketFactory factory) {
        try (Socket socket = serverSocket.accept()) {
            if (DEBUG) {
                System.out.println("accepted");
            }
            // This createSocket overload layers a server-mode SSLSocket over the accepted socket.
            // The connection is torn down when 'socket' is closed by try-with-resources.
            SSLSocket sslSocket = (SSLSocket) factory.createSocket(socket, null, true);
            sslSocket.startHandshake();
            InputStream is = new BufferedInputStream(sslSocket.getInputStream());
            byte[] data = new byte[2048];
            int len = is.read(data);
            if (DEBUG) {
                // Decode only the bytes actually read; read() returns -1 at end of stream.
                String received = len > 0 ? new String(data, 0, len, StandardCharsets.UTF_8) : "";
                System.out.println("Received: " + received);
            }
        } catch (IOException e) {
            e.printStackTrace();
        }
    }

    /**
     * Starts a background client thread.
     *
     * <p>The client connects to {@code localhost:port} using the default SSLContext, sends
     * "hello" and closes the connection.
     */
    private static void startClient(int port, int iteration) {
        new Thread(() -> runClient(port, iteration), "client-" + iteration).start();
    }

    /** Client body executed by the thread created in {@link #startClient(int, int)}. */
    private static void runClient(int port, int iteration) {
        try (SSLSocket socket = (SSLSocket) SSLContext.getDefault().getSocketFactory().createSocket()) {
            socket.connect(new InetSocketAddress("localhost", port));
            OutputStream os = new BufferedOutputStream(socket.getOutputStream());
            os.write(new byte[] {'h', 'e', 'l', 'l', 'o'});
            os.flush();
            // In debug mode, print the session creation time (presumably to observe whether
            // sessions are reused across iterations).
            long creationTime = socket.getSession().getCreationTime();
            if (DEBUG) {
                System.out.println(String.format("creationTime (%d) = %d", iteration, creationTime));
            }
        } catch (IOException | GeneralSecurityException ex) {
            ex.printStackTrace();
        }
    }

    /**
     * Creates a server-side socket factory backed by a fresh SSLContext.
     *
     * <p>The SSLContext is initialized with the key from {@link #KEYSTORE_FILENAME}.
     *
     * @param type the SSLContext protocol name passed to {@link SSLContext#getInstance(String)},
     *     e.g. {@code "TLS"}
     * @return a socket factory of the newly initialized context; never {@code null}
     * @throws IOException if the key store cannot be read
     * @throws GeneralSecurityException if the key store, key manager or context cannot be set up
     */
    private static SSLSocketFactory getSocketFactory(String type)
            throws IOException, GeneralSecurityException {
        char[] passphrase = KEYSTORE_PASSWORD.toCharArray();
        KeyStore ks = KeyStore.getInstance("JKS");
        try (InputStream in = new FileInputStream(KEYSTORE_FILENAME)) {
            ks.load(in, passphrase);
        }
        // Set up the key manager to do server authentication.
        KeyManagerFactory kmf = KeyManagerFactory.getInstance("SunX509");
        kmf.init(ks, passphrase);
        SSLContext ctx = SSLContext.getInstance(type);
        // null trust managers / SecureRandom: the provider's defaults are used.
        ctx.init(kmf.getKeyManagers(), null, null);
        return ctx.getSocketFactory();
    }

    /**
     * Runs {@code jcmd <pid> GC.class_histogram} against this JVM and prints every histogram
     * line that mentions {@code StatelessKey}. The jcmd stderr is forwarded to this process.
     */
    private static void printStatelessKeyHistogram() throws IOException, InterruptedException {
        long pid = ProcessHandle.current().pid();
        String jcmdPath = System.getProperty("java.home") + File.separator + "bin" + File.separator + "jcmd";
        Process jcmd = new ProcessBuilder(jcmdPath, Long.toString(pid), "GC.class_histogram")
                .redirectError(ProcessBuilder.Redirect.INHERIT)
                .start();
        List<String> matches;
        try (BufferedReader reader = jcmd.inputReader()) {
            matches = reader.lines().filter(l -> l.contains("StatelessKey")).toList();
        }
        jcmd.waitFor();
        matches.forEach(System.out::println);
    }
}
```