Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,159 @@
import pytest
from unittest.mock import MagicMock
from botocore.exceptions import ClientError
from tevico.engine.entities.report.check_model import (
CheckStatus,
CheckMetadata,
Remediation,
RemediationCode,
RemediationRecommendation,
)
from library.aws.checks.apigateway.apigateway_rest_api_waf_acl_attached import (
apigateway_rest_api_waf_acl_attached,
)


class TestApiGatewayRestApiWafAclAttached:
"""Test cases for API Gateway REST API WAF ACL Attached check."""

def setup_method(self):
self.metadata = CheckMetadata(
Provider="AWS",
CheckID="apigateway_rest_api_waf_acl_attached",
CheckTitle="API Gateway REST API has WAF ACL attached",
CheckType=["Security"],
ServiceName="APIGateway",
SubServiceName="REST API",
ResourceIdTemplate="arn:aws:apigateway:{region}::/restapis/{restapi_id}",
Severity="medium",
ResourceType="AWS::ApiGateway::RestApi",
Risk="APIs without WAF ACL may be vulnerable to web attacks.",
Description="Checks if API Gateway REST APIs have a WAF ACL attached.",
Remediation=Remediation(
Code=RemediationCode(CLI="", NativeIaC="", Terraform=""),
Recommendation=RemediationRecommendation(
Text="Attach a WAF ACL to API Gateway REST APIs.",
Url="https://docs.aws.amazon.com/apigateway/latest/developerguide/apigateway-control-access-to-api.html"
)
)
)
self.check = apigateway_rest_api_waf_acl_attached(metadata=self.metadata)
self.mock_session = MagicMock()
self.mock_apigw = MagicMock()
self.mock_session.client.return_value = self.mock_apigw
self.mock_session.region_name = "us-west-2"

def test_no_rest_apis(self):
self.mock_apigw.get_rest_apis.return_value = {"items": []}
report = self.check.execute(self.mock_session)
assert report.status == CheckStatus.PASSED
assert len(report.resource_ids_status) == 0

def test_waf_acl_attached(self):
self.mock_apigw.get_rest_apis.return_value = {
"items": [{"id": "api-1", "name": "API 1"}]
}
self.mock_apigw.get_stages.return_value = {
"item": [{"stageName": "prod", "webAclArn": "arn:aws:wafv2:us-west-2:123456789012:regional/webacl/sample"}]
}
report = self.check.execute(self.mock_session)
assert report.status == CheckStatus.PASSED
assert report.resource_ids_status[0].status == CheckStatus.PASSED
assert "WAF is attached to stage prod of API API 1." in report.resource_ids_status[0].summary

def test_waf_acl_not_attached(self):
self.mock_apigw.get_rest_apis.return_value = {
"items": [{"id": "api-2", "name": "API 2"}]
}
self.mock_apigw.get_stages.return_value = {
"item": [{"stageName": "dev"}]
}
report = self.check.execute(self.mock_session)
assert report.status == CheckStatus.FAILED
assert report.resource_ids_status[0].status == CheckStatus.FAILED
assert "No WAF attached to stage dev of API API 2." in report.resource_ids_status[0].summary

def test_api_with_no_stages(self):
self.mock_apigw.get_rest_apis.return_value = {
"items": [{"id": "api-3", "name": "API 3"}]
}
self.mock_apigw.get_stages.return_value = {"item": []}
report = self.check.execute(self.mock_session)
assert report.status == CheckStatus.FAILED
assert report.resource_ids_status[0].status == CheckStatus.FAILED
assert "API API 3 has no stages." in report.resource_ids_status[0].summary

def test_get_stages_exception(self):
self.mock_apigw.get_rest_apis.return_value = {
"items": [{"id": "api-4", "name": "API 4"}]
}
self.mock_apigw.get_stages.side_effect = Exception("stage error")
report = self.check.execute(self.mock_session)
assert report.status == CheckStatus.UNKNOWN
assert report.resource_ids_status[0].status == CheckStatus.UNKNOWN
assert "Error fetching stages for API API 4." in report.resource_ids_status[0].summary
assert report.resource_ids_status[0].exception is not None

def test_get_rest_apis_exception(self):
self.mock_apigw.get_rest_apis.side_effect = ClientError(
{"Error": {"Code": "AccessDeniedException", "Message": "Access denied"}},
"GetRestApis"
)
report = self.check.execute(self.mock_session)
assert report.status == CheckStatus.UNKNOWN
assert report.resource_ids_status[0].status == CheckStatus.UNKNOWN
assert "API Gateway listing error occurred." in report.resource_ids_status[0].summary
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  1. No partial WAF coverage test Add test with one stage with webAclArn, one without
  2. No multiple API coverage Add test with two APIs (one passes, one fails)
  3. Exception field untested Assert presence of exception in the result for exceptions
  4. All APIs missing WAF Add test for multiple APIs, all stages missing WAF

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Partial WAF coverage (some stages with WAF, some without)
Multiple APIs with mixed WAF coverage
Exception field coverage for get_stages and get_rest_apis
All APIs missing WAF scenario

assert report.resource_ids_status[0].exception is not None

def test_partial_waf_coverage(self):
self.mock_apigw.get_rest_apis.return_value = {
"items": [{"id": "api-6", "name": "API 6"}]
}
self.mock_apigw.get_stages.return_value = {
"item": [
{"stageName": "prod", "webAclArn": "arn:aws:waf::..."},
{"stageName": "dev"}
]
}
report = self.check.execute(self.mock_session)
assert report.status == CheckStatus.FAILED
statuses = {r.status for r in report.resource_ids_status}
assert CheckStatus.FAILED in statuses
assert CheckStatus.PASSED in statuses

def test_multiple_apis_mixed_results(self):
self.mock_apigw.get_rest_apis.return_value = {
"items": [
{"id": "api-7", "name": "API 7"},
{"id": "api-8", "name": "API 8"}
]
}

def get_stages_side_effect(**kwargs):
if kwargs["restApiId"] == "api-7":
return {"item": [{"stageName": "prod", "webAclArn": "arn:aws:waf::..."}]}
else:
return {"item": [{"stageName": "dev"}]}

self.mock_apigw.get_stages.side_effect = get_stages_side_effect

report = self.check.execute(self.mock_session)
assert report.status == CheckStatus.FAILED
statuses = {r.status for r in report.resource_ids_status}
assert CheckStatus.FAILED in statuses
assert CheckStatus.PASSED in statuses

def test_all_apis_missing_waf(self):
self.mock_apigw.get_rest_apis.return_value = {
"items": [
{"id": "api-9", "name": "API 9"},
{"id": "api-10", "name": "API 10"}
]
}
self.mock_apigw.get_stages.return_value = {
"item": [{"stageName": "stage1"}]
}

report = self.check.execute(self.mock_session)
assert report.status == CheckStatus.FAILED
assert all(r.status == CheckStatus.FAILED for r in report.resource_ids_status)