bug 10714: Implement packet fragmentation

Also add an simple nio unit test.

status 10714: resolved fixed
This commit is contained in:
Sheng Yang 2011-07-15 18:56:02 -07:00
parent bb60543fbf
commit 15bf729927
2 changed files with 293 additions and 37 deletions

View File

@ -56,8 +56,10 @@ public class Link {
private SelectionKey _key;
private final ConcurrentLinkedQueue<ByteBuffer[]> _writeQueue;
private ByteBuffer _readBuffer;
private ByteBuffer _plaintextBuffer;
private Object _attach;
private boolean _readSize;
private boolean _readHeader;
private boolean _gotFollowingPacket;
private SSLEngine _sslEngine;
@ -68,7 +70,8 @@ public class Link {
_attach = null;
_key = null;
_writeQueue = new ConcurrentLinkedQueue<ByteBuffer[]>();
_readSize = true;
_readHeader = true;
_gotFollowingPacket = false;
}
public Link (Link link) {
@ -135,39 +138,57 @@ public class Link {
*/
private static void doWrite(SocketChannel ch, ByteBuffer[] buffers, SSLEngine sslEngine) throws IOException {
ByteBuffer pkgBuf;
SSLSession sslSession = sslEngine.getSession();
ByteBuffer pkgBuf = ByteBuffer.allocate(sslSession.getPacketBufferSize() + 40);
SSLEngineResult engResult;
ByteBuffer headBuf = ByteBuffer.allocate(4);
pkgBuf = ByteBuffer.allocate(sslSession.getPacketBufferSize() + 40);
engResult = sslEngine.wrap(buffers, pkgBuf);
if (engResult.getHandshakeStatus() != HandshakeStatus.FINISHED &&
engResult.getHandshakeStatus() != HandshakeStatus.NOT_HANDSHAKING &&
engResult.getStatus() != SSLEngineResult.Status.OK) {
throw new IOException("SSL: SSLEngine return bad result! " + engResult);
int totalLen = 0;
for (ByteBuffer buffer : buffers) {
totalLen += buffer.limit();
}
int dataRemaining = pkgBuf.position();
int headRemaining = 4;
pkgBuf.flip();
headBuf.putInt(dataRemaining);
headBuf.flip();
int processedLen = 0;
while (processedLen < totalLen) {
headBuf.clear();
pkgBuf.clear();
engResult = sslEngine.wrap(buffers, pkgBuf);
if (engResult.getHandshakeStatus() != HandshakeStatus.FINISHED &&
engResult.getHandshakeStatus() != HandshakeStatus.NOT_HANDSHAKING &&
engResult.getStatus() != SSLEngineResult.Status.OK) {
throw new IOException("SSL: SSLEngine return bad result! " + engResult);
}
processedLen = 0;
for (ByteBuffer buffer : buffers) {
processedLen += buffer.position();
}
while (headRemaining > 0) {
if (s_logger.isTraceEnabled()) {
s_logger.trace("Writing Header " + headRemaining);
int dataRemaining = pkgBuf.position();
int header = dataRemaining;
int headRemaining = 4;
pkgBuf.flip();
if (processedLen < totalLen) {
header = header | HEADER_FLAG_FOLLOWING;
}
long count = ch.write(headBuf);
headRemaining -= count;
}
while (dataRemaining > 0) {
if (s_logger.isTraceEnabled()) {
s_logger.trace("Writing Data " + dataRemaining);
headBuf.putInt(header);
headBuf.flip();
while (headRemaining > 0) {
if (s_logger.isTraceEnabled()) {
s_logger.trace("Writing Header " + headRemaining);
}
long count = ch.write(headBuf);
headRemaining -= count;
}
while (dataRemaining > 0) {
if (s_logger.isTraceEnabled()) {
s_logger.trace("Writing Data " + dataRemaining);
}
long count = ch.write(pkgBuf);
dataRemaining -= count;
}
long count = ch.write(pkgBuf);
dataRemaining -= count;
}
}
@ -186,8 +207,12 @@ public class Link {
}
}
/* SSL has limitation of 16k, we may need to split packets. 18000 is 16k + some extra SSL informations */
protected static final int MAX_SIZE_PER_PACKET = 18000;
protected static final int HEADER_FLAG_FOLLOWING = 0x10000;
public byte[] read(SocketChannel ch) throws IOException {
if (_readSize) { // Start of a packet
if (_readHeader) { // Start of a packet
if (_readBuffer.position() == 0) {
_readBuffer.limit(4);
}
@ -201,16 +226,28 @@ public class Link {
return null;
}
_readBuffer.flip();
int readSize = _readBuffer.getInt();
int header = _readBuffer.getInt();
int readSize = (short)header;
if (s_logger.isTraceEnabled()) {
s_logger.trace("Packet length is " + readSize);
}
if (readSize > 65535) {
throw new IOException("Packet is too big! Discard it. Size: " + readSize);
if (readSize > MAX_SIZE_PER_PACKET) {
throw new IOException("Wrong packet size: " + readSize);
}
if (!_gotFollowingPacket) {
_plaintextBuffer = ByteBuffer.allocate(2000);
}
if ((header & HEADER_FLAG_FOLLOWING) != 0) {
_gotFollowingPacket = true;
} else {
_gotFollowingPacket = false;
}
_readBuffer.clear();
_readSize = false;
_readHeader = false;
if (_readBuffer.capacity() < readSize) {
if (s_logger.isTraceEnabled()) {
@ -239,7 +276,6 @@ public class Link {
SSLSession sslSession = _sslEngine.getSession();
SSLEngineResult engResult;
//TODO may need to adjust the buffer size
appBuf = ByteBuffer.allocate(sslSession.getApplicationBufferSize() + 40);
engResult = _sslEngine.unwrap(_readBuffer, appBuf);
if (engResult.getHandshakeStatus() != HandshakeStatus.FINISHED &&
@ -248,17 +284,33 @@ public class Link {
throw new IOException("SSL: SSLEngine return bad result! " + engResult);
}
byte[] result = new byte[appBuf.position()];
appBuf.flip();
appBuf.get(result);
if (_plaintextBuffer.remaining() < appBuf.limit()) {
// We need to expand _plaintextBuffer for more data
ByteBuffer newBuffer = ByteBuffer.allocate(_plaintextBuffer.capacity() + appBuf.limit() * 5);
_plaintextBuffer.flip();
newBuffer.put(_plaintextBuffer);
_plaintextBuffer = newBuffer;
}
_plaintextBuffer.put(appBuf);
_readBuffer.clear();
_readSize = true;
_readHeader = true;
if (s_logger.isTraceEnabled()) {
s_logger.trace("Done with packet: " + result.length);
s_logger.trace("Done with packet: " + appBuf.limit());
}
return result;
if (!_gotFollowingPacket) {
_plaintextBuffer.flip();
byte[] result = new byte[_plaintextBuffer.limit()];
_plaintextBuffer.get(result);
return result;
} else {
if (s_logger.isTraceEnabled()) {
s_logger.trace("Waiting for more packets");
}
return null;
}
}
public void send(byte[] data) throws ClosedChannelException {

View File

@ -0,0 +1,204 @@
package com.cloud.utils.testcase;
import java.nio.channels.ClosedChannelException;
import java.util.Random;
import org.apache.log4j.Logger;
import com.cloud.utils.nio.HandlerFactory;
import com.cloud.utils.nio.Link;
import com.cloud.utils.nio.NioClient;
import com.cloud.utils.nio.NioServer;
import com.cloud.utils.nio.Task;
import com.cloud.utils.nio.Task.Type;
import org.junit.Assert;
import junit.framework.TestCase;
/**
* Copyright (C) 2010 Cloud.com, Inc. All rights reserved.
*
* This software is licensed under the GNU General Public License v3 or later.
*
* It is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or any later version.
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*
*/
public class NioTest extends TestCase {
private static final Logger s_logger = Logger.getLogger(NioTest.class);
private NioServer _server;
private NioClient _client;
private Link _clientLink;
private int _testCount;
private int _completedCount;
private boolean isTestsDone() {
boolean result;
synchronized(this) {
result = (_testCount == _completedCount);
}
return result;
}
private void getOneMoreTest() {
synchronized(this) {
_testCount ++;
}
}
private void oneMoreTestDone() {
synchronized(this) {
_completedCount ++;
}
}
public void setUp() {
s_logger.info("Test");
_testCount = 0;
_completedCount = 0;
_server = new NioServer("NioTestServer", 7777, 5, new NioTestServer());
_server.start();
_client = new NioClient("NioTestServer", "127.0.0.1", 7777, 5, new NioTestClient());
_client.start();
while (_clientLink == null) {
try {
Thread.sleep(1000);
} catch (InterruptedException e) {
// TODO Auto-generated catch block
e.printStackTrace();
}
}
}
public void tearDown() {
while (!isTestsDone()) {
try {
Thread.sleep(1000);
} catch (InterruptedException e) {
// TODO Auto-generated catch block
e.printStackTrace();
}
}
stopClient();
stopServer();
}
protected void stopClient(){
_client.stop();
s_logger.info("Client stopped.");
}
protected void stopServer(){
_server.stop();
s_logger.info("Server stopped.");
}
protected void setClientLink(Link link)
{
_clientLink = link;
}
Random randomGenerator = new Random();
byte[] _testBytes;
public void testConnection() {
_testBytes = new byte[1000000];
randomGenerator.nextBytes(_testBytes);
try {
getOneMoreTest();
_clientLink.send(_testBytes);
s_logger.info("Client: Data sent");
getOneMoreTest();
_clientLink.send(_testBytes);
s_logger.info("Client: Data sent");
} catch (ClosedChannelException e) {
// TODO Auto-generated catch block
e.printStackTrace();
}
}
protected void doServerProcess(byte[] data) {
oneMoreTestDone();
Assert.assertArrayEquals(_testBytes, data);
s_logger.info("Verify done.");
}
public class NioTestClient implements HandlerFactory {
@Override
public Task create(Type type, Link link, byte[] data) {
return new NioTestClientHandler(type, link, data);
}
public class NioTestClientHandler extends Task {
public NioTestClientHandler(Type type, Link link, byte[] data) {
super(type, link, data);
}
@Override
public void doTask(final Task task) {
if (task.getType() == Task.Type.CONNECT) {
s_logger.info("Client: Received CONNECT task");
setClientLink(task.getLink());
} else if (task.getType() == Task.Type.DATA) {
s_logger.info("Client: Received DATA task");
} else if (task.getType() == Task.Type.DISCONNECT) {
s_logger.info("Client: Received DISCONNECT task");
stopClient();
} else if (task.getType() == Task.Type.OTHER) {
s_logger.info("Client: Received OTHER task");
}
}
}
}
public class NioTestServer implements HandlerFactory {
@Override
public Task create(Type type, Link link, byte[] data) {
return new NioTestServerHandler(type, link, data);
}
public class NioTestServerHandler extends Task {
public NioTestServerHandler(Type type, Link link, byte[] data) {
super(type, link, data);
}
@Override
public void doTask(final Task task) {
if (task.getType() == Task.Type.CONNECT) {
s_logger.info("Server: Received CONNECT task");
} else if (task.getType() == Task.Type.DATA) {
s_logger.info("Server: Received DATA task");
doServerProcess(task.getData());
} else if (task.getType() == Task.Type.DISCONNECT) {
s_logger.info("Server: Received DISCONNECT task");
stopServer();
} else if (task.getType() == Task.Type.OTHER) {
s_logger.info("Server: Received OTHER task");
}
}
}
}
}