diff --git a/engine/orchestration/src/com/cloud/agent/manager/ClusteredAgentManagerImpl.java b/engine/orchestration/src/com/cloud/agent/manager/ClusteredAgentManagerImpl.java
index ba82938fcdc..f65d83bab5b 100755
--- a/engine/orchestration/src/com/cloud/agent/manager/ClusteredAgentManagerImpl.java
+++ b/engine/orchestration/src/com/cloud/agent/manager/ClusteredAgentManagerImpl.java
@@ -501,7 +501,7 @@ public class ClusteredAgentManagerImpl extends AgentManagerImpl implements Clust
SocketChannel ch1 = null;
try {
ch1 = SocketChannel.open(new InetSocketAddress(addr, Port.value()));
- ch1.configureBlocking(true); // make sure we are working at blocking mode
+ ch1.configureBlocking(false);
ch1.socket().setKeepAlive(true);
ch1.socket().setSoTimeout(60 * 1000);
try {
@@ -509,8 +509,11 @@ public class ClusteredAgentManagerImpl extends AgentManagerImpl implements Clust
sslEngine = sslContext.createSSLEngine(ip, Port.value());
sslEngine.setUseClientMode(true);
sslEngine.setEnabledProtocols(SSLUtils.getSupportedProtocols(sslEngine.getEnabledProtocols()));
-
- Link.doHandshake(ch1, sslEngine, true);
+ sslEngine.beginHandshake();
+ if (!Link.doHandshake(ch1, sslEngine, true)) {
+ ch1.close();
+ throw new IOException("SSL handshake failed!");
+ }
s_logger.info("SSL: Handshake done");
} catch (Exception e) {
ch1.close();
diff --git a/utils/pom.xml b/utils/pom.xml
index 0aec5c62e62..0b84f38938a 100755
--- a/utils/pom.xml
+++ b/utils/pom.xml
@@ -178,10 +178,9 @@
com/cloud/utils/testcase/*TestCase*
com/cloud/utils/db/*Test*
- com/cloud/utils/testcase/NioTest.java
-
+
com.mycila
license-maven-plugin
diff --git a/utils/src/com/cloud/utils/nio/Link.java b/utils/src/com/cloud/utils/nio/Link.java
index ddfd474dfff..729af6c0d02 100755
--- a/utils/src/com/cloud/utils/nio/Link.java
+++ b/utils/src/com/cloud/utils/nio/Link.java
@@ -24,11 +24,8 @@ import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.net.InetSocketAddress;
-import java.net.SocketTimeoutException;
import java.nio.ByteBuffer;
-import java.nio.channels.Channels;
import java.nio.channels.ClosedChannelException;
-import java.nio.channels.ReadableByteChannel;
import java.nio.channels.SelectionKey;
import java.nio.channels.SocketChannel;
import java.security.GeneralSecurityException;
@@ -40,6 +37,7 @@ import javax.net.ssl.SSLContext;
import javax.net.ssl.SSLEngine;
import javax.net.ssl.SSLEngineResult;
import javax.net.ssl.SSLEngineResult.HandshakeStatus;
+import javax.net.ssl.SSLException;
import javax.net.ssl.SSLSession;
import javax.net.ssl.TrustManager;
import javax.net.ssl.TrustManagerFactory;
@@ -453,115 +451,185 @@ public class Link {
return sslContext;
}
- public static void doHandshake(SocketChannel ch, SSLEngine sslEngine, boolean isClient) throws IOException {
- if (s_logger.isTraceEnabled()) {
- s_logger.trace("SSL: begin Handshake, isClient: " + isClient);
+ public static ByteBuffer enlargeBuffer(ByteBuffer buffer, final int sessionProposedCapacity) {
+ if (buffer == null || sessionProposedCapacity < 0) {
+ return buffer;
}
-
- SSLEngineResult engResult;
- SSLSession sslSession = sslEngine.getSession();
- HandshakeStatus hsStatus;
- ByteBuffer in_pkgBuf = ByteBuffer.allocate(sslSession.getPacketBufferSize() + 40);
- ByteBuffer in_appBuf = ByteBuffer.allocate(sslSession.getApplicationBufferSize() + 40);
- ByteBuffer out_pkgBuf = ByteBuffer.allocate(sslSession.getPacketBufferSize() + 40);
- ByteBuffer out_appBuf = ByteBuffer.allocate(sslSession.getApplicationBufferSize() + 40);
- int count;
- ch.socket().setSoTimeout(60 * 1000);
- InputStream inStream = ch.socket().getInputStream();
- // Use readCh to make sure the timeout on reading is working
- ReadableByteChannel readCh = Channels.newChannel(inStream);
-
- if (isClient) {
- hsStatus = SSLEngineResult.HandshakeStatus.NEED_WRAP;
+ if (sessionProposedCapacity > buffer.capacity()) {
+ buffer = ByteBuffer.allocate(sessionProposedCapacity);
} else {
- hsStatus = SSLEngineResult.HandshakeStatus.NEED_UNWRAP;
+ buffer = ByteBuffer.allocate(buffer.capacity() * 2);
}
+ return buffer;
+ }
- while (hsStatus != SSLEngineResult.HandshakeStatus.FINISHED) {
- if (s_logger.isTraceEnabled()) {
- s_logger.trace("SSL: Handshake status " + hsStatus);
- }
- engResult = null;
- if (hsStatus == SSLEngineResult.HandshakeStatus.NEED_WRAP) {
- out_pkgBuf.clear();
- out_appBuf.clear();
- out_appBuf.put("Hello".getBytes());
- engResult = sslEngine.wrap(out_appBuf, out_pkgBuf);
- out_pkgBuf.flip();
- int remain = out_pkgBuf.limit();
- while (remain != 0) {
- remain -= ch.write(out_pkgBuf);
- if (remain < 0) {
- throw new IOException("Too much bytes sent?");
- }
- }
- } else if (hsStatus == SSLEngineResult.HandshakeStatus.NEED_UNWRAP) {
- in_appBuf.clear();
- // One packet may contained multiply operation
- if (in_pkgBuf.position() == 0 || !in_pkgBuf.hasRemaining()) {
- in_pkgBuf.clear();
- count = 0;
- try {
- count = readCh.read(in_pkgBuf);
- } catch (SocketTimeoutException ex) {
- if (s_logger.isTraceEnabled()) {
- s_logger.trace("Handshake reading time out! Cut the connection");
- }
- count = -1;
- }
- if (count == -1) {
- throw new IOException("Connection closed with -1 on reading size.");
- }
- in_pkgBuf.flip();
- }
- engResult = sslEngine.unwrap(in_pkgBuf, in_appBuf);
- ByteBuffer tmp_pkgBuf = ByteBuffer.allocate(sslSession.getPacketBufferSize() + 40);
- int loop_count = 0;
- while (engResult.getStatus() == SSLEngineResult.Status.BUFFER_UNDERFLOW) {
- // The client is too slow? Cut it and let it reconnect
- if (loop_count > 10) {
- throw new IOException("Too many times in SSL BUFFER_UNDERFLOW, disconnect guest.");
- }
- // We need more packets to complete this operation
- if (s_logger.isTraceEnabled()) {
- s_logger.trace("SSL: Buffer underflowed, getting more packets");
- }
- tmp_pkgBuf.clear();
- count = ch.read(tmp_pkgBuf);
- if (count == -1) {
- throw new IOException("Connection closed with -1 on reading size.");
- }
- tmp_pkgBuf.flip();
-
- in_pkgBuf.mark();
- in_pkgBuf.position(in_pkgBuf.limit());
- in_pkgBuf.limit(in_pkgBuf.limit() + tmp_pkgBuf.limit());
- in_pkgBuf.put(tmp_pkgBuf);
- in_pkgBuf.reset();
-
- in_appBuf.clear();
- engResult = sslEngine.unwrap(in_pkgBuf, in_appBuf);
- loop_count++;
- }
- } else if (hsStatus == SSLEngineResult.HandshakeStatus.NEED_TASK) {
- Runnable run;
- while ((run = sslEngine.getDelegatedTask()) != null) {
- if (s_logger.isTraceEnabled()) {
- s_logger.trace("SSL: Running delegated task!");
- }
- run.run();
- }
- } else if (hsStatus == SSLEngineResult.HandshakeStatus.NOT_HANDSHAKING) {
- throw new IOException("NOT a handshaking!");
- }
- if (engResult != null && engResult.getStatus() != SSLEngineResult.Status.OK) {
- throw new IOException("Fail to handshake! " + engResult.getStatus());
- }
- if (engResult != null)
- hsStatus = engResult.getHandshakeStatus();
- else
- hsStatus = sslEngine.getHandshakeStatus();
+ public static ByteBuffer handleBufferUnderflow(final SSLEngine engine, ByteBuffer buffer) {
+ if (engine == null || buffer == null) {
+ return buffer;
}
+ if (buffer.position() < buffer.limit()) {
+ return buffer;
+ }
+ ByteBuffer replaceBuffer = enlargeBuffer(buffer, engine.getSession().getPacketBufferSize());
+ buffer.flip();
+ replaceBuffer.put(buffer);
+ return replaceBuffer;
+ }
+
+ private static boolean doHandshakeUnwrap(final SocketChannel socketChannel, final SSLEngine sslEngine,
+ ByteBuffer peerAppData, ByteBuffer peerNetData, final int appBufferSize) throws IOException {
+ if (socketChannel == null || sslEngine == null || peerAppData == null || peerNetData == null || appBufferSize < 0) {
+ return false;
+ }
+ if (socketChannel.read(peerNetData) < 0) {
+ if (sslEngine.isInboundDone() && sslEngine.isOutboundDone()) {
+ return false;
+ }
+ try {
+ sslEngine.closeInbound();
+ } catch (SSLException e) {
+ s_logger.warn("This SSL engine was forced to close inbound due to end of stream.");
+ }
+ sslEngine.closeOutbound();
+ // After closeOutbound the engine will be set to WRAP state,
+ // in order to try to send a close message to the client.
+ return true;
+ }
+ peerNetData.flip();
+ SSLEngineResult result = null;
+ try {
+ result = sslEngine.unwrap(peerNetData, peerAppData);
+ peerNetData.compact();
+ } catch (SSLException sslException) {
+ s_logger.error("SSL error occurred while processing unwrap data: " + sslException.getMessage());
+ sslEngine.closeOutbound();
+ return true;
+ }
+ switch (result.getStatus()) {
+ case OK:
+ break;
+ case BUFFER_OVERFLOW:
+ // Will occur when peerAppData's capacity is smaller than the data derived from peerNetData's unwrap.
+ peerAppData = enlargeBuffer(peerAppData, appBufferSize);
+ break;
+ case BUFFER_UNDERFLOW:
+ // Will occur either when no data was read from the peer or when the peerNetData buffer
+ // was too small to hold all peer's data.
+ peerNetData = handleBufferUnderflow(sslEngine, peerNetData);
+ break;
+ case CLOSED:
+ if (sslEngine.isOutboundDone()) {
+ return false;
+ } else {
+ sslEngine.closeOutbound();
+ break;
+ }
+ default:
+ throw new IllegalStateException("Invalid SSL status: " + result.getStatus());
+ }
+ return true;
+ }
+
+ private static boolean doHandshakeWrap(final SocketChannel socketChannel, final SSLEngine sslEngine,
+ ByteBuffer myAppData, ByteBuffer myNetData, ByteBuffer peerNetData,
+ final int netBufferSize) throws IOException {
+ if (socketChannel == null || sslEngine == null || myNetData == null || peerNetData == null
+ || myAppData == null || netBufferSize < 0) {
+ return false;
+ }
+ myNetData.clear();
+ SSLEngineResult result = null;
+ try {
+ result = sslEngine.wrap(myAppData, myNetData);
+ } catch (SSLException sslException) {
+ s_logger.error("SSL error occurred while processing wrap data: " + sslException.getMessage());
+ sslEngine.closeOutbound();
+ return true;
+ }
+ switch (result.getStatus()) {
+ case OK :
+ myNetData.flip();
+ while (myNetData.hasRemaining()) {
+ socketChannel.write(myNetData);
+ }
+ break;
+ case BUFFER_OVERFLOW:
+ // Will occur if there is not enough space in myNetData buffer to write all the data
+ // that would be generated by the method wrap. Since myNetData is set to session's packet
+ // size we should not get to this point because SSLEngine is supposed to produce messages
+ // smaller or equal to that, but a general handling would be the following:
+ myNetData = enlargeBuffer(myNetData, netBufferSize);
+ break;
+ case BUFFER_UNDERFLOW:
+ throw new SSLException("Buffer underflow occurred after a wrap. We should not reach here.");
+ case CLOSED:
+ try {
+ myNetData.flip();
+ while (myNetData.hasRemaining()) {
+ socketChannel.write(myNetData);
+ }
+ // At this point the handshake status will probably be NEED_UNWRAP
+ // so we make sure that peerNetData is clear to read.
+ peerNetData.clear();
+ } catch (Exception e) {
+ s_logger.error("Failed to send server's CLOSE message due to socket channel's failure.");
+ }
+ break;
+ default:
+ throw new IllegalStateException("Invalid SSL status: " + result.getStatus());
+ }
+ return true;
+ }
+
+ public static boolean doHandshake(final SocketChannel socketChannel, final SSLEngine sslEngine, final boolean isClient) throws IOException {
+ if (socketChannel == null || sslEngine == null) {
+ return false;
+ }
+ final int appBufferSize = sslEngine.getSession().getApplicationBufferSize();
+ final int netBufferSize = sslEngine.getSession().getPacketBufferSize();
+ ByteBuffer myAppData = ByteBuffer.allocate(appBufferSize);
+ ByteBuffer peerAppData = ByteBuffer.allocate(appBufferSize);
+ ByteBuffer myNetData = ByteBuffer.allocate(netBufferSize);
+ ByteBuffer peerNetData = ByteBuffer.allocate(netBufferSize);
+
+ final long startTimeMills = System.currentTimeMillis();
+
+ HandshakeStatus handshakeStatus = sslEngine.getHandshakeStatus();
+ while (handshakeStatus != SSLEngineResult.HandshakeStatus.FINISHED
+ && handshakeStatus != SSLEngineResult.HandshakeStatus.NOT_HANDSHAKING) {
+ final long timeTaken = System.currentTimeMillis() - startTimeMills;
+ if (timeTaken > 15000L) {
+ s_logger.warn("SSL Handshake has taken more than 15s to connect to: " + socketChannel.getRemoteAddress() +
+ ". Please investigate this connection.");
+ return false;
+ }
+ switch (handshakeStatus) {
+ case NEED_UNWRAP:
+ if (!doHandshakeUnwrap(socketChannel, sslEngine, peerAppData, peerNetData, appBufferSize)) {
+ return false;
+ }
+ break;
+ case NEED_WRAP:
+ if (!doHandshakeWrap(socketChannel, sslEngine, myAppData, myNetData, peerNetData, netBufferSize)) {
+ return false;
+ }
+ break;
+ case NEED_TASK:
+ Runnable task;
+ while ((task = sslEngine.getDelegatedTask()) != null) {
+ new Thread(task).run();
+ }
+ break;
+ case FINISHED:
+ break;
+ case NOT_HANDSHAKING:
+ break;
+ default:
+ throw new IllegalStateException("Invalid SSL status: " + handshakeStatus);
+ }
+ handshakeStatus = sslEngine.getHandshakeStatus();
+ }
+ return true;
}
}
diff --git a/utils/src/com/cloud/utils/nio/NioClient.java b/utils/src/com/cloud/utils/nio/NioClient.java
index 2f742f99dc5..620a354545f 100755
--- a/utils/src/com/cloud/utils/nio/NioClient.java
+++ b/utils/src/com/cloud/utils/nio/NioClient.java
@@ -37,7 +37,6 @@ public class NioClient extends NioConnection {
private static final Logger s_logger = Logger.getLogger(NioClient.class);
protected String _host;
- protected String _bindAddress;
protected SocketChannel _clientConnection;
public NioClient(String name, String host, int port, int workers, HandlerFactory factory) {
@@ -45,10 +44,6 @@ public class NioClient extends NioConnection {
_host = host;
}
- public void setBindAddress(String ipAddress) {
- _bindAddress = ipAddress;
- }
-
@Override
protected void init() throws IOException {
_selector = Selector.open();
@@ -56,29 +51,23 @@ public class NioClient extends NioConnection {
try {
_clientConnection = SocketChannel.open();
- _clientConnection.configureBlocking(true);
s_logger.info("Connecting to " + _host + ":" + _port);
- if (_bindAddress != null) {
- s_logger.info("Binding outbound interface at " + _bindAddress);
-
- InetSocketAddress bindAddr = new InetSocketAddress(_bindAddress, 0);
- _clientConnection.socket().bind(bindAddr);
- }
-
InetSocketAddress peerAddr = new InetSocketAddress(_host, _port);
_clientConnection.connect(peerAddr);
+ _clientConnection.configureBlocking(false);
- SSLEngine sslEngine = null;
- // Begin SSL handshake in BLOCKING mode
- _clientConnection.configureBlocking(true);
-
- SSLContext sslContext = Link.initSSLContext(true);
- sslEngine = sslContext.createSSLEngine(_host, _port);
+ final SSLContext sslContext = Link.initSSLContext(true);
+ SSLEngine sslEngine = sslContext.createSSLEngine(_host, _port);
sslEngine.setUseClientMode(true);
sslEngine.setEnabledProtocols(SSLUtils.getSupportedProtocols(sslEngine.getEnabledProtocols()));
- Link.doHandshake(_clientConnection, sslEngine, true);
+ sslEngine.beginHandshake();
+ if (!Link.doHandshake(_clientConnection, sslEngine, true)) {
+ s_logger.error("SSL Handshake failed while connecting to host: " + _host + " port: " + _port);
+ _selector.close();
+ throw new IOException("SSL Handshake failed while connecting to host: " + _host + " port: " + _port);
+ }
s_logger.info("SSL: Handshake done");
s_logger.info("Connected to " + _host + ":" + _port);
diff --git a/utils/src/com/cloud/utils/nio/NioConnection.java b/utils/src/com/cloud/utils/nio/NioConnection.java
index 34679b8277b..fa92959e3e0 100755
--- a/utils/src/com/cloud/utils/nio/NioConnection.java
+++ b/utils/src/com/cloud/utils/nio/NioConnection.java
@@ -33,6 +33,7 @@ import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Set;
+import java.util.concurrent.Executors;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.ThreadPoolExecutor;
@@ -63,6 +64,7 @@ public abstract class NioConnection implements Runnable {
protected HandlerFactory _factory;
protected String _name;
protected ExecutorService _executor;
+ protected ExecutorService _sslHandshakeExecutor;
public NioConnection(String name, int port, int workers, HandlerFactory factory) {
_name = name;
@@ -72,6 +74,7 @@ public abstract class NioConnection implements Runnable {
_port = port;
_factory = factory;
_executor = new ThreadPoolExecutor(workers, 5 * workers, 1, TimeUnit.DAYS, new LinkedBlockingQueue(), new NamedThreadFactory(name + "-Handler"));
+ _sslHandshakeExecutor = Executors.newCachedThreadPool(new NamedThreadFactory(name + "-SSLHandshakeHandler"));
}
public void start() {
@@ -167,6 +170,8 @@ public abstract class NioConnection implements Runnable {
processTodos();
} catch (Throwable e) {
s_logger.warn("Caught an exception but continuing on.", e);
+ } finally {
+ _selector.wakeup();
}
}
synchronized (_thread) {
@@ -180,53 +185,69 @@ public abstract class NioConnection implements Runnable {
abstract void unregisterLink(InetSocketAddress saddr);
- protected void accept(SelectionKey key) throws IOException {
- ServerSocketChannel serverSocketChannel = (ServerSocketChannel)key.channel();
+ protected void accept(final SelectionKey key) throws IOException {
+ final ServerSocketChannel serverSocketChannel = (ServerSocketChannel)key.channel();
- SocketChannel socketChannel = serverSocketChannel.accept();
- Socket socket = socketChannel.socket();
+ final SocketChannel socketChannel = serverSocketChannel.accept();
+ socketChannel.configureBlocking(false);
+ final Socket socket = socketChannel.socket();
socket.setKeepAlive(true);
- if (s_logger.isTraceEnabled()) {
- s_logger.trace("Connection accepted for " + socket);
- }
-
- // Begin SSL handshake in BLOCKING mode
- socketChannel.configureBlocking(true);
-
- SSLEngine sslEngine = null;
+ final SSLEngine sslEngine;
try {
- SSLContext sslContext = Link.initSSLContext(false);
+ final SSLContext sslContext = Link.initSSLContext(false);
sslEngine = sslContext.createSSLEngine();
sslEngine.setUseClientMode(false);
sslEngine.setNeedClientAuth(false);
sslEngine.setEnabledProtocols(SSLUtils.getSupportedProtocols(sslEngine.getEnabledProtocols()));
- Link.doHandshake(socketChannel, sslEngine, false);
-
- } catch (Exception e) {
- if (s_logger.isTraceEnabled()) {
- s_logger.trace("Socket " + socket + " closed on read. Probably -1 returned: " + e.getMessage());
- }
+ final NioConnection nioConnection = this;
+ _sslHandshakeExecutor.submit(new Runnable() {
+ @Override
+ public void run() {
+ _selector.wakeup();
+ try {
+ sslEngine.beginHandshake();
+ if (!Link.doHandshake(socketChannel, sslEngine, false)) {
+ throw new IOException("SSL handshake timed out with " + socketChannel.getRemoteAddress());
+ }
+ if (s_logger.isTraceEnabled()) {
+ s_logger.trace("SSL: Handshake done");
+ }
+ final InetSocketAddress saddr = (InetSocketAddress)socket.getRemoteSocketAddress();
+ final Link link = new Link(saddr, nioConnection);
+ link.setSSLEngine(sslEngine);
+ link.setKey(socketChannel.register(key.selector(), SelectionKey.OP_READ, link));
+ final Task task = _factory.create(Task.Type.CONNECT, link, null);
+ registerLink(saddr, link);
+ _executor.submit(task);
+ } catch (IOException e) {
+ if (s_logger.isTraceEnabled()) {
+ s_logger.trace("Connection closed due to failure: " + e.getMessage());
+ }
+ try {
+ socketChannel.close();
+ socket.close();
+ } catch (IOException ignore) {
+ }
+ } finally {
+ _selector.wakeup();
+ }
+ }
+ });
+ } catch (final Exception e) {
+ if (s_logger.isTraceEnabled()) {
+ s_logger.trace("Connection closed due to failure: " + e.getMessage());
+ }
try {
socketChannel.close();
socket.close();
} catch (IOException ignore) {
}
return;
+ } finally {
+ _selector.wakeup();
}
-
- if (s_logger.isTraceEnabled()) {
- s_logger.trace("SSL: Handshake done");
- }
- socketChannel.configureBlocking(false);
- InetSocketAddress saddr = (InetSocketAddress)socket.getRemoteSocketAddress();
- Link link = new Link(saddr, this);
- link.setSSLEngine(sslEngine);
- link.setKey(socketChannel.register(key.selector(), SelectionKey.OP_READ, link));
- Task task = _factory.create(Task.Type.CONNECT, link, null);
- registerLink(saddr, link);
- _executor.execute(task);
}
protected void terminate(SelectionKey key) {
diff --git a/utils/src/com/cloud/utils/nio/NioServer.java b/utils/src/com/cloud/utils/nio/NioServer.java
index 98a4a51dbfa..adcecda4e64 100755
--- a/utils/src/com/cloud/utils/nio/NioServer.java
+++ b/utils/src/com/cloud/utils/nio/NioServer.java
@@ -43,6 +43,10 @@ public class NioServer extends NioConnection {
_links = new WeakHashMap(1024);
}
+ public int getPort() {
+ return _serverSocket.socket().getLocalPort();
+ }
+
@Override
protected void init() throws IOException {
_selector = SelectorProvider.provider().openSelector();
@@ -55,7 +59,7 @@ public class NioServer extends NioConnection {
_serverSocket.register(_selector, SelectionKey.OP_ACCEPT, null);
- s_logger.info("NioConnection started and listening on " + _localAddr.toString());
+ s_logger.info("NioConnection started and listening on " + _serverSocket.socket().getLocalSocketAddress());
}
@Override
diff --git a/utils/test/com/cloud/utils/testcase/NioTest.java b/utils/test/com/cloud/utils/testcase/NioTest.java
index fc166847dc2..515119d9e7e 100644
--- a/utils/test/com/cloud/utils/testcase/NioTest.java
+++ b/utils/test/com/cloud/utils/testcase/NioTest.java
@@ -19,198 +19,278 @@
package com.cloud.utils.testcase;
-import java.nio.channels.ClosedChannelException;
-import java.util.Random;
-
-import junit.framework.TestCase;
-
-import org.apache.log4j.Logger;
-import org.junit.Assert;
-
+import com.cloud.utils.concurrency.NamedThreadFactory;
import com.cloud.utils.nio.HandlerFactory;
import com.cloud.utils.nio.Link;
import com.cloud.utils.nio.NioClient;
import com.cloud.utils.nio.NioServer;
import com.cloud.utils.nio.Task;
import com.cloud.utils.nio.Task.Type;
+import org.apache.log4j.Logger;
+import org.junit.After;
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.Test;
+
+import java.io.IOException;
+import java.net.InetSocketAddress;
+import java.nio.channels.ClosedChannelException;
+import java.nio.channels.Selector;
+import java.nio.channels.SocketChannel;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Random;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
/**
- *
- *
- *
- *
+ * NioTest demonstrates that NioServer can function without getting its main IO
+ * loop blocked when an aggressive or malicious client connects to the server but
+ * fail to participate in SSL handshake. In this test, we run bunch of clients
+ * that send a known payload to the server, to which multiple malicious clients
+ * also try to connect and hang.
+ * A malicious client could cause denial-of-service if the server's main IO loop
+ * along with SSL handshake was blocking. A passing tests shows that NioServer
+ * can still function in case of connection load and that the main IO loop along
+ * with SSL handshake is non-blocking with some internal timeout mechanism.
*/
-public class NioTest extends TestCase {
+public class NioTest {
- private static final Logger s_logger = Logger.getLogger(NioTest.class);
+ private static final Logger LOGGER = Logger.getLogger(NioTest.class);
- private NioServer _server;
- private NioClient _client;
+ // Test should fail in due time instead of looping forever
+ private static final int TESTTIMEOUT = 60000;
- private Link _clientLink;
+ final private int totalTestCount = 4;
+ private int completedTestCount = 0;
- private int _testCount;
- private int _completedCount;
+ private NioServer server;
+ private List clients = new ArrayList<>();
+ private List maliciousClients = new ArrayList<>();
+
+ private ExecutorService clientExecutor = Executors.newFixedThreadPool(totalTestCount, new NamedThreadFactory("NioClientHandler"));;
+ private ExecutorService maliciousExecutor = Executors.newFixedThreadPool(totalTestCount, new NamedThreadFactory("MaliciousNioClientHandler"));;
+
+ private Random randomGenerator = new Random();
+ private byte[] testBytes;
private boolean isTestsDone() {
boolean result;
synchronized (this) {
- result = (_testCount == _completedCount);
+ result = totalTestCount == completedTestCount;
}
return result;
}
- private void getOneMoreTest() {
- synchronized (this) {
- _testCount++;
- }
- }
-
private void oneMoreTestDone() {
synchronized (this) {
- _completedCount++;
+ completedTestCount++;
}
}
- @Override
+ @Before
public void setUp() {
- s_logger.info("Test");
+ LOGGER.info("Setting up Benchmark Test");
- _testCount = 0;
- _completedCount = 0;
+ completedTestCount = 0;
+ testBytes = new byte[1000000];
+ randomGenerator.nextBytes(testBytes);
- _server = new NioServer("NioTestServer", 7777, 5, new NioTestServer());
- _server.start();
+ server = new NioServer("NioTestServer", 0, 1, new NioTestServer());
+ try {
+ server.start();
+ } catch (Exception e) {
+ Assert.fail(e.getMessage());
+ }
- _client = new NioClient("NioTestServer", "127.0.0.1", 7777, 5, new NioTestClient());
- _client.start();
+ /**
+ * The malicious client(s) tries to block NioServer's main IO loop
+ * thread until SSL handshake timeout value (from Link class, 15s) after
+ * which the valid NioClient(s) get the opportunity to make connection(s)
+ */
+ for (int i = 0; i < totalTestCount; i++) {
+ final NioClient maliciousClient = new NioMaliciousClient("NioMaliciousTestClient-" + i, "127.0.0.1", server.getPort(), 1, new NioMaliciousTestClient());
+ maliciousClients.add(maliciousClient);
+ maliciousExecutor.submit(new ThreadedNioClient(maliciousClient));
+ }
- while (_clientLink == null) {
- try {
- s_logger.debug("Link is not up! Waiting ...");
- Thread.sleep(1000);
- } catch (InterruptedException e) {
- // TODO Auto-generated catch block
- e.printStackTrace();
- }
+ for (int i = 0; i < totalTestCount; i++) {
+ final NioClient client = new NioClient("NioTestClient-" + i, "127.0.0.1", server.getPort(), 1, new NioTestClient());
+ clients.add(client);
+ clientExecutor.submit(new ThreadedNioClient(client));
}
}
- @Override
+ @After
public void tearDown() {
- while (!isTestsDone()) {
- try {
- s_logger.debug(this._completedCount + "/" + this._testCount + " tests done. Waiting for completion");
- Thread.sleep(1000);
- } catch (InterruptedException e) {
- // TODO Auto-generated catch block
- e.printStackTrace();
- }
- }
stopClient();
stopServer();
}
protected void stopClient() {
- _client.stop();
- s_logger.info("Client stopped.");
+ for (NioClient client : clients) {
+ client.stop();
+ }
+ for (NioClient maliciousClient : maliciousClients) {
+ maliciousClient.stop();
+ }
+ LOGGER.info("Clients stopped.");
}
protected void stopServer() {
- _server.stop();
- s_logger.info("Server stopped.");
+ server.stop();
+ LOGGER.info("Server stopped.");
}
- protected void setClientLink(Link link) {
- _clientLink = link;
- }
-
- Random randomGenerator = new Random();
-
- byte[] _testBytes;
-
+ @Test(timeout=TESTTIMEOUT)
public void testConnection() {
- _testBytes = new byte[1000000];
- randomGenerator.nextBytes(_testBytes);
- try {
- getOneMoreTest();
- _clientLink.send(_testBytes);
- s_logger.info("Client: Data sent");
- getOneMoreTest();
- _clientLink.send(_testBytes);
- s_logger.info("Client: Data sent");
- } catch (ClosedChannelException e) {
- // TODO Auto-generated catch block
- e.printStackTrace();
+ while (!isTestsDone()) {
+ try {
+ LOGGER.debug(completedTestCount + "/" + totalTestCount + " tests done. Waiting for completion");
+ Thread.sleep(1000);
+ } catch (final InterruptedException e) {
+ Assert.fail(e.getMessage());
+ }
+ }
+ LOGGER.debug(completedTestCount + "/" + totalTestCount + " tests done.");
+ }
+
+ protected void doServerProcess(final byte[] data) {
+ oneMoreTestDone();
+ Assert.assertArrayEquals(testBytes, data);
+ LOGGER.info("Verify data received by server done.");
+ }
+
+ public byte[] getTestBytes() {
+ return testBytes;
+ }
+
+ public class ThreadedNioClient implements Runnable {
+ final private NioClient client;
+ ThreadedNioClient(final NioClient client) {
+ this.client = client;
+ }
+
+ @Override
+ public void run() {
+ try {
+ client.start();
+ } catch (Exception e) {
+ Assert.fail(e.getMessage());
+ }
}
}
- protected void doServerProcess(byte[] data) {
- oneMoreTestDone();
- Assert.assertArrayEquals(_testBytes, data);
- s_logger.info("Verify done.");
+ public class NioMaliciousClient extends NioClient {
+
+ public NioMaliciousClient(String name, String host, int port, int workers, HandlerFactory factory) {
+ super(name, host, port, workers, factory);
+ }
+
+ @Override
+ protected void init() throws IOException {
+ _selector = Selector.open();
+ try {
+ _clientConnection = SocketChannel.open();
+ LOGGER.info("Connecting to " + _host + ":" + _port);
+ final InetSocketAddress peerAddr = new InetSocketAddress(_host, _port);
+ _clientConnection.connect(peerAddr);
+ // This is done on purpose, the malicious client would connect
+ // to the server and then do nothing, hence using a large sleep value
+ Thread.sleep(Long.MAX_VALUE);
+ } catch (final IOException e) {
+ _selector.close();
+ throw e;
+ } catch (InterruptedException e) {
+ LOGGER.debug(e.getMessage());
+ }
+ }
+ }
+
+ public class NioMaliciousTestClient implements HandlerFactory {
+
+ @Override
+ public Task create(final Type type, final Link link, final byte[] data) {
+ return new NioMaliciousTestClientHandler(type, link, data);
+ }
+
+ public class NioMaliciousTestClientHandler extends Task {
+
+ public NioMaliciousTestClientHandler(final Type type, final Link link, final byte[] data) {
+ super(type, link, data);
+ }
+
+ @Override
+ public void doTask(final Task task) {
+ LOGGER.info("Malicious Client: Received task " + task.getType().toString());
+ }
+ }
}
public class NioTestClient implements HandlerFactory {
@Override
- public Task create(Type type, Link link, byte[] data) {
+ public Task create(final Type type, final Link link, final byte[] data) {
return new NioTestClientHandler(type, link, data);
}
public class NioTestClientHandler extends Task {
- public NioTestClientHandler(Type type, Link link, byte[] data) {
+ public NioTestClientHandler(final Type type, final Link link, final byte[] data) {
super(type, link, data);
}
@Override
public void doTask(final Task task) {
if (task.getType() == Task.Type.CONNECT) {
- s_logger.info("Client: Received CONNECT task");
- setClientLink(task.getLink());
+ LOGGER.info("Client: Received CONNECT task");
+ try {
+ LOGGER.info("Sending data to server");
+ task.getLink().send(getTestBytes());
+ } catch (ClosedChannelException e) {
+ LOGGER.error(e.getMessage());
+ e.printStackTrace();
+ }
} else if (task.getType() == Task.Type.DATA) {
- s_logger.info("Client: Received DATA task");
+ LOGGER.info("Client: Received DATA task");
} else if (task.getType() == Task.Type.DISCONNECT) {
- s_logger.info("Client: Received DISCONNECT task");
+ LOGGER.info("Client: Received DISCONNECT task");
stopClient();
} else if (task.getType() == Task.Type.OTHER) {
- s_logger.info("Client: Received OTHER task");
+ LOGGER.info("Client: Received OTHER task");
}
}
-
}
}
public class NioTestServer implements HandlerFactory {
@Override
- public Task create(Type type, Link link, byte[] data) {
+ public Task create(final Type type, final Link link, final byte[] data) {
return new NioTestServerHandler(type, link, data);
}
public class NioTestServerHandler extends Task {
- public NioTestServerHandler(Type type, Link link, byte[] data) {
+ public NioTestServerHandler(final Type type, final Link link, final byte[] data) {
super(type, link, data);
}
@Override
public void doTask(final Task task) {
if (task.getType() == Task.Type.CONNECT) {
- s_logger.info("Server: Received CONNECT task");
+ LOGGER.info("Server: Received CONNECT task");
} else if (task.getType() == Task.Type.DATA) {
- s_logger.info("Server: Received DATA task");
+ LOGGER.info("Server: Received DATA task");
doServerProcess(task.getData());
} else if (task.getType() == Task.Type.DISCONNECT) {
- s_logger.info("Server: Received DISCONNECT task");
+ LOGGER.info("Server: Received DISCONNECT task");
stopServer();
} else if (task.getType() == Task.Type.OTHER) {
- s_logger.info("Server: Received OTHER task");
+ LOGGER.info("Server: Received OTHER task");
}
}
-
}
}
}