diff --git a/engine/orchestration/src/com/cloud/agent/manager/ClusteredAgentManagerImpl.java b/engine/orchestration/src/com/cloud/agent/manager/ClusteredAgentManagerImpl.java index ba82938fcdc..f65d83bab5b 100755 --- a/engine/orchestration/src/com/cloud/agent/manager/ClusteredAgentManagerImpl.java +++ b/engine/orchestration/src/com/cloud/agent/manager/ClusteredAgentManagerImpl.java @@ -501,7 +501,7 @@ public class ClusteredAgentManagerImpl extends AgentManagerImpl implements Clust SocketChannel ch1 = null; try { ch1 = SocketChannel.open(new InetSocketAddress(addr, Port.value())); - ch1.configureBlocking(true); // make sure we are working at blocking mode + ch1.configureBlocking(false); ch1.socket().setKeepAlive(true); ch1.socket().setSoTimeout(60 * 1000); try { @@ -509,8 +509,11 @@ public class ClusteredAgentManagerImpl extends AgentManagerImpl implements Clust sslEngine = sslContext.createSSLEngine(ip, Port.value()); sslEngine.setUseClientMode(true); sslEngine.setEnabledProtocols(SSLUtils.getSupportedProtocols(sslEngine.getEnabledProtocols())); - - Link.doHandshake(ch1, sslEngine, true); + sslEngine.beginHandshake(); + if (!Link.doHandshake(ch1, sslEngine, true)) { + ch1.close(); + throw new IOException("SSL handshake failed!"); + } s_logger.info("SSL: Handshake done"); } catch (Exception e) { ch1.close(); diff --git a/utils/pom.xml b/utils/pom.xml index 0aec5c62e62..0b84f38938a 100755 --- a/utils/pom.xml +++ b/utils/pom.xml @@ -178,10 +178,9 @@ com/cloud/utils/testcase/*TestCase* com/cloud/utils/db/*Test* - com/cloud/utils/testcase/NioTest.java - + com.mycila license-maven-plugin diff --git a/utils/src/com/cloud/utils/nio/Link.java b/utils/src/com/cloud/utils/nio/Link.java index ddfd474dfff..729af6c0d02 100755 --- a/utils/src/com/cloud/utils/nio/Link.java +++ b/utils/src/com/cloud/utils/nio/Link.java @@ -24,11 +24,8 @@ import java.io.FileInputStream; import java.io.IOException; import java.io.InputStream; import java.net.InetSocketAddress; -import java.net.SocketTimeoutException; import java.nio.ByteBuffer; -import java.nio.channels.Channels; import java.nio.channels.ClosedChannelException; -import java.nio.channels.ReadableByteChannel; import java.nio.channels.SelectionKey; import java.nio.channels.SocketChannel; import java.security.GeneralSecurityException; @@ -40,6 +37,7 @@ import javax.net.ssl.SSLContext; import javax.net.ssl.SSLEngine; import javax.net.ssl.SSLEngineResult; import javax.net.ssl.SSLEngineResult.HandshakeStatus; +import javax.net.ssl.SSLException; import javax.net.ssl.SSLSession; import javax.net.ssl.TrustManager; import javax.net.ssl.TrustManagerFactory; @@ -453,115 +451,185 @@ public class Link { return sslContext; } - public static void doHandshake(SocketChannel ch, SSLEngine sslEngine, boolean isClient) throws IOException { - if (s_logger.isTraceEnabled()) { - s_logger.trace("SSL: begin Handshake, isClient: " + isClient); + public static ByteBuffer enlargeBuffer(ByteBuffer buffer, final int sessionProposedCapacity) { + if (buffer == null || sessionProposedCapacity < 0) { + return buffer; } - - SSLEngineResult engResult; - SSLSession sslSession = sslEngine.getSession(); - HandshakeStatus hsStatus; - ByteBuffer in_pkgBuf = ByteBuffer.allocate(sslSession.getPacketBufferSize() + 40); - ByteBuffer in_appBuf = ByteBuffer.allocate(sslSession.getApplicationBufferSize() + 40); - ByteBuffer out_pkgBuf = ByteBuffer.allocate(sslSession.getPacketBufferSize() + 40); - ByteBuffer out_appBuf = ByteBuffer.allocate(sslSession.getApplicationBufferSize() + 40); - int count; - ch.socket().setSoTimeout(60 * 1000); - InputStream inStream = ch.socket().getInputStream(); - // Use readCh to make sure the timeout on reading is working - ReadableByteChannel readCh = Channels.newChannel(inStream); - - if (isClient) { - hsStatus = SSLEngineResult.HandshakeStatus.NEED_WRAP; + if (sessionProposedCapacity > buffer.capacity()) { + buffer = ByteBuffer.allocate(sessionProposedCapacity); } else { - hsStatus = SSLEngineResult.HandshakeStatus.NEED_UNWRAP; + buffer = ByteBuffer.allocate(buffer.capacity() * 2); } + return buffer; + } - while (hsStatus != SSLEngineResult.HandshakeStatus.FINISHED) { - if (s_logger.isTraceEnabled()) { - s_logger.trace("SSL: Handshake status " + hsStatus); - } - engResult = null; - if (hsStatus == SSLEngineResult.HandshakeStatus.NEED_WRAP) { - out_pkgBuf.clear(); - out_appBuf.clear(); - out_appBuf.put("Hello".getBytes()); - engResult = sslEngine.wrap(out_appBuf, out_pkgBuf); - out_pkgBuf.flip(); - int remain = out_pkgBuf.limit(); - while (remain != 0) { - remain -= ch.write(out_pkgBuf); - if (remain < 0) { - throw new IOException("Too much bytes sent?"); - } - } - } else if (hsStatus == SSLEngineResult.HandshakeStatus.NEED_UNWRAP) { - in_appBuf.clear(); - // One packet may contained multiply operation - if (in_pkgBuf.position() == 0 || !in_pkgBuf.hasRemaining()) { - in_pkgBuf.clear(); - count = 0; - try { - count = readCh.read(in_pkgBuf); - } catch (SocketTimeoutException ex) { - if (s_logger.isTraceEnabled()) { - s_logger.trace("Handshake reading time out! Cut the connection"); - } - count = -1; - } - if (count == -1) { - throw new IOException("Connection closed with -1 on reading size."); - } - in_pkgBuf.flip(); - } - engResult = sslEngine.unwrap(in_pkgBuf, in_appBuf); - ByteBuffer tmp_pkgBuf = ByteBuffer.allocate(sslSession.getPacketBufferSize() + 40); - int loop_count = 0; - while (engResult.getStatus() == SSLEngineResult.Status.BUFFER_UNDERFLOW) { - // The client is too slow? Cut it and let it reconnect - if (loop_count > 10) { - throw new IOException("Too many times in SSL BUFFER_UNDERFLOW, disconnect guest."); - } - // We need more packets to complete this operation - if (s_logger.isTraceEnabled()) { - s_logger.trace("SSL: Buffer underflowed, getting more packets"); - } - tmp_pkgBuf.clear(); - count = ch.read(tmp_pkgBuf); - if (count == -1) { - throw new IOException("Connection closed with -1 on reading size."); - } - tmp_pkgBuf.flip(); - - in_pkgBuf.mark(); - in_pkgBuf.position(in_pkgBuf.limit()); - in_pkgBuf.limit(in_pkgBuf.limit() + tmp_pkgBuf.limit()); - in_pkgBuf.put(tmp_pkgBuf); - in_pkgBuf.reset(); - - in_appBuf.clear(); - engResult = sslEngine.unwrap(in_pkgBuf, in_appBuf); - loop_count++; - } - } else if (hsStatus == SSLEngineResult.HandshakeStatus.NEED_TASK) { - Runnable run; - while ((run = sslEngine.getDelegatedTask()) != null) { - if (s_logger.isTraceEnabled()) { - s_logger.trace("SSL: Running delegated task!"); - } - run.run(); - } - } else if (hsStatus == SSLEngineResult.HandshakeStatus.NOT_HANDSHAKING) { - throw new IOException("NOT a handshaking!"); - } - if (engResult != null && engResult.getStatus() != SSLEngineResult.Status.OK) { - throw new IOException("Fail to handshake! " + engResult.getStatus()); - } - if (engResult != null) - hsStatus = engResult.getHandshakeStatus(); - else - hsStatus = sslEngine.getHandshakeStatus(); + public static ByteBuffer handleBufferUnderflow(final SSLEngine engine, ByteBuffer buffer) { + if (engine == null || buffer == null) { + return buffer; } + if (buffer.position() < buffer.limit()) { + return buffer; + } + ByteBuffer replaceBuffer = enlargeBuffer(buffer, engine.getSession().getPacketBufferSize()); + buffer.flip(); + replaceBuffer.put(buffer); + return replaceBuffer; + } + + private static boolean doHandshakeUnwrap(final SocketChannel socketChannel, final SSLEngine sslEngine, + ByteBuffer peerAppData, ByteBuffer peerNetData, final int appBufferSize) throws IOException { + if (socketChannel == null || sslEngine == null || peerAppData == null || peerNetData == null || appBufferSize < 0) { + return false; + } + if (socketChannel.read(peerNetData) < 0) { + if (sslEngine.isInboundDone() && sslEngine.isOutboundDone()) { + return false; + } + try { + sslEngine.closeInbound(); + } catch (SSLException e) { + s_logger.warn("This SSL engine was forced to close inbound due to end of stream."); + } + sslEngine.closeOutbound(); + // After closeOutbound the engine will be set to WRAP state, + // in order to try to send a close message to the client. + return true; + } + peerNetData.flip(); + SSLEngineResult result = null; + try { + result = sslEngine.unwrap(peerNetData, peerAppData); + peerNetData.compact(); + } catch (SSLException sslException) { + s_logger.error("SSL error occurred while processing unwrap data: " + sslException.getMessage()); + sslEngine.closeOutbound(); + return true; + } + switch (result.getStatus()) { + case OK: + break; + case BUFFER_OVERFLOW: + // Will occur when peerAppData's capacity is smaller than the data derived from peerNetData's unwrap. + peerAppData = enlargeBuffer(peerAppData, appBufferSize); + break; + case BUFFER_UNDERFLOW: + // Will occur either when no data was read from the peer or when the peerNetData buffer + // was too small to hold all peer's data. + peerNetData = handleBufferUnderflow(sslEngine, peerNetData); + break; + case CLOSED: + if (sslEngine.isOutboundDone()) { + return false; + } else { + sslEngine.closeOutbound(); + break; + } + default: + throw new IllegalStateException("Invalid SSL status: " + result.getStatus()); + } + return true; + } + + private static boolean doHandshakeWrap(final SocketChannel socketChannel, final SSLEngine sslEngine, + ByteBuffer myAppData, ByteBuffer myNetData, ByteBuffer peerNetData, + final int netBufferSize) throws IOException { + if (socketChannel == null || sslEngine == null || myNetData == null || peerNetData == null + || myAppData == null || netBufferSize < 0) { + return false; + } + myNetData.clear(); + SSLEngineResult result = null; + try { + result = sslEngine.wrap(myAppData, myNetData); + } catch (SSLException sslException) { + s_logger.error("SSL error occurred while processing wrap data: " + sslException.getMessage()); + sslEngine.closeOutbound(); + return true; + } + switch (result.getStatus()) { + case OK : + myNetData.flip(); + while (myNetData.hasRemaining()) { + socketChannel.write(myNetData); + } + break; + case BUFFER_OVERFLOW: + // Will occur if there is not enough space in myNetData buffer to write all the data + // that would be generated by the method wrap. Since myNetData is set to session's packet + // size we should not get to this point because SSLEngine is supposed to produce messages + // smaller or equal to that, but a general handling would be the following: + myNetData = enlargeBuffer(myNetData, netBufferSize); + break; + case BUFFER_UNDERFLOW: + throw new SSLException("Buffer underflow occurred after a wrap. We should not reach here."); + case CLOSED: + try { + myNetData.flip(); + while (myNetData.hasRemaining()) { + socketChannel.write(myNetData); + } + // At this point the handshake status will probably be NEED_UNWRAP + // so we make sure that peerNetData is clear to read. + peerNetData.clear(); + } catch (Exception e) { + s_logger.error("Failed to send server's CLOSE message due to socket channel's failure."); + } + break; + default: + throw new IllegalStateException("Invalid SSL status: " + result.getStatus()); + } + return true; + } + + public static boolean doHandshake(final SocketChannel socketChannel, final SSLEngine sslEngine, final boolean isClient) throws IOException { + if (socketChannel == null || sslEngine == null) { + return false; + } + final int appBufferSize = sslEngine.getSession().getApplicationBufferSize(); + final int netBufferSize = sslEngine.getSession().getPacketBufferSize(); + ByteBuffer myAppData = ByteBuffer.allocate(appBufferSize); + ByteBuffer peerAppData = ByteBuffer.allocate(appBufferSize); + ByteBuffer myNetData = ByteBuffer.allocate(netBufferSize); + ByteBuffer peerNetData = ByteBuffer.allocate(netBufferSize); + + final long startTimeMills = System.currentTimeMillis(); + + HandshakeStatus handshakeStatus = sslEngine.getHandshakeStatus(); + while (handshakeStatus != SSLEngineResult.HandshakeStatus.FINISHED + && handshakeStatus != SSLEngineResult.HandshakeStatus.NOT_HANDSHAKING) { + final long timeTaken = System.currentTimeMillis() - startTimeMills; + if (timeTaken > 15000L) { + s_logger.warn("SSL Handshake has taken more than 15s to connect to: " + socketChannel.getRemoteAddress() + + ". Please investigate this connection."); + return false; + } + switch (handshakeStatus) { + case NEED_UNWRAP: + if (!doHandshakeUnwrap(socketChannel, sslEngine, peerAppData, peerNetData, appBufferSize)) { + return false; + } + break; + case NEED_WRAP: + if (!doHandshakeWrap(socketChannel, sslEngine, myAppData, myNetData, peerNetData, netBufferSize)) { + return false; + } + break; + case NEED_TASK: + Runnable task; + while ((task = sslEngine.getDelegatedTask()) != null) { + new Thread(task).run(); + } + break; + case FINISHED: + break; + case NOT_HANDSHAKING: + break; + default: + throw new IllegalStateException("Invalid SSL status: " + handshakeStatus); + } + handshakeStatus = sslEngine.getHandshakeStatus(); + } + return true; } } diff --git a/utils/src/com/cloud/utils/nio/NioClient.java b/utils/src/com/cloud/utils/nio/NioClient.java index 2f742f99dc5..620a354545f 100755 --- a/utils/src/com/cloud/utils/nio/NioClient.java +++ b/utils/src/com/cloud/utils/nio/NioClient.java @@ -37,7 +37,6 @@ public class NioClient extends NioConnection { private static final Logger s_logger = Logger.getLogger(NioClient.class); protected String _host; - protected String _bindAddress; protected SocketChannel _clientConnection; public NioClient(String name, String host, int port, int workers, HandlerFactory factory) { @@ -45,10 +44,6 @@ public class NioClient extends NioConnection { _host = host; } - public void setBindAddress(String ipAddress) { - _bindAddress = ipAddress; - } - @Override protected void init() throws IOException { _selector = Selector.open(); @@ -56,29 +51,23 @@ public class NioClient extends NioConnection { try { _clientConnection = SocketChannel.open(); - _clientConnection.configureBlocking(true); s_logger.info("Connecting to " + _host + ":" + _port); - if (_bindAddress != null) { - s_logger.info("Binding outbound interface at " + _bindAddress); - - InetSocketAddress bindAddr = new InetSocketAddress(_bindAddress, 0); - _clientConnection.socket().bind(bindAddr); - } - InetSocketAddress peerAddr = new InetSocketAddress(_host, _port); _clientConnection.connect(peerAddr); + _clientConnection.configureBlocking(false); - SSLEngine sslEngine = null; - // Begin SSL handshake in BLOCKING mode - _clientConnection.configureBlocking(true); - - SSLContext sslContext = Link.initSSLContext(true); - sslEngine = sslContext.createSSLEngine(_host, _port); + final SSLContext sslContext = Link.initSSLContext(true); + SSLEngine sslEngine = sslContext.createSSLEngine(_host, _port); sslEngine.setUseClientMode(true); sslEngine.setEnabledProtocols(SSLUtils.getSupportedProtocols(sslEngine.getEnabledProtocols())); - Link.doHandshake(_clientConnection, sslEngine, true); + sslEngine.beginHandshake(); + if (!Link.doHandshake(_clientConnection, sslEngine, true)) { + s_logger.error("SSL Handshake failed while connecting to host: " + _host + " port: " + _port); + _selector.close(); + throw new IOException("SSL Handshake failed while connecting to host: " + _host + " port: " + _port); + } s_logger.info("SSL: Handshake done"); s_logger.info("Connected to " + _host + ":" + _port); diff --git a/utils/src/com/cloud/utils/nio/NioConnection.java b/utils/src/com/cloud/utils/nio/NioConnection.java index 34679b8277b..fa92959e3e0 100755 --- a/utils/src/com/cloud/utils/nio/NioConnection.java +++ b/utils/src/com/cloud/utils/nio/NioConnection.java @@ -33,6 +33,7 @@ import java.util.ArrayList; import java.util.Iterator; import java.util.List; import java.util.Set; +import java.util.concurrent.Executors; import java.util.concurrent.ExecutorService; import java.util.concurrent.LinkedBlockingQueue; import java.util.concurrent.ThreadPoolExecutor; @@ -63,6 +64,7 @@ public abstract class NioConnection implements Runnable { protected HandlerFactory _factory; protected String _name; protected ExecutorService _executor; + protected ExecutorService _sslHandshakeExecutor; public NioConnection(String name, int port, int workers, HandlerFactory factory) { _name = name; @@ -72,6 +74,7 @@ public abstract class NioConnection implements Runnable { _port = port; _factory = factory; _executor = new ThreadPoolExecutor(workers, 5 * workers, 1, TimeUnit.DAYS, new LinkedBlockingQueue(), new NamedThreadFactory(name + "-Handler")); + _sslHandshakeExecutor = Executors.newCachedThreadPool(new NamedThreadFactory(name + "-SSLHandshakeHandler")); } public void start() { @@ -167,6 +170,8 @@ public abstract class NioConnection implements Runnable { processTodos(); } catch (Throwable e) { s_logger.warn("Caught an exception but continuing on.", e); + } finally { + _selector.wakeup(); } } synchronized (_thread) { @@ -180,53 +185,69 @@ public abstract class NioConnection implements Runnable { abstract void unregisterLink(InetSocketAddress saddr); - protected void accept(SelectionKey key) throws IOException { - ServerSocketChannel serverSocketChannel = (ServerSocketChannel)key.channel(); + protected void accept(final SelectionKey key) throws IOException { + final ServerSocketChannel serverSocketChannel = (ServerSocketChannel)key.channel(); - SocketChannel socketChannel = serverSocketChannel.accept(); - Socket socket = socketChannel.socket(); + final SocketChannel socketChannel = serverSocketChannel.accept(); + socketChannel.configureBlocking(false); + final Socket socket = socketChannel.socket(); socket.setKeepAlive(true); - if (s_logger.isTraceEnabled()) { - s_logger.trace("Connection accepted for " + socket); - } - - // Begin SSL handshake in BLOCKING mode - socketChannel.configureBlocking(true); - - SSLEngine sslEngine = null; + final SSLEngine sslEngine; try { - SSLContext sslContext = Link.initSSLContext(false); + final SSLContext sslContext = Link.initSSLContext(false); sslEngine = sslContext.createSSLEngine(); sslEngine.setUseClientMode(false); sslEngine.setNeedClientAuth(false); sslEngine.setEnabledProtocols(SSLUtils.getSupportedProtocols(sslEngine.getEnabledProtocols())); - Link.doHandshake(socketChannel, sslEngine, false); - - } catch (Exception e) { - if (s_logger.isTraceEnabled()) { - s_logger.trace("Socket " + socket + " closed on read. Probably -1 returned: " + e.getMessage()); - } + final NioConnection nioConnection = this; + _sslHandshakeExecutor.submit(new Runnable() { + @Override + public void run() { + _selector.wakeup(); + try { + sslEngine.beginHandshake(); + if (!Link.doHandshake(socketChannel, sslEngine, false)) { + throw new IOException("SSL handshake timed out with " + socketChannel.getRemoteAddress()); + } + if (s_logger.isTraceEnabled()) { + s_logger.trace("SSL: Handshake done"); + } + final InetSocketAddress saddr = (InetSocketAddress)socket.getRemoteSocketAddress(); + final Link link = new Link(saddr, nioConnection); + link.setSSLEngine(sslEngine); + link.setKey(socketChannel.register(key.selector(), SelectionKey.OP_READ, link)); + final Task task = _factory.create(Task.Type.CONNECT, link, null); + registerLink(saddr, link); + _executor.submit(task); + } catch (IOException e) { + if (s_logger.isTraceEnabled()) { + s_logger.trace("Connection closed due to failure: " + e.getMessage()); + } + try { + socketChannel.close(); + socket.close(); + } catch (IOException ignore) { + } + } finally { + _selector.wakeup(); + } + } + }); + } catch (final Exception e) { + if (s_logger.isTraceEnabled()) { + s_logger.trace("Connection closed due to failure: " + e.getMessage()); + } try { socketChannel.close(); socket.close(); } catch (IOException ignore) { } return; + } finally { + _selector.wakeup(); } - - if (s_logger.isTraceEnabled()) { - s_logger.trace("SSL: Handshake done"); - } - socketChannel.configureBlocking(false); - InetSocketAddress saddr = (InetSocketAddress)socket.getRemoteSocketAddress(); - Link link = new Link(saddr, this); - link.setSSLEngine(sslEngine); - link.setKey(socketChannel.register(key.selector(), SelectionKey.OP_READ, link)); - Task task = _factory.create(Task.Type.CONNECT, link, null); - registerLink(saddr, link); - _executor.execute(task); } protected void terminate(SelectionKey key) { diff --git a/utils/src/com/cloud/utils/nio/NioServer.java b/utils/src/com/cloud/utils/nio/NioServer.java index 98a4a51dbfa..adcecda4e64 100755 --- a/utils/src/com/cloud/utils/nio/NioServer.java +++ b/utils/src/com/cloud/utils/nio/NioServer.java @@ -43,6 +43,10 @@ public class NioServer extends NioConnection { _links = new WeakHashMap(1024); } + public int getPort() { + return _serverSocket.socket().getLocalPort(); + } + @Override protected void init() throws IOException { _selector = SelectorProvider.provider().openSelector(); @@ -55,7 +59,7 @@ public class NioServer extends NioConnection { _serverSocket.register(_selector, SelectionKey.OP_ACCEPT, null); - s_logger.info("NioConnection started and listening on " + _localAddr.toString()); + s_logger.info("NioConnection started and listening on " + _serverSocket.socket().getLocalSocketAddress()); } @Override diff --git a/utils/test/com/cloud/utils/testcase/NioTest.java b/utils/test/com/cloud/utils/testcase/NioTest.java index fc166847dc2..515119d9e7e 100644 --- a/utils/test/com/cloud/utils/testcase/NioTest.java +++ b/utils/test/com/cloud/utils/testcase/NioTest.java @@ -19,198 +19,278 @@ package com.cloud.utils.testcase; -import java.nio.channels.ClosedChannelException; -import java.util.Random; - -import junit.framework.TestCase; - -import org.apache.log4j.Logger; -import org.junit.Assert; - +import com.cloud.utils.concurrency.NamedThreadFactory; import com.cloud.utils.nio.HandlerFactory; import com.cloud.utils.nio.Link; import com.cloud.utils.nio.NioClient; import com.cloud.utils.nio.NioServer; import com.cloud.utils.nio.Task; import com.cloud.utils.nio.Task.Type; +import org.apache.log4j.Logger; +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +import java.io.IOException; +import java.net.InetSocketAddress; +import java.nio.channels.ClosedChannelException; +import java.nio.channels.Selector; +import java.nio.channels.SocketChannel; +import java.util.ArrayList; +import java.util.List; +import java.util.Random; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; /** - * - * - * - * + * NioTest demonstrates that NioServer can function without getting its main IO + * loop blocked when an aggressive or malicious client connects to the server but + * fail to participate in SSL handshake. In this test, we run bunch of clients + * that send a known payload to the server, to which multiple malicious clients + * also try to connect and hang. + * A malicious client could cause denial-of-service if the server's main IO loop + * along with SSL handshake was blocking. A passing tests shows that NioServer + * can still function in case of connection load and that the main IO loop along + * with SSL handshake is non-blocking with some internal timeout mechanism. */ -public class NioTest extends TestCase { +public class NioTest { - private static final Logger s_logger = Logger.getLogger(NioTest.class); + private static final Logger LOGGER = Logger.getLogger(NioTest.class); - private NioServer _server; - private NioClient _client; + // Test should fail in due time instead of looping forever + private static final int TESTTIMEOUT = 60000; - private Link _clientLink; + final private int totalTestCount = 4; + private int completedTestCount = 0; - private int _testCount; - private int _completedCount; + private NioServer server; + private List clients = new ArrayList<>(); + private List maliciousClients = new ArrayList<>(); + + private ExecutorService clientExecutor = Executors.newFixedThreadPool(totalTestCount, new NamedThreadFactory("NioClientHandler"));; + private ExecutorService maliciousExecutor = Executors.newFixedThreadPool(totalTestCount, new NamedThreadFactory("MaliciousNioClientHandler"));; + + private Random randomGenerator = new Random(); + private byte[] testBytes; private boolean isTestsDone() { boolean result; synchronized (this) { - result = (_testCount == _completedCount); + result = totalTestCount == completedTestCount; } return result; } - private void getOneMoreTest() { - synchronized (this) { - _testCount++; - } - } - private void oneMoreTestDone() { synchronized (this) { - _completedCount++; + completedTestCount++; } } - @Override + @Before public void setUp() { - s_logger.info("Test"); + LOGGER.info("Setting up Benchmark Test"); - _testCount = 0; - _completedCount = 0; + completedTestCount = 0; + testBytes = new byte[1000000]; + randomGenerator.nextBytes(testBytes); - _server = new NioServer("NioTestServer", 7777, 5, new NioTestServer()); - _server.start(); + server = new NioServer("NioTestServer", 0, 1, new NioTestServer()); + try { + server.start(); + } catch (Exception e) { + Assert.fail(e.getMessage()); + } - _client = new NioClient("NioTestServer", "127.0.0.1", 7777, 5, new NioTestClient()); - _client.start(); + /** + * The malicious client(s) tries to block NioServer's main IO loop + * thread until SSL handshake timeout value (from Link class, 15s) after + * which the valid NioClient(s) get the opportunity to make connection(s) + */ + for (int i = 0; i < totalTestCount; i++) { + final NioClient maliciousClient = new NioMaliciousClient("NioMaliciousTestClient-" + i, "127.0.0.1", server.getPort(), 1, new NioMaliciousTestClient()); + maliciousClients.add(maliciousClient); + maliciousExecutor.submit(new ThreadedNioClient(maliciousClient)); + } - while (_clientLink == null) { - try { - s_logger.debug("Link is not up! Waiting ..."); - Thread.sleep(1000); - } catch (InterruptedException e) { - // TODO Auto-generated catch block - e.printStackTrace(); - } + for (int i = 0; i < totalTestCount; i++) { + final NioClient client = new NioClient("NioTestClient-" + i, "127.0.0.1", server.getPort(), 1, new NioTestClient()); + clients.add(client); + clientExecutor.submit(new ThreadedNioClient(client)); } } - @Override + @After public void tearDown() { - while (!isTestsDone()) { - try { - s_logger.debug(this._completedCount + "/" + this._testCount + " tests done. Waiting for completion"); - Thread.sleep(1000); - } catch (InterruptedException e) { - // TODO Auto-generated catch block - e.printStackTrace(); - } - } stopClient(); stopServer(); } protected void stopClient() { - _client.stop(); - s_logger.info("Client stopped."); + for (NioClient client : clients) { + client.stop(); + } + for (NioClient maliciousClient : maliciousClients) { + maliciousClient.stop(); + } + LOGGER.info("Clients stopped."); } protected void stopServer() { - _server.stop(); - s_logger.info("Server stopped."); + server.stop(); + LOGGER.info("Server stopped."); } - protected void setClientLink(Link link) { - _clientLink = link; - } - - Random randomGenerator = new Random(); - - byte[] _testBytes; - + @Test(timeout=TESTTIMEOUT) public void testConnection() { - _testBytes = new byte[1000000]; - randomGenerator.nextBytes(_testBytes); - try { - getOneMoreTest(); - _clientLink.send(_testBytes); - s_logger.info("Client: Data sent"); - getOneMoreTest(); - _clientLink.send(_testBytes); - s_logger.info("Client: Data sent"); - } catch (ClosedChannelException e) { - // TODO Auto-generated catch block - e.printStackTrace(); + while (!isTestsDone()) { + try { + LOGGER.debug(completedTestCount + "/" + totalTestCount + " tests done. Waiting for completion"); + Thread.sleep(1000); + } catch (final InterruptedException e) { + Assert.fail(e.getMessage()); + } + } + LOGGER.debug(completedTestCount + "/" + totalTestCount + " tests done."); + } + + protected void doServerProcess(final byte[] data) { + oneMoreTestDone(); + Assert.assertArrayEquals(testBytes, data); + LOGGER.info("Verify data received by server done."); + } + + public byte[] getTestBytes() { + return testBytes; + } + + public class ThreadedNioClient implements Runnable { + final private NioClient client; + ThreadedNioClient(final NioClient client) { + this.client = client; + } + + @Override + public void run() { + try { + client.start(); + } catch (Exception e) { + Assert.fail(e.getMessage()); + } } } - protected void doServerProcess(byte[] data) { - oneMoreTestDone(); - Assert.assertArrayEquals(_testBytes, data); - s_logger.info("Verify done."); + public class NioMaliciousClient extends NioClient { + + public NioMaliciousClient(String name, String host, int port, int workers, HandlerFactory factory) { + super(name, host, port, workers, factory); + } + + @Override + protected void init() throws IOException { + _selector = Selector.open(); + try { + _clientConnection = SocketChannel.open(); + LOGGER.info("Connecting to " + _host + ":" + _port); + final InetSocketAddress peerAddr = new InetSocketAddress(_host, _port); + _clientConnection.connect(peerAddr); + // This is done on purpose, the malicious client would connect + // to the server and then do nothing, hence using a large sleep value + Thread.sleep(Long.MAX_VALUE); + } catch (final IOException e) { + _selector.close(); + throw e; + } catch (InterruptedException e) { + LOGGER.debug(e.getMessage()); + } + } + } + + public class NioMaliciousTestClient implements HandlerFactory { + + @Override + public Task create(final Type type, final Link link, final byte[] data) { + return new NioMaliciousTestClientHandler(type, link, data); + } + + public class NioMaliciousTestClientHandler extends Task { + + public NioMaliciousTestClientHandler(final Type type, final Link link, final byte[] data) { + super(type, link, data); + } + + @Override + public void doTask(final Task task) { + LOGGER.info("Malicious Client: Received task " + task.getType().toString()); + } + } } public class NioTestClient implements HandlerFactory { @Override - public Task create(Type type, Link link, byte[] data) { + public Task create(final Type type, final Link link, final byte[] data) { return new NioTestClientHandler(type, link, data); } public class NioTestClientHandler extends Task { - public NioTestClientHandler(Type type, Link link, byte[] data) { + public NioTestClientHandler(final Type type, final Link link, final byte[] data) { super(type, link, data); } @Override public void doTask(final Task task) { if (task.getType() == Task.Type.CONNECT) { - s_logger.info("Client: Received CONNECT task"); - setClientLink(task.getLink()); + LOGGER.info("Client: Received CONNECT task"); + try { + LOGGER.info("Sending data to server"); + task.getLink().send(getTestBytes()); + } catch (ClosedChannelException e) { + LOGGER.error(e.getMessage()); + e.printStackTrace(); + } } else if (task.getType() == Task.Type.DATA) { - s_logger.info("Client: Received DATA task"); + LOGGER.info("Client: Received DATA task"); } else if (task.getType() == Task.Type.DISCONNECT) { - s_logger.info("Client: Received DISCONNECT task"); + LOGGER.info("Client: Received DISCONNECT task"); stopClient(); } else if (task.getType() == Task.Type.OTHER) { - s_logger.info("Client: Received OTHER task"); + LOGGER.info("Client: Received OTHER task"); } } - } } public class NioTestServer implements HandlerFactory { @Override - public Task create(Type type, Link link, byte[] data) { + public Task create(final Type type, final Link link, final byte[] data) { return new NioTestServerHandler(type, link, data); } public class NioTestServerHandler extends Task { - public NioTestServerHandler(Type type, Link link, byte[] data) { + public NioTestServerHandler(final Type type, final Link link, final byte[] data) { super(type, link, data); } @Override public void doTask(final Task task) { if (task.getType() == Task.Type.CONNECT) { - s_logger.info("Server: Received CONNECT task"); + LOGGER.info("Server: Received CONNECT task"); } else if (task.getType() == Task.Type.DATA) { - s_logger.info("Server: Received DATA task"); + LOGGER.info("Server: Received DATA task"); doServerProcess(task.getData()); } else if (task.getType() == Task.Type.DISCONNECT) { - s_logger.info("Server: Received DISCONNECT task"); + LOGGER.info("Server: Received DISCONNECT task"); stopServer(); } else if (task.getType() == Task.Type.OTHER) { - s_logger.info("Server: Received OTHER task"); + LOGGER.info("Server: Received OTHER task"); } } - } } }