This commit is contained in:
Abhishek Kumar 2026-03-09 13:14:41 +00:00 committed by GitHub
commit 15a514be67
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 163 additions and 170 deletions

View File

@ -64,23 +64,22 @@ public class Link {
private final NioConnection _connection;
private SelectionKey _key;
private final ConcurrentLinkedQueue<ByteBuffer[]> _writeQueue;
private ByteBuffer _readBuffer;
private ByteBuffer _plaintextBuffer;
private final ByteBuffer headerBuffer = ByteBuffer.allocate(4); // accumulates length header inside TLS
private ByteBuffer netBuffer;
private ByteBuffer appBuffer;
private ByteBuffer plainTextBuffer;
private int frameRemaining = -1; // remaining bytes for current frame (inside TLS)
private Object _attach;
private boolean _readHeader;
private boolean _gotFollowingPacket;
private SSLEngine _sslEngine;
public Link(InetSocketAddress addr, NioConnection connection) {
_addr = addr;
_connection = connection;
_readBuffer = ByteBuffer.allocate(2048);
_attach = null;
_key = null;
_writeQueue = new ConcurrentLinkedQueue<ByteBuffer[]>();
_readHeader = true;
_gotFollowingPacket = false;
plainTextBuffer = null;
}
public Link(Link link) {
@ -103,58 +102,82 @@ public class Link {
public void setSSLEngine(SSLEngine sslEngine) {
_sslEngine = sslEngine;
if (_sslEngine == null) {
netBuffer = null;
appBuffer = null;
headerBuffer.clear();
frameRemaining = -1;
plainTextBuffer = null;
return;
}
final SSLSession s = _sslEngine.getSession();
netBuffer = ByteBuffer.allocate(Math.max(s.getPacketBufferSize(), 16 * 1024));
appBuffer = ByteBuffer.allocate(Math.max(s.getApplicationBufferSize(), 16 * 1024));
headerBuffer.clear();
frameRemaining = -1;
plainTextBuffer = null;
}
private static void doWrite(SocketChannel ch, ByteBuffer[] buffers, SSLEngine sslEngine) throws IOException {
SSLSession sslSession = sslEngine.getSession();
ByteBuffer pkgBuf = ByteBuffer.allocate(sslSession.getPacketBufferSize() + 40);
SSLEngineResult engResult;
ByteBuffer headBuf = ByteBuffer.allocate(4);
if (sslEngine == null) {
throw new IOException("SSLEngine not set");
}
final SSLSession session = sslEngine.getSession();
ByteBuffer netBuf = ByteBuffer.allocate(session.getPacketBufferSize());
// Build app sequence: 4-byte length header + payload buffers
int totalLen = 0;
for (ByteBuffer buffer : buffers) {
totalLen += buffer.limit();
for (ByteBuffer b : buffers) totalLen += b.remaining();
ByteBuffer header = ByteBuffer.allocate(4);
header.putInt(totalLen).flip();
ByteBuffer[] appSeq = new ByteBuffer[buffers.length + 1];
appSeq[0] = header;
for (int i = 0; i < buffers.length; i++) {
appSeq[i + 1] = buffers[i].duplicate();
}
int processedLen = 0;
while (processedLen < totalLen) {
headBuf.clear();
pkgBuf.clear();
engResult = sslEngine.wrap(buffers, pkgBuf);
if (engResult.getHandshakeStatus() != HandshakeStatus.FINISHED && engResult.getHandshakeStatus() != HandshakeStatus.NOT_HANDSHAKING &&
engResult.getStatus() != SSLEngineResult.Status.OK) {
throw new IOException("SSL: SSLEngine return bad result! " + engResult);
}
processedLen = 0;
for (ByteBuffer buffer : buffers) {
processedLen += buffer.position();
}
int dataRemaining = pkgBuf.position();
int header = dataRemaining;
int headRemaining = 4;
pkgBuf.flip();
if (processedLen < totalLen) {
header = header | HEADER_FLAG_FOLLOWING;
}
headBuf.putInt(header);
headBuf.flip();
while (headRemaining > 0) {
if (LOGGER.isTraceEnabled()) {
LOGGER.trace("Writing Header " + headRemaining);
while (true) {
// Check if all app buffers are fully consumed
boolean allDone = true;
for (ByteBuffer b : appSeq) {
if (b.hasRemaining()) {
allDone = false;
break;
}
long count = ch.write(headBuf);
headRemaining -= count;
}
while (dataRemaining > 0) {
if (LOGGER.isTraceEnabled()) {
LOGGER.trace("Writing Data " + dataRemaining);
}
long count = ch.write(pkgBuf);
dataRemaining -= count;
if (allDone) break;
netBuf.clear();
SSLEngineResult res;
try {
res = sslEngine.wrap(appSeq, netBuf);
} catch (SSLException e) {
throw new IOException("SSL wrap failed: " + e.getMessage(), e);
}
switch (res.getStatus()) {
case OK:
netBuf.flip();
while (netBuf.hasRemaining()) {
ch.write(netBuf); // may be partial, loop until drained
}
break;
case BUFFER_OVERFLOW:
netBuf = enlargeBuffer(netBuf, session.getPacketBufferSize());
break;
case CLOSED:
throw new IOException("SSLEngine is CLOSED during write");
default:
throw new IOException("Unexpected SSLEngineResult status on wrap: " + res.getStatus());
}
// Drain delegated tasks if any
if (res.getHandshakeStatus() == HandshakeStatus.NEED_TASK) {
Runnable task;
while ((task = sslEngine.getDelegatedTask()) != null) task.run();
}
if (res.getHandshakeStatus() == HandshakeStatus.NEED_UNWRAP) {
// Unusual during application writes; upper layer should drive handshake
break;
}
}
}
@ -174,116 +197,88 @@ public class Link {
}
}
/* SSL has limitation of 16k, we may need to split packets. 18000 is 16k + some extra SSL informations */
protected static final int MAX_SIZE_PER_PACKET = 18000;
protected static final int HEADER_FLAG_FOLLOWING = 0x10000;
public byte[] read(SocketChannel ch) throws IOException {
if (_readHeader) { // Start of a packet
if (_readBuffer.position() == 0) {
_readBuffer.limit(4);
}
if (ch.read(_readBuffer) == -1) {
throw new IOException("Connection closed with -1 on reading size.");
}
if (_readBuffer.hasRemaining()) {
LOGGER.trace("Need to read the rest of the packet length");
return null;
}
_readBuffer.flip();
int header = _readBuffer.getInt();
int readSize = (short)header;
if (LOGGER.isTraceEnabled()) {
LOGGER.trace("Packet length is " + readSize);
}
if (readSize > MAX_SIZE_PER_PACKET) {
throw new IOException("Wrong packet size: " + readSize);
}
if (!_gotFollowingPacket) {
_plaintextBuffer = ByteBuffer.allocate(2000);
}
if ((header & HEADER_FLAG_FOLLOWING) != 0) {
_gotFollowingPacket = true;
} else {
_gotFollowingPacket = false;
}
_readBuffer.clear();
_readHeader = false;
if (_readBuffer.capacity() < readSize) {
if (LOGGER.isTraceEnabled()) {
LOGGER.trace("Resizing the byte buffer from " + _readBuffer.capacity());
}
_readBuffer = ByteBuffer.allocate(readSize);
}
_readBuffer.limit(readSize);
if (_sslEngine == null) {
throw new IOException("SSLEngine not set");
}
if (ch.read(_readBuffer) == -1) {
if (ch.read(netBuffer) == -1) {
throw new IOException("Connection closed with -1 on read.");
}
if (_readBuffer.hasRemaining()) { // We're not done yet.
if (LOGGER.isTraceEnabled()) {
LOGGER.trace("Still has " + _readBuffer.remaining());
netBuffer.flip();
while (netBuffer.hasRemaining()) {
SSLEngineResult res;
try {
res = _sslEngine.unwrap(netBuffer, appBuffer);
} catch (SSLException e) {
throw new IOException("SSL unwrap failed: " + e.getMessage(), e);
}
return null;
}
_readBuffer.flip();
ByteBuffer appBuf;
SSLSession sslSession = _sslEngine.getSession();
SSLEngineResult engResult;
int remaining = 0;
while (_readBuffer.hasRemaining()) {
remaining = _readBuffer.remaining();
appBuf = ByteBuffer.allocate(sslSession.getApplicationBufferSize() + 40);
engResult = _sslEngine.unwrap(_readBuffer, appBuf);
if (engResult.getHandshakeStatus() != HandshakeStatus.FINISHED && engResult.getHandshakeStatus() != HandshakeStatus.NOT_HANDSHAKING &&
engResult.getStatus() != SSLEngineResult.Status.OK) {
throw new IOException("SSL: SSLEngine return bad result! " + engResult);
switch (res.getStatus()) {
case OK:
appBuffer.flip();
while (appBuffer.hasRemaining()) {
if (frameRemaining < 0) {
int need = 4 - headerBuffer.position();
int take = Math.min(need, appBuffer.remaining());
int oldLimit = appBuffer.limit();
appBuffer.limit(appBuffer.position() + take);
headerBuffer.put(appBuffer);
appBuffer.limit(oldLimit);
if (headerBuffer.position() < 4) break;
headerBuffer.flip();
frameRemaining = headerBuffer.getInt();
headerBuffer.clear();
if (frameRemaining < 0) {
throw new IOException("Negative frame length");
}
if (plainTextBuffer == null || plainTextBuffer.capacity() < frameRemaining) {
plainTextBuffer = ByteBuffer.allocate(Math.max(frameRemaining, 2048));
}
plainTextBuffer.clear();
} else {
int toCopy = Math.min(frameRemaining, appBuffer.remaining());
if (plainTextBuffer.remaining() < toCopy) {
ByteBuffer newBuffer = ByteBuffer.allocate(plainTextBuffer.capacity() + Math.max(toCopy, 4096));
plainTextBuffer.flip();
newBuffer.put(plainTextBuffer);
plainTextBuffer = newBuffer;
}
int oldLimit = appBuffer.limit();
appBuffer.limit(appBuffer.position() + toCopy);
plainTextBuffer.put(appBuffer);
appBuffer.limit(oldLimit);
frameRemaining -= toCopy;
if (frameRemaining == 0) {
plainTextBuffer.flip();
byte[] result = new byte[plainTextBuffer.remaining()];
plainTextBuffer.get(result);
appBuffer.compact();
netBuffer.compact();
frameRemaining = -1;
return result;
}
}
}
appBuffer.compact();
break;
case BUFFER_OVERFLOW:
appBuffer = enlargeBuffer(appBuffer, _sslEngine.getSession().getApplicationBufferSize());
break;
case BUFFER_UNDERFLOW:
netBuffer = handleBufferUnderflow(_sslEngine, netBuffer);
netBuffer.compact();
return null;
case CLOSED:
throw new IOException("SSLEngine closed during read");
default:
throw new IOException("Unexpected SSLEngineResult status on unwrap: " + res.getStatus());
}
if (remaining == _readBuffer.remaining()) {
throw new IOException("SSL: Unable to unwrap received data! still remaining " + remaining + "bytes!");
}
appBuf.flip();
if (_plaintextBuffer.remaining() < appBuf.limit()) {
// We need to expand _plaintextBuffer for more data
ByteBuffer newBuffer = ByteBuffer.allocate(_plaintextBuffer.capacity() + appBuf.limit() * 5);
_plaintextBuffer.flip();
newBuffer.put(_plaintextBuffer);
_plaintextBuffer = newBuffer;
}
_plaintextBuffer.put(appBuf);
if (LOGGER.isTraceEnabled()) {
LOGGER.trace("Done with packet: " + appBuf.limit());
if (res.getHandshakeStatus() == HandshakeStatus.NEED_TASK) {
Runnable task;
while ((task = _sslEngine.getDelegatedTask()) != null) task.run();
}
}
_readBuffer.clear();
_readHeader = true;
if (!_gotFollowingPacket) {
_plaintextBuffer.flip();
byte[] result = new byte[_plaintextBuffer.limit()];
_plaintextBuffer.get(result);
return result;
} else {
if (LOGGER.isTraceEnabled()) {
LOGGER.trace("Waiting for more packets");
}
return null;
}
netBuffer.compact();
return null;
}
public void send(byte[] data) throws ClosedChannelException {
@ -295,19 +290,14 @@ public class Link {
}
public void send(ByteBuffer[] data, boolean close) throws ClosedChannelException {
ByteBuffer[] item = new ByteBuffer[data.length + 1];
ByteBuffer[] item = new ByteBuffer[data.length];
int remaining = 0;
for (int i = 0; i < data.length; i++) {
remaining += data[i].remaining();
item[i + 1] = data[i];
item[i] = data[i];
}
item[0] = ByteBuffer.allocate(4);
item[0].putInt(remaining);
item[0].flip();
if (LOGGER.isTraceEnabled()) {
LOGGER.trace("Sending packet of length " + remaining);
LOGGER.trace("Sending framed message of length " + remaining);
}
_writeQueue.add(item);
@ -341,11 +331,7 @@ public class Link {
}
return true;
}
ByteBuffer[] raw_data = new ByteBuffer[data.length - 1];
System.arraycopy(data, 1, raw_data, 0, data.length - 1);
doWrite(ch, raw_data, _sslEngine);
doWrite(ch, data, _sslEngine);
}
return false;
}
@ -376,7 +362,7 @@ public class Link {
}
public static SSLEngine initServerSSLEngine(final CAService caService, final String clientAddress) throws GeneralSecurityException, IOException {
final SSLContext sslContext = SSLUtils.getSSLContext();
final SSLContext sslContext = SSLUtils.getSSLContextWithLatestProtocolVersion();
if (caService != null) {
return caService.createSSLEngine(sslContext, clientAddress);
}
@ -405,7 +391,7 @@ public class Link {
final KeyManagerFactory kmf = KeyManagerFactory.getInstance("SunX509");
kmf.init(ks, passphrase);
final SSLContext sslContext = SSLUtils.getSSLContext();
final SSLContext sslContext = SSLUtils.getSSLContextWithLatestProtocolVersion();
sslContext.init(kmf.getKeyManagers(), tms, new SecureRandom());
return sslContext;
}
@ -449,7 +435,7 @@ public class Link {
final KeyManagerFactory kmf = KeyManagerFactory.getInstance("SunX509");
kmf.init(ks, passphrase);
final SSLContext sslContext = SSLUtils.getSSLContext();
final SSLContext sslContext = SSLUtils.getSSLContextWithLatestProtocolVersion();
sslContext.init(kmf.getKeyManagers(), tms, new SecureRandom());
return sslContext;
}

View File

@ -74,7 +74,8 @@ public class NioClient extends NioConnection {
if (!Link.doHandshake(clientConnection, sslEngine, getSslHandshakeTimeout())) {
throw new IOException(String.format("SSL Handshake failed while connecting to host: %s", hostLog));
}
logger.info("SSL: Handshake done");
logger.info("SSL: Handshake done with {} protocol: {}, cipher suite: {}",
serverAddress, sslEngine.getSession().getProtocol(), sslEngine.getSession().getCipherSuite());
final Link link = new Link(serverAddress, this);
link.setSSLEngine(sslEngine);

View File

@ -274,7 +274,9 @@ public abstract class NioConnection implements Callable<Boolean> {
if (!Link.doHandshake(socketChannel, sslEngine, getSslHandshakeTimeout())) {
throw new IOException("SSL handshake timed out with " + socketAddress);
}
logger.trace("SSL: Handshake done");
logger.trace("SSL: Handshake done with {} protocol: {}, cipher suite: {}",
socketAddress, sslEngine.getSession().getProtocol(),
sslEngine.getSession().getCipherSuite());
final Link link = new Link(socketAddress, nioConnection);
link.setSSLEngine(sslEngine);
link.setKey(socketChannel.register(key.selector(), SelectionKey.OP_READ, link));

View File

@ -70,6 +70,10 @@ public class SSLUtils {
return SSLContext.getInstance("TLSv1.2");
}
public static SSLContext getSSLContextWithLatestProtocolVersion() throws NoSuchAlgorithmException {
return SSLContext.getInstance("TLSv1.3");
}
public static SSLContext getSSLContext(String provider) throws NoSuchAlgorithmException, NoSuchProviderException {
return SSLContext.getInstance("TLSv1.2", provider);
}