diff --git a/engine/orchestration/src/com/cloud/agent/manager/ClusteredAgentManagerImpl.java b/engine/orchestration/src/com/cloud/agent/manager/ClusteredAgentManagerImpl.java index 9239adc0911..75f860ddfad 100644 --- a/engine/orchestration/src/com/cloud/agent/manager/ClusteredAgentManagerImpl.java +++ b/engine/orchestration/src/com/cloud/agent/manager/ClusteredAgentManagerImpl.java @@ -499,7 +499,7 @@ public class ClusteredAgentManagerImpl extends AgentManagerImpl implements Clust SocketChannel ch1 = null; try { ch1 = SocketChannel.open(new InetSocketAddress(addr, Port.value())); - ch1.configureBlocking(false); + ch1.configureBlocking(true); // make sure we are working at blocking mode ch1.socket().setKeepAlive(true); ch1.socket().setSoTimeout(60 * 1000); try { @@ -507,11 +507,8 @@ public class ClusteredAgentManagerImpl extends AgentManagerImpl implements Clust sslEngine = sslContext.createSSLEngine(ip, Port.value()); sslEngine.setUseClientMode(true); sslEngine.setEnabledProtocols(SSLUtils.getSupportedProtocols(sslEngine.getEnabledProtocols())); - sslEngine.beginHandshake(); - if (!Link.doHandshake(ch1, sslEngine, true)) { - ch1.close(); - throw new IOException("SSL handshake failed!"); - } + + Link.doHandshake(ch1, sslEngine, true); s_logger.info("SSL: Handshake done"); } catch (final Exception e) { ch1.close(); diff --git a/utils/pom.xml b/utils/pom.xml index 9e2358680f3..206eb1896a6 100755 --- a/utils/pom.xml +++ b/utils/pom.xml @@ -208,6 +208,7 @@ com/cloud/utils/testcase/*TestCase* com/cloud/utils/db/*Test* + com/cloud/utils/testcase/NioTest.java diff --git a/utils/src/main/java/com/cloud/utils/nio/Link.java b/utils/src/main/java/com/cloud/utils/nio/Link.java index f297d52c077..6d6306a53b8 100644 --- a/utils/src/main/java/com/cloud/utils/nio/Link.java +++ b/utils/src/main/java/com/cloud/utils/nio/Link.java @@ -19,32 +19,36 @@ package com.cloud.utils.nio; -import com.cloud.utils.PropertiesUtil; -import com.cloud.utils.db.DbProperties; -import org.apache.cloudstack.utils.security.SSLUtils; -import org.apache.log4j.Logger; +import java.io.File; +import java.io.FileInputStream; +import java.io.IOException; +import java.io.InputStream; +import java.net.InetSocketAddress; +import java.net.SocketTimeoutException; +import java.nio.ByteBuffer; +import java.nio.channels.Channels; +import java.nio.channels.ClosedChannelException; +import java.nio.channels.ReadableByteChannel; +import java.nio.channels.SelectionKey; +import java.nio.channels.SocketChannel; +import java.security.GeneralSecurityException; +import java.security.KeyStore; +import java.util.concurrent.ConcurrentLinkedQueue; import javax.net.ssl.KeyManagerFactory; import javax.net.ssl.SSLContext; import javax.net.ssl.SSLEngine; import javax.net.ssl.SSLEngineResult; import javax.net.ssl.SSLEngineResult.HandshakeStatus; -import javax.net.ssl.SSLException; import javax.net.ssl.SSLSession; import javax.net.ssl.TrustManager; import javax.net.ssl.TrustManagerFactory; -import java.io.File; -import java.io.FileInputStream; -import java.io.IOException; -import java.io.InputStream; -import java.net.InetSocketAddress; -import java.nio.ByteBuffer; -import java.nio.channels.ClosedChannelException; -import java.nio.channels.SelectionKey; -import java.nio.channels.SocketChannel; -import java.security.GeneralSecurityException; -import java.security.KeyStore; -import java.util.concurrent.ConcurrentLinkedQueue; + +import org.apache.cloudstack.utils.security.SSLUtils; +import org.apache.log4j.Logger; + +import com.cloud.utils.PropertiesUtil; +import com.cloud.utils.db.DbProperties; /** */ @@ -449,185 +453,115 @@ public class Link { return sslContext; } - public static ByteBuffer enlargeBuffer(ByteBuffer buffer, final int sessionProposedCapacity) { - if (buffer == null || sessionProposedCapacity < 0) { - return buffer; + public static void doHandshake(SocketChannel ch, SSLEngine sslEngine, boolean isClient) throws IOException { + if (s_logger.isTraceEnabled()) { + s_logger.trace("SSL: begin Handshake, isClient: " + isClient); } - if (sessionProposedCapacity > buffer.capacity()) { - buffer = ByteBuffer.allocate(sessionProposedCapacity); + + SSLEngineResult engResult; + SSLSession sslSession = sslEngine.getSession(); + HandshakeStatus hsStatus; + ByteBuffer in_pkgBuf = ByteBuffer.allocate(sslSession.getPacketBufferSize() + 40); + ByteBuffer in_appBuf = ByteBuffer.allocate(sslSession.getApplicationBufferSize() + 40); + ByteBuffer out_pkgBuf = ByteBuffer.allocate(sslSession.getPacketBufferSize() + 40); + ByteBuffer out_appBuf = ByteBuffer.allocate(sslSession.getApplicationBufferSize() + 40); + int count; + ch.socket().setSoTimeout(60 * 1000); + InputStream inStream = ch.socket().getInputStream(); + // Use readCh to make sure the timeout on reading is working + ReadableByteChannel readCh = Channels.newChannel(inStream); + + if (isClient) { + hsStatus = SSLEngineResult.HandshakeStatus.NEED_WRAP; } else { - buffer = ByteBuffer.allocate(buffer.capacity() * 2); + hsStatus = SSLEngineResult.HandshakeStatus.NEED_UNWRAP; } - return buffer; - } - public static ByteBuffer handleBufferUnderflow(final SSLEngine engine, ByteBuffer buffer) { - if (engine == null || buffer == null) { - return buffer; - } - if (buffer.position() < buffer.limit()) { - return buffer; - } - ByteBuffer replaceBuffer = enlargeBuffer(buffer, engine.getSession().getPacketBufferSize()); - buffer.flip(); - replaceBuffer.put(buffer); - return replaceBuffer; - } - - private static boolean doHandshakeUnwrap(final SocketChannel socketChannel, final SSLEngine sslEngine, - ByteBuffer peerAppData, ByteBuffer peerNetData, final int appBufferSize) throws IOException { - if (socketChannel == null || sslEngine == null || peerAppData == null || peerNetData == null || appBufferSize < 0) { - return false; - } - if (socketChannel.read(peerNetData) < 0) { - if (sslEngine.isInboundDone() && sslEngine.isOutboundDone()) { - return false; + while (hsStatus != SSLEngineResult.HandshakeStatus.FINISHED) { + if (s_logger.isTraceEnabled()) { + s_logger.trace("SSL: Handshake status " + hsStatus); } - try { - sslEngine.closeInbound(); - } catch (SSLException e) { - s_logger.warn("This SSL engine was forced to close inbound due to end of stream."); - } - sslEngine.closeOutbound(); - // After closeOutbound the engine will be set to WRAP state, - // in order to try to send a close message to the client. - return true; - } - peerNetData.flip(); - SSLEngineResult result = null; - try { - result = sslEngine.unwrap(peerNetData, peerAppData); - peerNetData.compact(); - } catch (SSLException sslException) { - s_logger.error("SSL error occurred while processing unwrap data: " + sslException.getMessage()); - sslEngine.closeOutbound(); - return true; - } - switch (result.getStatus()) { - case OK: - break; - case BUFFER_OVERFLOW: - // Will occur when peerAppData's capacity is smaller than the data derived from peerNetData's unwrap. - peerAppData = enlargeBuffer(peerAppData, appBufferSize); - break; - case BUFFER_UNDERFLOW: - // Will occur either when no data was read from the peer or when the peerNetData buffer - // was too small to hold all peer's data. - peerNetData = handleBufferUnderflow(sslEngine, peerNetData); - break; - case CLOSED: - if (sslEngine.isOutboundDone()) { - return false; - } else { - sslEngine.closeOutbound(); - break; + engResult = null; + if (hsStatus == SSLEngineResult.HandshakeStatus.NEED_WRAP) { + out_pkgBuf.clear(); + out_appBuf.clear(); + out_appBuf.put("Hello".getBytes()); + engResult = sslEngine.wrap(out_appBuf, out_pkgBuf); + out_pkgBuf.flip(); + int remain = out_pkgBuf.limit(); + while (remain != 0) { + remain -= ch.write(out_pkgBuf); + if (remain < 0) { + throw new IOException("Too much bytes sent?"); + } } - default: - throw new IllegalStateException("Invalid SSL status: " + result.getStatus()); - } - return true; - } - - private static boolean doHandshakeWrap(final SocketChannel socketChannel, final SSLEngine sslEngine, - ByteBuffer myAppData, ByteBuffer myNetData, ByteBuffer peerNetData, - final int netBufferSize) throws IOException { - if (socketChannel == null || sslEngine == null || myNetData == null || peerNetData == null - || myAppData == null || netBufferSize < 0) { - return false; - } - myNetData.clear(); - SSLEngineResult result = null; - try { - result = sslEngine.wrap(myAppData, myNetData); - } catch (SSLException sslException) { - s_logger.error("SSL error occurred while processing wrap data: " + sslException.getMessage()); - sslEngine.closeOutbound(); - return true; - } - switch (result.getStatus()) { - case OK : - myNetData.flip(); - while (myNetData.hasRemaining()) { - socketChannel.write(myNetData); + } else if (hsStatus == SSLEngineResult.HandshakeStatus.NEED_UNWRAP) { + in_appBuf.clear(); + // One packet may contained multiply operation + if (in_pkgBuf.position() == 0 || !in_pkgBuf.hasRemaining()) { + in_pkgBuf.clear(); + count = 0; + try { + count = readCh.read(in_pkgBuf); + } catch (SocketTimeoutException ex) { + if (s_logger.isTraceEnabled()) { + s_logger.trace("Handshake reading time out! Cut the connection"); + } + count = -1; + } + if (count == -1) { + throw new IOException("Connection closed with -1 on reading size."); + } + in_pkgBuf.flip(); } - break; - case BUFFER_OVERFLOW: - // Will occur if there is not enough space in myNetData buffer to write all the data - // that would be generated by the method wrap. Since myNetData is set to session's packet - // size we should not get to this point because SSLEngine is supposed to produce messages - // smaller or equal to that, but a general handling would be the following: - myNetData = enlargeBuffer(myNetData, netBufferSize); - break; - case BUFFER_UNDERFLOW: - throw new SSLException("Buffer underflow occurred after a wrap. We should not reach here."); - case CLOSED: - try { - myNetData.flip(); - while (myNetData.hasRemaining()) { - socketChannel.write(myNetData); + engResult = sslEngine.unwrap(in_pkgBuf, in_appBuf); + ByteBuffer tmp_pkgBuf = ByteBuffer.allocate(sslSession.getPacketBufferSize() + 40); + int loop_count = 0; + while (engResult.getStatus() == SSLEngineResult.Status.BUFFER_UNDERFLOW) { + // The client is too slow? Cut it and let it reconnect + if (loop_count > 10) { + throw new IOException("Too many times in SSL BUFFER_UNDERFLOW, disconnect guest."); } - // At this point the handshake status will probably be NEED_UNWRAP - // so we make sure that peerNetData is clear to read. - peerNetData.clear(); - } catch (Exception e) { - s_logger.error("Failed to send server's CLOSE message due to socket channel's failure."); + // We need more packets to complete this operation + if (s_logger.isTraceEnabled()) { + s_logger.trace("SSL: Buffer underflowed, getting more packets"); + } + tmp_pkgBuf.clear(); + count = ch.read(tmp_pkgBuf); + if (count == -1) { + throw new IOException("Connection closed with -1 on reading size."); + } + tmp_pkgBuf.flip(); + + in_pkgBuf.mark(); + in_pkgBuf.position(in_pkgBuf.limit()); + in_pkgBuf.limit(in_pkgBuf.limit() + tmp_pkgBuf.limit()); + in_pkgBuf.put(tmp_pkgBuf); + in_pkgBuf.reset(); + + in_appBuf.clear(); + engResult = sslEngine.unwrap(in_pkgBuf, in_appBuf); + loop_count++; } - break; - default: - throw new IllegalStateException("Invalid SSL status: " + result.getStatus()); - } - return true; - } - - public static boolean doHandshake(final SocketChannel socketChannel, final SSLEngine sslEngine, final boolean isClient) throws IOException { - if (socketChannel == null || sslEngine == null) { - return false; - } - final int appBufferSize = sslEngine.getSession().getApplicationBufferSize(); - final int netBufferSize = sslEngine.getSession().getPacketBufferSize(); - ByteBuffer myAppData = ByteBuffer.allocate(appBufferSize); - ByteBuffer peerAppData = ByteBuffer.allocate(appBufferSize); - ByteBuffer myNetData = ByteBuffer.allocate(netBufferSize); - ByteBuffer peerNetData = ByteBuffer.allocate(netBufferSize); - - final long startTimeMills = System.currentTimeMillis(); - - HandshakeStatus handshakeStatus = sslEngine.getHandshakeStatus(); - while (handshakeStatus != SSLEngineResult.HandshakeStatus.FINISHED - && handshakeStatus != SSLEngineResult.HandshakeStatus.NOT_HANDSHAKING) { - final long timeTaken = System.currentTimeMillis() - startTimeMills; - if (timeTaken > 60000L) { - s_logger.warn("SSL Handshake has taken more than 60s to connect to: " + socketChannel.getRemoteAddress() + - ". Please investigate this connection."); - return false; + } else if (hsStatus == SSLEngineResult.HandshakeStatus.NEED_TASK) { + Runnable run; + while ((run = sslEngine.getDelegatedTask()) != null) { + if (s_logger.isTraceEnabled()) { + s_logger.trace("SSL: Running delegated task!"); + } + run.run(); + } + } else if (hsStatus == SSLEngineResult.HandshakeStatus.NOT_HANDSHAKING) { + throw new IOException("NOT a handshaking!"); } - switch (handshakeStatus) { - case NEED_UNWRAP: - if (!doHandshakeUnwrap(socketChannel, sslEngine, peerAppData, peerNetData, appBufferSize)) { - return false; - } - break; - case NEED_WRAP: - if (!doHandshakeWrap(socketChannel, sslEngine, myAppData, myNetData, peerNetData, netBufferSize)) { - return false; - } - break; - case NEED_TASK: - Runnable task; - while ((task = sslEngine.getDelegatedTask()) != null) { - new Thread(task).run(); - } - break; - case FINISHED: - break; - case NOT_HANDSHAKING: - break; - default: - throw new IllegalStateException("Invalid SSL status: " + handshakeStatus); + if (engResult != null && engResult.getStatus() != SSLEngineResult.Status.OK) { + throw new IOException("Fail to handshake! " + engResult.getStatus()); } - handshakeStatus = sslEngine.getHandshakeStatus(); + if (engResult != null) + hsStatus = engResult.getHandshakeStatus(); + else + hsStatus = sslEngine.getHandshakeStatus(); } - return true; } } diff --git a/utils/src/main/java/com/cloud/utils/nio/NioClient.java b/utils/src/main/java/com/cloud/utils/nio/NioClient.java index dc4f670de12..d989f306c55 100644 --- a/utils/src/main/java/com/cloud/utils/nio/NioClient.java +++ b/utils/src/main/java/com/cloud/utils/nio/NioClient.java @@ -36,6 +36,7 @@ public class NioClient extends NioConnection { private static final Logger s_logger = Logger.getLogger(NioClient.class); protected String _host; + protected String _bindAddress; protected SocketChannel _clientConnection; public NioClient(final String name, final String host, final int port, final int workers, final HandlerFactory factory) { @@ -43,6 +44,10 @@ public class NioClient extends NioConnection { _host = host; } + public void setBindAddress(final String ipAddress) { + _bindAddress = ipAddress; + } + @Override protected void init() throws IOException { _selector = Selector.open(); @@ -50,25 +55,33 @@ public class NioClient extends NioConnection { try { _clientConnection = SocketChannel.open(); - + _clientConnection.configureBlocking(true); s_logger.info("Connecting to " + _host + ":" + _port); + + if (_bindAddress != null) { + s_logger.info("Binding outbound interface at " + _bindAddress); + + final InetSocketAddress bindAddr = new InetSocketAddress(_bindAddress, 0); + _clientConnection.socket().bind(bindAddr); + } + final InetSocketAddress peerAddr = new InetSocketAddress(_host, _port); _clientConnection.connect(peerAddr); - _clientConnection.configureBlocking(false); + + SSLEngine sslEngine = null; + // Begin SSL handshake in BLOCKING mode + _clientConnection.configureBlocking(true); final SSLContext sslContext = Link.initSSLContext(true); - SSLEngine sslEngine = sslContext.createSSLEngine(_host, _port); + sslEngine = sslContext.createSSLEngine(_host, _port); sslEngine.setUseClientMode(true); sslEngine.setEnabledProtocols(SSLUtils.getSupportedProtocols(sslEngine.getEnabledProtocols())); - sslEngine.beginHandshake(); - if (!Link.doHandshake(_clientConnection, sslEngine, true)) { - s_logger.error("SSL Handshake failed while connecting to host: " + _host + " port: " + _port); - _selector.close(); - throw new IOException("SSL Handshake failed while connecting to host: " + _host + " port: " + _port); - } + + Link.doHandshake(_clientConnection, sslEngine, true); s_logger.info("SSL: Handshake done"); s_logger.info("Connected to " + _host + ":" + _port); + _clientConnection.configureBlocking(false); final Link link = new Link(peerAddr, this); link.setSSLEngine(sslEngine); final SelectionKey key = _clientConnection.register(_selector, SelectionKey.OP_READ); diff --git a/utils/src/main/java/com/cloud/utils/nio/NioConnection.java b/utils/src/main/java/com/cloud/utils/nio/NioConnection.java index 6fdb4736ac7..249f512d9c9 100644 --- a/utils/src/main/java/com/cloud/utils/nio/NioConnection.java +++ b/utils/src/main/java/com/cloud/utils/nio/NioConnection.java @@ -19,13 +19,8 @@ package com.cloud.utils.nio; -import com.cloud.utils.concurrency.NamedThreadFactory; -import com.cloud.utils.exception.NioConnectionException; -import org.apache.cloudstack.utils.security.SSLUtils; -import org.apache.log4j.Logger; +import static com.cloud.utils.AutoCloseableUtil.closeAutoCloseable; -import javax.net.ssl.SSLContext; -import javax.net.ssl.SSLEngine; import java.io.IOException; import java.net.ConnectException; import java.net.InetSocketAddress; @@ -49,7 +44,14 @@ import java.util.concurrent.LinkedBlockingQueue; import java.util.concurrent.ThreadPoolExecutor; import java.util.concurrent.TimeUnit; -import static com.cloud.utils.AutoCloseableUtil.closeAutoCloseable; +import javax.net.ssl.SSLContext; +import javax.net.ssl.SSLEngine; + +import org.apache.cloudstack.utils.security.SSLUtils; +import org.apache.log4j.Logger; + +import com.cloud.utils.concurrency.NamedThreadFactory; +import com.cloud.utils.exception.NioConnectionException; /** * NioConnection abstracts the NIO socket operations. The Java implementation @@ -69,7 +71,6 @@ public abstract class NioConnection implements Callable { protected HandlerFactory _factory; protected String _name; protected ExecutorService _executor; - protected ExecutorService _sslHandshakeExecutor; public NioConnection(final String name, final int port, final int workers, final HandlerFactory factory) { _name = name; @@ -78,7 +79,6 @@ public abstract class NioConnection implements Callable { _port = port; _factory = factory; _executor = new ThreadPoolExecutor(workers, 5 * workers, 1, TimeUnit.DAYS, new LinkedBlockingQueue(), new NamedThreadFactory(name + "-Handler")); - _sslHandshakeExecutor = Executors.newCachedThreadPool(new NamedThreadFactory(name + "-SSLHandshakeHandler")); } public void start() throws NioConnectionException { @@ -185,9 +185,8 @@ public abstract class NioConnection implements Callable { protected void accept(final SelectionKey key) throws IOException { final ServerSocketChannel serverSocketChannel = (ServerSocketChannel)key.channel(); - final SocketChannel socketChannel = serverSocketChannel.accept(); - socketChannel.configureBlocking(false); + final SocketChannel socketChannel = serverSocketChannel.accept(); final Socket socket = socketChannel.socket(); socket.setKeepAlive(true); @@ -195,52 +194,43 @@ public abstract class NioConnection implements Callable { s_logger.trace("Connection accepted for " + socket); } - final SSLEngine sslEngine; + // Begin SSL handshake in BLOCKING mode + socketChannel.configureBlocking(true); + + SSLEngine sslEngine = null; try { final SSLContext sslContext = Link.initSSLContext(false); sslEngine = sslContext.createSSLEngine(); sslEngine.setUseClientMode(false); sslEngine.setNeedClientAuth(false); sslEngine.setEnabledProtocols(SSLUtils.getSupportedProtocols(sslEngine.getEnabledProtocols())); - final NioConnection nioConnection = this; - _sslHandshakeExecutor.submit(new Runnable() { - @Override - public void run() { - _selector.wakeup(); - try { - sslEngine.beginHandshake(); - if (!Link.doHandshake(socketChannel, sslEngine, false)) { - throw new IOException("SSL handshake timed out with " + socketChannel.getRemoteAddress()); - } - if (s_logger.isTraceEnabled()) { - s_logger.trace("SSL: Handshake done"); - } - final InetSocketAddress saddr = (InetSocketAddress)socket.getRemoteSocketAddress(); - final Link link = new Link(saddr, nioConnection); - link.setSSLEngine(sslEngine); - link.setKey(socketChannel.register(key.selector(), SelectionKey.OP_READ, link)); - final Task task = _factory.create(Task.Type.CONNECT, link, null); - registerLink(saddr, link); - _executor.submit(task); - } catch (IOException e) { - if (s_logger.isTraceEnabled()) { - s_logger.trace("Connection closed due to failure: " + e.getMessage()); - } - closeAutoCloseable(socket, "accepting socket"); - closeAutoCloseable(socketChannel, "accepting socketChannel"); - } finally { - _selector.wakeup(); - } - } - }); + + Link.doHandshake(socketChannel, sslEngine, false); + } catch (final Exception e) { if (s_logger.isTraceEnabled()) { - s_logger.trace("Connection closed due to failure: " + e.getMessage()); + s_logger.trace("Socket " + socket + " closed on read. Probably -1 returned: " + e.getMessage()); } - closeAutoCloseable(socket, "accepting socket"); closeAutoCloseable(socketChannel, "accepting socketChannel"); - } finally { - _selector.wakeup(); + closeAutoCloseable(socket, "opened socket"); + return; + } + + if (s_logger.isTraceEnabled()) { + s_logger.trace("SSL: Handshake done"); + } + socketChannel.configureBlocking(false); + final InetSocketAddress saddr = (InetSocketAddress)socket.getRemoteSocketAddress(); + final Link link = new Link(saddr, this); + link.setSSLEngine(sslEngine); + link.setKey(socketChannel.register(key.selector(), SelectionKey.OP_READ, link)); + final Task task = _factory.create(Task.Type.CONNECT, link, null); + registerLink(saddr, link); + + try { + _executor.submit(task); + } catch (final Exception e) { + s_logger.warn("Exception occurred when submitting the task", e); } } diff --git a/utils/src/main/java/com/cloud/utils/nio/NioServer.java b/utils/src/main/java/com/cloud/utils/nio/NioServer.java index 13d5476cba0..539c2bb13d8 100644 --- a/utils/src/main/java/com/cloud/utils/nio/NioServer.java +++ b/utils/src/main/java/com/cloud/utils/nio/NioServer.java @@ -43,10 +43,6 @@ public class NioServer extends NioConnection { _links = new WeakHashMap(1024); } - public int getPort() { - return _serverSocket.socket().getLocalPort(); - } - @Override protected void init() throws IOException { _selector = SelectorProvider.provider().openSelector(); @@ -57,9 +53,9 @@ public class NioServer extends NioConnection { _localAddr = new InetSocketAddress(_port); _serverSocket.socket().bind(_localAddr); - _serverSocket.register(_selector, SelectionKey.OP_ACCEPT); + _serverSocket.register(_selector, SelectionKey.OP_ACCEPT, null); - s_logger.info("NioConnection started and listening on " + _serverSocket.socket().getLocalSocketAddress()); + s_logger.info("NioConnection started and listening on " + _localAddr.toString()); } @Override diff --git a/utils/src/test/java/com/cloud/utils/testcase/NioTest.java b/utils/src/test/java/com/cloud/utils/testcase/NioTest.java index 20c31e21b1a..d8510cfcac2 100644 --- a/utils/src/test/java/com/cloud/utils/testcase/NioTest.java +++ b/utils/src/test/java/com/cloud/utils/testcase/NioTest.java @@ -19,7 +19,14 @@ package com.cloud.utils.testcase; -import com.cloud.utils.concurrency.NamedThreadFactory; +import java.nio.channels.ClosedChannelException; +import java.util.Random; + +import junit.framework.TestCase; + +import org.apache.log4j.Logger; +import org.junit.Assert; + import com.cloud.utils.exception.NioConnectionException; import com.cloud.utils.nio.HandlerFactory; import com.cloud.utils.nio.Link; @@ -27,200 +34,131 @@ import com.cloud.utils.nio.NioClient; import com.cloud.utils.nio.NioServer; import com.cloud.utils.nio.Task; import com.cloud.utils.nio.Task.Type; -import org.apache.log4j.Logger; -import org.junit.After; -import org.junit.Assert; -import org.junit.Before; -import org.junit.Test; - -import java.io.IOException; -import java.net.InetSocketAddress; -import java.nio.channels.ClosedChannelException; -import java.nio.channels.Selector; -import java.nio.channels.SocketChannel; -import java.util.ArrayList; -import java.util.List; -import java.util.Random; -import java.util.concurrent.ExecutorService; -import java.util.concurrent.Executors; /** - * NioTest demonstrates that NioServer can function without getting its main IO - * loop blocked when an aggressive or malicious client connects to the server but - * fail to participate in SSL handshake. In this test, we run bunch of clients - * that send a known payload to the server, to which multiple malicious clients - * also try to connect and hang. - * A malicious client could cause denial-of-service if the server's main IO loop - * along with SSL handshake was blocking. A passing tests shows that NioServer - * can still function in case of connection load and that the main IO loop along - * with SSL handshake is non-blocking with some internal timeout mechanism. + * + * + * + * */ -public class NioTest { +public class NioTest extends TestCase { - private static final Logger LOGGER = Logger.getLogger(NioTest.class); + private static final Logger s_logger = Logger.getLogger(NioTest.class); - // Test should fail in due time instead of looping forever - private static final int TESTTIMEOUT = 300000; + private NioServer _server; + private NioClient _client; - final private int totalTestCount = 5; - private int completedTestCount = 0; + private Link _clientLink; - private NioServer server; - private List clients = new ArrayList<>(); - private List maliciousClients = new ArrayList<>(); - - private ExecutorService clientExecutor = Executors.newFixedThreadPool(totalTestCount, new NamedThreadFactory("NioClientHandler"));; - private ExecutorService maliciousExecutor = Executors.newFixedThreadPool(5*totalTestCount, new NamedThreadFactory("MaliciousNioClientHandler"));; - - private Random randomGenerator = new Random(); - private byte[] testBytes; + private int _testCount; + private int _completedCount; private boolean isTestsDone() { boolean result; synchronized (this) { - result = totalTestCount == completedTestCount; + result = _testCount == _completedCount; } return result; } + private void getOneMoreTest() { + synchronized (this) { + _testCount++; + } + } + private void oneMoreTestDone() { synchronized (this) { - completedTestCount++; + _completedCount++; } } - @Before + @Override public void setUp() { - LOGGER.info("Setting up Benchmark Test"); + s_logger.info("Test"); - completedTestCount = 0; - testBytes = new byte[1000000]; - randomGenerator.nextBytes(testBytes); + _testCount = 0; + _completedCount = 0; - server = new NioServer("NioTestServer", 0, 1, new NioTestServer()); + _server = new NioServer("NioTestServer", 7777, 5, new NioTestServer()); try { - server.start(); + _server.start(); } catch (final NioConnectionException e) { - Assert.fail(e.getMessage()); + fail(e.getMessage()); } - for (int i = 0; i < totalTestCount; i++) { - for (int j = 0; j < 4; j++) { - final NioClient maliciousClient = new NioMaliciousClient("NioMaliciousTestClient-" + i, "127.0.0.1", server.getPort(), 1, new NioMaliciousTestClient()); - maliciousClients.add(maliciousClient); - maliciousExecutor.submit(new ThreadedNioClient(maliciousClient)); + _client = new NioClient("NioTestServer", "127.0.0.1", 7777, 5, new NioTestClient()); + try { + _client.start(); + } catch (final NioConnectionException e) { + fail(e.getMessage()); + } + + while (_clientLink == null) { + try { + s_logger.debug("Link is not up! Waiting ..."); + Thread.sleep(1000); + } catch (final InterruptedException e) { + // TODO Auto-generated catch block + e.printStackTrace(); } - final NioClient client = new NioClient("NioTestClient-" + i, "127.0.0.1", server.getPort(), 1, new NioTestClient()); - clients.add(client); - clientExecutor.submit(new ThreadedNioClient(client)); } } - @After + @Override public void tearDown() { + while (!isTestsDone()) { + try { + s_logger.debug(_completedCount + "/" + _testCount + " tests done. Waiting for completion"); + Thread.sleep(1000); + } catch (final InterruptedException e) { + // TODO Auto-generated catch block + e.printStackTrace(); + } + } stopClient(); stopServer(); } protected void stopClient() { - for (NioClient client : clients) { - client.stop(); - } - for (NioClient maliciousClient : maliciousClients) { - maliciousClient.stop(); - } - LOGGER.info("Clients stopped."); + _client.stop(); + s_logger.info("Client stopped."); } protected void stopServer() { - server.stop(); - LOGGER.info("Server stopped."); + _server.stop(); + s_logger.info("Server stopped."); } - @Test(timeout=TESTTIMEOUT) + protected void setClientLink(final Link link) { + _clientLink = link; + } + + Random randomGenerator = new Random(); + + byte[] _testBytes; + public void testConnection() { - while (!isTestsDone()) { - try { - LOGGER.debug(completedTestCount + "/" + totalTestCount + " tests done. Waiting for completion"); - Thread.sleep(1000); - } catch (final InterruptedException e) { - Assert.fail(e.getMessage()); - } + _testBytes = new byte[1000000]; + randomGenerator.nextBytes(_testBytes); + try { + getOneMoreTest(); + _clientLink.send(_testBytes); + s_logger.info("Client: Data sent"); + getOneMoreTest(); + _clientLink.send(_testBytes); + s_logger.info("Client: Data sent"); + } catch (final ClosedChannelException e) { + // TODO Auto-generated catch block + e.printStackTrace(); } - LOGGER.debug(completedTestCount + "/" + totalTestCount + " tests done."); } protected void doServerProcess(final byte[] data) { oneMoreTestDone(); - Assert.assertArrayEquals(testBytes, data); - LOGGER.info("Verify data received by server done."); - } - - public byte[] getTestBytes() { - return testBytes; - } - - public class ThreadedNioClient implements Runnable { - final private NioClient client; - ThreadedNioClient(final NioClient client) { - this.client = client; - } - - @Override - public void run() { - try { - client.start(); - } catch (NioConnectionException e) { - Assert.fail(e.getMessage()); - } - } - } - - public class NioMaliciousClient extends NioClient { - - public NioMaliciousClient(String name, String host, int port, int workers, HandlerFactory factory) { - super(name, host, port, workers, factory); - } - - @Override - protected void init() throws IOException { - _selector = Selector.open(); - try { - _clientConnection = SocketChannel.open(); - LOGGER.info("Connecting to " + _host + ":" + _port); - final InetSocketAddress peerAddr = new InetSocketAddress(_host, _port); - _clientConnection.connect(peerAddr); - // This is done on purpose, the malicious client would connect - // to the server and then do nothing, hence using a large sleep value - Thread.sleep(Long.MAX_VALUE); - } catch (final IOException e) { - _selector.close(); - throw e; - } catch (InterruptedException e) { - LOGGER.debug(e.getMessage()); - } - } - } - - public class NioMaliciousTestClient implements HandlerFactory { - - @Override - public Task create(final Type type, final Link link, final byte[] data) { - return new NioMaliciousTestClientHandler(type, link, data); - } - - public class NioMaliciousTestClientHandler extends Task { - - public NioMaliciousTestClientHandler(final Type type, final Link link, final byte[] data) { - super(type, link, data); - } - - @Override - public void doTask(final Task task) { - LOGGER.info("Malicious Client: Received task " + task.getType().toString()); - } - } + Assert.assertArrayEquals(_testBytes, data); + s_logger.info("Verify done."); } public class NioTestClient implements HandlerFactory { @@ -239,23 +177,18 @@ public class NioTest { @Override public void doTask(final Task task) { if (task.getType() == Task.Type.CONNECT) { - LOGGER.info("Client: Received CONNECT task"); - try { - LOGGER.info("Sending data to server"); - task.getLink().send(getTestBytes()); - } catch (ClosedChannelException e) { - LOGGER.error(e.getMessage()); - e.printStackTrace(); - } + s_logger.info("Client: Received CONNECT task"); + setClientLink(task.getLink()); } else if (task.getType() == Task.Type.DATA) { - LOGGER.info("Client: Received DATA task"); + s_logger.info("Client: Received DATA task"); } else if (task.getType() == Task.Type.DISCONNECT) { - LOGGER.info("Client: Received DISCONNECT task"); + s_logger.info("Client: Received DISCONNECT task"); stopClient(); } else if (task.getType() == Task.Type.OTHER) { - LOGGER.info("Client: Received OTHER task"); + s_logger.info("Client: Received OTHER task"); } } + } } @@ -275,15 +208,15 @@ public class NioTest { @Override public void doTask(final Task task) { if (task.getType() == Task.Type.CONNECT) { - LOGGER.info("Server: Received CONNECT task"); + s_logger.info("Server: Received CONNECT task"); } else if (task.getType() == Task.Type.DATA) { - LOGGER.info("Server: Received DATA task"); + s_logger.info("Server: Received DATA task"); doServerProcess(task.getData()); } else if (task.getType() == Task.Type.DISCONNECT) { - LOGGER.info("Server: Received DISCONNECT task"); + s_logger.info("Server: Received DISCONNECT task"); stopServer(); } else if (task.getType() == Task.Type.OTHER) { - LOGGER.info("Server: Received OTHER task"); + s_logger.info("Server: Received OTHER task"); } }