From d3e7b43ee9398690f094ae2eb2a1a1957f7f3b07 Mon Sep 17 00:00:00 2001 From: Likitha Shetty Date: Thu, 17 Jul 2014 13:04:58 +0530 Subject: [PATCH] CLOUDSTACK-7119. [VMware] Don't allow VM reset when VM has snapshots. --- server/src/com/cloud/vm/UserVmManagerImpl.java | 6 ++++++ server/test/com/cloud/vm/UserVmManagerTest.java | 17 +++++++++++++++-- 2 files changed, 21 insertions(+), 2 deletions(-) diff --git a/server/src/com/cloud/vm/UserVmManagerImpl.java b/server/src/com/cloud/vm/UserVmManagerImpl.java index dac4acf1826..d0bc1867b87 100755 --- a/server/src/com/cloud/vm/UserVmManagerImpl.java +++ b/server/src/com/cloud/vm/UserVmManagerImpl.java @@ -4627,6 +4627,12 @@ public class UserVmManagerImpl extends ManagerBase implements UserVmManager, Vir templateId = vm.getIsoId(); } + // If target VM has associated VM snapshots then don't allow restore of VM + List vmSnapshots = _vmSnapshotDao.findByVm(vmId); + if (vmSnapshots.size() > 0 && vm.getHypervisorType() == HypervisorType.VMware) { + throw new InvalidParameterValueException("Unable to restore VM, please specify a VM that does not have VM snapshots"); + } + VMTemplateVO template = null; //newTemplateId can be either template or ISO id. In the following snippet based on the vm deployment (from template or ISO) it is handled accordingly if (newTemplateId != null) { diff --git a/server/test/com/cloud/vm/UserVmManagerTest.java b/server/test/com/cloud/vm/UserVmManagerTest.java index 3188d047923..aed468d3b8b 100755 --- a/server/test/com/cloud/vm/UserVmManagerTest.java +++ b/server/test/com/cloud/vm/UserVmManagerTest.java @@ -27,6 +27,7 @@ import static org.mockito.Matchers.eq; import static org.mockito.Mockito.doNothing; import static org.mockito.Mockito.doReturn; import static org.mockito.Mockito.doThrow; +import static org.mockito.Mockito.mock; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; @@ -86,6 +87,8 @@ import com.cloud.utils.db.EntityManager; import com.cloud.utils.exception.CloudRuntimeException; import com.cloud.vm.dao.UserVmDao; import com.cloud.vm.dao.VMInstanceDao; +import com.cloud.vm.snapshot.VMSnapshotVO; +import com.cloud.vm.snapshot.dao.VMSnapshotDao; public class UserVmManagerTest { @@ -149,6 +152,8 @@ public class UserVmManagerTest { PrimaryDataStoreDao _storagePoolDao; @Mock UsageEventDao _usageEventDao; + @Mock + VMSnapshotDao _vmSnapshotDao; @Before public void setup() { @@ -172,6 +177,7 @@ public class UserVmManagerTest { _userVmMgr._scaleRetry = 2; _userVmMgr._entityMgr = _entityMgr; _userVmMgr._storagePoolDao = _storagePoolDao; + _userVmMgr._vmSnapshotDao = _vmSnapshotDao; doReturn(3L).when(_account).getId(); doReturn(8L).when(_vmMock).getAccountId(); @@ -181,6 +187,9 @@ public class UserVmManagerTest { when(_vmMock.getId()).thenReturn(314L); when(_vmInstance.getId()).thenReturn(1L); when(_vmInstance.getServiceOfferingId()).thenReturn(2L); + List mockList = mock(List.class); + when(_vmSnapshotDao.findByVm(anyLong())).thenReturn(mockList); + when(mockList.size()).thenReturn(0); } @@ -298,7 +307,9 @@ public class UserVmManagerTest { doNothing().when(_volsDao).attachVolume(anyLong(), anyLong(), anyLong()); when(_volumeMock.getId()).thenReturn(3L); doNothing().when(_volsDao).detachVolume(anyLong()); - + List mockList = mock(List.class); + when(_vmSnapshotDao.findByVm(anyLong())).thenReturn(mockList); + when(mockList.size()).thenReturn(0); when(_templateMock.getUuid()).thenReturn("b1a3626e-72e0-4697-8c7c-a110940cc55d"); Account account = new AccountVO("testaccount", 1L, "networkdomain", (short)0, "uuid"); @@ -343,7 +354,9 @@ public class UserVmManagerTest { doNothing().when(_volsDao).attachVolume(anyLong(), anyLong(), anyLong()); when(_volumeMock.getId()).thenReturn(3L); doNothing().when(_volsDao).detachVolume(anyLong()); - + List mockList = mock(List.class); + when(_vmSnapshotDao.findByVm(anyLong())).thenReturn(mockList); + when(mockList.size()).thenReturn(0); when(_templateMock.getUuid()).thenReturn("b1a3626e-72e0-4697-8c7c-a110940cc55d"); Account account = new AccountVO("testaccount", 1L, "networkdomain", (short)0, "uuid");