diff --git a/server/src/com/cloud/network/security/LocalSecurityGroupWorkQueue.java b/server/src/com/cloud/network/security/LocalSecurityGroupWorkQueue.java index 42cd42f430a..0486927a1f8 100644 --- a/server/src/com/cloud/network/security/LocalSecurityGroupWorkQueue.java +++ b/server/src/com/cloud/network/security/LocalSecurityGroupWorkQueue.java @@ -18,6 +18,8 @@ package com.cloud.network.security; import java.util.ArrayList; +import java.util.HashSet; +import java.util.Iterator; import java.util.List; import java.util.Set; import java.util.TreeSet; @@ -37,7 +39,9 @@ import com.cloud.network.security.SecurityGroupWork.Step; public class LocalSecurityGroupWorkQueue implements SecurityGroupWorkQueue { protected static Logger s_logger = Logger.getLogger(LocalSecurityGroupWorkQueue.class); - protected TreeSet _currentWork = new TreeSet(); + protected Set _currentWork = new HashSet(); + //protected Set _currentWork = new TreeSet(); + private final ReentrantLock _lock = new ReentrantLock(); private final Condition _notEmpty = _lock.newCondition(); private final AtomicInteger _count = new AtomicInteger(0); @@ -83,6 +87,20 @@ public class LocalSecurityGroupWorkQueue implements SecurityGroupWorkQueue { public int compareTo(LocalSecurityGroupWork o) { return this._instanceId.compareTo(o.getInstanceId()); } + + @Override + public boolean equals(Object obj) { + if (obj instanceof LocalSecurityGroupWork) { + LocalSecurityGroupWork other = (LocalSecurityGroupWork)obj; + return this.getInstanceId().longValue()==other.getInstanceId().longValue(); + } + return false; + } + + @Override + public int hashCode() { + return getInstanceId().hashCode(); + } } @@ -124,7 +142,7 @@ public class LocalSecurityGroupWorkQueue implements SecurityGroupWorkQueue { @Override - public List getWork(int numberOfWorkItems) { + public List getWork(int numberOfWorkItems) throws InterruptedException { List work = new ArrayList(numberOfWorkItems); _lock.lock(); int i = 0; @@ -133,16 +151,14 @@ public class LocalSecurityGroupWorkQueue implements SecurityGroupWorkQueue { _notEmpty.await(); } int n = Math.min(numberOfWorkItems, _count.get()); + Iterator iter = _currentWork.iterator(); while (i < n ) { - SecurityGroupWork w = _currentWork.first(); + SecurityGroupWork w = iter.next(); w.setStep(Step.Processing); work.add(w); - _currentWork.remove(w); + iter.remove(); ++i; } - } catch (InterruptedException e) { - // TODO Auto-generated catch block - e.printStackTrace(); } finally { int c = _count.addAndGet(-i); if (c > 0) diff --git a/server/src/com/cloud/network/security/SecurityGroupWorkQueue.java b/server/src/com/cloud/network/security/SecurityGroupWorkQueue.java index 0fb0773d558..29a592731c2 100644 --- a/server/src/com/cloud/network/security/SecurityGroupWorkQueue.java +++ b/server/src/com/cloud/network/security/SecurityGroupWorkQueue.java @@ -32,7 +32,7 @@ public interface SecurityGroupWorkQueue { int submitWorkForVms(Set vmIds); - List getWork(int numberOfWorkItems); + List getWork(int numberOfWorkItems) throws InterruptedException; int size(); diff --git a/server/test/com/cloud/network/security/SecurityGroupManagerImpl2Test.java b/server/test/com/cloud/network/security/SecurityGroupManagerImpl2Test.java index e36dadaad31..c81640342cc 100644 --- a/server/test/com/cloud/network/security/SecurityGroupManagerImpl2Test.java +++ b/server/test/com/cloud/network/security/SecurityGroupManagerImpl2Test.java @@ -65,7 +65,7 @@ public class SecurityGroupManagerImpl2Test extends TestCase { } public void testSchedule() { - final int numVms = 100000; + final int numVms = 10000; System.out.println("Starting"); ComponentLocator locator = ComponentLocator.getCurrentLocator(); SecurityGroupManagerImpl2 sgMgr = ComponentLocator.inject(SecurityGroupManagerImpl2.class); diff --git a/server/test/com/cloud/network/security/SecurityGroupQueueTest.java b/server/test/com/cloud/network/security/SecurityGroupQueueTest.java index 8cf0da187f7..e84f65dcf0c 100644 --- a/server/test/com/cloud/network/security/SecurityGroupQueueTest.java +++ b/server/test/com/cloud/network/security/SecurityGroupQueueTest.java @@ -1,9 +1,12 @@ package com.cloud.network.security; +import java.util.ArrayList; import java.util.HashSet; import java.util.List; import java.util.Set; +import com.cloud.utils.Profiler; + import junit.framework.TestCase; public class SecurityGroupQueueTest extends TestCase { @@ -44,7 +47,13 @@ public class SecurityGroupQueueTest extends TestCase { } public void run() { - List result = queue.getWork(_numJobsToDequeue); + List result = new ArrayList(); + try { + result = queue.getWork(_numJobsToDequeue); + } catch (InterruptedException e) { + // TODO Auto-generated catch block + e.printStackTrace(); + } this._numJobsDequeued = result.size(); } @@ -79,18 +88,18 @@ public class SecurityGroupQueueTest extends TestCase { } } System.out.println("Num Vms= " + numProducers + " Queue size = " + queue.size()); - assert(numProducers == queue.size()); + assertEquals(numProducers, queue.size()); } - public void testNumJobsEqToNumVms2() { + protected void testNumJobsEqToNumVms2(int numProducers, int maxVmId) { queue.clear(); - final int numProducers = 50; Thread [] pThreads = new Thread[numProducers]; Producer [] producers = new Producer[numProducers]; int numProduced = 0; - int maxVmId = 10000; + Profiler p = new Profiler(); + p.start(); for (int i=0; i < numProducers; i++) { producers[i] = new Producer(maxVmId); pThreads[i] = new Thread(producers[i]); @@ -104,16 +113,22 @@ public class SecurityGroupQueueTest extends TestCase { ie.printStackTrace(); } } - System.out.println("Num Vms= " + maxVmId + " Queue size = " + queue.size()); - assert(maxVmId == queue.size()); + p.stop(); + System.out.println("Num Vms= " + maxVmId + " Queue size = " + queue.size() + " time=" + p.getDuration() + " ms"); + assertEquals(maxVmId, queue.size()); } - public void testDequeueOneJob() { - queue.clear(); + public void testNumJobsEqToNumVms3() { + testNumJobsEqToNumVms2(50,20000); + testNumJobsEqToNumVms2(400,5000); + testNumJobsEqToNumVms2(1,1); + testNumJobsEqToNumVms2(1,1000000); + testNumJobsEqToNumVms2(1000,1); - final int numProducers = 2; - final int numConsumers = 5; - final int maxVmId = 200; + } + + protected void _testDequeueOneJob(final int numConsumers, final int numProducers, final int maxVmId) { + queue.clear(); Thread [] pThreads = new Thread[numProducers]; Thread [] cThreads = new Thread[numConsumers]; @@ -141,12 +156,14 @@ public class SecurityGroupQueueTest extends TestCase { ie.printStackTrace(); } } -// try { -// Thread.sleep(2000); -// } catch (InterruptedException e) { -// // TODO Auto-generated catch block -// e.printStackTrace(); -// } + for (int i=0; i < numProducers ; i++) { + try { + pThreads[i].join(); + } catch (InterruptedException ie){ + ie.printStackTrace(); + } + } + int totalDequeued = 0; for (int i=0; i < numConsumers; i++) { //System.out.println("Consumer " + i + " ask to dequeue " + consumers[i].getNumJobsToDequeue() + ", dequeued " + consumers[i].getNumJobsDequeued()); @@ -158,8 +175,29 @@ public class SecurityGroupQueueTest extends TestCase { totalQueued += producers[i].getNewWork(); } System.out.println("Total jobs dequeued = " + totalDequeued + ", num queued=" + totalQueued + " queue current size=" + queue.size()); - assert(totalDequeued == numConsumers); - assert(totalQueued - totalDequeued == queue.size()); + assertEquals(totalDequeued, numConsumers); + assertEquals(totalQueued - totalDequeued, queue.size()); + } + + public void testDequeueOneJobAgain() { + _testDequeueOneJob(10,10,1000); + int queueSize = queue.size(); + Thread cThread = new Thread(new Consumer(1)); + cThread.start(); + try { + cThread.join(); + } catch (InterruptedException e) { + // TODO Auto-generated catch block + e.printStackTrace(); + } + assertEquals(queue.size(), queueSize-1); + } + + public void testDequeueOneJob() { + _testDequeueOneJob(10,10,1000); + _testDequeueOneJob(1,10,1000); + _testDequeueOneJob(10,1,1000); + _testDequeueOneJob(10,1,10); } }