This commit is contained in:
Harikrishna 2026-05-13 10:15:55 +00:00 committed by GitHub
commit 713e71bd4b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
30 changed files with 1361 additions and 87 deletions

View File

@ -69,7 +69,9 @@ public interface FirewallRule extends ControlledEntity, Identity, InternalIdenti
State getState();
long getNetworkId();
Long getNetworkId();
Long getVpcId();
Long getSourceIpAddressId();

View File

@ -212,7 +212,7 @@ public class CreateEgressFirewallRuleCmd extends BaseAsyncCreateCmd implements F
}
@Override
public long getNetworkId() {
public Long getNetworkId() {
return networkId;
}

View File

@ -223,13 +223,9 @@ public class CreateFirewallRuleCmd extends BaseAsyncCreateCmd implements Firewal
}
@Override
public long getNetworkId() {
IpAddress ip = _entityMgr.findById(IpAddress.class, getIpAddressId());
Long ntwkId = null;
if (ip.getAssociatedWithNetworkId() != null) {
ntwkId = ip.getAssociatedWithNetworkId();
}
public Long getNetworkId() {
IpAddress ip = getIp();
Long ntwkId = isVpcIp(ip) ? getVpcNetworkIdForFirewallRule(ip) : getIsolatedNetworkIdForFirewallRule(ip);
if (ntwkId == null) {
throw new InvalidParameterValueException("Unable to create firewall rule for the IP address ID=" + ipAddressId +
@ -238,6 +234,12 @@ public class CreateFirewallRuleCmd extends BaseAsyncCreateCmd implements Firewal
return ntwkId;
}
@Override
public Long getVpcId() {
IpAddress ip = getIp();
return isVpcIp(ip) ? ip.getVpcId() : null;
}
@Override
public long getEntityOwnerId() {
Account account = CallContext.current().getCallingAccount();
@ -300,7 +302,21 @@ public class CreateFirewallRuleCmd extends BaseAsyncCreateCmd implements Firewal
@Override
public Long getSyncObjId() {
return getIp().getAssociatedWithNetworkId();
Long syncObjId = getIp().getAssociatedWithNetworkId();
return syncObjId != null ? syncObjId : getNetworkId();
}
private boolean isVpcIp(IpAddress ip) {
return ip.getVpcId() != null;
}
private Long getIsolatedNetworkIdForFirewallRule(IpAddress ip) {
return ip.getAssociatedWithNetworkId();
}
private Long getVpcNetworkIdForFirewallRule(IpAddress ip) {
// VPC flow is independent from tier association; manager resolves execution network.
return ip.getNetworkId();
}
private IpAddress getIp() {
@ -311,6 +327,7 @@ public class CreateFirewallRuleCmd extends BaseAsyncCreateCmd implements Firewal
return ip;
}
@Override
public Integer getIcmpCode() {
if (icmpCode != null) {

View File

@ -176,7 +176,7 @@ public class CreatePortForwardingRuleCmd extends BaseAsyncCreateCmd implements P
}
}
private Long getVpcId() {
public Long getVpcId() {
if (ipAddressId != null) {
IpAddress ipAddr = _networkService.getIp(ipAddressId);
if (ipAddr == null || !ipAddr.readyToUse()) {
@ -275,7 +275,7 @@ public class CreatePortForwardingRuleCmd extends BaseAsyncCreateCmd implements P
}
@Override
public long getNetworkId() {
public Long getNetworkId() {
IpAddress ip = _entityMgr.findById(IpAddress.class, getIpAddressId());
Long ntwkId = _networkService.getPreferredNetworkIdForPublicIpRuleAssignment(ip, networkId);
if (ntwkId == null) {

View File

@ -229,8 +229,13 @@ public class CreateIpForwardingRuleCmd extends BaseAsyncCreateCmd implements Sta
}
@Override
public long getNetworkId() {
return -1;
public Long getNetworkId() {
return -1L;
}
@Override
public Long getVpcId() {
return null;
}
@Override

View File

@ -51,6 +51,10 @@ public class FirewallResponse extends BaseResponse {
@Param(description = "The Network ID of the firewall rule")
private String networkId;
@SerializedName(ApiConstants.VPC_ID)
@Param(description = "The VPC ID of the firewall rule")
private String vpcId;
@SerializedName(ApiConstants.IP_ADDRESS)
@Param(description = "The public IP address for the firewall rule")
private String publicIpAddress;
@ -115,6 +119,10 @@ public class FirewallResponse extends BaseResponse {
this.networkId = networkId;
}
public void setVpcId(String vpcId) {
this.vpcId = vpcId;
}
public void setState(String state) {
this.state = state;
}

View File

@ -21,10 +21,15 @@ import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import com.cloud.network.IpAddress;
import com.cloud.network.NetworkService;
import com.cloud.utils.db.EntityManager;
import org.apache.commons.collections.CollectionUtils;
import org.junit.Assert;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.mockito.Mock;
import org.mockito.Mockito;
import org.mockito.junit.MockitoJUnitRunner;
import org.springframework.test.util.ReflectionTestUtils;
@ -33,6 +38,12 @@ import com.cloud.utils.net.NetUtils;
@RunWith(MockitoJUnitRunner.class)
public class CreateFirewallRuleCmdTest {
@Mock
private EntityManager entityManager;
@Mock
private NetworkService networkService;
private void validateAllIp4Cidr(final CreateFirewallRuleCmd cmd) {
Assert.assertTrue(CollectionUtils.isNotEmpty(cmd.getSourceCidrList()));
Assert.assertEquals(1, cmd.getSourceCidrList().size());
@ -88,4 +99,22 @@ public class CreateFirewallRuleCmdTest {
Assert.assertEquals(2, cmd.getSourceCidrList().size());
Assert.assertEquals(cidr, cmd.getSourceCidrList().get(1));
}
@Test
public void testGetNetworkIdVpcWithoutAssociatedNetworkUsesVpcFallbackAndSyncObjId() {
final CreateFirewallRuleCmd cmd = new CreateFirewallRuleCmd();
final IpAddress ip = Mockito.mock(IpAddress.class);
cmd._entityMgr = entityManager;
cmd._networkService = networkService;
ReflectionTestUtils.setField(cmd, "ipAddressId", 42L);
Mockito.when(networkService.getIp(42L)).thenReturn(ip);
Mockito.when(ip.getAssociatedWithNetworkId()).thenReturn(null);
Mockito.when(ip.getVpcId()).thenReturn(100L);
Mockito.when(ip.getNetworkId()).thenReturn(2L);
Assert.assertEquals(Long.valueOf(2L), cmd.getNetworkId());
Assert.assertEquals(Long.valueOf(2L), cmd.getSyncObjId());
}
}

View File

@ -80,10 +80,15 @@ public class StaticNatRuleImpl implements StaticNatRule {
}
@Override
public long getNetworkId() {
public Long getNetworkId() {
return networkId;
}
@Override
public Long getVpcId() {
return null;
}
@Override
public long getId() {
return id;

View File

@ -561,6 +561,7 @@ public class NetworkOrchestrator extends ManagerBase implements NetworkOrchestra
defaultVPCOffProviders.put(Service.StaticNat, defaultProviders);
defaultVPCOffProviders.put(Service.PortForwarding, defaultProviders);
defaultVPCOffProviders.put(Service.Vpn, defaultProviders);
defaultVPCOffProviders.put(Service.Firewall, defaultProviders);
Transaction.execute(new TransactionCallbackNoReturn() {
@Override

View File

@ -85,6 +85,9 @@ public class FirewallRuleVO implements FirewallRule {
@Column(name = "network_id")
Long networkId;
@Column(name = "vpc_id")
Long vpcId;
@Column(name = "icmp_code")
Integer icmpCode;
@ -190,10 +193,18 @@ public class FirewallRuleVO implements FirewallRule {
}
@Override
public long getNetworkId() {
public Long getNetworkId() {
return networkId;
}
public Long getVpcId() {
return vpcId;
}
public void setVpcId(Long vpcId) {
this.vpcId = vpcId;
}
@Override
public FirewallRuleType getType() {
return type;
@ -207,7 +218,7 @@ public class FirewallRuleVO implements FirewallRule {
uuid = UUID.randomUUID().toString();
}
public FirewallRuleVO(String xId, Long ipAddressId, Integer portStart, Integer portEnd, String protocol, long networkId, long accountId, long domainId,
public FirewallRuleVO(String xId, Long ipAddressId, Integer portStart, Integer portEnd, String protocol, Long networkId, long accountId, long domainId,
Purpose purpose, List<String> sourceCidrs, Integer icmpCode, Integer icmpType, Long related, TrafficType trafficType) {
this.xId = xId;
if (xId == null) {
@ -251,7 +262,7 @@ public class FirewallRuleVO implements FirewallRule {
}
public FirewallRuleVO(String xId, Long ipAddressId, Integer portStart, Integer portEnd, String protocol, long networkId, long accountId, long domainId,
public FirewallRuleVO(String xId, Long ipAddressId, Integer portStart, Integer portEnd, String protocol, Long networkId, long accountId, long domainId,
Purpose purpose, List<String> sourceCidrs, List<String> destCidrs, Integer icmpCode, Integer icmpType, Long related, TrafficType trafficType) {
this(xId,ipAddressId, portStart, portEnd, protocol, networkId, accountId, domainId, purpose, sourceCidrs, icmpCode, icmpType, related, trafficType);
this.destinationCidrs = destCidrs;

View File

@ -131,3 +131,6 @@ CREATE TABLE IF NOT EXISTS `cloud_usage`.`quota_tariff_usage` (
-- Add the 'keep_mac_address_on_public_nic' column to the 'cloud.networks' and 'cloud.vpc' tables
CALL `cloud`.`IDEMPOTENT_ADD_COLUMN`('cloud.networks', 'keep_mac_address_on_public_nic', 'TINYINT(1) NOT NULL DEFAULT 1');
CALL `cloud`.`IDEMPOTENT_ADD_COLUMN`('cloud.vpc', 'keep_mac_address_on_public_nic', 'TINYINT(1) NOT NULL DEFAULT 1');
-- This is part of allowing firewall rules on public IP addresses in VPC network
ALTER TABLE `cloud`.`firewall_rules` MODIFY COLUMN `network_id` BIGINT UNSIGNED NULL;

View File

@ -2986,9 +2986,8 @@ public class KubernetesClusterManagerImpl extends ManagerBase implements Kuberne
defaultKubernetesServiceNetworkOfferingProviders.put(Service.UserData, provider);
if (forVpc) {
defaultKubernetesServiceNetworkOfferingProviders.put(Service.NetworkACL, forNsx ? Network.Provider.Nsx : provider);
} else {
defaultKubernetesServiceNetworkOfferingProviders.put(Service.Firewall, forNsx ? Network.Provider.Nsx : provider);
}
defaultKubernetesServiceNetworkOfferingProviders.put(Service.Firewall, forNsx ? Network.Provider.Nsx : provider);
defaultKubernetesServiceNetworkOfferingProviders.put(Service.Lb, forNsx ? Network.Provider.Nsx : provider);
defaultKubernetesServiceNetworkOfferingProviders.put(Service.SourceNat, forNsx ? Network.Provider.Nsx : provider);
defaultKubernetesServiceNetworkOfferingProviders.put(Service.StaticNat, forNsx ? Network.Provider.Nsx : provider);

View File

@ -155,7 +155,7 @@ public class KubernetesClusterManagerImplTest {
}
private FirewallRuleVO createRule(int startPort, int endPort) {
FirewallRuleVO rule = new FirewallRuleVO(null, null, startPort, endPort, "tcp", 1, 1, 1, FirewallRule.Purpose.Firewall, List.of("0.0.0.0/0"), null, null, null, FirewallRule.TrafficType.Ingress);
FirewallRuleVO rule = new FirewallRuleVO(null, null, startPort, endPort, "tcp", 1L, 1, 1, FirewallRule.Purpose.Firewall, List.of("0.0.0.0/0"), null, null, null, FirewallRule.TrafficType.Ingress);
return rule;
}

View File

@ -290,7 +290,7 @@ public class PaloAltoResourceTest {
List<FirewallRuleTO> rules = new ArrayList<FirewallRuleTO>();
List<String> cidrList = new ArrayList<String>();
cidrList.add("0.0.0.0/0");
FirewallRuleVO activeVO = new FirewallRuleVO(null, null, 80, 80, "tcp", 1, 1, 1, Purpose.Firewall, cidrList, null, null, null, FirewallRule.TrafficType.Egress);
FirewallRuleVO activeVO = new FirewallRuleVO(null, null, 80, 80, "tcp", 1L, 1, 1, Purpose.Firewall, cidrList, null, null, null, FirewallRule.TrafficType.Egress);
FirewallRuleTO active = new FirewallRuleTO(activeVO, Long.toString(vlanId), null, Purpose.Firewall, FirewallRule.TrafficType.Egress);
rules.add(active);
@ -319,7 +319,7 @@ public class PaloAltoResourceTest {
long vlanId = 3954;
List<FirewallRuleTO> rules = new ArrayList<FirewallRuleTO>();
FirewallRuleVO revokedVO = new FirewallRuleVO(null, null, 80, 80, "tcp", 1, 1, 1, Purpose.Firewall, null, null, null, null, FirewallRule.TrafficType.Egress);
FirewallRuleVO revokedVO = new FirewallRuleVO(null, null, 80, 80, "tcp", 1L, 1, 1, Purpose.Firewall, null, null, null, null, FirewallRule.TrafficType.Egress);
revokedVO.setState(State.Revoke);
FirewallRuleTO revoked = new FirewallRuleTO(revokedVO, Long.toString(vlanId), null, Purpose.Firewall, FirewallRule.TrafficType.Egress);
rules.add(revoked);

View File

@ -2957,8 +2957,21 @@ public class ApiResponseHelper implements ResponseGenerator, ResourceIdSupport {
}
}
Network network = ApiDBUtils.findNetworkById(fwRule.getNetworkId());
response.setNetworkId(network.getUuid());
Long networkId = fwRule.getNetworkId();
if (networkId != null) {
Network network = ApiDBUtils.findNetworkById(networkId);
if (network != null) {
response.setNetworkId(network.getUuid());
}
}
Long vpcId = fwRule.getVpcId();
if (vpcId != null) {
Vpc vpc = ApiDBUtils.findVpcById(vpcId);
if (vpc != null) {
response.setVpcId(vpc.getUuid());
}
}
FirewallRule.State state = fwRule.getState();
String stateToSet = state.toString();
@ -5405,8 +5418,27 @@ public class ApiResponseHelper implements ResponseGenerator, ResourceIdSupport {
response.setIcmpCode(fwRule.getIcmpCode());
response.setIcmpType(fwRule.getIcmpType());
Network network = ApiDBUtils.findNetworkById(fwRule.getNetworkId());
response.setNetworkId(network.getUuid());
Long networkId = fwRule.getNetworkId();
if (networkId != null) {
Network network = ApiDBUtils.findNetworkById(networkId);
if (network != null) {
response.setNetworkId(network.getUuid());
}
}
Long vpcId = fwRule.getVpcId();
if (vpcId == null && networkId != null) {
Network network = ApiDBUtils.findNetworkById(networkId);
if (network != null) {
vpcId = network.getVpcId();
}
}
if (vpcId != null) {
Vpc vpc = ApiDBUtils.findVpcById(vpcId);
if (vpc != null) {
response.setVpcId(vpc.getUuid());
}
}
FirewallRule.State state = fwRule.getState();
String stateToSet = state.toString();

View File

@ -7223,10 +7223,12 @@ public class ConfigurationManagerImpl extends ManagerBase implements Configurati
}
if (forVpc == null) {
if (service == Service.SecurityGroup || service == Service.Firewall) {
if (service == Service.SecurityGroup) {
forVpc = false;
} else if (service == Service.NetworkACL) {
forVpc = true;
} else if (service == Service.Firewall) {
forVpc = true;
}
}

View File

@ -656,28 +656,60 @@ public class IpAddressManagerImpl extends ManagerBase implements IpAddressManage
}
boolean success = true;
Network network = _networksDao.findById(rules.get(0).getNetworkId());
FirewallRuleVO.TrafficType trafficType = rules.get(0).getTrafficType();
FirewallRule firstRule = rules.get(0);
Long networkId = firstRule.getNetworkId();
Long vpcId = firstRule.getVpcId();
FirewallRuleVO.TrafficType trafficType = firstRule.getTrafficType();
List<PublicIp> publicIps = new ArrayList<PublicIp>();
if (!(rules.get(0).getPurpose() == FirewallRule.Purpose.Firewall && trafficType == FirewallRule.TrafficType.Egress)) {
// get the list of public ip's owned by the network
List<IPAddressVO> userIps = _ipAddressDao.listByAssociatedNetwork(network.getId(), null);
if (userIps != null && !userIps.isEmpty()) {
for (IPAddressVO userIp : userIps) {
PublicIp publicIp = PublicIp.createFromAddrAndVlan(userIp, _vlanDao.findById(userIp.getVlanId()));
publicIps.add(publicIp);
// For VPC firewall rules the networkId on the rule is null; resolve via VPC.
Network network;
if (networkId != null) {
network = _networksDao.findById(networkId);
} else if (vpcId != null) {
List<? extends Network> vpcNetworks = _vpcMgr.getVpcNetworks(vpcId);
network = (vpcNetworks != null && !vpcNetworks.isEmpty()) ? _networksDao.findById(vpcNetworks.get(0).getId()) : null;
} else {
network = null;
}
if (network == null) {
logger.warn("Unable to resolve network for firewall rules (networkId={}, vpcId={}); skipping IP association", networkId, vpcId);
} else if (!(firstRule.getPurpose() == FirewallRule.Purpose.Firewall && trafficType == FirewallRule.TrafficType.Egress)) {
// For VPC ingress rules, collect public IPs tied to the VPC rather than network association
if (vpcId != null && networkId == null) {
List<IPAddressVO> vpcIps = _ipAddressDao.listByAssociatedVpc(vpcId, null);
if (vpcIps != null) {
for (IPAddressVO userIp : vpcIps) {
PublicIp publicIp = PublicIp.createFromAddrAndVlan(userIp, _vlanDao.findById(userIp.getVlanId()));
publicIps.add(publicIp);
}
}
} else {
// get the list of public ip's owned by the network
List<IPAddressVO> userIps = _ipAddressDao.listByAssociatedNetwork(network.getId(), null);
if (userIps != null && !userIps.isEmpty()) {
for (IPAddressVO userIp : userIps) {
PublicIp publicIp = PublicIp.createFromAddrAndVlan(userIp, _vlanDao.findById(userIp.getVlanId()));
publicIps.add(publicIp);
}
}
}
}
// rules can not programmed unless IP is associated with network service provider, so run IP assoication for
// rules can not programmed unless IP is associated with network service provider, so run IP association for
// the network so as to ensure IP is associated before applying rules (in add state)
if (checkIfIpAssocRequired(network, false, publicIps)) {
if (network != null && checkIfIpAssocRequired(network, false, publicIps)) {
applyIpAssociations(network, false, continueOnError, publicIps);
}
try {
applier.applyRules(network, purpose, rules);
if (network != null) {
applier.applyRules(network, purpose, rules);
} else {
logger.warn("Skipping applyRules: no network resolved for rules (vpcId={})", vpcId);
success = false;
}
} catch (ResourceUnavailableException e) {
if (!continueOnError) {
throw e;
@ -688,7 +720,7 @@ public class IpAddressManagerImpl extends ManagerBase implements IpAddressManage
// if there are no active rules associated with a public IP, then public IP need not be associated with a provider.
// This IPAssoc ensures, public IP is dis-associated after last active rule is revoked.
if (checkIfIpAssocRequired(network, true, publicIps)) {
if (network != null && checkIfIpAssocRequired(network, true, publicIps)) {
applyIpAssociations(network, true, continueOnError, publicIps);
}

View File

@ -104,9 +104,11 @@ import com.cloud.network.rules.FirewallRuleVO;
import com.cloud.network.rules.dao.PortForwardingRulesDao;
import com.cloud.network.vpc.Vpc;
import com.cloud.network.vpc.VpcGatewayVO;
import com.cloud.network.vpc.VpcOfferingServiceMapVO;
import com.cloud.network.vpc.dao.PrivateIpDao;
import com.cloud.network.vpc.dao.VpcDao;
import com.cloud.network.vpc.dao.VpcGatewayDao;
import com.cloud.network.vpc.dao.VpcOfferingServiceMapDao;
import com.cloud.offering.NetworkOffering;
import com.cloud.offering.NetworkOffering.Detail;
import com.cloud.offerings.NetworkOfferingServiceMapVO;
@ -183,6 +185,8 @@ public class NetworkModelImpl extends ManagerBase implements NetworkModel, Confi
NetworkPermissionDao _networkPermissionDao;
@Inject
VpcDao vpcDao;
@Inject
VpcOfferingServiceMapDao _vpcOffSvcMapDao;
private List<NetworkElement> networkElements;
@ -457,12 +461,16 @@ public class NetworkModelImpl extends ManagerBase implements NetworkModel, Confi
// We only support one provider for one service now
Map<Service, Set<Provider>> serviceToProviders = getServiceProvidersMap(networkId);
// Since IP already has service to bind with, the oldProvider can't be null
Set<Provider> newProviders = serviceToProviders.get(service);
Set<Provider> newProviders = getProvidersForServiceWithVpcFallback(serviceToProviders, service, publicIp.getVpcId());
if (newProviders == null || newProviders.isEmpty()) {
throw new InvalidParameterValueException("There is no new provider for IP " + publicIp.getAddress() + " of service " + service.getName() + "!");
}
Provider newProvider = (Provider)newProviders.toArray()[0];
Set<Provider> oldProviders = serviceToProviders.get(services.toArray()[0]);
Service existingService = (Service) services.toArray()[0];
Set<Provider> oldProviders = getProvidersForServiceWithVpcFallback(serviceToProviders, existingService, publicIp.getVpcId());
if (oldProviders == null || oldProviders.isEmpty()) {
throw new InvalidParameterValueException("There is no existing provider for IP " + publicIp.getAddress() + " of service " + existingService.getName() + "!");
}
Provider oldProvider = (Provider)oldProviders.toArray()[0];
Network network = _networksDao.findById(networkId);
NetworkElement oldElement = getElementImplementingProvider(oldProvider.getName());
@ -477,6 +485,35 @@ public class NetworkModelImpl extends ManagerBase implements NetworkModel, Confi
return true;
}
private Set<Provider> getProvidersForServiceWithVpcFallback(Map<Service, Set<Provider>> serviceToProviders, Service service, Long vpcId) {
Set<Provider> providers = serviceToProviders.get(service);
if (providers != null && !providers.isEmpty()) {
return providers;
}
if (vpcId == null || service != Service.Firewall) {
return providers;
}
Set<Provider> vpcProviders = new HashSet<Provider>();
Vpc vpc = vpcDao.findById(vpcId);
if (vpc == null) {
return vpcProviders;
}
List<VpcOfferingServiceMapVO> offeringProviders = _vpcOffSvcMapDao.listProvidersForServiceForVpcOffering(vpc.getVpcOfferingId(), Service.Firewall);
if (offeringProviders != null) {
for (VpcOfferingServiceMapVO offeringProvider : offeringProviders) {
Provider provider = Provider.getProvider(offeringProvider.getProvider());
if (provider != null) {
vpcProviders.add(provider);
}
}
}
return vpcProviders;
}
Map<Provider, Set<Service>> getProviderServicesMap(long networkId) {
Map<Provider, Set<Service>> map = new HashMap<Provider, Set<Service>>();
List<NetworkServiceMapVO> nsms = _ntwkSrvcDao.getServicesInNetwork(networkId);

View File

@ -139,7 +139,13 @@ public class VpcVirtualRouterElement extends VirtualRouterElement implements Vpc
return false;
}
} else {
if (!_networkMdl.isProviderSupportServiceInNetwork(network.getId(), service, getProvider())) {
boolean supportsService;
if (service == Service.Firewall) {
supportsService = _vpcMgr.isProviderSupportServiceInVpc(network.getVpcId(), service, getProvider());
} else {
supportsService = _networkMdl.isProviderSupportServiceInNetwork(network.getId(), service, getProvider());
}
if (!supportsService) {
logger.trace("Element " + getProvider().getName() + " doesn't support service " + service.getName() + " in the network " + network);
return false;
}
@ -412,10 +418,6 @@ public class VpcVirtualRouterElement extends VirtualRouterElement implements Vpc
vpnCapabilities.putAll(capabilities.get(Service.Vpn));
vpnCapabilities.put(Capability.VpnTypes, "s2svpn");
capabilities.put(Service.Vpn, vpnCapabilities);
// remove firewall capability
capabilities.remove(Service.Firewall);
// add network ACL capability
final Map<Capability, String> networkACLCapabilities = new HashMap<Capability, String>();
networkACLCapabilities.put(Capability.SupportedProtocols, "tcp,udp,icmp");

View File

@ -198,25 +198,67 @@ public class FirewallManagerImpl extends ManagerBase implements FirewallService,
if (sourceCidrs != null && !sourceCidrs.isEmpty())
Collections.replaceAll(sourceCidrs, "0.0.0.0/0", network.getCidr());
return createFirewallRule(null, caller, rule.getXid(), rule.getSourcePortStart(), rule.getSourcePortEnd(), rule.getProtocol(), sourceCidrs, rule.getDestinationCidrList(),
rule.getIcmpCode(), rule.getIcmpType(), null, rule.getType(), rule.getNetworkId(), rule.getTrafficType(), rule.isDisplay());
return createFirewallRuleForNonVPC(null, caller, rule.getXid(), rule.getSourcePortStart(), rule.getSourcePortEnd(), rule.getProtocol(), sourceCidrs,
rule.getDestinationCidrList(), rule.getIcmpCode(), rule.getIcmpType(), null, rule.getType(), rule.getNetworkId(), rule.getTrafficType(), rule.isDisplay());
}
@Override
@ActionEvent(eventType = EventTypes.EVENT_FIREWALL_OPEN, eventDescription = "creating firewall rule", create = true)
public FirewallRule createIngressFirewallRule(FirewallRule rule) throws NetworkRuleConflictException {
Account caller = CallContext.current().getCallingAccount();
Account caller = CallContext.current().getCallingAccount();
Long sourceIpAddressId = rule.getSourceIpAddressId();
IPAddressVO sourceIp = getSourceIpForIngressRule(sourceIpAddressId);
return createFirewallRule(sourceIpAddressId, caller, rule.getXid(), rule.getSourcePortStart(), rule.getSourcePortEnd(), rule.getProtocol(),
rule.getSourceCidrList(), null, rule.getIcmpCode(), rule.getIcmpType(), null, rule.getType(), rule.getNetworkId(), rule.getTrafficType(), rule.isDisplay());
if (sourceIp.getVpcId() != null) {
return createIngressFirewallRuleForVpcIp(rule, caller, sourceIp);
}
return createIngressFirewallRuleForIsolatedIp(rule, caller, sourceIp);
}
protected IPAddressVO getSourceIpForIngressRule(Long sourceIpAddressId) {
if (sourceIpAddressId == null) {
return null;
}
IPAddressVO sourceIp = _ipAddressDao.findById(sourceIpAddressId);
if (sourceIp == null) {
throw new CloudRuntimeException("Unable to find IP address by id=" + sourceIpAddressId);
}
return sourceIp;
}
protected FirewallRule createIngressFirewallRuleForIsolatedIp(FirewallRule rule, Account caller, IPAddressVO sourceIp)
throws NetworkRuleConflictException {
return createFirewallRuleForNonVPC(rule.getSourceIpAddressId(), caller, rule.getXid(), rule.getSourcePortStart(), rule.getSourcePortEnd(),
rule.getProtocol(), rule.getSourceCidrList(), null, rule.getIcmpCode(), rule.getIcmpType(), null, rule.getType(),
rule.getNetworkId(), rule.getTrafficType(), rule.isDisplay());
}
protected FirewallRule createIngressFirewallRuleForVpcIp(FirewallRule rule, Account caller, IPAddressVO sourceIp)
throws NetworkRuleConflictException {
Long vpcId = sourceIp != null ? sourceIp.getVpcId() : null;
return createFirewallRuleForVpc(rule.getSourceIpAddressId(), caller, rule.getXid(), rule.getSourcePortStart(), rule.getSourcePortEnd(),
rule.getProtocol(), rule.getSourceCidrList(), null, rule.getIcmpCode(), rule.getIcmpType(), null, rule.getType(),
vpcId, rule.getTrafficType(), rule.isDisplay());
}
//Destination CIDR capability is currently implemented for egress rules only. For others, the field is passed as null.
@DB
protected FirewallRule createFirewallRule(final Long ipAddrId, Account caller, final String xId, final Integer portStart, final Integer portEnd, final String protocol,
final List<String> sourceCidrList, final List<String> destCidrList, final Integer icmpCode, final Integer icmpType, final Long relatedRuleId,
final FirewallRule.FirewallRuleType type, final Long networkId, final FirewallRule.TrafficType trafficType, final Boolean forDisplay) throws NetworkRuleConflictException {
final FirewallRule.FirewallRuleType type, final Long networkId, final Long vpcId, final FirewallRule.TrafficType trafficType, final Boolean forDisplay) throws NetworkRuleConflictException {
if (vpcId != null) {
return createFirewallRuleForVpc(ipAddrId, caller, xId, portStart, portEnd, protocol, sourceCidrList, destCidrList, icmpCode, icmpType, relatedRuleId,
type, vpcId, trafficType, forDisplay);
}
return createFirewallRuleForNonVPC(ipAddrId, caller, xId, portStart, portEnd, protocol, sourceCidrList, destCidrList, icmpCode, icmpType, relatedRuleId,
type, networkId, trafficType, forDisplay);
}
@DB
protected FirewallRule createFirewallRuleForNonVPC(final Long ipAddrId, Account caller, final String xId, final Integer portStart, final Integer portEnd, final String protocol,
final List<String> sourceCidrList, final List<String> destCidrList, final Integer icmpCode, final Integer icmpType, final Long relatedRuleId,
final FirewallRule.FirewallRuleType type, final Long networkId, final FirewallRule.TrafficType trafficType, final Boolean forDisplay) throws NetworkRuleConflictException {
IPAddressVO ipAddress = null;
try {
// Validate ip address
@ -283,6 +325,158 @@ public class FirewallManagerImpl extends ManagerBase implements FirewallService,
}
}
@DB
protected FirewallRule createFirewallRuleForVpc(final Long ipAddrId, Account caller, final String xId, final Integer portStart, final Integer portEnd, final String protocol,
final List<String> sourceCidrList, final List<String> destCidrList, final Integer icmpCode, final Integer icmpType,
final Long relatedRuleId, final FirewallRuleType type, final Long vpcId,
final FirewallRule.TrafficType trafficType, final Boolean forDisplay) throws NetworkRuleConflictException {
IPAddressVO ipAddress = null;
try {
Long resolvedVpcId = vpcId;
if (ipAddrId != null) {
ipAddress = _ipAddressDao.acquireInLockTable(ipAddrId);
if (ipAddress == null) {
throw new InvalidParameterValueException("Unable to create firewall rule; " + "couldn't locate IP address by id in the system");
}
resolvedVpcId = resolvedVpcId != null ? resolvedVpcId : ipAddress.getVpcId();
}
if (resolvedVpcId == null) {
throw new InvalidParameterValueException("Unable to create VPC firewall rule; couldn't locate VPC id");
}
validateFirewallRuleForVpc(caller, ipAddress, portStart, portEnd, protocol, Purpose.Firewall, type, resolvedVpcId, trafficType);
if (!protocol.equalsIgnoreCase(NetUtils.ICMP_PROTO) && (icmpCode != null || icmpType != null)) {
throw new InvalidParameterValueException("Can specify icmpCode and icmpType for ICMP protocol only");
}
if (protocol.equalsIgnoreCase(NetUtils.ICMP_PROTO) && (portStart != null || portEnd != null)) {
throw new InvalidParameterValueException("Can't specify start/end port when protocol is ICMP");
}
Long accountId = null;
Long domainId = null;
if (ipAddress != null) {
accountId = ipAddress.getAllocatedToAccountId();
domainId = ipAddress.getAllocatedInDomainId();
} else {
Vpc vpc = _vpcMgr.getActiveVpc(resolvedVpcId);
if (vpc == null) {
throw new InvalidParameterValueException("Unable to create VPC firewall rule; couldn't locate VPC by id=" + resolvedVpcId);
}
accountId = vpc.getAccountId();
domainId = vpc.getDomainId();
}
final Long accountIdFinal = accountId;
final Long domainIdFinal = domainId;
final Long resolvedNetworkIdFinal = null;
final Long resolvedVpcIdFinal = resolvedVpcId;
return Transaction.execute((TransactionCallbackWithException<FirewallRuleVO, NetworkRuleConflictException>) status -> {
FirewallRuleVO newRule = new FirewallRuleVO(xId, ipAddrId, portStart, portEnd, protocol.toLowerCase(), resolvedNetworkIdFinal, accountIdFinal, domainIdFinal, Purpose.Firewall,
sourceCidrList, destCidrList, icmpCode, icmpType, relatedRuleId, trafficType);
newRule.setVpcId(resolvedVpcIdFinal);
newRule.setType(type);
if (forDisplay != null) {
newRule.setDisplay(forDisplay);
}
newRule = _firewallDao.persist(newRule);
if (type == FirewallRuleType.User)
detectRulesConflict(newRule);
if (!_firewallDao.setStateToAdd(newRule)) {
throw new CloudRuntimeException("Unable to update the state to add for " + newRule);
}
CallContext.current().setEventDetails("Rule ID: " + newRule.getUuid());
CallContext.current().putContextParameter(FirewallRule.class, newRule.getId());
return newRule;
});
} finally {
if (ipAddrId != null) {
_ipAddressDao.releaseFromLockTable(ipAddrId);
}
}
}
protected void validateFirewallRuleForVpc(Account caller, IPAddressVO ipAddress, Integer portStart, Integer portEnd, String proto, Purpose purpose,
FirewallRuleType type, Long vpcId, FirewallRule.TrafficType trafficType) {
if (portStart != null && !NetUtils.isValidPort(portStart)) {
throw new InvalidParameterValueException("publicPort is an invalid value: " + portStart);
}
if (portEnd != null && !NetUtils.isValidPort(portEnd)) {
throw new InvalidParameterValueException("Public port range is an invalid value: " + portEnd);
}
if (portStart != null && portEnd != null && portStart > portEnd) {
throw new InvalidParameterValueException("Start port can't be bigger than end port");
}
if (ipAddress == null && type == FirewallRuleType.System) {
return;
}
if (vpcId == null) {
throw new InvalidParameterValueException("Unable to retrieve VPC id to validate the rule");
}
if (ipAddress != null) {
_accountMgr.checkAccess(caller, null, true, ipAddress);
}
Vpc vpc = _vpcMgr.getActiveVpc(vpcId);
if (vpc == null) {
throw new InvalidParameterValueException("Unable to retrieve VPC to validate the rule by id=" + vpcId);
}
Map<Network.Capability, String> caps = null;
if (purpose == Purpose.Firewall) {
caps = getFirewallServiceCapabilitiesForVpc(vpcId);
}
if (caps != null) {
String supportedTrafficTypes = null;
if (purpose == FirewallRule.Purpose.Firewall) {
supportedTrafficTypes = caps.get(Capability.SupportedTrafficDirection).toLowerCase();
}
String supportedProtocols;
if (purpose == FirewallRule.Purpose.Firewall && trafficType == FirewallRule.TrafficType.Egress) {
supportedProtocols = caps.get(Capability.SupportedEgressProtocols).toLowerCase();
} else {
supportedProtocols = caps.get(Capability.SupportedProtocols).toLowerCase();
}
if (!supportedProtocols.contains(proto.toLowerCase())) {
throw new InvalidParameterValueException("Protocol " + proto + " is not supported in VPC " + vpcId);
} else if (proto.equalsIgnoreCase(NetUtils.ICMP_PROTO) && purpose != Purpose.Firewall) {
throw new InvalidParameterValueException("Protocol " + proto + " is currently supported only for rules with purpose " + Purpose.Firewall);
} else if (purpose == Purpose.Firewall && !supportedTrafficTypes.contains(trafficType.toString().toLowerCase())) {
throw new InvalidParameterValueException(String.format("Traffic Type %s is currently supported by Firewall in VPC %s", trafficType, vpc.getUuid()));
}
}
}
protected Map<Network.Capability, String> getFirewallServiceCapabilitiesForVpc(Long vpcId) {
for (FirewallServiceProvider fwElement : _firewallElements) {
Network.Provider provider = fwElement.getProvider();
if (_vpcMgr.isProviderSupportServiceInVpc(vpcId, Service.Firewall, provider)) {
Map<Service, Map<Capability, String>> capabilities = fwElement.getCapabilities();
if (capabilities != null && capabilities.get(Service.Firewall) != null) {
return capabilities.get(Service.Firewall);
}
}
}
return null;
}
protected Long resolveIsolatedFirewallRuleNetworkId(IPAddressVO ipAddress, Long networkId) {
_networkModel.checkIpForService(ipAddress, Service.Firewall, networkId);
return ipAddress.getAssociatedWithNetworkId();
}
@Override
public Pair<List<? extends FirewallRule>, Integer> listFirewallRules(IListFirewallRulesCmd cmd) {
Long ipId = cmd.getIpAddressId();
@ -399,9 +593,16 @@ public class FirewallManagerImpl extends ManagerBase implements FirewallService,
assert (rules.size() >= 1);
}
NetworkVO newRuleNetwork = getNewRuleNetwork(newRule);
boolean newRuleIsOnVpcNetwork = newRuleNetwork.getVpcId() != null;
boolean vpcConserveModeEnabled = _vpcMgr.isNetworkOnVpcEnabledConserveMode(newRuleNetwork);
Long newRuleVpcId = newRule.getVpcId();
boolean newRuleIsVpc = newRuleVpcId != null;
NetworkVO newRuleNetwork = null;
boolean newRuleIsOnVpcNetwork = false;
boolean vpcConserveModeEnabled = false;
if (!newRuleIsVpc) {
newRuleNetwork = getNewRuleNetwork(newRule);
newRuleIsOnVpcNetwork = newRuleNetwork.getVpcId() != null;
vpcConserveModeEnabled = newRuleIsOnVpcNetwork && _vpcMgr.isNetworkOnVpcEnabledConserveMode(newRuleNetwork);
}
for (FirewallRuleVO rule : rules) {
if (rule.getId() == newRule.getId()) {
@ -452,8 +653,8 @@ public class FirewallManagerImpl extends ManagerBase implements FirewallService,
// Checking if the rule applied is to the same network that is passed in the rule.
// (except for VPCs with conserve mode = true)
if ((!newRuleIsOnVpcNetwork || !vpcConserveModeEnabled)
&& rule.getNetworkId() != newRule.getNetworkId() && rule.getState() != State.Revoke) {
if (!newRuleIsVpc && (!newRuleIsOnVpcNetwork || !vpcConserveModeEnabled)
&& !Objects.equals(rule.getNetworkId(), newRule.getNetworkId()) && rule.getState() != State.Revoke) {
String errMsg = String.format("New rule is for a different network than what's specified in rule %s", rule.getXid());
if (newRuleIsOnVpcNetwork) {
Vpc vpc = _vpcMgr.getActiveVpc(newRuleNetwork.getVpcId());
@ -575,11 +776,9 @@ public class FirewallManagerImpl extends ManagerBase implements FirewallService,
}
if (ipAddress != null) {
if (ipAddress.getAssociatedWithNetworkId() == null) {
throw new InvalidParameterValueException("Unable to create firewall rule ; ip with specified id is not associated with any network");
} else {
networkId = ipAddress.getAssociatedWithNetworkId();
}
networkId = isVpcIpAddress(ipAddress)
? validateFirewallRuleForVpcIp(ipAddress, networkId)
: validateFirewallRuleForIsolatedIp(ipAddress);
// Validate ip address
_accountMgr.checkAccess(caller, null, true, ipAddress);
@ -610,7 +809,7 @@ public class FirewallManagerImpl extends ManagerBase implements FirewallService,
if (routedIpv4Manager.isVirtualRouterGateway(network)) {
throw new CloudRuntimeException("Unable to create routing firewall rule. Please use routing firewall API instead.");
}
caps = _networkModel.getNetworkServiceCapabilities(network.getId(), Service.Firewall);
caps = getFirewallServiceCapabilities(network);
}
if (caps != null) {
@ -637,6 +836,41 @@ public class FirewallManagerImpl extends ManagerBase implements FirewallService,
}
protected boolean isVpcIpAddress(IPAddressVO ipAddress) {
return ipAddress.getVpcId() != null;
}
protected Long validateFirewallRuleForIsolatedIp(IPAddressVO ipAddress) {
if (ipAddress.getAssociatedWithNetworkId() == null) {
throw new InvalidParameterValueException("Unable to create firewall rule ; ip with specified id is not associated with any network");
}
return ipAddress.getAssociatedWithNetworkId();
}
protected Long validateFirewallRuleForVpcIp(IPAddressVO ipAddress, Long networkId) {
if (networkId == null) {
throw new InvalidParameterValueException("Unable to retrieve network id to validate the rule");
}
return networkId;
}
protected Map<Network.Capability, String> getFirewallServiceCapabilities(Network network) {
if (network.getVpcId() == null) {
return _networkModel.getNetworkServiceCapabilities(network.getId(), Service.Firewall);
}
for (FirewallServiceProvider fwElement : _firewallElements) {
Network.Provider provider = fwElement.getProvider();
if (_vpcMgr.isProviderSupportServiceInVpc(network.getVpcId(), Service.Firewall, provider)) {
Map<Service, Map<Capability, String>> capabilities = fwElement.getCapabilities();
if (capabilities != null && capabilities.get(Service.Firewall) != null) {
return capabilities.get(Service.Firewall);
}
}
}
return _networkModel.getNetworkServiceCapabilities(network.getId(), Service.Firewall);
}
@Override
public boolean applyRules(List<? extends FirewallRule> rules, boolean continueOnError, boolean updateRulesInDB) throws ResourceUnavailableException {
boolean success = true;
@ -665,7 +899,7 @@ public class FirewallManagerImpl extends ManagerBase implements FirewallService,
success = false;
} else {
removeRule(rule);
if (rule.getSourceIpAddressId() != null) {
if (rule.getSourceIpAddressId() != null && rule.getVpcId() == null) {
//if the rule is the last one for the ip address assigned to VPC, unassign it from the network
_vpcMgr.unassignIPFromVpcNetwork(rule.getSourceIpAddressId(), rule.getNetworkId());
}
@ -692,7 +926,12 @@ public class FirewallManagerImpl extends ManagerBase implements FirewallService,
case Ipv6Firewall:
for (FirewallServiceProvider fwElement : _firewallElements) {
Network.Provider provider = fwElement.getProvider();
boolean isFwProvider = _networkModel.isProviderSupportServiceInNetwork(network.getId(), Service.Firewall, provider);
boolean isFwProvider;
if (network.getVpcId() != null) {
isFwProvider = _vpcMgr.isProviderSupportServiceInVpc(network.getVpcId(), Service.Firewall, provider);
} else {
isFwProvider = _networkModel.isProviderSupportServiceInNetwork(network.getId(), Service.Firewall, provider);
}
if (!isFwProvider) {
continue;
}
@ -779,8 +1018,10 @@ public class FirewallManagerImpl extends ManagerBase implements FirewallService,
for (FirewallRuleVO rule : rules) {
// validate rule - for NSX
long networkId = rule.getNetworkId();
validateNsxConstraints(networkId, rule);
Long networkId = rule.getNetworkId();
if (networkId != null) {
validateNsxConstraints(networkId, rule);
}
// load cidrs if any
rule.setSourceCidrList(_firewallCidrsDao.getSourceCidrs(rule.getId()));
rule.setDestinationCidrsList(_firewallDcidrsDao.getDestCidrs(rule.getId()));
@ -1040,7 +1281,7 @@ public class FirewallManagerImpl extends ManagerBase implements FirewallService,
List<String> oneCidr = new ArrayList<String>();
oneCidr.add(NetUtils.ALL_IP4_CIDRS);
return createFirewallRule(ipAddrId, caller, null, startPort, endPort, protocol, oneCidr, null, icmpCode, icmpType, relatedRuleId, FirewallRule.FirewallRuleType.User,
networkId, FirewallRule.TrafficType.Ingress, true);
networkId, null, FirewallRule.TrafficType.Ingress, true);
}
@Override
@ -1155,7 +1396,7 @@ public class FirewallManagerImpl extends ManagerBase implements FirewallService,
_firewallDao.loadSourceCidrs(rule);
}
createFirewallRule(ip.getId(), acct, rule.getXid(), rule.getSourcePortStart(), rule.getSourcePortEnd(), rule.getProtocol(), rule.getSourceCidrList(),null,
rule.getIcmpCode(), rule.getIcmpType(), rule.getRelated(), FirewallRuleType.System, rule.getNetworkId(), rule.getTrafficType(), true);
rule.getIcmpCode(), rule.getIcmpType(), rule.getRelated(), FirewallRuleType.System, rule.getNetworkId(), rule.getVpcId(), rule.getTrafficType(), true);
} catch (Exception e) {
logger.debug("Failed to add system wide firewall rule, due to:" + e.toString());
}

View File

@ -332,7 +332,7 @@ public class VpcManagerImpl extends ManagerBase implements VpcManager, VpcProvis
private final ScheduledExecutorService _executor = Executors.newScheduledThreadPool(1, new NamedThreadFactory("VpcChecker"));
private List<VpcProvider> vpcElements = null;
private final List<Service> nonSupportedServices = Arrays.asList(Service.SecurityGroup, Service.Firewall);
private final List<Service> nonSupportedServices = Arrays.asList(Service.SecurityGroup);
private final List<Provider> supportedProviders = Arrays.asList(Provider.VPCVirtualRouter, Provider.NiciraNvp, Provider.InternalLbVm, Provider.Netscaler,
Provider.JuniperContrailVpcRouter, Provider.Ovs, Provider.BigSwitchBcf, Provider.ConfigDrive, Provider.Nsx, Provider.Netris);

View File

@ -989,15 +989,15 @@ public class RoutedIpv4ManagerImpl extends ComponentLifecycleBase implements Rou
@Override
public boolean isVirtualRouterGateway(Network network) {
return isRoutedNetwork(network)
&& (networkServiceMapDao.canProviderSupportServiceInNetwork(network.getId(), Service.Gateway, Provider.VirtualRouter))
|| networkServiceMapDao.canProviderSupportServiceInNetwork(network.getId(), Service.Gateway, Provider.VPCVirtualRouter);
&& (networkServiceMapDao.canProviderSupportServiceInNetwork(network.getId(), Service.Gateway, Provider.VirtualRouter)
|| networkServiceMapDao.canProviderSupportServiceInNetwork(network.getId(), Service.Gateway, Provider.VPCVirtualRouter));
}
@Override
public boolean isVirtualRouterGateway(NetworkOffering networkOffering) {
return NetworkOffering.NetworkMode.ROUTED.equals(networkOffering.getNetworkMode())
&& networkOfferingServiceMapDao.canProviderSupportServiceInNetworkOffering(networkOffering.getId(), Service.Gateway, Provider.VirtualRouter)
|| networkOfferingServiceMapDao.canProviderSupportServiceInNetworkOffering(networkOffering.getId(), Service.Gateway, Provider.VPCVirtualRouter);
&& (networkOfferingServiceMapDao.canProviderSupportServiceInNetworkOffering(networkOffering.getId(), Service.Gateway, Provider.VirtualRouter)
|| networkOfferingServiceMapDao.canProviderSupportServiceInNetworkOffering(networkOffering.getId(), Service.Gateway, Provider.VPCVirtualRouter));
}
@Override

View File

@ -19,10 +19,13 @@ package com.cloud.network;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertTrue;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyBoolean;
import static org.mockito.ArgumentMatchers.anyLong;
import static org.mockito.Mockito.doReturn;
import static org.mockito.Mockito.lenient;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
@ -54,8 +57,12 @@ import com.cloud.network.dao.IPAddressDao;
import com.cloud.network.dao.IPAddressVO;
import com.cloud.network.dao.NetworkDao;
import com.cloud.network.dao.NetworkVO;
import com.cloud.network.rules.FirewallRule;
import com.cloud.network.rules.StaticNat;
import com.cloud.network.rules.StaticNatImpl;
import com.cloud.network.vpc.VpcManager;
import com.cloud.dc.dao.VlanDao;
import com.cloud.dc.VlanVO;
import com.cloud.offerings.NetworkOfferingVO;
import com.cloud.offerings.dao.NetworkOfferingDao;
import com.cloud.user.AccountVO;
@ -105,6 +112,12 @@ public class IpAddressManagerTest {
@Mock
AccountManager accountManagerMock;
@Mock
VpcManager vpcMgr;
@Mock
VlanDao vlanDao;
final long dummyID = 1L;
final String UUID = "uuid";
@ -491,4 +504,166 @@ public class IpAddressManagerTest {
Assert.assertTrue(result);
}
private FirewallRule makeRule(Long networkId, Long vpcId, FirewallRule.Purpose purpose,
FirewallRule.TrafficType trafficType) {
FirewallRule rule = mock(FirewallRule.class);
lenient().when(rule.getNetworkId()).thenReturn(networkId);
lenient().when(rule.getVpcId()).thenReturn(vpcId);
lenient().when(rule.getPurpose()).thenReturn(purpose);
lenient().when(rule.getTrafficType()).thenReturn(trafficType);
return rule;
}
/** Stub the two IP-association helper methods so they are no-ops. */
private void stubIpAssocHelpers() throws ResourceUnavailableException {
doReturn(false).when(ipAddressManager).checkIfIpAssocRequired(any(Network.class), anyBoolean(), any());
}
/**
* Test: Non-VPC rules still resolve via networkId (backward compatibility).
*/
@Test
public void applyRulesNonVpcRuleStillWorksViaNetworkId() throws ResourceUnavailableException {
long networkId = 10L;
NetworkVO network = mock(NetworkVO.class);
when(network.getId()).thenReturn(networkId);
when(networkDao.findById(networkId)).thenReturn(network);
FirewallRule rule = makeRule(networkId, null, FirewallRule.Purpose.Firewall, FirewallRule.TrafficType.Ingress);
NetworkRuleApplier applier = mock(NetworkRuleApplier.class);
when(ipAddressDao.listByAssociatedNetwork(networkId, null)).thenReturn(new ArrayList<>());
stubIpAssocHelpers();
boolean result = ipAddressManager.applyRules(
Collections.singletonList(rule), FirewallRule.Purpose.Firewall, applier, false);
assertTrue(result);
verify(networkDao).findById(networkId);
verify(applier).applyRules(network, FirewallRule.Purpose.Firewall, Collections.singletonList(rule));
}
/**
* Test: VPC rule resolves network via VpcManager.getVpcNetworks()
* when networkId is null but vpcId is set.
*/
@Test
public void applyRulesVpcRuleResolvesNetworkViaVpcManager() throws ResourceUnavailableException {
long vpcId = 20L;
long resolvedNetworkId = 30L;
NetworkVO resolvedNetwork = mock(NetworkVO.class);
when(resolvedNetwork.getId()).thenReturn(resolvedNetworkId);
when(networkDao.findById(resolvedNetworkId)).thenReturn(resolvedNetwork);
doReturn(Collections.singletonList(resolvedNetwork)).when(vpcMgr).getVpcNetworks(vpcId);
FirewallRule rule = makeRule(null, vpcId, FirewallRule.Purpose.Firewall, FirewallRule.TrafficType.Ingress);
NetworkRuleApplier applier = mock(NetworkRuleApplier.class);
IPAddressVO vpcIp = mock(IPAddressVO.class);
when(vpcIp.getVlanId()).thenReturn(1L);
VlanVO vlan = mock(VlanVO.class);
when(ipAddressDao.listByAssociatedVpc(vpcId, null)).thenReturn(Collections.singletonList(vpcIp));
when(vlanDao.findById(1L)).thenReturn(vlan);
stubIpAssocHelpers();
boolean result = ipAddressManager.applyRules(
Collections.singletonList(rule), FirewallRule.Purpose.Firewall, applier, false);
assertTrue(result);
verify(vpcMgr).getVpcNetworks(vpcId);
verify(ipAddressDao).listByAssociatedVpc(vpcId, null);
verify(applier).applyRules(resolvedNetwork, FirewallRule.Purpose.Firewall, Collections.singletonList(rule));
}
/**
* Test: For VPC egress firewall rules, IP collection should be skipped.
*/
@Test
public void applyRulesVpcEgressFirewallRuleSkipsIpCollection() throws ResourceUnavailableException {
long vpcId = 20L;
long resolvedNetworkId = 30L;
NetworkVO resolvedNetwork = mock(NetworkVO.class);
when(resolvedNetwork.getId()).thenReturn(resolvedNetworkId);
when(networkDao.findById(resolvedNetworkId)).thenReturn(resolvedNetwork);
doReturn(Collections.singletonList(resolvedNetwork)).when(vpcMgr).getVpcNetworks(vpcId);
FirewallRule rule = makeRule(null, vpcId, FirewallRule.Purpose.Firewall, FirewallRule.TrafficType.Egress);
NetworkRuleApplier applier = mock(NetworkRuleApplier.class);
stubIpAssocHelpers();
boolean result = ipAddressManager.applyRules(
Collections.singletonList(rule), FirewallRule.Purpose.Firewall, applier, false);
assertTrue(result);
verify(ipAddressDao, never()).listByAssociatedVpc(anyLong(), any());
verify(applier).applyRules(resolvedNetwork, FirewallRule.Purpose.Firewall, Collections.singletonList(rule));
}
/**
* Test: VPC ingress firewall rules collect public IPs from VPC (listByAssociatedVpc),
* NOT from network (listByAssociatedNetwork).
*/
@Test
public void applyRulesVpcIngressRuleCollectsIpsFromVpcNotNetwork() throws ResourceUnavailableException {
long vpcId = 20L;
long resolvedNetworkId = 30L;
NetworkVO resolvedNetwork = mock(NetworkVO.class);
when(resolvedNetwork.getId()).thenReturn(resolvedNetworkId);
when(networkDao.findById(resolvedNetworkId)).thenReturn(resolvedNetwork);
doReturn(Collections.singletonList(resolvedNetwork)).when(vpcMgr).getVpcNetworks(vpcId);
IPAddressVO vpcIp = mock(IPAddressVO.class);
when(vpcIp.getVlanId()).thenReturn(1L);
VlanVO vlan = mock(VlanVO.class);
when(ipAddressDao.listByAssociatedVpc(vpcId, null)).thenReturn(Collections.singletonList(vpcIp));
when(vlanDao.findById(1L)).thenReturn(vlan);
stubIpAssocHelpers();
NetworkRuleApplier applier = mock(NetworkRuleApplier.class);
FirewallRule rule = makeRule(null, vpcId, FirewallRule.Purpose.Firewall, FirewallRule.TrafficType.Ingress);
ipAddressManager.applyRules(Collections.singletonList(rule), FirewallRule.Purpose.Firewall, applier, false);
verify(ipAddressDao).listByAssociatedVpc(vpcId, null);
verify(ipAddressDao, never()).listByAssociatedNetwork(anyLong(), any());
}
/**
* Test: Error handling respects continueOnError flag.
* When continueOnError=true, exceptions are caught and false is returned.
*/
@Test
public void applyRulesVpcRuleErrorHandlingWithContinueOnErrorTrue() throws ResourceUnavailableException {
long vpcId = 20L;
long resolvedNetworkId = 30L;
NetworkVO resolvedNetwork = mock(NetworkVO.class);
when(resolvedNetwork.getId()).thenReturn(resolvedNetworkId);
when(networkDao.findById(resolvedNetworkId)).thenReturn(resolvedNetwork);
doReturn(Collections.singletonList(resolvedNetwork)).when(vpcMgr).getVpcNetworks(vpcId);
IPAddressVO vpcIp = mock(IPAddressVO.class);
when(vpcIp.getVlanId()).thenReturn(1L);
VlanVO vlan = mock(VlanVO.class);
when(ipAddressDao.listByAssociatedVpc(vpcId, null)).thenReturn(Collections.singletonList(vpcIp));
when(vlanDao.findById(1L)).thenReturn(vlan);
stubIpAssocHelpers();
NetworkRuleApplier applier = mock(NetworkRuleApplier.class);
when(applier.applyRules(any(), any(), any())).thenThrow(new ResourceUnavailableException("test", Network.class, 0L));
FirewallRule rule = makeRule(null, vpcId, FirewallRule.Purpose.Firewall, FirewallRule.TrafficType.Ingress);
boolean result = ipAddressManager.applyRules(
Collections.singletonList(rule), FirewallRule.Purpose.Firewall, applier, true);
assertFalse(result);
}
}

View File

@ -19,10 +19,13 @@ package com.cloud.network.element;
import com.cloud.dc.DataCenterVO;
import com.cloud.dc.dao.DataCenterDao;
import com.cloud.exception.ResourceUnavailableException;
import com.cloud.network.Network;
import com.cloud.network.NetworkModel;
import com.cloud.network.RemoteAccessVpn;
import com.cloud.network.VpnUser;
import com.cloud.network.router.VpcVirtualNetworkApplianceManagerImpl;
import com.cloud.network.vpc.Vpc;
import com.cloud.network.vpc.VpcManager;
import com.cloud.network.vpc.dao.VpcDao;
import com.cloud.utils.db.EntityManager;
import com.cloud.vm.DomainRouterVO;
@ -43,7 +46,9 @@ import java.util.List;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertNull;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
@ -60,6 +65,12 @@ public class VpcVirtualRouterElementTest {
@Mock
EntityManager _entityMgr;
@Mock
NetworkModel _networkMdl;
@Mock
VpcManager _vpcMgr;
@Mock
NetworkTopologyContext networkTopologyContext;
@ -188,4 +199,18 @@ public class VpcVirtualRouterElementTest {
verify(remoteAccessVpn, times(1)).getVpcId();
}
@Test
public void testCanHandleFirewallUsesVpcCapability() {
final Network network = Mockito.mock(Network.class);
when(_networkMdl.getPhysicalNetworkId(network)).thenReturn(1L);
when(network.getVpcId()).thenReturn(100L);
when(_networkMdl.isProviderEnabledInPhysicalNetwork(1L, Network.Provider.VPCVirtualRouter.getName())).thenReturn(true);
when(_vpcMgr.isProviderSupportServiceInVpc(100L, Network.Service.Firewall, Network.Provider.VPCVirtualRouter)).thenReturn(true);
assertTrue(vpcVirtualRouterElement.canHandle(network, Network.Service.Firewall));
verify(_vpcMgr).isProviderSupportServiceInVpc(100L, Network.Service.Firewall, Network.Provider.VPCVirtualRouter);
verify(_networkMdl, never()).isProviderSupportServiceInNetwork(network.getId(), Network.Service.Firewall, Network.Provider.VPCVirtualRouter);
}
}

View File

@ -17,27 +17,36 @@
package com.cloud.network.firewall;
import com.cloud.exception.InvalidParameterValueException;
import com.cloud.exception.NetworkRuleConflictException;
import com.cloud.exception.ResourceUnavailableException;
import com.cloud.network.IpAddressManager;
import com.cloud.network.Network;
import com.cloud.network.Network.Capability;
import com.cloud.network.Network.Service;
import com.cloud.network.NetworkModel;
import com.cloud.network.NetworkRuleApplier;
import com.cloud.network.dao.FirewallRulesDao;
import com.cloud.network.dao.IPAddressDao;
import com.cloud.network.dao.IPAddressVO;
import com.cloud.network.dao.NetworkDao;
import com.cloud.network.dao.NetworkVO;
import com.cloud.network.element.FirewallServiceProvider;
import com.cloud.network.element.VirtualRouterElement;
import com.cloud.network.element.VpcVirtualRouterElement;
import com.cloud.network.rules.FirewallRule;
import com.cloud.network.rules.FirewallRule.FirewallRuleType;
import com.cloud.network.rules.FirewallRule.Purpose;
import com.cloud.network.rules.FirewallRuleVO;
import com.cloud.network.vpc.Vpc;
import com.cloud.network.vpc.VpcManager;
import com.cloud.user.Account;
import com.cloud.user.AccountManager;
import com.cloud.user.DomainManager;
import com.cloud.utils.component.ComponentContext;
import com.cloud.utils.exception.CloudRuntimeException;
import org.apache.cloudstack.engine.orchestration.service.NetworkOrchestrationService;
import org.apache.cloudstack.network.RoutedIpv4Manager;
import org.junit.After;
import org.junit.Assert;
import org.junit.Before;
@ -53,12 +62,18 @@ import org.mockito.junit.MockitoJUnitRunner;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyBoolean;
import static org.mockito.Mockito.doReturn;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
@ -79,9 +94,13 @@ public class FirewallManagerTest {
@Mock
IpAddressManager _ipAddrMgr;
@Mock
RoutedIpv4Manager routedIpv4Manager;
@Mock
FirewallRulesDao _firewallDao;
@Mock
NetworkDao _networkDao;
@Mock
IPAddressDao _ipAddressDao;
@Spy
@InjectMocks
@ -115,7 +134,7 @@ public class FirewallManagerTest {
}
private FirewallRule createFirewallRule(int startPort, int endPort, Purpose purpose) {
return new FirewallRuleVO("xid", 1L, startPort, endPort, "TCP", 2, 3, 4, purpose, new ArrayList<>(),
return new FirewallRuleVO("xid", 1L, startPort, endPort, "TCP", 2L, 3, 4, purpose, new ArrayList<>(),
new ArrayList<>(), 5, 6, null, FirewallRule.TrafficType.Ingress);
}
@ -332,4 +351,357 @@ public class FirewallManagerTest {
Assert.assertFalse(result);
}
@Test
public void testValidateFirewallRuleVpcWithoutAssociatedNetworkUsesVpcCapabilities() {
final Account caller = Mockito.mock(Account.class);
final IPAddressVO ipAddress = Mockito.mock(IPAddressVO.class);
final NetworkVO network = Mockito.mock(NetworkVO.class);
final FirewallServiceProvider firewallServiceProvider = Mockito.mock(FirewallServiceProvider.class);
final Map<Capability, String> firewallCaps = new HashMap<>();
final Map<Service, Map<Capability, String>> capabilities = new HashMap<>();
firewallCaps.put(Capability.SupportedTrafficDirection, "ingress, egress");
firewallCaps.put(Capability.SupportedProtocols, "tcp,udp,icmp");
firewallCaps.put(Capability.SupportedEgressProtocols, "tcp,udp,icmp");
capabilities.put(Service.Firewall, firewallCaps);
when(ipAddress.getVpcId()).thenReturn(10L);
when(_networkModel.getNetwork(2L)).thenReturn(network);
when(network.getVpcId()).thenReturn(10L);
when(routedIpv4Manager.isVirtualRouterGateway(network)).thenReturn(false);
when(firewallServiceProvider.getProvider()).thenReturn(Network.Provider.VPCVirtualRouter);
when(firewallServiceProvider.getCapabilities()).thenReturn(capabilities);
when(_vpcMgr.isProviderSupportServiceInVpc(10L, Service.Firewall, Network.Provider.VPCVirtualRouter)).thenReturn(true);
_firewallMgr._firewallElements = List.of(firewallServiceProvider);
_firewallMgr.validateFirewallRule(caller, ipAddress, 80, 80, "tcp", Purpose.Firewall, FirewallRuleType.User, 2L, FirewallRule.TrafficType.Ingress);
verify(_networkModel, Mockito.never()).getNetworkServiceCapabilities(Mockito.anyLong(), Mockito.eq(Service.Firewall));
}
@Test
public void testIsVpcIpAddressReturnsTrueWhenVpcIdPresent() {
IPAddressVO ipAddress = Mockito.mock(IPAddressVO.class);
when(ipAddress.getVpcId()).thenReturn(5L);
Assert.assertTrue(_firewallMgr.isVpcIpAddress(ipAddress));
}
@Test
public void testIsVpcIpAddressReturnsFalseWhenVpcIdNull() {
IPAddressVO ipAddress = Mockito.mock(IPAddressVO.class);
when(ipAddress.getVpcId()).thenReturn(null);
Assert.assertFalse(_firewallMgr.isVpcIpAddress(ipAddress));
}
@Test
public void testValidateFirewallRuleForIsolatedIpReturnsNetworkId() {
IPAddressVO ipAddress = Mockito.mock(IPAddressVO.class);
when(ipAddress.getAssociatedWithNetworkId()).thenReturn(42L);
Long result = _firewallMgr.validateFirewallRuleForIsolatedIp(ipAddress);
Assert.assertEquals(Long.valueOf(42L), result);
}
@Test(expected = InvalidParameterValueException.class)
public void testValidateFirewallRuleForIsolatedIpThrowsWhenNotAssociated() {
IPAddressVO ipAddress = Mockito.mock(IPAddressVO.class);
when(ipAddress.getAssociatedWithNetworkId()).thenReturn(null);
_firewallMgr.validateFirewallRuleForIsolatedIp(ipAddress);
}
@Test
public void testValidateFirewallRuleForVpcIpReturnsNetworkId() {
IPAddressVO ipAddress = Mockito.mock(IPAddressVO.class);
Long result = _firewallMgr.validateFirewallRuleForVpcIp(ipAddress, 99L);
Assert.assertEquals(Long.valueOf(99L), result);
}
@Test(expected = InvalidParameterValueException.class)
public void testValidateFirewallRuleForVpcIpThrowsWhenNetworkIdNull() {
IPAddressVO ipAddress = Mockito.mock(IPAddressVO.class);
_firewallMgr.validateFirewallRuleForVpcIp(ipAddress, null);
}
@Test
public void testGetFirewallServiceCapabilitiesForNonVpcNetworkUsesNetworkModel() {
NetworkVO network = Mockito.mock(NetworkVO.class);
when(network.getId()).thenReturn(1L);
when(network.getVpcId()).thenReturn(null);
Map<Capability, String> caps = new HashMap<>();
caps.put(Capability.SupportedProtocols, "tcp,udp");
when(_networkModel.getNetworkServiceCapabilities(1L, Service.Firewall)).thenReturn(caps);
Map<Network.Capability, String> result = _firewallMgr.getFirewallServiceCapabilities(network);
Assert.assertEquals(caps, result);
verify(_networkModel, times(1)).getNetworkServiceCapabilities(1L, Service.Firewall);
}
@Test
public void testGetFirewallServiceCapabilitiesForVpcNetworkUsesVpcProvider() {
NetworkVO network = Mockito.mock(NetworkVO.class);
FirewallServiceProvider fwProvider = Mockito.mock(FirewallServiceProvider.class);
Map<Capability, String> firewallCaps = new HashMap<>();
firewallCaps.put(Capability.SupportedProtocols, "tcp,udp,icmp");
Map<Service, Map<Capability, String>> providerCapabilities = new HashMap<>();
providerCapabilities.put(Service.Firewall, firewallCaps);
when(network.getVpcId()).thenReturn(10L);
when(fwProvider.getProvider()).thenReturn(Network.Provider.VPCVirtualRouter);
when(fwProvider.getCapabilities()).thenReturn(providerCapabilities);
when(_vpcMgr.isProviderSupportServiceInVpc(10L, Service.Firewall, Network.Provider.VPCVirtualRouter)).thenReturn(true);
_firewallMgr._firewallElements = List.of(fwProvider);
Map<Network.Capability, String> result = _firewallMgr.getFirewallServiceCapabilities(network);
Assert.assertEquals(firewallCaps, result);
verify(_networkModel, never()).getNetworkServiceCapabilities(Mockito.anyLong(), Mockito.eq(Service.Firewall));
}
@Test
public void testGetFirewallServiceCapabilitiesForVpcNetworkFallsBackToNetworkModelWhenNoProvider() {
NetworkVO network = Mockito.mock(NetworkVO.class);
FirewallServiceProvider fwProvider = Mockito.mock(FirewallServiceProvider.class);
Map<Capability, String> fallbackCaps = new HashMap<>();
when(network.getId()).thenReturn(1L);
when(network.getVpcId()).thenReturn(10L);
when(fwProvider.getProvider()).thenReturn(Network.Provider.VPCVirtualRouter);
when(_vpcMgr.isProviderSupportServiceInVpc(10L, Service.Firewall, Network.Provider.VPCVirtualRouter)).thenReturn(false);
when(_networkModel.getNetworkServiceCapabilities(1L, Service.Firewall)).thenReturn(fallbackCaps);
_firewallMgr._firewallElements = List.of(fwProvider);
Map<Network.Capability, String> result = _firewallMgr.getFirewallServiceCapabilities(network);
Assert.assertEquals(fallbackCaps, result);
verify(_networkModel, times(1)).getNetworkServiceCapabilities(1L, Service.Firewall);
}
@Test
public void testGetFirewallServiceCapabilitiesForVpcReturnsCapabilitiesWhenProviderSupports() {
FirewallServiceProvider fwProvider = Mockito.mock(FirewallServiceProvider.class);
Map<Capability, String> firewallCaps = new HashMap<>();
firewallCaps.put(Capability.SupportedProtocols, "tcp,udp");
Map<Service, Map<Capability, String>> caps = new HashMap<>();
caps.put(Service.Firewall, firewallCaps);
when(fwProvider.getProvider()).thenReturn(Network.Provider.VPCVirtualRouter);
when(fwProvider.getCapabilities()).thenReturn(caps);
when(_vpcMgr.isProviderSupportServiceInVpc(10L, Service.Firewall, Network.Provider.VPCVirtualRouter)).thenReturn(true);
_firewallMgr._firewallElements = List.of(fwProvider);
Map<Network.Capability, String> result = _firewallMgr.getFirewallServiceCapabilitiesForVpc(10L);
Assert.assertNotNull(result);
Assert.assertEquals("tcp,udp", result.get(Capability.SupportedProtocols));
}
@Test
public void testGetFirewallServiceCapabilitiesForVpcReturnsNullWhenNoProviderSupports() {
FirewallServiceProvider fwProvider = Mockito.mock(FirewallServiceProvider.class);
when(fwProvider.getProvider()).thenReturn(Network.Provider.VPCVirtualRouter);
when(_vpcMgr.isProviderSupportServiceInVpc(10L, Service.Firewall, Network.Provider.VPCVirtualRouter)).thenReturn(false);
_firewallMgr._firewallElements = List.of(fwProvider);
Map<Network.Capability, String> result = _firewallMgr.getFirewallServiceCapabilitiesForVpc(10L);
Assert.assertNull(result);
}
@Test(expected = InvalidParameterValueException.class)
public void testValidateFirewallRuleForVpcThrowsOnInvalidStartPort() {
Account caller = Mockito.mock(Account.class);
_firewallMgr.validateFirewallRuleForVpc(caller, null, -1, 80, "tcp", Purpose.Firewall, FirewallRuleType.User, 10L, FirewallRule.TrafficType.Ingress);
}
@Test(expected = InvalidParameterValueException.class)
public void testValidateFirewallRuleForVpcThrowsOnInvalidEndPort() {
Account caller = Mockito.mock(Account.class);
_firewallMgr.validateFirewallRuleForVpc(caller, null, 80, 70000, "tcp", Purpose.Firewall, FirewallRuleType.User, 10L, FirewallRule.TrafficType.Ingress);
}
@Test(expected = InvalidParameterValueException.class)
public void testValidateFirewallRuleForVpcThrowsWhenStartPortGreaterThanEndPort() {
Account caller = Mockito.mock(Account.class);
_firewallMgr.validateFirewallRuleForVpc(caller, null, 200, 100, "tcp", Purpose.Firewall, FirewallRuleType.User, 10L, FirewallRule.TrafficType.Ingress);
}
@Test
public void testValidateFirewallRuleForVpcSystemTypeWithNullIpReturnsEarly() {
// System rule type + null IP should return without further validation
Account caller = Mockito.mock(Account.class);
// Should not throw even though vpcId checks come after this
_firewallMgr.validateFirewallRuleForVpc(caller, null, 80, 80, "tcp", Purpose.Firewall, FirewallRuleType.System, 10L, FirewallRule.TrafficType.Ingress);
}
@Test(expected = InvalidParameterValueException.class)
public void testValidateFirewallRuleForVpcThrowsWhenVpcIdNullAndNotSystemRule() {
Account caller = Mockito.mock(Account.class);
IPAddressVO ipAddress = Mockito.mock(IPAddressVO.class);
_firewallMgr.validateFirewallRuleForVpc(caller, ipAddress, 80, 80, "tcp", Purpose.Firewall, FirewallRuleType.User, null, FirewallRule.TrafficType.Ingress);
}
@Test(expected = InvalidParameterValueException.class)
public void testValidateFirewallRuleForVpcThrowsWhenActiveVpcNotFound() {
Account caller = Mockito.mock(Account.class);
IPAddressVO ipAddress = Mockito.mock(IPAddressVO.class);
when(_vpcMgr.getActiveVpc(10L)).thenReturn(null);
_firewallMgr.validateFirewallRuleForVpc(caller, ipAddress, 80, 80, "tcp", Purpose.Firewall, FirewallRuleType.User, 10L, FirewallRule.TrafficType.Ingress);
}
@Test(expected = InvalidParameterValueException.class)
public void testValidateFirewallRuleForVpcThrowsOnUnsupportedProtocol() {
Account caller = Mockito.mock(Account.class);
IPAddressVO ipAddress = Mockito.mock(IPAddressVO.class);
Vpc vpc = Mockito.mock(Vpc.class);
FirewallServiceProvider fwProvider = Mockito.mock(FirewallServiceProvider.class);
Map<Capability, String> firewallCaps = new HashMap<>();
firewallCaps.put(Capability.SupportedProtocols, "tcp,udp");
firewallCaps.put(Capability.SupportedTrafficDirection, "ingress,egress");
Map<Service, Map<Capability, String>> caps = new HashMap<>();
caps.put(Service.Firewall, firewallCaps);
when(_vpcMgr.getActiveVpc(10L)).thenReturn(vpc);
when(fwProvider.getProvider()).thenReturn(Network.Provider.VPCVirtualRouter);
when(fwProvider.getCapabilities()).thenReturn(caps);
when(_vpcMgr.isProviderSupportServiceInVpc(10L, Service.Firewall, Network.Provider.VPCVirtualRouter)).thenReturn(true);
_firewallMgr._firewallElements = List.of(fwProvider);
_firewallMgr.validateFirewallRuleForVpc(caller, ipAddress, 80, 80, "gre", Purpose.Firewall, FirewallRuleType.User, 10L, FirewallRule.TrafficType.Ingress);
}
@Test
public void testValidateFirewallRuleForVpcSucceedsWithSupportedProtocolAndTrafficType() {
Account caller = Mockito.mock(Account.class);
IPAddressVO ipAddress = Mockito.mock(IPAddressVO.class);
Vpc vpc = Mockito.mock(Vpc.class);
FirewallServiceProvider fwProvider = Mockito.mock(FirewallServiceProvider.class);
Map<Capability, String> firewallCaps = new HashMap<>();
firewallCaps.put(Capability.SupportedProtocols, "tcp,udp,icmp");
firewallCaps.put(Capability.SupportedTrafficDirection, "ingress,egress");
Map<Service, Map<Capability, String>> caps = new HashMap<>();
caps.put(Service.Firewall, firewallCaps);
when(_vpcMgr.getActiveVpc(10L)).thenReturn(vpc);
when(fwProvider.getProvider()).thenReturn(Network.Provider.VPCVirtualRouter);
when(fwProvider.getCapabilities()).thenReturn(caps);
when(_vpcMgr.isProviderSupportServiceInVpc(10L, Service.Firewall, Network.Provider.VPCVirtualRouter)).thenReturn(true);
_firewallMgr._firewallElements = List.of(fwProvider);
// Should not throw
_firewallMgr.validateFirewallRuleForVpc(caller, ipAddress, 80, 80, "tcp", Purpose.Firewall, FirewallRuleType.User, 10L, FirewallRule.TrafficType.Ingress);
verify(_accountMgr, times(1)).checkAccess(caller, null, true, ipAddress);
}
@Test
public void testCreateFirewallRuleRoutesToVpcWhenVpcIdProvided() throws NetworkRuleConflictException {
Account caller = Mockito.mock(Account.class);
FirewallRule vpcRule = Mockito.mock(FirewallRule.class);
doReturn(vpcRule).when(_firewallMgr).createFirewallRuleForVpc(
Mockito.anyLong(), Mockito.eq(caller), Mockito.any(), Mockito.anyInt(), Mockito.anyInt(),
Mockito.anyString(), Mockito.any(), Mockito.any(), Mockito.any(), Mockito.any(),
Mockito.any(), Mockito.any(FirewallRuleType.class), Mockito.anyLong(),
Mockito.any(FirewallRule.TrafficType.class), Mockito.anyBoolean());
_firewallMgr.createFirewallRule(1L, caller, "xid", 80, 80, "tcp",
Collections.singletonList("0.0.0.0/0"), null, null, null, null,
FirewallRuleType.User, null, 10L, FirewallRule.TrafficType.Ingress, true);
verify(_firewallMgr, times(1)).createFirewallRuleForVpc(
Mockito.anyLong(), Mockito.eq(caller), Mockito.any(), Mockito.anyInt(), Mockito.anyInt(),
Mockito.anyString(), Mockito.any(), Mockito.any(), Mockito.any(), Mockito.any(),
Mockito.any(), Mockito.any(FirewallRuleType.class), Mockito.anyLong(),
Mockito.any(FirewallRule.TrafficType.class), Mockito.anyBoolean());
verify(_firewallMgr, never()).createFirewallRuleForNonVPC(
Mockito.any(), Mockito.any(), Mockito.any(), Mockito.any(), Mockito.any(),
Mockito.any(), Mockito.any(), Mockito.any(), Mockito.any(), Mockito.any(),
Mockito.any(), Mockito.any(), Mockito.any(), Mockito.any(), Mockito.any());
}
@Test
public void testCreateFirewallRuleRoutesToNonVpcWhenVpcIdNull() throws NetworkRuleConflictException {
Account caller = Mockito.mock(Account.class);
FirewallRule nonVpcRule = Mockito.mock(FirewallRule.class);
doReturn(nonVpcRule).when(_firewallMgr).createFirewallRuleForNonVPC(
Mockito.any(), Mockito.eq(caller), Mockito.any(), Mockito.any(), Mockito.any(),
Mockito.anyString(), Mockito.any(), Mockito.any(), Mockito.any(), Mockito.any(),
Mockito.any(), Mockito.any(FirewallRuleType.class), Mockito.anyLong(),
Mockito.any(FirewallRule.TrafficType.class), Mockito.anyBoolean());
_firewallMgr.createFirewallRule(null, caller, "xid", 80, 80, "tcp",
Collections.singletonList("0.0.0.0/0"), null, null, null, null,
FirewallRuleType.User, 2L, null, FirewallRule.TrafficType.Ingress, true);
verify(_firewallMgr, times(1)).createFirewallRuleForNonVPC(
Mockito.any(), Mockito.eq(caller), Mockito.any(), Mockito.any(), Mockito.any(),
Mockito.anyString(), Mockito.any(), Mockito.any(), Mockito.any(), Mockito.any(),
Mockito.any(), Mockito.any(FirewallRuleType.class), Mockito.anyLong(),
Mockito.any(FirewallRule.TrafficType.class), Mockito.anyBoolean());
verify(_firewallMgr, never()).createFirewallRuleForVpc(
Mockito.any(), Mockito.any(), Mockito.any(), Mockito.any(), Mockito.any(),
Mockito.any(), Mockito.any(), Mockito.any(), Mockito.any(), Mockito.any(),
Mockito.any(), Mockito.any(), Mockito.any(), Mockito.any(), Mockito.any());
}
@Test
public void testApplyRulesForVpcNetworkUsesVpcProviderCheck() throws ResourceUnavailableException {
Network network = Mockito.mock(Network.class);
FirewallServiceProvider fwProvider = Mockito.mock(FirewallServiceProvider.class);
List<FirewallRule> rules = new ArrayList<>();
FirewallRuleVO rule = new FirewallRuleVO("rule1", 1L, 80, 80, "tcp", 1L, 2, 3, Purpose.Firewall,
Collections.emptyList(), Collections.emptyList(), null, null, null, FirewallRule.TrafficType.Ingress);
rules.add(rule);
when(network.getVpcId()).thenReturn(10L);
when(fwProvider.getProvider()).thenReturn(Network.Provider.VPCVirtualRouter);
when(_vpcMgr.isProviderSupportServiceInVpc(10L, Service.Firewall, Network.Provider.VPCVirtualRouter)).thenReturn(true);
when(fwProvider.applyFWRules(Mockito.eq(network), Mockito.anyList())).thenReturn(true);
_firewallMgr._firewallElements = List.of(fwProvider);
boolean result = _firewallMgr.applyRules(network, Purpose.Firewall, rules);
Assert.assertTrue(result);
verify(_vpcMgr, times(1)).isProviderSupportServiceInVpc(10L, Service.Firewall, Network.Provider.VPCVirtualRouter);
verify(_networkModel, never()).isProviderSupportServiceInNetwork(Mockito.anyLong(), Mockito.eq(Service.Firewall), Mockito.any());
}
@Test
public void testApplyRulesForNonVpcNetworkUsesNetworkModelProviderCheck() throws ResourceUnavailableException {
Network network = Mockito.mock(Network.class);
FirewallServiceProvider fwProvider = Mockito.mock(FirewallServiceProvider.class);
List<FirewallRule> rules = new ArrayList<>();
FirewallRuleVO rule = new FirewallRuleVO("rule1", 1L, 80, 80, "tcp", 1L, 2, 3, Purpose.Firewall,
Collections.emptyList(), Collections.emptyList(), null, null, null, FirewallRule.TrafficType.Ingress);
rules.add(rule);
when(network.getId()).thenReturn(1L);
when(network.getVpcId()).thenReturn(null);
when(fwProvider.getProvider()).thenReturn(Network.Provider.VirtualRouter);
when(_networkModel.isProviderSupportServiceInNetwork(1L, Service.Firewall, Network.Provider.VirtualRouter)).thenReturn(true);
when(fwProvider.applyFWRules(Mockito.eq(network), Mockito.anyList())).thenReturn(true);
_firewallMgr._firewallElements = List.of(fwProvider);
boolean result = _firewallMgr.applyRules(network, Purpose.Firewall, rules);
Assert.assertTrue(result);
verify(_networkModel, times(1)).isProviderSupportServiceInNetwork(1L, Service.Firewall, Network.Provider.VirtualRouter);
verify(_vpcMgr, never()).isProviderSupportServiceInVpc(Mockito.anyLong(), Mockito.eq(Service.Firewall), Mockito.any());
}
@Test
public void testGetSourceIpForIngressRuleReturnsNullWhenIdIsNull() {
IPAddressVO result = _firewallMgr.getSourceIpForIngressRule(null);
Assert.assertNull(result);
}
@Test(expected = CloudRuntimeException.class)
public void testGetSourceIpForIngressRuleReturnsNullWhenIpIsnotPresent() {
when(_ipAddressDao.findById(1L)).thenReturn(null);
_firewallMgr.getSourceIpForIngressRule(1L);
}
}

View File

@ -703,14 +703,79 @@ class CsAcl(CsDataBag):
self.add_routing_rules()
return
desired_firewall_ips = self._get_desired_vpc_firewall_ips()
fw_chains_created = set()
for item in self.dbag:
if item == "id":
continue
if self.config.is_vpc():
if self.config.is_vpc() and not ("purpose" in self.dbag[item] and self.dbag[item]["purpose"] == "Firewall"):
self.AclDevice(self.dbag[item], self.config).create()
else:
# For VPC firewall rules, create the PREROUTING jump and chain skeleton
# once per public IP before adding the individual rule
if self.config.is_vpc() and self.dbag[item].get("purpose") == "Firewall":
src_ip = self.dbag[item].get("src_ip")
if src_ip and src_ip not in fw_chains_created:
fw = self.config.get_fw()
fw.append(["mangle", "front",
"-A PREROUTING -d %s/32 -j FIREWALL_%s" % (src_ip, src_ip)])
fw.append(["mangle", "front",
"-A FIREWALL_%s -m state --state RELATED,ESTABLISHED -j RETURN" % src_ip])
fw.append(["mangle", "",
"-A FIREWALL_%s -j DROP" % src_ip])
fw_chains_created.add(src_ip)
self.AclIP(self.dbag[item], self.config).create()
if self.config.is_vpc():
self._cleanup_removed_vpc_firewall_chains(desired_firewall_ips)
def _get_desired_vpc_firewall_ips(self):
desired_firewall_ips = set()
if not self.config.is_vpc():
return desired_firewall_ips
for item in self.dbag:
if item == "id":
continue
rule = self.dbag[item]
if rule.get("purpose") == "Firewall":
src_ip = rule.get("src_ip")
if src_ip:
desired_firewall_ips.add(src_ip)
return desired_firewall_ips
def _cleanup_removed_vpc_firewall_chains(self, desired_firewall_ips):
"""Delete FIREWALL_<ip> chain only when no firewall rule remains for that VPC public IP."""
try:
mangle_save = CsHelper.execute("iptables-save -t mangle")
existing_firewall_ips = []
for line in mangle_save:
if line.startswith(":FIREWALL_"):
chain = line.split(" ")[0][1:]
existing_firewall_ips.append(chain.replace("FIREWALL_", "", 1))
for src_ip in existing_firewall_ips:
if src_ip in desired_firewall_ips:
continue
self._delete_vpc_firewall_chain(src_ip)
except Exception as e:
logging.debug("Failed VPC firewall chain cleanup: %s", e)
def _delete_vpc_firewall_chain(self, src_ip):
chain = "FIREWALL_%s" % src_ip
try:
prerouting_rules = CsHelper.execute("iptables -t mangle -S PREROUTING")
for rule in prerouting_rules:
if ("-d %s/32" % src_ip) in rule and ("-j %s" % chain) in rule:
delete_rule = rule.replace("-A PREROUTING", "-D PREROUTING", 1)
CsHelper.execute2("iptables -t mangle %s" % delete_rule, False)
CsHelper.execute2("iptables -t mangle -F %s" % chain, False)
CsHelper.execute2("iptables -t mangle -X %s" % chain, False)
logging.info("Deleted VPC firewall chain %s as last firewall rule was removed", chain)
except Exception as e:
logging.debug("Failed deleting VPC firewall chain %s: %s", chain, e)
class CsIpv6Firewall(CsDataBag):
"""
Deal with IPv6 Firewall

View File

@ -647,6 +647,7 @@ class CsIP:
(self.address['network'], self.address['network'])])
if self.get_type() in ["public"]:
self.fw.append(
["mangle", "", "-A FORWARD -j VPN_STATS_%s" % self.dev])
self.fw.append(

View File

@ -0,0 +1,182 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Smoke tests for firewall rules on VPC public IPs."""
from nose.plugins.attrib import attr
from marvin.cloudstackTestCase import cloudstackTestCase
from marvin.lib.base import Account, FireWallRule, Network, NetworkOffering, PublicIPAddress, VPC, VpcOffering
from marvin.lib.common import get_domain, get_zone, list_publicIP
from marvin.lib.utils import cleanup_resources, wait_until
class TestVpcFirewallRules(cloudstackTestCase):
@classmethod
def setUpClass(cls):
cls.testClient = super(TestVpcFirewallRules, cls).getClsTestClient()
cls.apiclient = cls.testClient.getApiClient()
cls.services = cls.testClient.getParsedTestDataConfig()
cls.zone = get_zone(cls.apiclient, cls.testClient.getZoneForTests())
cls.domain = get_domain(cls.apiclient)
cls._cleanup = []
cls.account = Account.create(
cls.apiclient,
cls.services["account"],
domainid=cls.domain.id
)
cls._cleanup.append(cls.account)
cls.services["vpc_offering"]["supportedservices"] = "Vpn,Dhcp,Dns,SourceNat,Lb,UserData,StaticNat,NetworkACL,PortForwarding,Firewall"
cls.services["vpc_offering"]["serviceProviderList"] = {
"Vpn": "VpcVirtualRouter",
"Dhcp": "VpcVirtualRouter",
"Dns": "VpcVirtualRouter",
"SourceNat": "VpcVirtualRouter",
"Lb": "VpcVirtualRouter",
"UserData": "VpcVirtualRouter",
"StaticNat": "VpcVirtualRouter",
"NetworkACL": "VpcVirtualRouter",
"PortForwarding": "VpcVirtualRouter",
"Firewall": "VpcVirtualRouter"
}
cls.vpc_offering = VpcOffering.create(
cls.apiclient,
cls.services["vpc_offering"]
)
cls.vpc_offering.update(cls.apiclient, state="Enabled")
cls._cleanup.append(cls.vpc_offering)
network_offering = NetworkOffering.list(
cls.apiclient,
name="DefaultIsolatedNetworkOfferingForVpcNetworks"
)
cls.assertTrue(network_offering is not None and len(network_offering) > 0,
"No VPC tier network offering found")
cls.network_offering = network_offering[0]
cls.services["vpc"]["cidr"] = "10.20.30.0/24"
cls.vpc = VPC.create(
cls.apiclient,
cls.services["vpc"],
vpcofferingid=cls.vpc_offering.id,
zoneid=cls.zone.id,
account=cls.account.name,
domainid=cls.account.domainid
)
cls._cleanup.append(cls.vpc)
cls.tier = Network.create(
cls.apiclient,
services={"name": "vpc-fw-tier", "displaytext": "vpc-fw-tier"},
accountid=cls.account.name,
domainid=cls.account.domainid,
networkofferingid=cls.network_offering.id,
zoneid=cls.zone.id,
vpcid=cls.vpc.id,
gateway="10.20.30.1",
netmask="255.255.255.0"
)
cls._cleanup.append(cls.tier)
@classmethod
def tearDownClass(cls):
try:
cleanup_resources(cls.apiclient, cls._cleanup)
except Exception as e:
raise Exception("Warning: Exception during cleanup: %s" % e)
def setUp(self):
self.apiclient = self.testClient.getApiClient()
self.cleanup = []
def tearDown(self):
cleanup_resources(self.apiclient, self.cleanup)
def _wait_for_firewall_rule(self, rule_id):
rules = FireWallRule.list(self.apiclient, id=rule_id, listall=True)
if rules and len(rules) == 1:
return True, rules[0]
return False, None
@attr(tags=["advanced", "advancedns", "smoke"], required_hardware="false")
def test_01_create_firewall_rule_on_vpc_public_ip(self):
"""Verify firewall rule can be created and listed on a dedicated VPC public IP."""
public_ip = PublicIPAddress.create(
self.apiclient,
zoneid=self.zone.id,
accountid=self.account.name,
domainid=self.account.domainid,
vpcid=self.vpc.id
)
self.cleanup.append(public_ip)
firewall_rule = FireWallRule.create(
self.apiclient,
ipaddressid=public_ip.ipaddress.id,
protocol="tcp",
cidrlist=["0.0.0.0/0"],
startport=19090,
endport=19090,
vpcid=self.vpc.id
)
self.cleanup.insert(0, firewall_rule)
result, listed_rule = wait_until(2, 10, self._wait_for_firewall_rule, firewall_rule.id)
self.assertTrue(result, "Firewall rule was not listed for the VPC public IP")
self.assertEqual(listed_rule.id, firewall_rule.id)
self.assertEqual(listed_rule.ipaddressid, public_ip.ipaddress.id)
self.assertEqual(listed_rule.vpcid, self.vpc.id)
self.assertEqual(listed_rule.protocol.lower(), "tcp")
self.assertEqual(int(listed_rule.startport), 19090)
self.assertEqual(int(listed_rule.endport), 19090)
@attr(tags=["advanced", "advancedns", "smoke"], required_hardware="false")
def test_02_create_firewall_rule_on_vpc_source_nat_ip(self):
"""Verify firewall rule can be created and listed on the VPC source NAT IP."""
source_nat_ips = list_publicIP(
self.apiclient,
vpcid=self.vpc.id,
listall=True,
issourcenat=True
)
self.assertTrue(source_nat_ips is not None and len(source_nat_ips) > 0,
"No source NAT IP found for the VPC")
source_nat_ip = source_nat_ips[0]
firewall_rule = FireWallRule.create(
self.apiclient,
ipaddressid=source_nat_ip.id,
protocol="tcp",
cidrlist=["0.0.0.0/0"],
startport=19443,
endport=19443,
vpcid=self.vpc.id
)
self.cleanup.append(firewall_rule)
result, listed_rule = wait_until(2, 10, self._wait_for_firewall_rule, firewall_rule.id)
self.assertTrue(result, "Firewall rule was not listed for the VPC source NAT IP")
self.assertEqual(listed_rule.id, firewall_rule.id)
self.assertEqual(listed_rule.ipaddressid, source_nat_ip.id)
self.assertEqual(listed_rule.vpcid, self.vpc.id)
self.assertEqual(listed_rule.protocol.lower(), "tcp")
self.assertEqual(int(listed_rule.startport), 19443)
self.assertEqual(int(listed_rule.endport), 19443)

View File

@ -136,23 +136,34 @@ export default {
}
if (this.resource && this.resource.vpcid) {
const vpc = await this.fetchVpc()
const hasFirewallCapability = this.hasVpcFirewallCapability(vpc)
// VPC IPs with source nat have only VPN when VPC offering conserve mode = false
if (this.resource.issourcenat && vpc?.vpcofferingconservemode === false) {
this.tabs = this.defaultTabs.concat(this.$route.meta.tabs.filter(tab => tab.name === 'vpn'))
const tabs = this.defaultTabs.concat(this.$route.meta.tabs.filter(tab => tab.name === 'vpn'))
this.tabs = hasFirewallCapability ? this.addFirewallTab(tabs) : tabs
return
}
// VPC IPs with static nat have nothing
// VPC IPs with static nat keep existing VPN behavior and always show firewall
if (this.resource.isstaticnat) {
if (this.resource.virtualmachinetype === 'DomainRouter') {
this.tabs = this.defaultTabs.concat(this.$route.meta.tabs.filter(tab => tab.name === 'vpn'))
}
const tabs = this.addFirewallTab(this.$route.meta.tabs).map(tab => {
if (tab.name !== 'firewall') {
return tab
}
const staticNatFirewallTab = { ...tab }
delete staticNatFirewallTab.networkServiceFilter
return staticNatFirewallTab
})
this.tabs = tabs
return
}
// VPC IPs don't have firewall
let tabs = this.$route.meta.tabs.filter(tab => tab.name !== 'firewall')
// VPC IPs have all tabs; firewall is shown only if VPC has firewall capability
let tabs = this.$route.meta.tabs
if (!hasFirewallCapability) {
tabs = tabs.filter(tab => tab.name !== 'firewall')
}
const network = await this.fetchNetwork()
if (network && network.networkofferingconservemode) {
@ -168,12 +179,12 @@ export default {
this.portFWRuleCount = await this.fetchPortFWRule()
this.loadBalancerRuleCount = await this.fetchLoadBalancerRule()
// VPC IPs with PF only have PF
// VPC IPs with PF only have PF (and firewall)
if (this.portFWRuleCount > 0) {
tabs = tabs.filter(tab => tab.name !== 'loadbalancing')
}
// VPC IPs with LB rules only have LB
// VPC IPs with LB rules only have LB (and firewall)
if (this.loadBalancerRuleCount > 0) {
tabs = tabs.filter(tab => tab.name !== 'portforwarding')
}
@ -200,6 +211,17 @@ export default {
fetchAction () {
this.actions = this.$route.meta.actions || []
},
addFirewallTab (tabs) {
const firewallTab = this.$route.meta.tabs.find(tab => tab.name === 'firewall')
if (!firewallTab || tabs.some(tab => tab.name === 'firewall')) {
return tabs
}
return tabs.concat(firewallTab)
},
hasVpcFirewallCapability (vpc) {
const services = vpc?.service || []
return Array.isArray(services) && services.some(service => (service?.name || '').toLowerCase() === 'firewall')
},
fetchVpc () {
if (!this.resource.vpcid) {
return null

View File

@ -534,6 +534,12 @@ export default {
{ name: 'ConfigDrive' }
]
})
services.push({
name: 'Firewall',
provider: [
{ name: 'VpcVirtualRouter' }
]
})
services.push({
name: 'Lb',
provider: [