diff --git a/packages/backend/app/db/schema.sql b/packages/backend/app/db/schema.sql index 410189de..190c0d08 100644 --- a/packages/backend/app/db/schema.sql +++ b/packages/backend/app/db/schema.sql @@ -123,3 +123,18 @@ CREATE TABLE IF NOT EXISTS audit_logs ( action VARCHAR(100) NOT NULL, created_at TIMESTAMP NOT NULL DEFAULT NOW() ); + +CREATE TABLE IF NOT EXISTS bank_connections ( + id SERIAL PRIMARY KEY, + user_id INT NOT NULL REFERENCES users(id) ON DELETE CASCADE, + connector_id VARCHAR(50) NOT NULL, + connection_id VARCHAR(255) NOT NULL UNIQUE, + account_id VARCHAR(255), + account_name VARCHAR(255), + institution_name VARCHAR(255), + status VARCHAR(20) NOT NULL DEFAULT 'PENDING', + last_sync_at TIMESTAMP, + created_at TIMESTAMP NOT NULL DEFAULT NOW(), + updated_at TIMESTAMP NOT NULL DEFAULT NOW() +); +CREATE INDEX IF NOT EXISTS idx_bank_connections_user ON bank_connections(user_id); diff --git a/packages/backend/app/models.py b/packages/backend/app/models.py index 64d44810..5a06ba4f 100644 --- a/packages/backend/app/models.py +++ b/packages/backend/app/models.py @@ -133,3 +133,19 @@ class AuditLog(db.Model): user_id = db.Column(db.Integer, db.ForeignKey("users.id"), nullable=True) action = db.Column(db.String(100), nullable=False) created_at = db.Column(db.DateTime, default=datetime.utcnow, nullable=False) + + +class BankConnection(db.Model): + """Stores bank connection metadata for users.""" + __tablename__ = "bank_connections" + id = db.Column(db.Integer, primary_key=True) + user_id = db.Column(db.Integer, db.ForeignKey("users.id"), nullable=False) + connector_id = db.Column(db.String(50), nullable=False) + connection_id = db.Column(db.String(255), nullable=False, unique=True) + account_id = db.Column(db.String(255), nullable=True) + account_name = db.Column(db.String(255), nullable=True) + institution_name = db.Column(db.String(255), nullable=True) + status = db.Column(db.String(20), default="PENDING", nullable=False) + last_sync_at = db.Column(db.DateTime, nullable=True) + created_at = db.Column(db.DateTime, default=datetime.utcnow, nullable=False) + updated_at = db.Column(db.DateTime, default=datetime.utcnow, onupdate=datetime.utcnow, nullable=False) diff --git a/packages/backend/app/routes/__init__.py b/packages/backend/app/routes/__init__.py index f13b0f89..5190586b 100644 --- a/packages/backend/app/routes/__init__.py +++ b/packages/backend/app/routes/__init__.py @@ -7,6 +7,7 @@ from .categories import bp as categories_bp from .docs import bp as docs_bp from .dashboard import bp as dashboard_bp +from .bank_connections import bp as bank_connections_bp def register_routes(app: Flask): @@ -18,3 +19,4 @@ def register_routes(app: Flask): app.register_blueprint(categories_bp, url_prefix="/categories") app.register_blueprint(docs_bp, url_prefix="/docs") app.register_blueprint(dashboard_bp, url_prefix="/dashboard") + app.register_blueprint(bank_connections_bp, url_prefix="/bank") diff --git a/packages/backend/app/routes/bank_connections.py b/packages/backend/app/routes/bank_connections.py new file mode 100644 index 00000000..7e8fa16e --- /dev/null +++ b/packages/backend/app/routes/bank_connections.py @@ -0,0 +1,295 @@ +""" +Bank Connections API + +Provides endpoints for managing bank connections and syncing transactions. +""" + +from datetime import datetime + +from flask import Blueprint, jsonify, request +from flask_jwt_extended import jwt_required, get_jwt_identity + +from ..extensions import db +from ..models import BankConnection, Expense +from ..services.bank_connector import ( + BankConnector, + ConnectorRegistry, + ConnectionResult, + ConnectionStatus, + MockBankConnector, +) + +bp = Blueprint("bank_connections", __name__) + + +def _connector_to_dict(connector: BankConnector) -> dict: + """Convert connector to dict for API response.""" + return { + "connector_id": connector.connector_id, + "connector_name": connector.connector_name, + "supported_features": connector.supported_features, + } + + +def _connection_to_dict(connection: BankConnection) -> dict: + """Convert BankConnection model to dict for API response.""" + return { + "id": connection.id, + "connector_id": connection.connector_id, + "connection_id": connection.connection_id, + "account_id": connection.account_id, + "account_name": connection.account_name, + "institution_name": connection.institution_name, + "status": connection.status, + "last_sync_at": connection.last_sync_at.isoformat() if connection.last_sync_at else None, + "created_at": connection.created_at.isoformat(), + "updated_at": connection.updated_at.isoformat(), + } + + +@bp.get("/connectors") +@jwt_required() +def list_connectors(): + """List all available bank connectors.""" + connectors = [ + _connector_to_dict(ConnectorRegistry.create_instance(conn_id)) + for conn_id in ConnectorRegistry.list_connectors() + ] + return jsonify(connectors) + + +@bp.post("/connections") +@jwt_required() +def create_connection(): + """Create a new bank connection.""" + uid = int(get_jwt_identity()) + data = request.get_json() or {} + + connector_id = data.get("connector_id") + credentials = data.get("credentials", {}) + + if not connector_id: + return jsonify(error="connector_id required"), 400 + + # Get connector + connector = ConnectorRegistry.create_instance(connector_id) + if not connector: + return jsonify(error="Unknown connector"), 400 + + # Validate credentials + is_valid, error_msg = connector.validate_credentials(credentials) + if not is_valid: + return jsonify(error=error_msg), 400 + + # Connect to bank + result = connector.connect(credentials) + if not result.success: + return jsonify(error=result.message, code=result.error_code), 400 + + # Store connection + connection_id = f"conn_{connector_id}_{uid}_{datetime.utcnow().timestamp()}" + bank_conn = BankConnection( + user_id=uid, + connector_id=connector_id, + connection_id=connection_id, + account_id=result.accounts[0].account_id if result.accounts else None, + account_name=result.accounts[0].account_name if result.accounts else None, + institution_name=result.accounts[0].institution_name if result.accounts else None, + status=ConnectionStatus.CONNECTED.value, + ) + db.session.add(bank_conn) + db.session.commit() + + return jsonify(_connection_to_dict(bank_conn)), 201 + + +@bp.get("/connections") +@jwt_required() +def list_connections(): + """List all bank connections for the current user.""" + uid = int(get_jwt_identity()) + connections = BankConnection.query.filter_by(user_id=uid).all() + return jsonify([_connection_to_dict(c) for c in connections]) + + +@bp.get("/connections/") +@jwt_required() +def get_connection(connection_id: int): + """Get a specific bank connection.""" + uid = int(get_jwt_identity()) + connection = BankConnection.query.filter_by(id=connection_id, user_id=uid).first() + if not connection: + return jsonify(error="Connection not found"), 404 + return jsonify(_connection_to_dict(connection)) + + +@bp.delete("/connections/") +@jwt_required() +def delete_connection(connection_id: int): + """Delete a bank connection.""" + uid = int(get_jwt_identity()) + connection = BankConnection.query.filter_by(id=connection_id, user_id=uid).first() + if not connection: + return jsonify(error="Connection not found"), 404 + + # Get connector and disconnect + connector = ConnectorRegistry.create_instance(connection.connector_id) + if connector: + connector.disconnect(connection.connection_id) + + db.session.delete(connection) + db.session.commit() + + return jsonify(message="Connection deleted") + + +@bp.post("/connections//import") +@jwt_required() +def import_transactions(connection_id: int): + """Import transactions from a bank connection.""" + uid = int(get_jwt_identity()) + data = request.get_json() or {} + + connection = BankConnection.query.filter_by(id=connection_id, user_id=uid).first() + if not connection: + return jsonify(error="Connection not found"), 404 + + # Parse dates + start_date = None + end_date = None + if data.get("start_date"): + try: + start_date = datetime.fromisoformat(data["start_date"]) + except ValueError: + return jsonify(error="Invalid start_date"), 400 + if data.get("end_date"): + try: + end_date = datetime.fromisoformat(data["end_date"]) + except ValueError: + return jsonify(error="Invalid end_date"), 400 + + # Get connector and import + connector = ConnectorRegistry.create_instance(connection.connector_id) + if not connector: + return jsonify(error="Connector not found"), 500 + + result = connector.import_transactions( + connection.connection_id, + start_date=start_date, + end_date=end_date, + ) + + if not result.success: + return jsonify(error=result.message, code=result.error_code), 400 + + # Convert transactions to expense format + expenses_data = [t.to_expense_dict() for t in result.transactions] + + return jsonify({ + "message": result.message, + "transactions": expenses_data, + "count": len(expenses_data), + }) + + +@bp.post("/connections//refresh") +@jwt_required() +def refresh_transactions(connection_id: int): + """Refresh transactions from a bank connection.""" + uid = int(get_jwt_identity()) + data = request.get_json() or {} + + connection = BankConnection.query.filter_by(id=connection_id, user_id=uid).first() + if not connection: + return jsonify(error="Connection not found"), 404 + + # Parse 'since' date + since = None + if data.get("since"): + try: + since = datetime.fromisoformat(data["since"]) + except ValueError: + return jsonify(error="Invalid since date"), 400 + elif connection.last_sync_at: + since = connection.last_sync_at + + # Get connector and refresh + connector = ConnectorRegistry.create_instance(connection.connector_id) + if not connector: + return jsonify(error="Connector not found"), 500 + + result = connector.refresh_transactions( + connection.connection_id, + since=since, + ) + + if not result.success: + return jsonify(error=result.message, code=result.error_code), 400 + + # Update last sync time + connection.last_sync_at = datetime.utcnow() + db.session.commit() + + # Convert transactions to expense format + expenses_data = [t.to_expense_dict() for t in result.transactions] + + return jsonify({ + "message": result.message, + "transactions": expenses_data, + "count": len(expenses_data), + }) + + +@bp.post("/connections//import/commit") +@jwt_required() +def commit_imported_transactions(connection_id: int): + """Commit imported transactions as expenses.""" + uid = int(get_jwt_identity()) + data = request.get_json() or {} + + connection = BankConnection.query.filter_by(id=connection_id, user_id=uid).first() + if not connection: + return jsonify(error="Connection not found"), 404 + + transactions = data.get("transactions", []) + if not transactions: + return jsonify(error="No transactions to import"), 400 + + inserted = 0 + duplicates = 0 + + for tx in transactions: + # Check for duplicate (same date, amount, description) + existing = Expense.query.filter_by( + user_id=uid, + spent_at=tx.get("date"), + amount=tx.get("amount"), + notes=tx.get("description"), + ).first() + + if existing: + duplicates += 1 + continue + + expense = Expense( + user_id=uid, + amount=tx.get("amount"), + currency=tx.get("currency", "USD"), + expense_type=tx.get("expense_type", "EXPENSE"), + category_id=tx.get("category_id"), + notes=tx.get("description", ""), + spent_at=tx.get("date"), + ) + db.session.add(expense) + inserted += 1 + + db.session.commit() + + # Update last sync time + connection.last_sync_at = datetime.utcnow() + db.session.commit() + + return jsonify({ + "inserted": inserted, + "duplicates": duplicates, + }), 201 \ No newline at end of file diff --git a/packages/backend/app/services/bank_connector.py b/packages/backend/app/services/bank_connector.py new file mode 100644 index 00000000..caf9ac74 --- /dev/null +++ b/packages/backend/app/services/bank_connector.py @@ -0,0 +1,373 @@ +""" +Bank Sync Connector Architecture + +Provides a pluggable architecture for bank integrations with: +- Connector interface (abstract base class) +- Import & refresh support +- Mock connector for testing +""" + +from abc import ABC, abstractmethod +from dataclasses import dataclass, field +from datetime import datetime +from enum import Enum +from typing import Any + + +class ConnectionStatus(str, Enum): + """Status of a bank connection.""" + PENDING = "PENDING" + CONNECTED = "CONNECTED" + ERROR = "ERROR" + DISCONNECTED = "DISCONNECTED" + + +class TransactionType(str, Enum): + """Type of transaction.""" + EXPENSE = "EXPENSE" + INCOME = "INCOME" + TRANSFER = "TRANSFER" + + +@dataclass +class BankTransaction: + """Represents a transaction from a bank account.""" + date: datetime + amount: float + description: str + transaction_type: TransactionType = TransactionType.EXPENSE + currency: str = "USD" + category_id: int | None = None + notes: str = "" + raw_data: dict[str, Any] = field(default_factory=dict) + + def to_expense_dict(self) -> dict[str, Any]: + """Convert to expense dict format for import.""" + return { + "date": self.date.date().isoformat() if isinstance(self.date, datetime) else self.date, + "amount": abs(self.amount), + "description": self.description[:500], + "category_id": self.category_id, + "expense_type": self.transaction_type.value, + "currency": self.currency, + } + + +@dataclass +class BankAccount: + """Represents a bank account.""" + account_id: str + account_name: str + account_type: str # checking, savings, credit, etc. + balance: float + currency: str = "USD" + institution_name: str = "" + institution_id: str = "" + mask: str = "" # Last 4 digits + + +@dataclass +class ConnectionResult: + """Result of a bank connection operation.""" + success: bool + message: str + accounts: list[BankAccount] = field(default_factory=list) + transactions: list[BankTransaction] = field(default_factory=list) + error_code: str | None = None + + +class BankConnector(ABC): + """ + Abstract base class for bank connectors. + + Implement this interface to create a new bank integration. + """ + + @property + @abstractmethod + def connector_id(self) -> str: + """Unique identifier for this connector.""" + pass + + @property + @abstractmethod + def connector_name(self) -> str: + """Display name for this connector.""" + pass + + @property + def supported_features(self) -> list[str]: + """List of supported features. Override in subclass.""" + return ["import", "refresh"] + + @abstractmethod + def connect(self, credentials: dict[str, Any]) -> ConnectionResult: + """ + Connect to the bank using provided credentials. + + Args: + credentials: Dict containing authentication details + (e.g., api_key, access_token, etc.) + + Returns: + ConnectionResult with status and account info + """ + pass + + @abstractmethod + def disconnect(self, connection_id: str) -> bool: + """ + Disconnect from the bank. + + Args: + connection_id: The connection identifier to disconnect + + Returns: + True if disconnected successfully + """ + pass + + @abstractmethod + def import_transactions( + self, + connection_id: str, + start_date: datetime | None = None, + end_date: datetime | None = None, + ) -> ConnectionResult: + """ + Import transactions from the connected bank account. + + Args: + connection_id: The connection identifier + start_date: Start date for transaction import (optional) + end_date: End date for transaction import (optional) + + Returns: + ConnectionResult with transactions + """ + pass + + @abstractmethod + def refresh_transactions( + self, + connection_id: str, + since: datetime | None = None, + ) -> ConnectionResult: + """ + Refresh/fetch new transactions since last sync. + + Args: + connection_id: The connection identifier + since: Fetch transactions after this datetime (optional) + + Returns: + ConnectionResult with new transactions + """ + pass + + def get_connection_status(self, connection_id: str) -> ConnectionStatus: + """ + Get the current status of a connection. + + Args: + connection_id: The connection identifier + + Returns: + ConnectionStatus enum value + """ + return ConnectionStatus.CONNECTED + + def validate_credentials(self, credentials: dict[str, Any]) -> tuple[bool, str]: + """ + Validate credentials before attempting connection. + + Args: + credentials: Dict containing authentication details + + Returns: + Tuple of (is_valid, error_message) + """ + if not credentials: + return False, "Credentials cannot be empty" + return True, "" + + +class MockBankConnector(BankConnector): + """ + Mock bank connector for testing purposes. + + Simulates a bank connection with configurable behavior. + """ + + def __init__( + self, + mock_accounts: list[BankAccount] | None = None, + mock_transactions: list[BankTransaction] | None = None, + should_fail: bool = False, + failure_message: str = "Mock connection failed", + ): + self._connector_id = "mock" + self._connector_name = "Mock Bank" + self._mock_accounts = mock_accounts or [ + BankAccount( + account_id="mock_acc_001", + account_name="Mock Checking", + account_type="checking", + balance=5000.00, + currency="USD", + institution_name="Mock Bank", + institution_id="mock_bank", + mask="1234", + ) + ] + self._mock_transactions = mock_transactions or [] + self._should_fail = should_fail + self._failure_message = failure_message + self._connections: dict[str, dict[str, Any]] = {} + + @property + def connector_id(self) -> str: + return self._connector_id + + @property + def connector_name(self) -> str: + return self._connector_name + + def connect(self, credentials: dict[str, Any]) -> ConnectionResult: + if self._should_fail: + return ConnectionResult( + success=False, + message=self._failure_message, + error_code="MOCK_ERROR", + ) + + # Validate mock credentials + is_valid, error_msg = self.validate_credentials(credentials) + if not is_valid: + return ConnectionResult( + success=False, + message=error_msg, + error_code="INVALID_CREDENTIALS", + ) + + connection_id = f"conn_{self._connector_id}_{len(self._connections)}" + self._connections[connection_id] = { + "status": ConnectionStatus.CONNECTED, + "credentials": credentials, + "connected_at": datetime.utcnow(), + } + + return ConnectionResult( + success=True, + message="Connected successfully", + accounts=self._mock_accounts, + ) + + def disconnect(self, connection_id: str) -> bool: + if connection_id in self._connections: + del self._connections[connection_id] + return True + return False + + def import_transactions( + self, + connection_id: str, + start_date: datetime | None = None, + end_date: datetime | None = None, + ) -> ConnectionResult: + if connection_id not in self._connections: + return ConnectionResult( + success=False, + message="Connection not found", + error_code="CONNECTION_NOT_FOUND", + ) + + # Filter transactions by date range if provided + transactions = self._mock_transactions + if start_date: + transactions = [t for t in transactions if t.date >= start_date] + if end_date: + transactions = [t for t in transactions if t.date <= end_date] + + return ConnectionResult( + success=True, + message=f"Imported {len(transactions)} transactions", + transactions=transactions, + ) + + def refresh_transactions( + self, + connection_id: str, + since: datetime | None = None, + ) -> ConnectionResult: + if connection_id not in self._connections: + return ConnectionResult( + success=False, + message="Connection not found", + error_code="CONNECTION_NOT_FOUND", + ) + + # Filter transactions after 'since' date + transactions = self._mock_transactions + if since: + transactions = [t for t in transactions if t.date > since] + + return ConnectionResult( + success=True, + message=f"Refreshed {len(transactions)} new transactions", + transactions=transactions, + ) + + def get_connection_status(self, connection_id: str) -> ConnectionStatus: + if connection_id in self._connections: + return self._connections[connection_id].get("status", ConnectionStatus.DISCONNECTED) + return ConnectionStatus.DISCONNECTED + + def add_mock_transaction(self, transaction: BankTransaction) -> None: + """Add a transaction to the mock connector.""" + self._mock_transactions.append(transaction) + + def clear_transactions(self) -> None: + """Clear all mock transactions.""" + self._mock_transactions.clear() + + def set_failure(self, should_fail: bool, message: str = "Mock connection failed") -> None: + """Configure the mock to fail or succeed.""" + self._should_fail = should_fail + self._failure_message = message + + +# Registry for available connectors +class ConnectorRegistry: + """Registry for managing available bank connectors.""" + + _connectors: dict[str, type[BankConnector]] = {} + + @classmethod + def register(cls, connector_class: type[BankConnector]) -> None: + """Register a connector class.""" + # Create temporary instance to get connector_id + instance = connector_class() + cls._connectors[instance.connector_id] = connector_class + + @classmethod + def get(cls, connector_id: str) -> type[BankConnector] | None: + """Get a connector class by ID.""" + return cls._connectors.get(connector_id) + + @classmethod + def list_connectors(cls) -> list[str]: + """List all registered connector IDs.""" + return list(cls._connectors.keys()) + + @classmethod + def create_instance(cls, connector_id: str) -> BankConnector | None: + """Create an instance of a connector.""" + connector_class = cls.get(connector_id) + if connector_class: + return connector_class() + return None + + +# Register the mock connector +ConnectorRegistry.register(MockBankConnector) \ No newline at end of file diff --git a/packages/backend/tests/test_bank_connector.py b/packages/backend/tests/test_bank_connector.py new file mode 100644 index 00000000..422f95dd --- /dev/null +++ b/packages/backend/tests/test_bank_connector.py @@ -0,0 +1,437 @@ +""" +Tests for Bank Connector Architecture +""" + +import pytest +from datetime import datetime, date + +from app.services.bank_connector import ( + BankConnector, + BankTransaction, + BankAccount, + ConnectionResult, + ConnectionStatus, + TransactionType, + MockBankConnector, + ConnectorRegistry, +) + + +class TestBankTransaction: + """Tests for BankTransaction dataclass.""" + + def test_to_expense_dict_expense(self): + tx = BankTransaction( + date=datetime(2026, 2, 15), + amount=-50.00, + description="Grocery Store", + transaction_type=TransactionType.EXPENSE, + currency="USD", + ) + result = tx.to_expense_dict() + assert result["amount"] == 50.00 + assert result["description"] == "Grocery Store" + assert result["expense_type"] == "EXPENSE" + assert result["currency"] == "USD" + + def test_to_expense_dict_income(self): + tx = BankTransaction( + date=datetime(2026, 2, 15), + amount=5000.00, + description="Salary", + transaction_type=TransactionType.INCOME, + currency="USD", + ) + result = tx.to_expense_dict() + assert result["amount"] == 5000.00 + assert result["expense_type"] == "INCOME" + + def test_to_expense_dict_with_category(self): + tx = BankTransaction( + date=datetime(2026, 2, 15), + amount=-25.00, + description="Coffee", + category_id=5, + ) + result = tx.to_expense_dict() + assert result["category_id"] == 5 + + +class TestBankAccount: + """Tests for BankAccount dataclass.""" + + def test_basic_account(self): + account = BankAccount( + account_id="acc_123", + account_name="Checking Account", + account_type="checking", + balance=1000.00, + currency="USD", + institution_name="Test Bank", + institution_id="test_bank", + mask="1234", + ) + assert account.account_id == "acc_123" + assert account.balance == 1000.00 + assert account.mask == "1234" + + +class TestMockBankConnector: + """Tests for MockBankConnector.""" + + def test_connector_id_and_name(self): + connector = MockBankConnector() + assert connector.connector_id == "mock" + assert connector.connector_name == "Mock Bank" + + def test_connect_success(self): + connector = MockBankConnector() + result = connector.connect({"api_key": "test_key"}) + + assert result.success is True + assert result.message == "Connected successfully" + assert len(result.accounts) == 1 + assert result.accounts[0].account_name == "Mock Checking" + + def test_connect_failure(self): + connector = MockBankConnector(should_fail=True, failure_message="Invalid credentials") + result = connector.connect({"api_key": "bad_key"}) + + assert result.success is False + assert result.message == "Invalid credentials" + assert result.error_code == "MOCK_ERROR" + + def test_connect_invalid_credentials(self): + connector = MockBankConnector() + result = connector.connect({}) + + assert result.success is False + assert "empty" in result.message.lower() + + def test_disconnect(self): + connector = MockBankConnector() + result = connector.connect({"api_key": "test"}) + connection_id = result.accounts[0].account_id # This is wrong, need actual connection_id + + # Actually, let's check the connection is stored + assert len(connector._connections) == 1 + conn_id = list(connector._connections.keys())[0] + + disconnected = connector.disconnect(conn_id) + assert disconnected is True + assert len(connector._connections) == 0 + + def test_disconnect_not_found(self): + connector = MockBankConnector() + result = connector.disconnect("nonexistent") + assert result is False + + def test_import_transactions(self): + transactions = [ + BankTransaction( + date=datetime(2026, 2, 10), + amount=-50.00, + description="Coffee", + ), + BankTransaction( + date=datetime(2026, 2, 11), + amount=-100.00, + description="Groceries", + ), + ] + connector = MockBankConnector(mock_transactions=transactions) + + # First connect to get a connection_id + connect_result = connector.connect({"api_key": "test"}) + conn_id = list(connector._connections.keys())[0] + + result = connector.import_transactions(conn_id) + + assert result.success is True + assert len(result.transactions) == 2 + + def test_import_transactions_with_date_filter(self): + transactions = [ + BankTransaction(date=datetime(2026, 2, 10), amount=-50.00, description="Coffee"), + BankTransaction(date=datetime(2026, 2, 15), amount=-100.00, description="Groceries"), + ] + connector = MockBankConnector(mock_transactions=transactions) + + connect_result = connector.connect({"api_key": "test"}) + conn_id = list(connector._connections.keys())[0] + + result = connector.import_transactions( + conn_id, + start_date=datetime(2026, 2, 12), + ) + + assert result.success is True + assert len(result.transactions) == 1 + assert result.transactions[0].description == "Groceries" + + def test_refresh_transactions(self): + transactions = [ + BankTransaction(date=datetime(2026, 2, 10), amount=-50.00, description="Old"), + BankTransaction(date=datetime(2026, 2, 20), amount=-100.00, description="New"), + ] + connector = MockBankConnector(mock_transactions=transactions) + + connect_result = connector.connect({"api_key": "test"}) + conn_id = list(connector._connections.keys())[0] + + # Refresh since Feb 15 - should only get the "New" transaction + result = connector.refresh_transactions( + conn_id, + since=datetime(2026, 2, 15), + ) + + assert result.success is True + assert len(result.transactions) == 1 + assert result.transactions[0].description == "New" + + def test_get_connection_status(self): + connector = MockBankConnector() + connect_result = connector.connect({"api_key": "test"}) + conn_id = list(connector._connections.keys())[0] + + status = connector.get_connection_status(conn_id) + assert status == ConnectionStatus.CONNECTED + + # After disconnect + connector.disconnect(conn_id) + status = connector.get_connection_status(conn_id) + assert status == ConnectionStatus.DISCONNECTED + + def test_add_mock_transaction(self): + connector = MockBankConnector() + connector.add_mock_transaction( + BankTransaction( + date=datetime(2026, 2, 10), + amount=-50.00, + description="Test", + ) + ) + assert len(connector._mock_transactions) == 1 + + def test_clear_transactions(self): + connector = MockBankConnector(mock_transactions=[ + BankTransaction(date=datetime(2026, 2, 10), amount=-50.00, description="Test"), + ]) + connector.clear_transactions() + assert len(connector._mock_transactions) == 0 + + +class TestConnectorRegistry: + """Tests for ConnectorRegistry.""" + + def test_register_and_get(self): + # MockBankConnector should already be registered + connector_class = ConnectorRegistry.get("mock") + assert connector_class is MockBankConnector + + def test_list_connectors(self): + connectors = ConnectorRegistry.list_connectors() + assert "mock" in connectors + + def test_create_instance(self): + connector = ConnectorRegistry.create_instance("mock") + assert isinstance(connector, MockBankConnector) + + def test_create_instance_unknown(self): + connector = ConnectorRegistry.create_instance("unknown_connector") + assert connector is None + + +class TestConnectionResult: + """Tests for ConnectionResult dataclass.""" + + def test_default_values(self): + result = ConnectionResult(success=True, message="OK") + assert result.success is True + assert result.message == "OK" + assert result.accounts == [] + assert result.transactions == [] + assert result.error_code is None + + +# Integration-style tests with the Flask app +def test_bank_connections_api_list_connectors(client, auth_header): + """Test the /bank/connectors endpoint.""" + r = client.get("/bank/connectors", headers=auth_header) + assert r.status_code == 200 + data = r.get_json() + assert len(data) >= 1 + assert data[0]["connector_id"] == "mock" + + +def test_bank_connections_api_create_connection(client, auth_header): + """Test creating a bank connection.""" + r = client.post( + "/bank/connections", + json={"connector_id": "mock", "credentials": {"api_key": "test"}}, + headers=auth_header, + ) + assert r.status_code == 201 + data = r.get_json() + assert "connection_id" in data + assert data["connector_id"] == "mock" + assert data["status"] == "CONNECTED" + + +def test_bank_connections_api_list_connections(client, auth_header): + """Test listing bank connections.""" + # First create one + client.post( + "/bank/connections", + json={"connector_id": "mock", "credentials": {"api_key": "test"}}, + headers=auth_header, + ) + + r = client.get("/bank/connections", headers=auth_header) + assert r.status_code == 200 + data = r.get_json() + assert len(data) == 1 + + +def test_bank_connections_api_get_connection(client, auth_header): + """Test getting a specific connection.""" + create_r = client.post( + "/bank/connections", + json={"connector_id": "mock", "credentials": {"api_key": "test"}}, + headers=auth_header, + ) + conn_id = create_r.get_json()["id"] + + r = client.get(f"/bank/connections/{conn_id}", headers=auth_header) + assert r.status_code == 200 + assert r.get_json()["id"] == conn_id + + +def test_bank_connections_api_delete_connection(client, auth_header): + """Test deleting a bank connection.""" + create_r = client.post( + "/bank/connections", + json={"connector_id": "mock", "credentials": {"api_key": "test"}}, + headers=auth_header, + ) + conn_id = create_r.get_json()["id"] + + r = client.delete(f"/bank/connections/{conn_id}", headers=auth_header) + assert r.status_code == 200 + + # Verify it's gone + r = client.get(f"/bank/connections/{conn_id}", headers=auth_header) + assert r.status_code == 404 + + +def test_bank_connections_api_import_transactions(client, auth_header): + """Test importing transactions from a connection.""" + # Create connection + create_r = client.post( + "/bank/connections", + json={"connector_id": "mock", "credentials": {"api_key": "test"}}, + headers=auth_header, + ) + conn_id = create_r.get_json()["id"] + + # Import transactions (empty because mock has no transactions by default) + r = client.post( + f"/bank/connections/{conn_id}/import", + json={}, + headers=auth_header, + ) + assert r.status_code == 200 + data = r.get_json() + assert "transactions" in data + + +def test_bank_connections_api_import_commit(client, auth_header): + """Test committing imported transactions.""" + # Create connection + create_r = client.post( + "/bank/connections", + json={"connector_id": "mock", "credentials": {"api_key": "test"}}, + headers=auth_header, + ) + conn_id = create_r.get_json()["id"] + + # Commit transactions + transactions = [ + { + "date": "2026-02-15", + "amount": 50.00, + "description": "Test Expense", + "expense_type": "EXPENSE", + "currency": "USD", + } + ] + r = client.post( + f"/bank/connections/{conn_id}/import/commit", + json={"transactions": transactions}, + headers=auth_header, + ) + assert r.status_code == 201 + data = r.get_json() + assert data["inserted"] == 1 + assert data["duplicates"] == 0 + + # Verify expense was created + r = client.get("/expenses", headers=auth_header) + assert r.status_code == 200 + expenses = r.get_json() + assert len(expenses) == 1 + assert expenses[0]["notes"] == "Test Expense" + + +def test_bank_connections_api_duplicate_detection(client, auth_header): + """Test that duplicate transactions are detected on commit.""" + # Create connection + create_r = client.post( + "/bank/connections", + json={"connector_id": "mock", "credentials": {"api_key": "test"}}, + headers=auth_header, + ) + conn_id = create_r.get_json()["id"] + + transactions = [ + { + "date": "2026-02-15", + "amount": 50.00, + "description": "Duplicate Test", + "expense_type": "EXPENSE", + "currency": "USD", + } + ] + + # First commit + r = client.post( + f"/bank/connections/{conn_id}/import/commit", + json={"transactions": transactions}, + headers=auth_header, + ) + assert r.get_json()["inserted"] == 1 + + # Second commit should be duplicate + r = client.post( + f"/bank/connections/{conn_id}/import/commit", + json={"transactions": transactions}, + headers=auth_header, + ) + assert r.get_json()["inserted"] == 0 + assert r.get_json()["duplicates"] == 1 + + +def test_bank_connections_api_invalid_connector(client, auth_header): + """Test creating connection with invalid connector.""" + r = client.post( + "/bank/connections", + json={"connector_id": "nonexistent", "credentials": {}}, + headers=auth_header, + ) + assert r.status_code == 400 + + +def test_bank_connections_api_not_found(client, auth_header): + """Test accessing non-existent connection.""" + r = client.get("/bank/connections/99999", headers=auth_header) + assert r.status_code == 404 \ No newline at end of file