Fix policy rule ID and add more unit tests

This commit is contained in:
nvazquez 2023-11-13 22:59:21 -03:00
parent 1a24ba6949
commit d72829c602
No known key found for this signature in database
GPG Key ID: 656E1BCC8CB54F84
3 changed files with 44 additions and 14 deletions

View File

@ -767,37 +767,36 @@ public class NsxApiClient {
}
}
public void createSegmentDistributedFirewall(String policyName, List<NsxNetworkRule> nsxRules) {
public void createSegmentDistributedFirewall(String segmentName, List<NsxNetworkRule> nsxRules) {
try {
SecurityPolicies services = (SecurityPolicies) nsxService.apply(SecurityPolicies.class);
List<Rule> rules = getRulesForDistributedFirewall(policyName, nsxRules);
List<Rule> rules = getRulesForDistributedFirewall(segmentName, nsxRules);
SecurityPolicy policy = new SecurityPolicy.Builder()
.setDisplayName(policyName)
.setId(policyName)
.setDisplayName(segmentName)
.setId(segmentName)
.setCategory("Application")
.setRules(rules)
.build();
services.patch(DEFAULT_DOMAIN, policyName, policy);
services.patch(DEFAULT_DOMAIN, segmentName, policy);
} catch (Error error) {
ApiError ae = error.getData()._convertTo(ApiError.class);
String msg = String.format("Failed to create NSX distributed firewall policy for segment %s, due to: %s", policyName, ae.getErrorMessage());
String msg = String.format("Failed to create NSX distributed firewall policy for segment %s, due to: %s", segmentName, ae.getErrorMessage());
LOGGER.error(msg);
throw new CloudRuntimeException(msg);
}
}
private List<Rule> getRulesForDistributedFirewall(String policyName, List<NsxNetworkRule> nsxRules) {
private List<Rule> getRulesForDistributedFirewall(String segmentName, List<NsxNetworkRule> nsxRules) {
List<Rule> rules = new ArrayList<>();
for (NsxNetworkRule rule: nsxRules) {
String ruleId = String.format("%s-%s", policyName, rule.getRuleId());
String trafficType = rule.getTrafficType();
String ruleId = NsxControllerUtils.getNsxDistributedFirewallPolicyRuleId(segmentName, rule.getRuleId());
Rule ruleToAdd = new Rule.Builder()
.setAction(rule.getAclAction().toUpperCase())
.setId(ruleId)
.setDisplayName(ruleId)
.setResourceType("SecurityPolicy")
.setSourceGroups(getGroupsForTraffic(rule, trafficType, policyName, true))
.setDestinationGroups(getGroupsForTraffic(rule, trafficType, policyName, false))
.setSourceGroups(getGroupsForTraffic(rule, segmentName, true))
.setDestinationGroups(getGroupsForTraffic(rule, segmentName, false))
.setServices(List.of("ANY"))
.setScope(List.of("ANY"))
.build();
@ -806,11 +805,12 @@ public class NsxApiClient {
return rules;
}
private List<String> getGroupsForTraffic(NsxNetworkRule rule, String trafficType,
String policyName, boolean source) {
List<String> segmentGroup = List.of(String.format("%s/%s", GROUPS_PATH_PREFIX, policyName));
protected List<String> getGroupsForTraffic(NsxNetworkRule rule,
String segmentName, boolean source) {
List<String> segmentGroup = List.of(String.format("%s/%s", GROUPS_PATH_PREFIX, segmentName));
List<String> ruleCidrList = rule.getCidrList();
String trafficType = rule.getTrafficType();
if (trafficType.equalsIgnoreCase("ingress")) {
return source ? ruleCidrList : segmentGroup;
} else if (trafficType.equalsIgnoreCase("egress")) {

View File

@ -45,6 +45,10 @@ public class NsxControllerUtils {
return String.format("D%s-A%s-Z%s-%s%s-NAT", domainId, accountId, dataCenterId, resourcePrefix, resourceId);
}
public static String getNsxDistributedFirewallPolicyRuleId(String segmentName, long ruleId) {
return String.format("%s-P%s", segmentName, ruleId);
}
public NsxAnswer sendNsxCommand(NsxCommand cmd, long zoneId) throws IllegalArgumentException {
NsxProviderVO nsxProviderVO = nsxProviderDao.findByZoneId(zoneId);

View File

@ -20,6 +20,8 @@ import com.vmware.nsx_policy.infra.domains.Groups;
import com.vmware.nsx_policy.model.Group;
import com.vmware.nsx_policy.model.PathExpression;
import com.vmware.vapi.bindings.Service;
import org.apache.cloudstack.resource.NsxNetworkRule;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;
import org.mockito.Mock;
@ -64,4 +66,28 @@ public class NsxApiClientTest {
Mockito.verify(pathExpressions[0]).setPaths(List.of(segmentPath));
}
}
@Test
public void testGetGroupsForTrafficIngress() {
NsxNetworkRule rule = Mockito.mock(NsxNetworkRule.class);
Mockito.when(rule.getCidrList()).thenReturn(List.of("ANY"));
Mockito.when(rule.getTrafficType()).thenReturn("Ingress");
String segmentName = "segment";
List<String> sourceGroups = client.getGroupsForTraffic(rule, segmentName, true);
List<String> destinationGroups = client.getGroupsForTraffic(rule, segmentName, false);
Assert.assertEquals(List.of("ANY"), sourceGroups);
Assert.assertEquals(List.of(String.format("%s/%s", NsxApiClient.GROUPS_PATH_PREFIX, segmentName)), destinationGroups);
}
@Test
public void testGetGroupsForTrafficEgress() {
NsxNetworkRule rule = Mockito.mock(NsxNetworkRule.class);
Mockito.when(rule.getCidrList()).thenReturn(List.of("ANY"));
Mockito.when(rule.getTrafficType()).thenReturn("Egress");
String segmentName = "segment";
List<String> sourceGroups = client.getGroupsForTraffic(rule, segmentName, true);
List<String> destinationGroups = client.getGroupsForTraffic(rule, segmentName, false);
Assert.assertEquals(List.of(String.format("%s/%s", NsxApiClient.GROUPS_PATH_PREFIX, segmentName)), sourceGroups);
Assert.assertEquals(List.of("ANY"), destinationGroups);
}
}