From d925aa32e57dd0efc1e7963eb84aae2e17377c71 Mon Sep 17 00:00:00 2001 From: Sheng Yang Date: Fri, 15 Jul 2011 18:56:02 -0700 Subject: [PATCH] bug 10714: Implement packet fragmentation Also add an simple nio unit test. status 10714: resolved fixed --- utils/src/com/cloud/utils/nio/Link.java | 126 +++++++---- .../com/cloud/utils/testcase/NioTest.java | 204 ++++++++++++++++++ 2 files changed, 293 insertions(+), 37 deletions(-) create mode 100644 utils/test/com/cloud/utils/testcase/NioTest.java diff --git a/utils/src/com/cloud/utils/nio/Link.java b/utils/src/com/cloud/utils/nio/Link.java index ea0c003d26d..a893a26af07 100755 --- a/utils/src/com/cloud/utils/nio/Link.java +++ b/utils/src/com/cloud/utils/nio/Link.java @@ -56,8 +56,10 @@ public class Link { private SelectionKey _key; private final ConcurrentLinkedQueue _writeQueue; private ByteBuffer _readBuffer; + private ByteBuffer _plaintextBuffer; private Object _attach; - private boolean _readSize; + private boolean _readHeader; + private boolean _gotFollowingPacket; private SSLEngine _sslEngine; @@ -68,7 +70,8 @@ public class Link { _attach = null; _key = null; _writeQueue = new ConcurrentLinkedQueue(); - _readSize = true; + _readHeader = true; + _gotFollowingPacket = false; } public Link (Link link) { @@ -135,39 +138,57 @@ public class Link { */ private static void doWrite(SocketChannel ch, ByteBuffer[] buffers, SSLEngine sslEngine) throws IOException { - ByteBuffer pkgBuf; SSLSession sslSession = sslEngine.getSession(); + ByteBuffer pkgBuf = ByteBuffer.allocate(sslSession.getPacketBufferSize() + 40); SSLEngineResult engResult; ByteBuffer headBuf = ByteBuffer.allocate(4); - - pkgBuf = ByteBuffer.allocate(sslSession.getPacketBufferSize() + 40); - engResult = sslEngine.wrap(buffers, pkgBuf); - if (engResult.getHandshakeStatus() != HandshakeStatus.FINISHED && - engResult.getHandshakeStatus() != HandshakeStatus.NOT_HANDSHAKING && - engResult.getStatus() != SSLEngineResult.Status.OK) { - throw new IOException("SSL: SSLEngine return bad result! " + engResult); + + int totalLen = 0; + for (ByteBuffer buffer : buffers) { + totalLen += buffer.limit(); } - int dataRemaining = pkgBuf.position(); - int headRemaining = 4; - pkgBuf.flip(); - headBuf.putInt(dataRemaining); - headBuf.flip(); + int processedLen = 0; + while (processedLen < totalLen) { + headBuf.clear(); + pkgBuf.clear(); + engResult = sslEngine.wrap(buffers, pkgBuf); + if (engResult.getHandshakeStatus() != HandshakeStatus.FINISHED && + engResult.getHandshakeStatus() != HandshakeStatus.NOT_HANDSHAKING && + engResult.getStatus() != SSLEngineResult.Status.OK) { + throw new IOException("SSL: SSLEngine return bad result! " + engResult); + } + + processedLen = 0; + for (ByteBuffer buffer : buffers) { + processedLen += buffer.position(); + } - while (headRemaining > 0) { - if (s_logger.isTraceEnabled()) { - s_logger.trace("Writing Header " + headRemaining); + int dataRemaining = pkgBuf.position(); + int header = dataRemaining; + int headRemaining = 4; + pkgBuf.flip(); + if (processedLen < totalLen) { + header = header | HEADER_FLAG_FOLLOWING; } - long count = ch.write(headBuf); - headRemaining -= count; - } - while (dataRemaining > 0) { - if (s_logger.isTraceEnabled()) { - s_logger.trace("Writing Data " + dataRemaining); + headBuf.putInt(header); + headBuf.flip(); + + while (headRemaining > 0) { + if (s_logger.isTraceEnabled()) { + s_logger.trace("Writing Header " + headRemaining); + } + long count = ch.write(headBuf); + headRemaining -= count; + } + while (dataRemaining > 0) { + if (s_logger.isTraceEnabled()) { + s_logger.trace("Writing Data " + dataRemaining); + } + long count = ch.write(pkgBuf); + dataRemaining -= count; } - long count = ch.write(pkgBuf); - dataRemaining -= count; } } @@ -186,8 +207,12 @@ public class Link { } } + /* SSL has limitation of 16k, we may need to split packets. 18000 is 16k + some extra SSL informations */ + protected static final int MAX_SIZE_PER_PACKET = 18000; + protected static final int HEADER_FLAG_FOLLOWING = 0x10000; + public byte[] read(SocketChannel ch) throws IOException { - if (_readSize) { // Start of a packet + if (_readHeader) { // Start of a packet if (_readBuffer.position() == 0) { _readBuffer.limit(4); } @@ -201,16 +226,28 @@ public class Link { return null; } _readBuffer.flip(); - int readSize = _readBuffer.getInt(); + int header = _readBuffer.getInt(); + int readSize = (short)header; if (s_logger.isTraceEnabled()) { s_logger.trace("Packet length is " + readSize); } - if (readSize > 65535) { - throw new IOException("Packet is too big! Discard it. Size: " + readSize); + if (readSize > MAX_SIZE_PER_PACKET) { + throw new IOException("Wrong packet size: " + readSize); } + + if (!_gotFollowingPacket) { + _plaintextBuffer = ByteBuffer.allocate(2000); + } + + if ((header & HEADER_FLAG_FOLLOWING) != 0) { + _gotFollowingPacket = true; + } else { + _gotFollowingPacket = false; + } + _readBuffer.clear(); - _readSize = false; + _readHeader = false; if (_readBuffer.capacity() < readSize) { if (s_logger.isTraceEnabled()) { @@ -239,7 +276,6 @@ public class Link { SSLSession sslSession = _sslEngine.getSession(); SSLEngineResult engResult; - //TODO may need to adjust the buffer size appBuf = ByteBuffer.allocate(sslSession.getApplicationBufferSize() + 40); engResult = _sslEngine.unwrap(_readBuffer, appBuf); if (engResult.getHandshakeStatus() != HandshakeStatus.FINISHED && @@ -248,17 +284,33 @@ public class Link { throw new IOException("SSL: SSLEngine return bad result! " + engResult); } - byte[] result = new byte[appBuf.position()]; appBuf.flip(); - appBuf.get(result); + if (_plaintextBuffer.remaining() < appBuf.limit()) { + // We need to expand _plaintextBuffer for more data + ByteBuffer newBuffer = ByteBuffer.allocate(_plaintextBuffer.capacity() + appBuf.limit() * 5); + _plaintextBuffer.flip(); + newBuffer.put(_plaintextBuffer); + _plaintextBuffer = newBuffer; + } + _plaintextBuffer.put(appBuf); _readBuffer.clear(); - _readSize = true; + _readHeader = true; if (s_logger.isTraceEnabled()) { - s_logger.trace("Done with packet: " + result.length); + s_logger.trace("Done with packet: " + appBuf.limit()); } - return result; + if (!_gotFollowingPacket) { + _plaintextBuffer.flip(); + byte[] result = new byte[_plaintextBuffer.limit()]; + _plaintextBuffer.get(result); + return result; + } else { + if (s_logger.isTraceEnabled()) { + s_logger.trace("Waiting for more packets"); + } + return null; + } } public void send(byte[] data) throws ClosedChannelException { diff --git a/utils/test/com/cloud/utils/testcase/NioTest.java b/utils/test/com/cloud/utils/testcase/NioTest.java new file mode 100644 index 00000000000..72758f27ae8 --- /dev/null +++ b/utils/test/com/cloud/utils/testcase/NioTest.java @@ -0,0 +1,204 @@ +package com.cloud.utils.testcase; + +import java.nio.channels.ClosedChannelException; +import java.util.Random; + +import org.apache.log4j.Logger; + +import com.cloud.utils.nio.HandlerFactory; +import com.cloud.utils.nio.Link; +import com.cloud.utils.nio.NioClient; +import com.cloud.utils.nio.NioServer; +import com.cloud.utils.nio.Task; +import com.cloud.utils.nio.Task.Type; + +import org.junit.Assert; +import junit.framework.TestCase; + +/** + * Copyright (C) 2010 Cloud.com, Inc. All rights reserved. + * + * This software is licensed under the GNU General Public License v3 or later. + * + * It is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or any later version. + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + * + */ + +public class NioTest extends TestCase { + + private static final Logger s_logger = Logger.getLogger(NioTest.class); + + private NioServer _server; + private NioClient _client; + + private Link _clientLink; + + private int _testCount; + private int _completedCount; + + private boolean isTestsDone() { + boolean result; + synchronized(this) { + result = (_testCount == _completedCount); + } + return result; + } + + private void getOneMoreTest() { + synchronized(this) { + _testCount ++; + } + } + private void oneMoreTestDone() { + synchronized(this) { + _completedCount ++; + } + } + + public void setUp() { + s_logger.info("Test"); + + _testCount = 0; + _completedCount = 0; + + _server = new NioServer("NioTestServer", 7777, 5, new NioTestServer()); + _server.start(); + + _client = new NioClient("NioTestServer", "127.0.0.1", 7777, 5, new NioTestClient()); + _client.start(); + + while (_clientLink == null) { + try { + Thread.sleep(1000); + } catch (InterruptedException e) { + // TODO Auto-generated catch block + e.printStackTrace(); + } + } + } + + public void tearDown() { + while (!isTestsDone()) { + try { + Thread.sleep(1000); + } catch (InterruptedException e) { + // TODO Auto-generated catch block + e.printStackTrace(); + } + } + stopClient(); + stopServer(); + } + + protected void stopClient(){ + _client.stop(); + s_logger.info("Client stopped."); + } + + protected void stopServer(){ + _server.stop(); + s_logger.info("Server stopped."); + } + + protected void setClientLink(Link link) + { + _clientLink = link; + } + + Random randomGenerator = new Random(); + + byte[] _testBytes; + + public void testConnection() { + _testBytes = new byte[1000000]; + randomGenerator.nextBytes(_testBytes); + try { + getOneMoreTest(); + _clientLink.send(_testBytes); + s_logger.info("Client: Data sent"); + getOneMoreTest(); + _clientLink.send(_testBytes); + s_logger.info("Client: Data sent"); + } catch (ClosedChannelException e) { + // TODO Auto-generated catch block + e.printStackTrace(); + } + } + + protected void doServerProcess(byte[] data) { + oneMoreTestDone(); + Assert.assertArrayEquals(_testBytes, data); + s_logger.info("Verify done."); + } + + public class NioTestClient implements HandlerFactory { + + @Override + public Task create(Type type, Link link, byte[] data) { + return new NioTestClientHandler(type, link, data); + } + + public class NioTestClientHandler extends Task { + + public NioTestClientHandler(Type type, Link link, byte[] data) { + super(type, link, data); + } + + @Override + public void doTask(final Task task) { + if (task.getType() == Task.Type.CONNECT) { + s_logger.info("Client: Received CONNECT task"); + setClientLink(task.getLink()); + } else if (task.getType() == Task.Type.DATA) { + s_logger.info("Client: Received DATA task"); + } else if (task.getType() == Task.Type.DISCONNECT) { + s_logger.info("Client: Received DISCONNECT task"); + stopClient(); + } else if (task.getType() == Task.Type.OTHER) { + s_logger.info("Client: Received OTHER task"); + } + } + + } + } + + public class NioTestServer implements HandlerFactory { + + @Override + public Task create(Type type, Link link, byte[] data) { + return new NioTestServerHandler(type, link, data); + } + + public class NioTestServerHandler extends Task { + + public NioTestServerHandler(Type type, Link link, byte[] data) { + super(type, link, data); + } + + @Override + public void doTask(final Task task) { + if (task.getType() == Task.Type.CONNECT) { + s_logger.info("Server: Received CONNECT task"); + } else if (task.getType() == Task.Type.DATA) { + s_logger.info("Server: Received DATA task"); + doServerProcess(task.getData()); + } else if (task.getType() == Task.Type.DISCONNECT) { + s_logger.info("Server: Received DISCONNECT task"); + stopServer(); + } else if (task.getType() == Task.Type.OTHER) { + s_logger.info("Server: Received OTHER task"); + } + } + + } + } +}