From a3c2e06a58d99cfa7d2d5a35d21b7f9e2a7c04b0 Mon Sep 17 00:00:00 2001 From: Harikrishna Patnala Date: Wed, 13 May 2026 13:20:59 +0530 Subject: [PATCH] more unit tests --- .../cloud/network/IpAddressManagerTest.java | 175 +++++++++ .../network/firewall/FirewallManagerTest.java | 343 ++++++++++++++++++ 2 files changed, 518 insertions(+) diff --git a/server/src/test/java/com/cloud/network/IpAddressManagerTest.java b/server/src/test/java/com/cloud/network/IpAddressManagerTest.java index 824d4ee4701..b324a757000 100644 --- a/server/src/test/java/com/cloud/network/IpAddressManagerTest.java +++ b/server/src/test/java/com/cloud/network/IpAddressManagerTest.java @@ -19,10 +19,13 @@ package com.cloud.network; import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertTrue; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyBoolean; import static org.mockito.ArgumentMatchers.anyLong; import static org.mockito.Mockito.doReturn; import static org.mockito.Mockito.lenient; import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; @@ -54,8 +57,12 @@ import com.cloud.network.dao.IPAddressDao; import com.cloud.network.dao.IPAddressVO; import com.cloud.network.dao.NetworkDao; import com.cloud.network.dao.NetworkVO; +import com.cloud.network.rules.FirewallRule; import com.cloud.network.rules.StaticNat; import com.cloud.network.rules.StaticNatImpl; +import com.cloud.network.vpc.VpcManager; +import com.cloud.dc.dao.VlanDao; +import com.cloud.dc.VlanVO; import com.cloud.offerings.NetworkOfferingVO; import com.cloud.offerings.dao.NetworkOfferingDao; import com.cloud.user.AccountVO; @@ -105,6 +112,12 @@ public class IpAddressManagerTest { @Mock AccountManager accountManagerMock; + @Mock + VpcManager vpcMgr; + + @Mock + VlanDao vlanDao; + final long dummyID = 1L; final String UUID = "uuid"; @@ -491,4 +504,166 @@ public class IpAddressManagerTest { Assert.assertTrue(result); } + + + private FirewallRule makeRule(Long networkId, Long vpcId, FirewallRule.Purpose purpose, + FirewallRule.TrafficType trafficType) { + FirewallRule rule = mock(FirewallRule.class); + lenient().when(rule.getNetworkId()).thenReturn(networkId); + lenient().when(rule.getVpcId()).thenReturn(vpcId); + lenient().when(rule.getPurpose()).thenReturn(purpose); + lenient().when(rule.getTrafficType()).thenReturn(trafficType); + return rule; + } + + /** Stub the two IP-association helper methods so they are no-ops. */ + private void stubIpAssocHelpers() throws ResourceUnavailableException { + doReturn(false).when(ipAddressManager).checkIfIpAssocRequired(any(Network.class), anyBoolean(), any()); + } + + /** + * Test: Non-VPC rules still resolve via networkId (backward compatibility). + */ + @Test + public void applyRulesNonVpcRuleStillWorksViaNetworkId() throws ResourceUnavailableException { + long networkId = 10L; + NetworkVO network = mock(NetworkVO.class); + when(network.getId()).thenReturn(networkId); + when(networkDao.findById(networkId)).thenReturn(network); + + FirewallRule rule = makeRule(networkId, null, FirewallRule.Purpose.Firewall, FirewallRule.TrafficType.Ingress); + NetworkRuleApplier applier = mock(NetworkRuleApplier.class); + + when(ipAddressDao.listByAssociatedNetwork(networkId, null)).thenReturn(new ArrayList<>()); + stubIpAssocHelpers(); + + boolean result = ipAddressManager.applyRules( + Collections.singletonList(rule), FirewallRule.Purpose.Firewall, applier, false); + + assertTrue(result); + verify(networkDao).findById(networkId); + verify(applier).applyRules(network, FirewallRule.Purpose.Firewall, Collections.singletonList(rule)); + } + + /** + * Test: VPC rule resolves network via VpcManager.getVpcNetworks() + * when networkId is null but vpcId is set. + */ + @Test + public void applyRulesVpcRuleResolvesNetworkViaVpcManager() throws ResourceUnavailableException { + long vpcId = 20L; + long resolvedNetworkId = 30L; + NetworkVO resolvedNetwork = mock(NetworkVO.class); + when(resolvedNetwork.getId()).thenReturn(resolvedNetworkId); + when(networkDao.findById(resolvedNetworkId)).thenReturn(resolvedNetwork); + + doReturn(Collections.singletonList(resolvedNetwork)).when(vpcMgr).getVpcNetworks(vpcId); + + FirewallRule rule = makeRule(null, vpcId, FirewallRule.Purpose.Firewall, FirewallRule.TrafficType.Ingress); + NetworkRuleApplier applier = mock(NetworkRuleApplier.class); + + IPAddressVO vpcIp = mock(IPAddressVO.class); + when(vpcIp.getVlanId()).thenReturn(1L); + VlanVO vlan = mock(VlanVO.class); + when(ipAddressDao.listByAssociatedVpc(vpcId, null)).thenReturn(Collections.singletonList(vpcIp)); + when(vlanDao.findById(1L)).thenReturn(vlan); + + stubIpAssocHelpers(); + + boolean result = ipAddressManager.applyRules( + Collections.singletonList(rule), FirewallRule.Purpose.Firewall, applier, false); + + assertTrue(result); + verify(vpcMgr).getVpcNetworks(vpcId); + verify(ipAddressDao).listByAssociatedVpc(vpcId, null); + verify(applier).applyRules(resolvedNetwork, FirewallRule.Purpose.Firewall, Collections.singletonList(rule)); + } + + + /** + * Test: For VPC egress firewall rules, IP collection should be skipped. + */ + @Test + public void applyRulesVpcEgressFirewallRuleSkipsIpCollection() throws ResourceUnavailableException { + long vpcId = 20L; + long resolvedNetworkId = 30L; + NetworkVO resolvedNetwork = mock(NetworkVO.class); + when(resolvedNetwork.getId()).thenReturn(resolvedNetworkId); + when(networkDao.findById(resolvedNetworkId)).thenReturn(resolvedNetwork); + doReturn(Collections.singletonList(resolvedNetwork)).when(vpcMgr).getVpcNetworks(vpcId); + + FirewallRule rule = makeRule(null, vpcId, FirewallRule.Purpose.Firewall, FirewallRule.TrafficType.Egress); + NetworkRuleApplier applier = mock(NetworkRuleApplier.class); + + stubIpAssocHelpers(); + + boolean result = ipAddressManager.applyRules( + Collections.singletonList(rule), FirewallRule.Purpose.Firewall, applier, false); + + assertTrue(result); + verify(ipAddressDao, never()).listByAssociatedVpc(anyLong(), any()); + verify(applier).applyRules(resolvedNetwork, FirewallRule.Purpose.Firewall, Collections.singletonList(rule)); + } + + /** + * Test: VPC ingress firewall rules collect public IPs from VPC (listByAssociatedVpc), + * NOT from network (listByAssociatedNetwork). + */ + @Test + public void applyRulesVpcIngressRuleCollectsIpsFromVpcNotNetwork() throws ResourceUnavailableException { + long vpcId = 20L; + long resolvedNetworkId = 30L; + NetworkVO resolvedNetwork = mock(NetworkVO.class); + when(resolvedNetwork.getId()).thenReturn(resolvedNetworkId); + when(networkDao.findById(resolvedNetworkId)).thenReturn(resolvedNetwork); + doReturn(Collections.singletonList(resolvedNetwork)).when(vpcMgr).getVpcNetworks(vpcId); + + IPAddressVO vpcIp = mock(IPAddressVO.class); + when(vpcIp.getVlanId()).thenReturn(1L); + VlanVO vlan = mock(VlanVO.class); + when(ipAddressDao.listByAssociatedVpc(vpcId, null)).thenReturn(Collections.singletonList(vpcIp)); + when(vlanDao.findById(1L)).thenReturn(vlan); + + stubIpAssocHelpers(); + + NetworkRuleApplier applier = mock(NetworkRuleApplier.class); + FirewallRule rule = makeRule(null, vpcId, FirewallRule.Purpose.Firewall, FirewallRule.TrafficType.Ingress); + + ipAddressManager.applyRules(Collections.singletonList(rule), FirewallRule.Purpose.Firewall, applier, false); + + verify(ipAddressDao).listByAssociatedVpc(vpcId, null); + verify(ipAddressDao, never()).listByAssociatedNetwork(anyLong(), any()); + } + + /** + * Test: Error handling respects continueOnError flag. + * When continueOnError=true, exceptions are caught and false is returned. + */ + @Test + public void applyRulesVpcRuleErrorHandlingWithContinueOnErrorTrue() throws ResourceUnavailableException { + long vpcId = 20L; + long resolvedNetworkId = 30L; + NetworkVO resolvedNetwork = mock(NetworkVO.class); + when(resolvedNetwork.getId()).thenReturn(resolvedNetworkId); + when(networkDao.findById(resolvedNetworkId)).thenReturn(resolvedNetwork); + doReturn(Collections.singletonList(resolvedNetwork)).when(vpcMgr).getVpcNetworks(vpcId); + + IPAddressVO vpcIp = mock(IPAddressVO.class); + when(vpcIp.getVlanId()).thenReturn(1L); + VlanVO vlan = mock(VlanVO.class); + when(ipAddressDao.listByAssociatedVpc(vpcId, null)).thenReturn(Collections.singletonList(vpcIp)); + when(vlanDao.findById(1L)).thenReturn(vlan); + + stubIpAssocHelpers(); + + NetworkRuleApplier applier = mock(NetworkRuleApplier.class); + when(applier.applyRules(any(), any(), any())).thenThrow(new ResourceUnavailableException("test", Network.class, 0L)); + + FirewallRule rule = makeRule(null, vpcId, FirewallRule.Purpose.Firewall, FirewallRule.TrafficType.Ingress); + + boolean result = ipAddressManager.applyRules( + Collections.singletonList(rule), FirewallRule.Purpose.Firewall, applier, true); + + assertFalse(result); + } } diff --git a/server/src/test/java/com/cloud/network/firewall/FirewallManagerTest.java b/server/src/test/java/com/cloud/network/firewall/FirewallManagerTest.java index c33af926e81..5ef60c699b2 100644 --- a/server/src/test/java/com/cloud/network/firewall/FirewallManagerTest.java +++ b/server/src/test/java/com/cloud/network/firewall/FirewallManagerTest.java @@ -17,6 +17,7 @@ package com.cloud.network.firewall; +import com.cloud.exception.InvalidParameterValueException; import com.cloud.exception.NetworkRuleConflictException; import com.cloud.exception.ResourceUnavailableException; import com.cloud.network.IpAddressManager; @@ -59,14 +60,18 @@ import org.mockito.junit.MockitoJUnitRunner; import java.util.ArrayList; import java.util.Arrays; +import java.util.Collections; import java.util.HashMap; import java.util.List; import java.util.Map; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.anyBoolean; +import static org.mockito.Mockito.doReturn; import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; @@ -370,4 +375,342 @@ public class FirewallManagerTest { verify(_networkModel, Mockito.never()).getNetworkServiceCapabilities(Mockito.anyLong(), Mockito.eq(Service.Firewall)); } + + @Test + public void testIsVpcIpAddressReturnsTrueWhenVpcIdPresent() { + IPAddressVO ipAddress = Mockito.mock(IPAddressVO.class); + when(ipAddress.getVpcId()).thenReturn(5L); + Assert.assertTrue(_firewallMgr.isVpcIpAddress(ipAddress)); + } + + @Test + public void testIsVpcIpAddressReturnsFalseWhenVpcIdNull() { + IPAddressVO ipAddress = Mockito.mock(IPAddressVO.class); + when(ipAddress.getVpcId()).thenReturn(null); + Assert.assertFalse(_firewallMgr.isVpcIpAddress(ipAddress)); + } + + @Test + public void testValidateFirewallRuleForIsolatedIpReturnsNetworkId() { + IPAddressVO ipAddress = Mockito.mock(IPAddressVO.class); + when(ipAddress.getAssociatedWithNetworkId()).thenReturn(42L); + Long result = _firewallMgr.validateFirewallRuleForIsolatedIp(ipAddress); + Assert.assertEquals(Long.valueOf(42L), result); + } + + @Test(expected = InvalidParameterValueException.class) + public void testValidateFirewallRuleForIsolatedIpThrowsWhenNotAssociated() { + IPAddressVO ipAddress = Mockito.mock(IPAddressVO.class); + when(ipAddress.getAssociatedWithNetworkId()).thenReturn(null); + _firewallMgr.validateFirewallRuleForIsolatedIp(ipAddress); + } + + @Test + public void testValidateFirewallRuleForVpcIpReturnsNetworkId() { + IPAddressVO ipAddress = Mockito.mock(IPAddressVO.class); + Long result = _firewallMgr.validateFirewallRuleForVpcIp(ipAddress, 99L); + Assert.assertEquals(Long.valueOf(99L), result); + } + + @Test(expected = InvalidParameterValueException.class) + public void testValidateFirewallRuleForVpcIpThrowsWhenNetworkIdNull() { + IPAddressVO ipAddress = Mockito.mock(IPAddressVO.class); + _firewallMgr.validateFirewallRuleForVpcIp(ipAddress, null); + } + + @Test + public void testGetFirewallServiceCapabilitiesForNonVpcNetworkUsesNetworkModel() { + NetworkVO network = Mockito.mock(NetworkVO.class); + when(network.getId()).thenReturn(1L); + when(network.getVpcId()).thenReturn(null); + Map caps = new HashMap<>(); + caps.put(Capability.SupportedProtocols, "tcp,udp"); + when(_networkModel.getNetworkServiceCapabilities(1L, Service.Firewall)).thenReturn(caps); + + Map result = _firewallMgr.getFirewallServiceCapabilities(network); + + Assert.assertEquals(caps, result); + verify(_networkModel, times(1)).getNetworkServiceCapabilities(1L, Service.Firewall); + } + + @Test + public void testGetFirewallServiceCapabilitiesForVpcNetworkUsesVpcProvider() { + NetworkVO network = Mockito.mock(NetworkVO.class); + FirewallServiceProvider fwProvider = Mockito.mock(FirewallServiceProvider.class); + Map firewallCaps = new HashMap<>(); + firewallCaps.put(Capability.SupportedProtocols, "tcp,udp,icmp"); + Map> providerCapabilities = new HashMap<>(); + providerCapabilities.put(Service.Firewall, firewallCaps); + + when(network.getId()).thenReturn(1L); + when(network.getVpcId()).thenReturn(10L); + when(fwProvider.getProvider()).thenReturn(Network.Provider.VPCVirtualRouter); + when(fwProvider.getCapabilities()).thenReturn(providerCapabilities); + when(_vpcMgr.isProviderSupportServiceInVpc(10L, Service.Firewall, Network.Provider.VPCVirtualRouter)).thenReturn(true); + _firewallMgr._firewallElements = List.of(fwProvider); + + Map result = _firewallMgr.getFirewallServiceCapabilities(network); + + Assert.assertEquals(firewallCaps, result); + verify(_networkModel, never()).getNetworkServiceCapabilities(Mockito.anyLong(), Mockito.eq(Service.Firewall)); + } + + @Test + public void testGetFirewallServiceCapabilitiesForVpcNetworkFallsBackToNetworkModelWhenNoProvider() { + NetworkVO network = Mockito.mock(NetworkVO.class); + FirewallServiceProvider fwProvider = Mockito.mock(FirewallServiceProvider.class); + Map fallbackCaps = new HashMap<>(); + + when(network.getId()).thenReturn(1L); + when(network.getVpcId()).thenReturn(10L); + when(fwProvider.getProvider()).thenReturn(Network.Provider.VPCVirtualRouter); + when(_vpcMgr.isProviderSupportServiceInVpc(10L, Service.Firewall, Network.Provider.VPCVirtualRouter)).thenReturn(false); + when(_networkModel.getNetworkServiceCapabilities(1L, Service.Firewall)).thenReturn(fallbackCaps); + _firewallMgr._firewallElements = List.of(fwProvider); + + Map result = _firewallMgr.getFirewallServiceCapabilities(network); + + Assert.assertEquals(fallbackCaps, result); + verify(_networkModel, times(1)).getNetworkServiceCapabilities(1L, Service.Firewall); + } + + @Test + public void testGetFirewallServiceCapabilitiesForVpcReturnsCapabilitiesWhenProviderSupports() { + FirewallServiceProvider fwProvider = Mockito.mock(FirewallServiceProvider.class); + Map firewallCaps = new HashMap<>(); + firewallCaps.put(Capability.SupportedProtocols, "tcp,udp"); + Map> caps = new HashMap<>(); + caps.put(Service.Firewall, firewallCaps); + + when(fwProvider.getProvider()).thenReturn(Network.Provider.VPCVirtualRouter); + when(fwProvider.getCapabilities()).thenReturn(caps); + when(_vpcMgr.isProviderSupportServiceInVpc(10L, Service.Firewall, Network.Provider.VPCVirtualRouter)).thenReturn(true); + _firewallMgr._firewallElements = List.of(fwProvider); + + Map result = _firewallMgr.getFirewallServiceCapabilitiesForVpc(10L); + + Assert.assertNotNull(result); + Assert.assertEquals("tcp,udp", result.get(Capability.SupportedProtocols)); + } + + @Test + public void testGetFirewallServiceCapabilitiesForVpcReturnsNullWhenNoProviderSupports() { + FirewallServiceProvider fwProvider = Mockito.mock(FirewallServiceProvider.class); + when(fwProvider.getProvider()).thenReturn(Network.Provider.VPCVirtualRouter); + when(_vpcMgr.isProviderSupportServiceInVpc(10L, Service.Firewall, Network.Provider.VPCVirtualRouter)).thenReturn(false); + _firewallMgr._firewallElements = List.of(fwProvider); + + Map result = _firewallMgr.getFirewallServiceCapabilitiesForVpc(10L); + + Assert.assertNull(result); + } + + @Test(expected = InvalidParameterValueException.class) + public void testValidateFirewallRuleForVpcThrowsOnInvalidStartPort() { + Account caller = Mockito.mock(Account.class); + _firewallMgr.validateFirewallRuleForVpc(caller, null, -1, 80, "tcp", Purpose.Firewall, FirewallRuleType.User, 10L, FirewallRule.TrafficType.Ingress); + } + + @Test(expected = InvalidParameterValueException.class) + public void testValidateFirewallRuleForVpcThrowsOnInvalidEndPort() { + Account caller = Mockito.mock(Account.class); + _firewallMgr.validateFirewallRuleForVpc(caller, null, 80, 70000, "tcp", Purpose.Firewall, FirewallRuleType.User, 10L, FirewallRule.TrafficType.Ingress); + } + + @Test(expected = InvalidParameterValueException.class) + public void testValidateFirewallRuleForVpcThrowsWhenStartPortGreaterThanEndPort() { + Account caller = Mockito.mock(Account.class); + _firewallMgr.validateFirewallRuleForVpc(caller, null, 200, 100, "tcp", Purpose.Firewall, FirewallRuleType.User, 10L, FirewallRule.TrafficType.Ingress); + } + + @Test + public void testValidateFirewallRuleForVpcSystemTypeWithNullIpReturnsEarly() { + // System rule type + null IP should return without further validation + Account caller = Mockito.mock(Account.class); + // Should not throw even though vpcId checks come after this + _firewallMgr.validateFirewallRuleForVpc(caller, null, 80, 80, "tcp", Purpose.Firewall, FirewallRuleType.System, 10L, FirewallRule.TrafficType.Ingress); + } + + @Test(expected = InvalidParameterValueException.class) + public void testValidateFirewallRuleForVpcThrowsWhenVpcIdNullAndNotSystemRule() { + Account caller = Mockito.mock(Account.class); + IPAddressVO ipAddress = Mockito.mock(IPAddressVO.class); + _firewallMgr.validateFirewallRuleForVpc(caller, ipAddress, 80, 80, "tcp", Purpose.Firewall, FirewallRuleType.User, null, FirewallRule.TrafficType.Ingress); + } + + @Test(expected = InvalidParameterValueException.class) + public void testValidateFirewallRuleForVpcThrowsWhenActiveVpcNotFound() { + Account caller = Mockito.mock(Account.class); + IPAddressVO ipAddress = Mockito.mock(IPAddressVO.class); + when(_vpcMgr.getActiveVpc(10L)).thenReturn(null); + _firewallMgr.validateFirewallRuleForVpc(caller, ipAddress, 80, 80, "tcp", Purpose.Firewall, FirewallRuleType.User, 10L, FirewallRule.TrafficType.Ingress); + } + + @Test(expected = InvalidParameterValueException.class) + public void testValidateFirewallRuleForVpcThrowsOnUnsupportedProtocol() { + Account caller = Mockito.mock(Account.class); + IPAddressVO ipAddress = Mockito.mock(IPAddressVO.class); + Vpc vpc = Mockito.mock(Vpc.class); + FirewallServiceProvider fwProvider = Mockito.mock(FirewallServiceProvider.class); + Map firewallCaps = new HashMap<>(); + firewallCaps.put(Capability.SupportedProtocols, "tcp,udp"); + firewallCaps.put(Capability.SupportedTrafficDirection, "ingress,egress"); + Map> caps = new HashMap<>(); + caps.put(Service.Firewall, firewallCaps); + + when(_vpcMgr.getActiveVpc(10L)).thenReturn(vpc); + when(fwProvider.getProvider()).thenReturn(Network.Provider.VPCVirtualRouter); + when(fwProvider.getCapabilities()).thenReturn(caps); + when(_vpcMgr.isProviderSupportServiceInVpc(10L, Service.Firewall, Network.Provider.VPCVirtualRouter)).thenReturn(true); + _firewallMgr._firewallElements = List.of(fwProvider); + + _firewallMgr.validateFirewallRuleForVpc(caller, ipAddress, 80, 80, "gre", Purpose.Firewall, FirewallRuleType.User, 10L, FirewallRule.TrafficType.Ingress); + } + + @Test + public void testValidateFirewallRuleForVpcSucceedsWithSupportedProtocolAndTrafficType() { + Account caller = Mockito.mock(Account.class); + IPAddressVO ipAddress = Mockito.mock(IPAddressVO.class); + Vpc vpc = Mockito.mock(Vpc.class); + FirewallServiceProvider fwProvider = Mockito.mock(FirewallServiceProvider.class); + Map firewallCaps = new HashMap<>(); + firewallCaps.put(Capability.SupportedProtocols, "tcp,udp,icmp"); + firewallCaps.put(Capability.SupportedTrafficDirection, "ingress,egress"); + Map> caps = new HashMap<>(); + caps.put(Service.Firewall, firewallCaps); + + when(_vpcMgr.getActiveVpc(10L)).thenReturn(vpc); + when(fwProvider.getProvider()).thenReturn(Network.Provider.VPCVirtualRouter); + when(fwProvider.getCapabilities()).thenReturn(caps); + when(_vpcMgr.isProviderSupportServiceInVpc(10L, Service.Firewall, Network.Provider.VPCVirtualRouter)).thenReturn(true); + _firewallMgr._firewallElements = List.of(fwProvider); + + // Should not throw + _firewallMgr.validateFirewallRuleForVpc(caller, ipAddress, 80, 80, "tcp", Purpose.Firewall, FirewallRuleType.User, 10L, FirewallRule.TrafficType.Ingress); + + verify(_accountMgr, times(1)).checkAccess(caller, null, true, ipAddress); + } + + @Test + public void testCreateIngressFirewallRuleRoutesToVpcMethodWhenIpHasVpcId() throws NetworkRuleConflictException { + FirewallRule rule = Mockito.mock(FirewallRule.class); + IPAddressVO ipAddress = Mockito.mock(IPAddressVO.class); + + when(rule.getSourceIpAddressId()).thenReturn(1L); + when(ipAddress.getVpcId()).thenReturn(10L); + + doReturn(ipAddress).when(_firewallMgr).getSourceIpForIngressRule(1L); + doReturn(rule).when(_firewallMgr).createIngressFirewallRuleForVpcIp(rule, null, ipAddress); + + doReturn(rule).when(_firewallMgr).createIngressFirewallRuleForVpcIp( + Mockito.eq(rule), Mockito.any(), Mockito.eq(ipAddress)); + + verify(_firewallMgr, never()).createIngressFirewallRuleForIsolatedIp( + Mockito.any(), Mockito.any(), Mockito.any()); + } + + @Test + public void testCreateFirewallRuleRoutesToVpcWhenVpcIdProvided() throws NetworkRuleConflictException { + Account caller = Mockito.mock(Account.class); + FirewallRule vpcRule = Mockito.mock(FirewallRule.class); + + doReturn(vpcRule).when(_firewallMgr).createFirewallRuleForVpc( + Mockito.anyLong(), Mockito.eq(caller), Mockito.any(), Mockito.anyInt(), Mockito.anyInt(), + Mockito.anyString(), Mockito.any(), Mockito.any(), Mockito.any(), Mockito.any(), + Mockito.any(), Mockito.any(FirewallRuleType.class), Mockito.anyLong(), + Mockito.any(FirewallRule.TrafficType.class), Mockito.anyBoolean()); + + _firewallMgr.createFirewallRule(1L, caller, "xid", 80, 80, "tcp", + Collections.singletonList("0.0.0.0/0"), null, null, null, null, + FirewallRuleType.User, null, 10L, FirewallRule.TrafficType.Ingress, true); + + verify(_firewallMgr, times(1)).createFirewallRuleForVpc( + Mockito.anyLong(), Mockito.eq(caller), Mockito.any(), Mockito.anyInt(), Mockito.anyInt(), + Mockito.anyString(), Mockito.any(), Mockito.any(), Mockito.any(), Mockito.any(), + Mockito.any(), Mockito.any(FirewallRuleType.class), Mockito.anyLong(), + Mockito.any(FirewallRule.TrafficType.class), Mockito.anyBoolean()); + + verify(_firewallMgr, never()).createFirewallRuleForNonVPC( + Mockito.any(), Mockito.any(), Mockito.any(), Mockito.any(), Mockito.any(), + Mockito.any(), Mockito.any(), Mockito.any(), Mockito.any(), Mockito.any(), + Mockito.any(), Mockito.any(), Mockito.any(), Mockito.any(), Mockito.any()); + } + + @Test + public void testCreateFirewallRuleRoutesToNonVpcWhenVpcIdNull() throws NetworkRuleConflictException { + Account caller = Mockito.mock(Account.class); + FirewallRule nonVpcRule = Mockito.mock(FirewallRule.class); + + doReturn(nonVpcRule).when(_firewallMgr).createFirewallRuleForNonVPC( + Mockito.any(), Mockito.eq(caller), Mockito.any(), Mockito.any(), Mockito.any(), + Mockito.anyString(), Mockito.any(), Mockito.any(), Mockito.any(), Mockito.any(), + Mockito.any(), Mockito.any(FirewallRuleType.class), Mockito.anyLong(), + Mockito.any(FirewallRule.TrafficType.class), Mockito.anyBoolean()); + + _firewallMgr.createFirewallRule(null, caller, "xid", 80, 80, "tcp", + Collections.singletonList("0.0.0.0/0"), null, null, null, null, + FirewallRuleType.User, 2L, null, FirewallRule.TrafficType.Ingress, true); + + verify(_firewallMgr, times(1)).createFirewallRuleForNonVPC( + Mockito.any(), Mockito.eq(caller), Mockito.any(), Mockito.any(), Mockito.any(), + Mockito.anyString(), Mockito.any(), Mockito.any(), Mockito.any(), Mockito.any(), + Mockito.any(), Mockito.any(FirewallRuleType.class), Mockito.anyLong(), + Mockito.any(FirewallRule.TrafficType.class), Mockito.anyBoolean()); + + verify(_firewallMgr, never()).createFirewallRuleForVpc( + Mockito.any(), Mockito.any(), Mockito.any(), Mockito.any(), Mockito.any(), + Mockito.any(), Mockito.any(), Mockito.any(), Mockito.any(), Mockito.any(), + Mockito.any(), Mockito.any(), Mockito.any(), Mockito.any(), Mockito.any()); + } + + @Test + public void testApplyRulesForVpcNetworkUsesVpcProviderCheck() throws ResourceUnavailableException { + Network network = Mockito.mock(Network.class); + FirewallServiceProvider fwProvider = Mockito.mock(FirewallServiceProvider.class); + List rules = new ArrayList<>(); + FirewallRuleVO rule = new FirewallRuleVO("rule1", 1L, 80, 80, "tcp", 1L, 2, 3, Purpose.Firewall, + Collections.emptyList(), Collections.emptyList(), null, null, null, FirewallRule.TrafficType.Ingress); + rules.add(rule); + + when(network.getVpcId()).thenReturn(10L); + when(fwProvider.getProvider()).thenReturn(Network.Provider.VPCVirtualRouter); + when(_vpcMgr.isProviderSupportServiceInVpc(10L, Service.Firewall, Network.Provider.VPCVirtualRouter)).thenReturn(true); + when(fwProvider.applyFWRules(Mockito.eq(network), Mockito.anyList())).thenReturn(true); + _firewallMgr._firewallElements = List.of(fwProvider); + + boolean result = _firewallMgr.applyRules(network, Purpose.Firewall, rules); + + Assert.assertTrue(result); + verify(_vpcMgr, times(1)).isProviderSupportServiceInVpc(10L, Service.Firewall, Network.Provider.VPCVirtualRouter); + verify(_networkModel, never()).isProviderSupportServiceInNetwork(Mockito.anyLong(), Mockito.eq(Service.Firewall), Mockito.any()); + } + + @Test + public void testApplyRulesForNonVpcNetworkUsesNetworkModelProviderCheck() throws ResourceUnavailableException { + Network network = Mockito.mock(Network.class); + FirewallServiceProvider fwProvider = Mockito.mock(FirewallServiceProvider.class); + List rules = new ArrayList<>(); + FirewallRuleVO rule = new FirewallRuleVO("rule1", 1L, 80, 80, "tcp", 1L, 2, 3, Purpose.Firewall, + Collections.emptyList(), Collections.emptyList(), null, null, null, FirewallRule.TrafficType.Ingress); + rules.add(rule); + + when(network.getId()).thenReturn(1L); + when(network.getVpcId()).thenReturn(null); + when(fwProvider.getProvider()).thenReturn(Network.Provider.VirtualRouter); + when(_networkModel.isProviderSupportServiceInNetwork(1L, Service.Firewall, Network.Provider.VirtualRouter)).thenReturn(true); + when(fwProvider.applyFWRules(Mockito.eq(network), Mockito.anyList())).thenReturn(true); + _firewallMgr._firewallElements = List.of(fwProvider); + + boolean result = _firewallMgr.applyRules(network, Purpose.Firewall, rules); + + Assert.assertTrue(result); + verify(_networkModel, times(1)).isProviderSupportServiceInNetwork(1L, Service.Firewall, Network.Provider.VirtualRouter); + verify(_vpcMgr, never()).isProviderSupportServiceInVpc(Mockito.anyLong(), Mockito.eq(Service.Firewall), Mockito.any()); + } + + @Test + public void testGetSourceIpForIngressRuleReturnsNullWhenIdIsNull() { + IPAddressVO result = _firewallMgr.getSourceIpForIngressRule(null); + Assert.assertNull(result); + } }