diff --git a/.github/workflows/deploy-infra.yml b/.github/workflows/deploy-infra.yml index a1d11451..c87e7c57 100644 --- a/.github/workflows/deploy-infra.yml +++ b/.github/workflows/deploy-infra.yml @@ -177,4 +177,27 @@ jobs: ParameterKey=UniqueIdentifier,ParameterValue='v1' \ ParameterKey=ResaleServiceEndpoint,ParameterValue=$RESALE_SERVICE_ENDPOINT \ ParameterKey=SenderEmail,ParameterValue=$SENDER_EMAIL \ - ParameterKey=PagerDutyEndpoint,ParameterValue=$PAGERDUTY_ENDPOINT \ No newline at end of file + ParameterKey=PagerDutyEndpoint,ParameterValue=$PAGERDUTY_ENDPOINT + + integration-tests: + needs: deploy-aws + if: github.event_name == 'push' + runs-on: ubuntu-latest + environment: ${{ github.ref_name }} + steps: + - uses: actions/checkout@v4 + - uses: actions/setup-python@v3 + - name: Run Integration Tests + run: | + cd tests/integration + pip install -r requirements.txt + python -m pytest test_transfer_flow.py -v + env: + API_ENDPOINT: ${{ vars.ZENOBIA_API_ENDPOINT }} + AUTH0_DOMAIN: ${{ vars.AUTH0_DOMAIN }} + AUTH0_M2M_MERCHANT_CLIENT_ID: ${{ secrets.AUTH0_M2M_MERCHANT_CLIENT_ID }} + AUTH0_M2M_MERCHANT_CLIENT_SECRET: ${{ secrets.AUTH0_M2M_MERCHANT_CLIENT_SECRET }} + CUSTOMER_SUB: ${{ vars.CUSTOMER_SUB }} + CUSTOMER_REFRESH_TOKEN: ${{ secrets.CUSTOMER_REFRESH_TOKEN }} + CUSTOMER_PRIVATE_KEY: ${{ secrets.CUSTOMER_PRIVATE_KEY }} + CUSTOMER_DEVICE_ID: ${{ secrets.CUSTOMER_DEVICE_ID }} diff --git a/__tests__/unit/handlers/get-all-items.test.mjs b/__tests__/unit/handlers/get-all-items.test.mjs deleted file mode 100644 index f9b642f1..00000000 --- a/__tests__/unit/handlers/get-all-items.test.mjs +++ /dev/null @@ -1,38 +0,0 @@ -// Import getAllItemsHandler function from get-all-items.mjs -import { getAllItemsHandler } from '../../../src/handlers/get-all-items.mjs'; -// Import dynamodb from aws-sdk -import { DynamoDBDocumentClient, ScanCommand } from '@aws-sdk/lib-dynamodb'; -import { mockClient } from "aws-sdk-client-mock"; - -// This includes all tests for getAllItemsHandler() -describe('Test getAllItemsHandler', () => { - const ddbMock = mockClient(DynamoDBDocumentClient); - - beforeEach(() => { - ddbMock.reset(); - }); - - it('should return ids', async () => { - const items = [{ id: 'id1' }, { id: 'id2' }]; - - // Return the specified value whenever the spied scan function is called - ddbMock.on(ScanCommand).resolves({ - Items: items, - }); - - const event = { - httpMethod: 'GET' - }; - - // Invoke helloFromLambdaHandler() - const result = await getAllItemsHandler(event); - - const expectedResult = { - statusCode: 200, - body: JSON.stringify(items) - }; - - // Compare the result with the expected result - expect(result).toEqual(expectedResult); - }); -}); diff --git a/__tests__/unit/handlers/get-by-id.test.mjs b/__tests__/unit/handlers/get-by-id.test.mjs deleted file mode 100644 index 2d97c98d..00000000 --- a/__tests__/unit/handlers/get-by-id.test.mjs +++ /dev/null @@ -1,43 +0,0 @@ -// Import getByIdHandler function from get-by-id.mjs -import { getByIdHandler } from '../../../src/handlers/get-by-id.mjs'; -// Import dynamodb from aws-sdk -import { DynamoDBDocumentClient, GetCommand } from '@aws-sdk/lib-dynamodb'; -import { mockClient } from "aws-sdk-client-mock"; - -// This includes all tests for getByIdHandler() -describe('Test getByIdHandler', () => { - const ddbMock = mockClient(DynamoDBDocumentClient); - - beforeEach(() => { - ddbMock.reset(); - }); - - // This test invokes getByIdHandler() and compare the result - it('should get item by id', async () => { - const item = { id: 'id1' }; - - // Return the specified value whenever the spied get function is called - ddbMock.on(GetCommand).resolves({ - Item: item, - }); - - const event = { - httpMethod: 'GET', - pathParameters: { - id: 'id1' - } - }; - - // Invoke getByIdHandler() - const result = await getByIdHandler(event); - - const expectedResult = { - statusCode: 200, - body: JSON.stringify(item) - }; - - // Compare the result with the expected result - expect(result).toEqual(expectedResult); - }); -}); - \ No newline at end of file diff --git a/__tests__/unit/handlers/put-item.test.mjs b/__tests__/unit/handlers/put-item.test.mjs deleted file mode 100644 index 64ea140d..00000000 --- a/__tests__/unit/handlers/put-item.test.mjs +++ /dev/null @@ -1,40 +0,0 @@ -// Import putItemHandler function from put-item.mjs -import { putItemHandler } from '../../../src/handlers/put-item.mjs'; -// Import dynamodb from aws-sdk -import { DynamoDBDocumentClient, PutCommand } from '@aws-sdk/lib-dynamodb'; -import { mockClient } from "aws-sdk-client-mock"; -// This includes all tests for putItemHandler() -describe('Test putItemHandler', function () { - const ddbMock = mockClient(DynamoDBDocumentClient); - - beforeEach(() => { - ddbMock.reset(); - }); - - // This test invokes putItemHandler() and compare the result - it('should add id to the table', async () => { - const returnedItem = { id: 'id1', name: 'name1' }; - - // Return the specified value whenever the spied put function is called - ddbMock.on(PutCommand).resolves({ - returnedItem - }); - - const event = { - httpMethod: 'POST', - body: '{"id": "id1","name": "name1"}' - }; - - // Invoke putItemHandler() - const result = await putItemHandler(event); - - const expectedResult = { - statusCode: 200, - body: JSON.stringify(returnedItem) - }; - - // Compare the result with the expected result - expect(result).toEqual(expectedResult); - }); -}); - \ No newline at end of file diff --git a/golang/authorizer/main.go b/golang/authorizer/main.go index 6ee366a0..8390a552 100644 --- a/golang/authorizer/main.go +++ b/golang/authorizer/main.go @@ -31,6 +31,7 @@ func handler(ctx context.Context, event events.APIGatewayCustomAuthorizerRequest } else if authorization, ok := event.Headers["Authorization"]; ok { hasUnauthenticatedHeader = authorization == "NONE" // Explicitly set by app if no auth is provided } + println("Has unauthenticated header: ", hasUnauthenticatedHeader) if isValidPath(event.Path, validOrumRoutes) { return handleOrumWebhookEndpoint(ctx, event) } else if isValidPath(event.Path, validPlaidRoutes) { diff --git a/kotlin/lambda/transfer-handler/src/main/kotlin/com/zenobiapay/transfer/operations/GetAdminTransferOperation.kt b/kotlin/lambda/transfer-handler/src/main/kotlin/com/zenobiapay/transfer/operations/GetAdminTransferOperation.kt index 236a3a6a..4f3e8aa8 100644 --- a/kotlin/lambda/transfer-handler/src/main/kotlin/com/zenobiapay/transfer/operations/GetAdminTransferOperation.kt +++ b/kotlin/lambda/transfer-handler/src/main/kotlin/com/zenobiapay/transfer/operations/GetAdminTransferOperation.kt @@ -24,8 +24,7 @@ class GetAdminTransferOperation @Inject constructor( context: Context, userId: String? ): GetAdminTransfer200Response { - val transferId = input.queryStringParameters?.get("id") - ?: throw ResourceNotFoundException("Missing transfer ID") + val transferId = request.id ?: throw ResourceNotFoundException("Missing transfer ID") val transferItem = transferDao.getTransfer(transferId) ?: throw ResourceNotFoundException("TRANSFER") diff --git a/kotlin/lambda/transfer-handler/src/test/kotlin/com/zenobiapay/transfer/operations/FulfillTransferOperationTest.kt b/kotlin/lambda/transfer-handler/src/test/kotlin/com/zenobiapay/transfer/operations/FulfillTransferOperationTest.kt index c11c1a6f..f74474b0 100644 --- a/kotlin/lambda/transfer-handler/src/test/kotlin/com/zenobiapay/transfer/operations/FulfillTransferOperationTest.kt +++ b/kotlin/lambda/transfer-handler/src/test/kotlin/com/zenobiapay/transfer/operations/FulfillTransferOperationTest.kt @@ -29,6 +29,7 @@ import com.zenobiapay.table.user.dao.UserDao import com.zenobiapay.table.user.model.MerchantData import com.zenobiapay.table.user.model.UserItem import com.zenobiapay.table.user.model.UserItemData +import com.zenobiapay.transfer.model.FulfillTransferRequestMixin import io.mockk.every import io.mockk.mockk import io.mockk.mockkStatic @@ -140,6 +141,22 @@ class FulfillTransferOperationTest { } } + @Test + fun testSignatureOrdering() { + val mixin = objectMapper.copy() + .addMixIn(FulfillTransferRequest::class.java, FulfillTransferRequestMixin::class.java) + + val request = FulfillTransferRequest() + .transferRequestId("4a573cdf-8e70-4e38-8c93-02f98439db88") + .bankAccountId("BjKzqRAad9TLKWE6D7DEFMjrDBDLa4S4xPAzo") + .deviceId("AFB81FDD-DC90-4BD0-857D-9C21BA5BD65B") + .signature(FulfillTransferRequestSignature() + .signatureType(SIGNATURE_TYPE) + .signatureValue(SIGNATURE)) + val json = mixin.writeValueAsString(request) + println(json) + } + private fun createMockGatewayEvent(): APIGatewayProxyRequestEvent { val mockEvent = mockk() return mockEvent diff --git a/tests/integration/.env.template b/tests/integration/.env.template new file mode 100644 index 00000000..3631404e --- /dev/null +++ b/tests/integration/.env.template @@ -0,0 +1,12 @@ +# API endpoint (e.g., https://your-api-endpoint.execute-api.us-east-1.amazonaws.com/prod) +API_ENDPOINT= + +# Auth0 configuration +AUTH0_DOMAIN= +AUTH0_M2M_CLIENT_ID= +AUTH0_M2M_CLIENT_SECRET= + +CUSTOMER_CLIENT_ID= +CUSTOMER_REFRESH_TOKEN= +CUSTOMER_PRIVATE_KEY= +CUSTOMER_DEVICE_ID= diff --git a/tests/integration/__init__.py b/tests/integration/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/integration/conftest.py b/tests/integration/conftest.py new file mode 100644 index 00000000..49d74fa8 --- /dev/null +++ b/tests/integration/conftest.py @@ -0,0 +1,29 @@ +import os +import pytest +import boto3 +from dotenv import load_dotenv + +def pytest_sessionstart(session): + """ + Called after the Session object has been created and before tests are executed. + """ + # Load environment variables from .env file if it exists (for local testing) + load_dotenv() + + # Verify required environment variables + required_vars = [ + 'API_ENDPOINT', + 'AUTH0_DOMAIN', + 'AUTH0_M2M_MERCHANT_CLIENT_ID', + 'AUTH0_M2M_MERCHANT_CLIENT_SECRET', + ] + + # Check for either M2M credentials or customer refresh token + customer_auth_vars = ['CUSTOMER_SUB', 'CUSTOMER_REFRESH_TOKEN'] + if not any(os.environ.get(var) for var in customer_auth_vars): + print("Warning: No customer authentication variables found. Tests may fail if they require customer authentication.") + print("Consider setting CUSTOMER_SUB and CUSTOMER_REFRESH_TOKEN in your .env file.") + + missing = [var for var in required_vars if not os.environ.get(var)] + if missing: + pytest.exit(f"Missing required environment variables: {', '.join(missing)}") diff --git a/tests/integration/requirements.txt b/tests/integration/requirements.txt new file mode 100644 index 00000000..18fec4d3 --- /dev/null +++ b/tests/integration/requirements.txt @@ -0,0 +1,8 @@ +pytest==7.4.0 +pytest-xdist==3.3.1 +requests==2.31.0 +python-dotenv==1.0.0 +boto3==1.28.38 +cryptography==41.0.5 +ecdsa==0.18.0 +asn1crypto==1.5.1 diff --git a/tests/integration/run_local_tests.sh b/tests/integration/run_local_tests.sh new file mode 100755 index 00000000..0091cc97 --- /dev/null +++ b/tests/integration/run_local_tests.sh @@ -0,0 +1,24 @@ +#!/bin/bash + +# Activate virtual environment if not already activated +if [[ "$VIRTUAL_ENV" == "" ]]; then + echo "Activating virtual environment..." + source ../../.venv/bin/activate +fi + +# Check if .env file exists, if not, create from template +if [ ! -f ".env" ]; then + echo "No .env file found. Creating from template..." + if [ -f ".env.template" ]; then + cp .env.template .env + echo "Please edit .env file with your actual values before running tests." + exit 1 + else + echo "Error: .env.template not found. Please create a .env file manually." + exit 1 + fi +fi + +# Run the tests (python-dotenv will automatically load .env file) +# Adding -s flag to show print statements +python -m pytest test_transfer_flow.py -v -s diff --git a/tests/integration/test_transfer_flow.py b/tests/integration/test_transfer_flow.py new file mode 100644 index 00000000..2c216ff9 --- /dev/null +++ b/tests/integration/test_transfer_flow.py @@ -0,0 +1,94 @@ +import pytest +import time +import uuid +import os +import sys +from utils.api import MerchantApi, CustomerApi + +def test_full_transfer_flow(): + """Test the full transfer flow from creation to completion""" + + transfer_amount = 1 # 1 cent + + # Step 1: Create a transfer request as merchant + merchant_api = MerchantApi() + statement_items = [ + { + "name": "Test Item", + "amount": transfer_amount, + "itemId": f"test-item-{uuid.uuid4()}" + } + ] + transfer_metadata = {} + create_response = merchant_api.create_transfer_request( + amount=transfer_amount, + statement_items=statement_items, + transfer_metadata=transfer_metadata + ) + assert create_response.status_code == 200, f"Failed to create transfer: {create_response.text}" + transfer_id = create_response.json()["transferRequestId"] + print(f"Created transfer with ID: {transfer_id}") + + # Step 2: List bank accounts as customer + customer_api = CustomerApi() + + bank_accounts_response = customer_api.list_bank_accounts() + assert bank_accounts_response.status_code == 200, f"Failed to list bank accounts: {bank_accounts_response.text}" + + # Get a bank account ID for fulfillment + try: + response_json = bank_accounts_response.json() + print(f"JSON response: {response_json}") + # The API returns bank accounts in the 'items' field, not 'bankAccounts' + bank_accounts = response_json.get("items", []) + print(f"Bank accounts: {bank_accounts}") + except Exception as e: + print(f"Error parsing JSON: {e}") + bank_accounts = [] + print("--- END DEBUG ---\n") + + if not bank_accounts: + print("\n*** SKIPPING TEST: No bank accounts available for testing ***\n") + pytest.skip("No bank accounts available for testing") + bank_account_id = bank_accounts[0].get("bankAccountId") + + # Step 3: Get customer transfer (authenticated) + transfer_authed_response = customer_api.get_customer_transfer( + transfer_id=transfer_id, + authed=True + ) + assert transfer_authed_response.status_code == 200, f"Failed to get transfer (authed): {transfer_authed_response.text}" + print(f"Successfully retrieved transfer (authenticated)") + + # Step 4: Get customer transfer (unauthenticated) + unauthed_api = CustomerApi() # No customer_id = no token + transfer_unauthed_response = unauthed_api.get_customer_transfer( + transfer_id=transfer_id, + authed=False + ) + assert transfer_unauthed_response.status_code == 200, f"Failed to get transfer (unauthed): {transfer_unauthed_response.text}" + print(f"Successfully retrieved transfer (unauthenticated)") + + # Step 5: Fulfill the transfer + fulfill_response = customer_api.fulfill_transfer( + transfer_id=transfer_id, + bank_account_id=bank_account_id, + ) + assert fulfill_response.status_code == 200, f"Failed to fulfill transfer: {fulfill_response.text}" + print(f"Successfully fulfilled transfer") + + # Wait for transfer processing (adjust as needed) + print(f"Waiting for transfer processing...") + time.sleep(10) # Increased wait time to ensure processing completes + + # Step 6: List customer transfers and verify success + transfers_response = customer_api.list_customer_transfers(continuation_token=None) + assert transfers_response.status_code == 200, f"Failed to list transfers: {transfers_response.text}" + + # Verify the transfer is in the list and has status "PAID" + transfers = transfers_response.json().get("items", []) + matching_transfers = [t for t in transfers if t.get("transferRequestId") == transfer_id] + assert len(matching_transfers) == 1, f"Transfer {transfer_id} not found in list" + assert matching_transfers[0].get("status") == "PAID", f"Transfer status is not PAID: {matching_transfers[0].get('status')}" + print(f"Successfully verified transfer status is PAID") + diff --git a/tests/integration/utils/__init__.py b/tests/integration/utils/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/integration/utils/api.py b/tests/integration/utils/api.py new file mode 100644 index 00000000..905e5aac --- /dev/null +++ b/tests/integration/utils/api.py @@ -0,0 +1,158 @@ +import requests +import os +import json +import base64 +from collections import OrderedDict +from hashlib import sha256 +import ecdsa +from asn1crypto.algos import DSASignature +from .auth import get_merchant_token, get_customer_token + +class MerchantApi: + def __init__(self): + self.base_url = os.environ.get('API_ENDPOINT') + self.token = get_merchant_token() + + def create_transfer_request(self, amount, statement_items=None, transfer_metadata=None): + """Create a new transfer request + + Args: + amount: Integer amount in cents + statement_items: List of statement items (optional) + transfer_metadata: Dictionary of metadata (optional) + """ + if statement_items is None: + statement_items = [] + if transfer_metadata is None: + transfer_metadata = {} + + return requests.post( + f"{self.base_url}/create-transfer-request", + headers={"Authorization": f"Bearer {self.token}"}, + json={ + "amount": amount, + "statementItems": statement_items, + "transferMetadata": transfer_metadata + } + ) + +class CustomerApi: + def __init__(self): + self.base_url = os.environ.get('API_ENDPOINT') + self.device_id = os.environ.get('CUSTOMER_DEVICE_ID') + self.token = get_customer_token() + + def list_bank_accounts(self): + """List bank accounts for a customer""" + headers = {"Authorization": f"Bearer {self.token}"} if self.token else {} + return requests.post( + f"{self.base_url}/list-bank-accounts", + headers=headers, + json={ + "deviceId": self.device_id + } + ) + + def get_customer_transfer(self, transfer_id, authed=True): + """Get a customer transfer (with or without auth) + + Args: + transfer_id: ID of the transfer to retrieve + authed: Whether to include authentication token (default: True) + """ + print("Token: ", self.token) + print("Transfer id: ", transfer_id) + headers = {"Authorization": f"Bearer {self.token}"} if authed and self.token else {"Authorization": "NONE"} + params = {"id": transfer_id} + return requests.get( + f"{self.base_url}/get-customer-transfer", + headers=headers, + params=params + ) + + def list_customer_transfers(self, continuation_token=None): + """List transfers for a customer + + Args: + continuation_token: Token for pagination (optional) + """ + headers = {"Authorization": f"Bearer {self.token}"} if self.token else {} + json_body = {} + if continuation_token: + json_body["continuationToken"] = continuation_token + + return requests.post( + f"{self.base_url}/list-customer-transfers", + headers=headers, + json=json_body + ) + + def fulfill_transfer(self, transfer_id, bank_account_id): + """Fulfill a transfer with proper request signing + + Args: + transfer_id: ID of the transfer to fulfill + bank_account_id: ID of the bank account to use + """ + # Create request body with all fields except signature + request_body = { + "transferRequestId": transfer_id, + "bankAccountId": bank_account_id, + "deviceId": self.device_id + } + + # Sort fields alphabetically for signature calculation + sorted_body = OrderedDict(sorted(request_body.items())) + + # Convert to JSON string for signing + json_body = json.dumps(sorted_body, separators=(',', ':')) + + # Generate signature using private key from environment variable + # Expecting base64 encoded private key + signature = self.sign_with_private_key(json_body) + + # Add the signature to the request body after sorting/serializing + request_body["signature"] = signature + + headers = {"Authorization": f"Bearer {self.token}"} if self.token else {} + return requests.post( + f"{self.base_url}/fulfill-transfer", + headers=headers, + json=request_body + ) + + def sign_with_private_key(self, data): + encoded_private_key = os.environ.get('CUSTOMER_PRIVATE_KEY') + if not encoded_private_key: + raise ValueError("CUSTOMER_PRIVATE_KEY environment variable is required for signing") + try: + # Decode the base64 encoded private key + decoded_private_key = base64.b64decode(encoded_private_key) + + # Parse the private key using ecdsa + private_key = ecdsa.SigningKey.from_string(decoded_private_key, curve=ecdsa.NIST256p) + + # Sign the data using raw R||S format first + data_to_sign = data.encode('utf-8') + raw_signature = private_key.sign_deterministic( + data_to_sign, + hashfunc=sha256, + sigencode=ecdsa.util.sigencode_string # Use raw R||S encoding + ) + + # Convert raw signature to ASN.1/DER format (X9.62 format used by iOS) + r_int = int.from_bytes(raw_signature[:32], byteorder='big') + s_int = int.from_bytes(raw_signature[32:], byteorder='big') + der_signature = DSASignature({'r': r_int, 's': s_int}).dump() + + # Convert signature to base64 + sig_b64 = base64.b64encode(der_signature).decode('utf-8') + + # Create signature object + signature = { + "signatureType": "SHA256_WITH_ECDSA", + "signatureValue": sig_b64 + } + except Exception as e: + raise Exception(f"Failed to sign request: {str(e)}") + return signature diff --git a/tests/integration/utils/auth.py b/tests/integration/utils/auth.py new file mode 100644 index 00000000..90742f57 --- /dev/null +++ b/tests/integration/utils/auth.py @@ -0,0 +1,58 @@ +import requests +import os +import json + +def get_merchant_token(): + """Get an M2M token for merchant API access""" + auth0_domain = os.environ.get('AUTH0_DOMAIN') + client_id = os.environ.get('AUTH0_M2M_MERCHANT_CLIENT_ID') + client_secret = os.environ.get('AUTH0_M2M_MERCHANT_CLIENT_SECRET') + audience = "https://dashboard.zenobiapay.com" + + response = requests.post( + f"https://{auth0_domain}/oauth/token", + headers={"content-type": "application/json"}, + data=json.dumps({ + "client_id": client_id, + "client_secret": client_secret, + "audience": audience, + "grant_type": "client_credentials" + }) + ) + return response.json()["access_token"] + +def get_customer_token(): + """Get a token for customer API access using the issue-jwt endpoint""" + # Get environment variables + api_endpoint = os.environ.get('API_ENDPOINT') + sub = os.environ.get('CUSTOMER_SUB') + refresh_token = os.environ.get('CUSTOMER_REFRESH_TOKEN') + + # Check if we have the required variables + if not (api_endpoint and sub and refresh_token): + print("Warning: Missing required environment variables for customer authentication.") + print("Using merchant token as fallback. This may cause test failures.") + + # Call the issue-jwt endpoint to get a JWT for the customer + try: + response = requests.post( + f"{api_endpoint}/issue-jwt", + headers={ + "content-type": "application/json", + "authorization": "NONE" + }, + json={ + "refreshToken": refresh_token, + "sub": sub + } + ) + + if response.status_code != 200: + print(f"Error getting customer token: {response.status_code} - {response.text}") + return None + + return response.json()["jwt"] + except Exception as e: + print(f"Exception getting customer token: {str(e)}") + # Fall back to merchant token if we can't get a customer token + return get_merchant_token()