diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml new file mode 100644 index 0000000..a51e1b1 --- /dev/null +++ b/.github/workflows/ci.yml @@ -0,0 +1,45 @@ +name: CI + +on: + push: + branches: [main, master] + pull_request: + branches: [main, master] + +permissions: + contents: read + +jobs: + test: + runs-on: ubuntu-latest + strategy: + matrix: + python-version: ["3.10", "3.11", "3.12"] + + steps: + - uses: actions/checkout@v4 + + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install -e .[dev] + + - name: Lint with ruff + run: ruff check . + + - name: Type check with mypy + run: mypy af_server_client --ignore-missing-imports + + - name: Test with pytest + run: pytest --cov=af_server_client --cov-report=xml -v + + - name: Upload coverage + uses: codecov/codecov-action@v3 + with: + file: ./coverage.xml + fail_ci_if_error: false diff --git a/.gitignore b/.gitignore index 7a3a3f1..b45648a 100644 --- a/.gitignore +++ b/.gitignore @@ -1,2 +1,52 @@ /.venv/ /.idea/ + +# Python +__pycache__/ +*.py[cod] +*$py.class +*.so +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +*.egg-info/ +.installed.cfg +*.egg + +# Testing +.tox/ +.nox/ +.coverage +.coverage.* +htmlcov/ +.pytest_cache/ +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ + +# IDE +.vscode/ + +# OS +.DS_Store +Thumbs.db + +# Benchmark results +benchmark_results/ +*.json +!af_server_config.toml + +# PsychoPy experiments (user-created) +experiments/ diff --git a/README.md b/README.md new file mode 100644 index 0000000..5fa37fa --- /dev/null +++ b/README.md @@ -0,0 +1,297 @@ +# AF-Server Client + +Low-latency TCP client for LabVIEW server communication with AF-Serializer protocol. 
+ +## Features + +- **Low-latency TCP communication** with optimizations: + - Nagle's algorithm disabled (TCP_NODELAY) + - QoS with DSCP EF for low latency + - Minimal buffer sizes + - Asyncio + uvloop for improved performance + +- **Request-response architecture** with UUID correlation + +- **Interactive Textual TUI console** with: + - Status panel showing connection and latency metrics + - Rich log for command output + - Command input with history + - Real-time progress for benchmarks + +- **RTT benchmarking system** with statistics and export + +- **PsychoPy integration** for laboratory experiments + +## Installation + +### From Source + +```bash +git clone https://github.com/Faragoz/AF-Server.git +cd AF-Server +pip install -e . +``` + +### With Development Dependencies + +```bash +pip install -e ".[dev]" +``` + +### With PsychoPy Support + +```bash +pip install -e ".[psychopy]" +``` + +## Quick Start + +### Launch Console + +```bash +af-server-client +``` + +### In Console + +``` +> connect 127.0.0.1 2001 +> status +> benchmark +> export json results.json +> help +``` + +### Command-Line Options + +```bash +# Specify config file +af-server-client -c myconfig.toml + +# Override host and port +af-server-client -H 192.168.1.100 -p 3000 + +# Auto-connect on startup +af-server-client --connect + +# Run benchmark and exit +af-server-client --benchmark --iterations 100 --export results.json +``` + +## Configuration + +Create `af_server_config.toml` in your working directory: + +```toml +[network] +host = "127.0.0.1" +port = 2001 +timeout = 5.0 +buffer_size = 256 +enable_nodelay = true +enable_qos = true +heartbeat_interval = 30.0 + +[console] +log_level = "DEBUG" +history_size = 100 +max_log_lines = 1000 +theme = "monokai" + +[benchmark] +default_iterations = 10 +payload_sizes = [0, 128, 256, 512, 1024, 2048, 4096, 8192] +export_format = "json" +export_directory = "./benchmark_results" + +[psychopy] +experiments_dir = "./experiments" +default_experiment = 
"feedback_loop.py" +subject_gui_width = 800 +subject_gui_height = 600 +``` + +## Console Commands + +| Command | Description | +|---------|-------------| +| `connect [host] [port]` | Connect to server (uses config defaults if no args) | +| `disconnect` | Disconnect from server | +| `status` | Show connection status | +| `send ` | Send a command to server | +| `benchmark [iterations] [sizes...]` | Run RTT benchmark | +| `launch-psychopy [experiment]` | Launch PsychoPy experiment | +| `config show` | Display current configuration | +| `config set ` | Update configuration | +| `config reload` | Reload configuration from file | +| `export ` | Export benchmark data (json/csv/excel) | +| `clear` | Clear console log | +| `help [command]` | Show help | +| `quit` / `exit` | Exit the application | + +### Keyboard Shortcuts + +| Key | Action | +|-----|--------| +| `Ctrl+C` | Quit | +| `Ctrl+L` | Clear log | +| `Ctrl+D` | Disconnect | +| `↑` / `↓` | Command history | + +## Architecture + +``` +af_server_client/ +├── core/ +│ ├── tcp_client.py # Asyncio TCP client with uvloop +│ ├── protocol.py # @lvclass command definitions +│ ├── response_handler.py # Request-response correlation +│ └── config.py # TOML configuration loader +├── console/ +│ ├── app.py # Main Textual TUI application +│ ├── commands.py # Command executor and registry +│ ├── widgets.py # StatusPanel, ProgressPanel +│ └── command_parser.py # Parse and validate commands +├── benchmark/ +│ ├── rtt_test.py # Round-trip time benchmarking +│ ├── metrics.py # Sample, BenchmarkRun classes +│ └── export.py # JSON/CSV/Excel exporters +├── psychopy_bridge/ +│ ├── launcher.py # Launch PsychoPy experiments +│ ├── subject_gui.py # Subject interface (buttons/sliders) +│ └── shared_client.py # Singleton TCP client +└── cli.py # CLI entry point +``` + +## Protocol Classes + +The client uses AF-Serializer for communication: + +```python +from af_server_client.core.protocol import Protocol, EchoCommand, Response + +# General 
protocol wrapper +protocol = Protocol() +protocol.data = "request-id" + +# Echo command for benchmarking +cmd = EchoCommand() +cmd.payload = "test data" +cmd.timestamp = 1234567890.123 + +# Server response +response = Response() +response.request_id = "request-id" +response.success = True +response.exec_time = 1500.0 # μs +response.message = "OK" +``` + +## Benchmarking + +The benchmark system measures round-trip time with various payload sizes: + +```python +from af_server_client.benchmark import RTTBenchmark, BenchmarkExporter + +# Run benchmark +benchmark = RTTBenchmark(client, config) +result = await benchmark.run(iterations=10, payload_sizes=[0, 128, 512, 1024]) + +# Access statistics +print(f"Avg Latency: {result.stats['avg_total_latency']:.2f}ms") +print(f"P99 Latency: {result.stats['p99_latency']:.2f}ms") +print(f"Jitter: {result.stats['jitter']:.2f}ms") + +# Export results +exporter = BenchmarkExporter(result) +exporter.to_json("benchmark_results.json") +exporter.to_csv("benchmark_results.csv") +exporter.to_excel("benchmark_results.xlsx") +``` + +### Metrics Collected + +- **Total Latency**: Full round-trip time +- **Network Latency**: RTT minus server execution time +- **Execution Time**: Server-reported processing time +- **Jitter**: Standard deviation of latency +- **Percentiles**: P95 and P99 latency + +## PsychoPy Integration + +For laboratory experiments with CRIO devices: + +```python +from af_server_client.psychopy_bridge import SharedClient, ExperimentLauncher + +# Initialize shared client +client = SharedClient.initialize(config) +await client.connect() + +# Launch experiment +launcher = ExperimentLauncher(config) +await launcher.launch("my_experiment.py") +``` + +### Subject GUI + +```python +from af_server_client.psychopy_bridge.subject_gui import SubjectGUI + +gui = SubjectGUI(client, width=800, height=600) +gui.register_callback('start_trial', on_start_trial) +gui.register_callback('intensity_changed', on_intensity_changed) +gui.run() +``` + 
+## Development + +### Install Development Dependencies + +```bash +pip install -e ".[dev]" +``` + +### Run Tests + +```bash +pytest -v +``` + +### Run Tests with Coverage + +```bash +pytest --cov=af_server_client --cov-report=html +``` + +### Lint Code + +```bash +ruff check . +``` + +### Type Check + +```bash +mypy af_server_client +``` + +## Requirements + +- Python >= 3.10 +- af-serializer +- textual >= 0.50.0 +- uvloop >= 0.19.0 (Unix only) +- rich >= 13.7.0 +- pandas >= 2.2.0 +- numpy >= 1.26.0 +- openpyxl >= 3.1.0 + +## License + +MIT License - See LICENSE file for details. + +## Author + +Faragoz - aragon.froylan@gmail.com diff --git a/af_server_client/__init__.py b/af_server_client/__init__.py new file mode 100644 index 0000000..ec2df9c --- /dev/null +++ b/af_server_client/__init__.py @@ -0,0 +1,18 @@ +"""AF-Server Client - Low-latency TCP client for LabVIEW server communication.""" + +__version__ = "0.1.0" + +from af_server_client.core.tcp_client import TCPClient +from af_server_client.core.protocol import Protocol, EchoCommand, Response +from af_server_client.core.response_handler import ResponseHandler +from af_server_client.core.config import Config, load_config + +__all__ = [ + "TCPClient", + "Protocol", + "EchoCommand", + "Response", + "ResponseHandler", + "Config", + "load_config", +] diff --git a/af_server_client/_mock/__init__.py b/af_server_client/_mock/__init__.py new file mode 100644 index 0000000..540056d --- /dev/null +++ b/af_server_client/_mock/__init__.py @@ -0,0 +1,21 @@ +"""Mock module for af_serializer.""" + +from af_server_client._mock.af_serializer import ( + lvclass, + lvflatten, + lvunflatten, + LVString, + LVDouble, + LVBoolean, + LVU16, +) + +__all__ = [ + "lvclass", + "lvflatten", + "lvunflatten", + "LVString", + "LVDouble", + "LVBoolean", + "LVU16", +] diff --git a/af_server_client/_mock/af_serializer.py b/af_server_client/_mock/af_serializer.py new file mode 100644 index 0000000..2a12f15 --- /dev/null +++ 
b/af_server_client/_mock/af_serializer.py @@ -0,0 +1,153 @@ +"""Mock af_serializer module for testing when the real package is not available. + +This module provides minimal mock implementations of the af_serializer +decorators and functions needed for testing. +""" + +import struct +from dataclasses import dataclass +from typing import Any, Type, TypeVar + +T = TypeVar("T") + + +# Type hints (these are just markers, actual types are handled by dataclass) +LVString = str +LVDouble = float +LVBoolean = bool +LVU16 = int + + +# Registry to store class metadata +_class_registry: dict[str, Type[Any]] = {} + + +def lvclass(library: str = "", class_name: str = ""): + """Decorator to mark a class as a LabVIEW class. + + Args: + library: Library name. + class_name: Class name. + + Returns: + Decorated class. + """ + + def decorator(cls: Type[T]) -> Type[T]: + # Convert to dataclass if not already + if not hasattr(cls, "__dataclass_fields__"): + cls = dataclass(cls) + + # Store metadata + cls._lv_library = library # type: ignore + cls._lv_class_name = class_name # type: ignore + + # Register class + key = f"{library}:{class_name}" + _class_registry[key] = cls + + return cls + + return decorator + + +def lvflatten(obj: Any) -> bytes: + """Serialize an object to bytes. + + Args: + obj: Object to serialize. + + Returns: + Serialized bytes. 
+ """ + if isinstance(obj, str): + encoded = obj.encode("utf-8") + return struct.pack(">I", len(encoded)) + encoded + + # Check bool before int/float because bool is a subclass of int + if isinstance(obj, bool): + return struct.pack(">?", obj) + + if isinstance(obj, (int, float)): + return struct.pack(">d", float(obj)) + + # For lvclass objects + if hasattr(obj, "_lv_class_name"): + result = b"" + library = getattr(obj, "_lv_library", "") + class_name = getattr(obj, "_lv_class_name", "") + + # Write class identifier + key = f"{library}:{class_name}".encode("utf-8") + result += struct.pack(">I", len(key)) + key + + # Write fields + if hasattr(obj, "__dataclass_fields__"): + for field_name in obj.__dataclass_fields__: + value = getattr(obj, field_name) + result += lvflatten(value) + + return result + + raise TypeError(f"Cannot serialize type: {type(obj)}") + + +def lvunflatten(data: bytes) -> Any: + """Deserialize bytes to an object. + + Args: + data: Bytes to deserialize. + + Returns: + Deserialized object. 
+ """ + offset = 0 + + # Read class identifier + key_len = struct.unpack(">I", data[offset : offset + 4])[0] + offset += 4 + key = data[offset : offset + key_len].decode("utf-8") + offset += key_len + + # Find class in registry + if key not in _class_registry: + raise ValueError(f"Unknown class: {key}") + + cls = _class_registry[key] + + # Create instance + kwargs: dict[str, Any] = {} + if hasattr(cls, "__dataclass_fields__"): + for field_name, field_info in cls.__dataclass_fields__.items(): + field_type = field_info.type + value: Any + + if field_type in (str, LVString) or ( + isinstance(field_type, str) and "str" in field_type.lower() + ): + str_len = struct.unpack(">I", data[offset : offset + 4])[0] + offset += 4 + value = data[offset : offset + str_len].decode("utf-8") + offset += str_len + # Check bool before float since we want to match the type annotation + elif field_type in (bool, LVBoolean) or ( + isinstance(field_type, str) and "bool" in field_type.lower() + ): + value = struct.unpack(">?", data[offset : offset + 1])[0] + offset += 1 + elif field_type in (float, LVDouble) or ( + isinstance(field_type, str) and "float" in field_type.lower() + ): + value = struct.unpack(">d", data[offset : offset + 8])[0] + offset += 8 + elif field_type in (int, LVU16): + # For consistency with lvflatten which uses double for all numbers + value_float = struct.unpack(">d", data[offset : offset + 8])[0] + value = int(value_float) + offset += 8 + else: + raise TypeError(f"Unknown field type: {field_type}") + + kwargs[field_name] = value + + return cls(**kwargs) diff --git a/af_server_client/benchmark/__init__.py b/af_server_client/benchmark/__init__.py new file mode 100644 index 0000000..1aea07c --- /dev/null +++ b/af_server_client/benchmark/__init__.py @@ -0,0 +1,12 @@ +"""Benchmark module for RTT testing and metrics.""" + +from af_server_client.benchmark.metrics import Sample, BenchmarkRun +from af_server_client.benchmark.rtt_test import RTTBenchmark +from 
af_server_client.benchmark.export import BenchmarkExporter + +__all__ = [ + "Sample", + "BenchmarkRun", + "RTTBenchmark", + "BenchmarkExporter", +] diff --git a/af_server_client/benchmark/export.py b/af_server_client/benchmark/export.py new file mode 100644 index 0000000..fbb94bc --- /dev/null +++ b/af_server_client/benchmark/export.py @@ -0,0 +1,245 @@ +"""Benchmark export functionality. + +This module provides exporters for benchmark results in various formats +including JSON, CSV, and Excel. +""" + +import json +from datetime import datetime +from pathlib import Path +from typing import Optional + +import pandas as pd + +from af_server_client.benchmark.metrics import BenchmarkRun + + +class BenchmarkExporter: + """Export benchmark results to various formats. + + Supports JSON, CSV, and Excel export formats. + """ + + def __init__(self, benchmark_run: BenchmarkRun) -> None: + """Initialize the exporter. + + Args: + benchmark_run: Benchmark run to export. + """ + self.benchmark_run = benchmark_run + + def to_json( + self, + filename: str, + directory: Optional[str] = None, + ) -> Path: + """Export to JSON file. + + Args: + filename: Output filename. + directory: Optional output directory. + + Returns: + Path to the exported file. + """ + path = self._get_path(filename, directory, ".json") + + data = self.benchmark_run.to_dict() + data["export_timestamp"] = datetime.now().isoformat() + + with open(path, "w") as f: + json.dump(data, f, indent=2, default=str) + + return path + + def to_csv( + self, + filename: str, + directory: Optional[str] = None, + ) -> Path: + """Export to CSV file. + + Args: + filename: Output filename. + directory: Optional output directory. + + Returns: + Path to the exported file. 
+ """ + path = self._get_path(filename, directory, ".csv") + + # Convert samples to DataFrame + rows = [] + for sample_id, sample in self.benchmark_run.samples.items(): + row = { + "sample_id": sample_id, + "request_timestamp": sample.request.get("timestamp"), + "response_timestamp": sample.response.get("timestamp"), + "payload_size": sample.request.get("payload_size"), + "total_latency_ms": sample.metrics.get("total_latency"), + "network_latency_ms": sample.metrics.get("network_latency"), + "exec_time_ms": sample.metrics.get("exec_time"), + } + rows.append(row) + + df = pd.DataFrame(rows) + df.to_csv(path, index=False) + + return path + + def to_excel( + self, + filename: str, + directory: Optional[str] = None, + ) -> Path: + """Export to Excel file with multiple sheets. + + Args: + filename: Output filename. + directory: Optional output directory. + + Returns: + Path to the exported file. + """ + path = self._get_path(filename, directory, ".xlsx") + + with pd.ExcelWriter(path, engine="openpyxl") as writer: + # Summary sheet + summary_data = { + "Metric": [ + "Total Samples", + "Duration (s)", + "Avg Total Latency (ms)", + "Avg Network Latency (ms)", + "Avg Exec Time (ms)", + "Jitter (ms)", + "Min Latency (ms)", + "Max Latency (ms)", + "Median Latency (ms)", + "P95 Latency (ms)", + "P99 Latency (ms)", + ], + "Value": [ + self.benchmark_run.stats.get("samples_count", 0), + self.benchmark_run.timing.get("duration", 0), + self.benchmark_run.stats.get("avg_total_latency", 0), + self.benchmark_run.stats.get("avg_network_latency", 0), + self.benchmark_run.stats.get("avg_exec_time", 0), + self.benchmark_run.stats.get("jitter", 0), + self.benchmark_run.stats.get("min_latency", 0), + self.benchmark_run.stats.get("max_latency", 0), + self.benchmark_run.stats.get("median_latency", 0), + self.benchmark_run.stats.get("p95_latency", 0), + self.benchmark_run.stats.get("p99_latency", 0), + ], + } + df_summary = pd.DataFrame(summary_data) + df_summary.to_excel(writer, 
sheet_name="Summary", index=False) + + # Samples sheet + rows = [] + for sample_id, sample in self.benchmark_run.samples.items(): + row = { + "Sample ID": sample_id, + "Request Timestamp": sample.request.get("timestamp"), + "Response Timestamp": sample.response.get("timestamp"), + "Payload Size (bytes)": sample.request.get("payload_size"), + "Total Latency (ms)": sample.metrics.get("total_latency"), + "Network Latency (ms)": sample.metrics.get("network_latency"), + "Exec Time (ms)": sample.metrics.get("exec_time"), + } + rows.append(row) + + df_samples = pd.DataFrame(rows) + df_samples.to_excel(writer, sheet_name="Samples", index=False) + + # Metadata sheet + metadata_data = { + "Key": list(self.benchmark_run.metadata.keys()), + "Value": [str(v) for v in self.benchmark_run.metadata.values()], + } + df_metadata = pd.DataFrame(metadata_data) + df_metadata.to_excel(writer, sheet_name="Metadata", index=False) + + return path + + def get_dataframe(self) -> pd.DataFrame: + """Get samples as a pandas DataFrame. + + Returns: + DataFrame with sample data. + """ + rows = [] + for sample_id, sample in self.benchmark_run.samples.items(): + row = { + "sample_id": sample_id, + "request_timestamp": sample.request.get("timestamp"), + "response_timestamp": sample.response.get("timestamp"), + "payload_size": sample.request.get("payload_size"), + "total_latency_ms": sample.metrics.get("total_latency"), + "network_latency_ms": sample.metrics.get("network_latency"), + "exec_time_ms": sample.metrics.get("exec_time"), + } + rows.append(row) + + return pd.DataFrame(rows) + + def _get_path( + self, + filename: str, + directory: Optional[str], + extension: str, + ) -> Path: + """Get the full path for output file. + + Args: + filename: Base filename. + directory: Optional directory. + extension: File extension. + + Returns: + Full path object. 
+ """ + if not filename.endswith(extension): + filename = f"{filename}{extension}" + + if directory: + dir_path = Path(directory) + dir_path.mkdir(parents=True, exist_ok=True) + return dir_path / filename + + return Path(filename) + + +def export_benchmark( + benchmark_run: BenchmarkRun, + format_type: str, + filename: str, + directory: Optional[str] = None, +) -> Path: + """Export benchmark to specified format. + + Convenience function for exporting benchmark results. + + Args: + benchmark_run: Benchmark run to export. + format_type: Export format ('json', 'csv', 'excel'). + filename: Output filename. + directory: Optional output directory. + + Returns: + Path to the exported file. + + Raises: + ValueError: If format_type is unknown. + """ + exporter = BenchmarkExporter(benchmark_run) + + if format_type == "json": + return exporter.to_json(filename, directory) + elif format_type == "csv": + return exporter.to_csv(filename, directory) + elif format_type == "excel": + return exporter.to_excel(filename, directory) + else: + raise ValueError(f"Unknown export format: {format_type}") diff --git a/af_server_client/benchmark/metrics.py b/af_server_client/benchmark/metrics.py new file mode 100644 index 0000000..fe9b036 --- /dev/null +++ b/af_server_client/benchmark/metrics.py @@ -0,0 +1,190 @@ +"""Benchmark metrics classes. + +This module provides data classes for storing benchmark measurements +and statistics. +""" + +from dataclasses import dataclass, field +from typing import Any, Optional + + +@dataclass +class Sample: + """Single request-response measurement. + + Stores timing and payload information for a single benchmark sample. + + Attributes: + request: Request timing and data. + response: Response timing and data. + metrics: Calculated metrics (latency, exec_time, etc.). 
+ """ + + request: dict[str, Any] = field( + default_factory=lambda: { + "timestamp": None, + "payload_size": None, + "raw": None, + } + ) + response: dict[str, Any] = field( + default_factory=lambda: { + "timestamp": None, + "payload_size": None, + "raw": None, + } + ) + metrics: dict[str, float] = field( + default_factory=lambda: { + "exec_time": 0.0, # Server execution time (ms) + "total_latency": 0.0, # Full round-trip (ms) + "network_latency": 0.0, # RTT minus exec_time (ms) + } + ) + + def calculate_metrics(self) -> None: + """Calculate metrics from request/response timestamps.""" + if ( + self.request["timestamp"] is not None + and self.response["timestamp"] is not None + ): + total_ms = (self.response["timestamp"] - self.request["timestamp"]) * 1000 + exec_ms = self.metrics.get("exec_time", 0.0) + self.metrics["total_latency"] = total_ms + self.metrics["network_latency"] = max(0, total_ms - exec_ms) + + +@dataclass +class BenchmarkRun: + """Collection of samples with statistics. + + Contains all samples from a benchmark run and computed statistics. + + Attributes: + timing: Start time, end time, and duration. + stats: Computed statistics (averages, percentiles, etc.). + samples: Dictionary of samples keyed by request ID. + metadata: Additional metadata about the benchmark run. 
+ """ + + timing: dict[str, Optional[float]] = field( + default_factory=lambda: { + "start_time": None, + "end_time": None, + "duration": None, + } + ) + stats: dict[str, float] = field( + default_factory=lambda: { + "samples_count": 0, + "avg_exec_time": 0.0, + "avg_total_latency": 0.0, + "avg_network_latency": 0.0, + "jitter": 0.0, # Std dev of total_latency + "p95_latency": 0.0, # 95th percentile + "p99_latency": 0.0, # 99th percentile + "min_latency": 0.0, + "max_latency": 0.0, + "median_latency": 0.0, + } + ) + samples: dict[str, Sample] = field(default_factory=dict) + metadata: dict[str, Any] = field( + default_factory=lambda: { + "iterations": 0, + "payload_sizes": [], + "host": "", + "port": 0, + } + ) + + def add_sample(self, sample_id: str, sample: Sample) -> None: + """Add a sample to the benchmark run. + + Args: + sample_id: Unique identifier for the sample. + sample: Sample to add. + """ + self.samples[sample_id] = sample + + def compute_statistics(self) -> None: + """Compute statistics from all samples.""" + if not self.samples: + return + + # Collect all latencies + total_latencies = [] + network_latencies = [] + exec_times = [] + + for sample in self.samples.values(): + total_latencies.append(sample.metrics["total_latency"]) + network_latencies.append(sample.metrics["network_latency"]) + exec_times.append(sample.metrics["exec_time"]) + + self.stats["samples_count"] = len(self.samples) + + # Averages + self.stats["avg_total_latency"] = sum(total_latencies) / len(total_latencies) + self.stats["avg_network_latency"] = ( + sum(network_latencies) / len(network_latencies) + ) + self.stats["avg_exec_time"] = sum(exec_times) / len(exec_times) + + # Min/Max + self.stats["min_latency"] = min(total_latencies) + self.stats["max_latency"] = max(total_latencies) + + # Median + sorted_latencies = sorted(total_latencies) + mid = len(sorted_latencies) // 2 + if len(sorted_latencies) % 2 == 0: + self.stats["median_latency"] = ( + sorted_latencies[mid - 1] + 
sorted_latencies[mid] + ) / 2 + else: + self.stats["median_latency"] = sorted_latencies[mid] + + # Jitter (standard deviation) + mean = self.stats["avg_total_latency"] + variance = sum((x - mean) ** 2 for x in total_latencies) / len(total_latencies) + self.stats["jitter"] = variance ** 0.5 + + # Percentiles + p95_idx = int(len(sorted_latencies) * 0.95) + p99_idx = int(len(sorted_latencies) * 0.99) + self.stats["p95_latency"] = sorted_latencies[ + min(p95_idx, len(sorted_latencies) - 1) + ] + self.stats["p99_latency"] = sorted_latencies[ + min(p99_idx, len(sorted_latencies) - 1) + ] + + # Duration + if ( + self.timing["start_time"] is not None + and self.timing["end_time"] is not None + ): + self.timing["duration"] = ( + self.timing["end_time"] - self.timing["start_time"] + ) + + def to_dict(self) -> dict[str, Any]: + """Convert benchmark run to dictionary. + + Returns: + Dictionary representation of the benchmark run. + """ + return { + "timing": self.timing, + "stats": self.stats, + "samples": { + sid: { + "request": s.request, + "response": s.response, + "metrics": s.metrics, + } + for sid, s in self.samples.items() + }, + "metadata": self.metadata, + } diff --git a/af_server_client/benchmark/rtt_test.py b/af_server_client/benchmark/rtt_test.py new file mode 100644 index 0000000..fbd5187 --- /dev/null +++ b/af_server_client/benchmark/rtt_test.py @@ -0,0 +1,237 @@ +"""RTT benchmarking implementation. + +This module provides round-trip time benchmarking functionality +for measuring latency to the LabVIEW server. +""" + +import asyncio +import time +import uuid +from typing import TYPE_CHECKING, Callable, Optional + +from af_server_client.benchmark.metrics import Sample, BenchmarkRun +from af_server_client.core.protocol import EchoCommand + +if TYPE_CHECKING: + from af_server_client.core.config import Config + from af_server_client.core.tcp_client import TCPClient + + +class RTTBenchmark: + """Round-trip time benchmark runner. 
+ + Measures latency by sending EchoCommand packets of various sizes + and recording response times. + """ + + def __init__( + self, + client: "TCPClient", + config: "Config", + ) -> None: + """Initialize the benchmark runner. + + Args: + client: TCP client to use for benchmarking. + config: Configuration object. + """ + self.client = client + self.config = config + self._progress_callback: Optional[Callable[[int, int], None]] = None + + def set_progress_callback( + self, callback: Optional[Callable[[int, int], None]] + ) -> None: + """Set progress callback. + + Args: + callback: Callback function(current, total). + """ + self._progress_callback = callback + + async def run( + self, + iterations: Optional[int] = None, + payload_sizes: Optional[list[int]] = None, + ) -> BenchmarkRun: + """Run the benchmark. + + Args: + iterations: Number of iterations per payload size. + payload_sizes: List of payload sizes to test. + + Returns: + Benchmark run with all samples and statistics. + """ + iterations = iterations or self.config.benchmark.default_iterations + payload_sizes = payload_sizes or self.config.benchmark.payload_sizes + + result = BenchmarkRun() + result.metadata["iterations"] = iterations + result.metadata["payload_sizes"] = payload_sizes + result.metadata["host"] = self.client.host + result.metadata["port"] = self.client.port + + total_samples = len(payload_sizes) * iterations + current = 0 + + result.timing["start_time"] = time.time() + + for payload_size in payload_sizes: + for i in range(iterations): + sample_id = str(uuid.uuid4()) + sample = await self._run_single(payload_size, sample_id) + result.add_sample(sample_id, sample) + + current += 1 + if self._progress_callback: + self._progress_callback(current, total_samples) + + # Small delay between samples + await asyncio.sleep(0.01) + + result.timing["end_time"] = time.time() + result.compute_statistics() + + return result + + async def _run_single(self, payload_size: int, sample_id: str) -> Sample: + 
"""Run a single benchmark sample. + + Args: + payload_size: Size of payload to send. + sample_id: Unique identifier for this sample. + + Returns: + Sample with timing metrics. + """ + sample = Sample() + + # Create payload + payload = "x" * payload_size + + # Create echo command + cmd = EchoCommand() + cmd.payload = payload + cmd.timestamp = time.time() + + # Record request time + sample.request["timestamp"] = time.perf_counter() + sample.request["payload_size"] = payload_size + sample.request["raw"] = { + "payload": payload, + "timestamp": cmd.timestamp, + } + + try: + # Send and wait for response + response = await self.client.send_command(cmd, timeout=10.0) + + # Record response time + sample.response["timestamp"] = time.perf_counter() + + if response: + sample.response["payload_size"] = len(response.message) + sample.response["raw"] = { + "request_id": response.request_id, + "success": response.success, + "exec_time": response.exec_time, + "message": response.message, + } + # Convert exec_time from μs to ms + sample.metrics["exec_time"] = response.exec_time / 1000.0 + + # Calculate total latency + sample.calculate_metrics() + + except asyncio.TimeoutError: + sample.response["timestamp"] = time.perf_counter() + sample.metrics["total_latency"] = -1 # Indicate timeout + sample.metrics["network_latency"] = -1 + + return sample + + +class LatencyHistogram: + """Simple histogram for latency distribution. + + Buckets latency values for visualization and analysis. + """ + + def __init__( + self, + min_value: float = 0.0, + max_value: float = 100.0, + num_buckets: int = 20, + ) -> None: + """Initialize the histogram. + + Args: + min_value: Minimum latency value (ms). + max_value: Maximum latency value (ms). + num_buckets: Number of histogram buckets. 
+ """ + self.min_value = min_value + self.max_value = max_value + self.num_buckets = num_buckets + self.bucket_size = (max_value - min_value) / num_buckets + self.buckets: list[int] = [0] * num_buckets + self.count = 0 + self.overflow = 0 + self.underflow = 0 + + def add(self, value: float) -> None: + """Add a value to the histogram. + + Args: + value: Latency value in ms. + """ + self.count += 1 + + if value < self.min_value: + self.underflow += 1 + return + + if value >= self.max_value: + self.overflow += 1 + return + + bucket_idx = int((value - self.min_value) / self.bucket_size) + bucket_idx = min(bucket_idx, self.num_buckets - 1) + self.buckets[bucket_idx] += 1 + + def get_distribution(self) -> list[tuple[float, float, int]]: + """Get the histogram distribution. + + Returns: + List of (bucket_start, bucket_end, count) tuples. + """ + result = [] + for i in range(self.num_buckets): + start = self.min_value + i * self.bucket_size + end = start + self.bucket_size + result.append((start, end, self.buckets[i])) + return result + + def to_ascii(self, width: int = 40) -> str: + """Generate ASCII representation of histogram. + + Args: + width: Maximum bar width. + + Returns: + ASCII histogram string. + """ + lines = [] + max_count = max(self.buckets) if self.buckets else 1 + + for i, count in enumerate(self.buckets): + start = self.min_value + i * self.bucket_size + bar_width = int((count / max_count) * width) if max_count > 0 else 0 + bar = "█" * bar_width + lines.append(f"{start:6.1f}ms |{bar} ({count})") + + if self.overflow: + lines.append(f" >max |{'█' * int((self.overflow / max_count) * width)} ({self.overflow})") + + return "\n".join(lines) diff --git a/af_server_client/cli.py b/af_server_client/cli.py new file mode 100644 index 0000000..e871c59 --- /dev/null +++ b/af_server_client/cli.py @@ -0,0 +1,208 @@ +"""CLI entry point for AF-Server client. + +This module provides the command-line interface for launching +the AF-Server client console. 
def parse_args(args: Optional[list[str]] = None) -> argparse.Namespace:
    """Parse command-line arguments.

    Args:
        args: Command-line arguments. If None, uses sys.argv.

    Returns:
        Parsed arguments namespace.
    """
    parser = argparse.ArgumentParser(
        prog="af-server-client",
        description="Low-latency TCP client for LabVIEW server communication",
    )

    # Data-driven option table: (flags, add_argument keyword options).
    option_specs: list[tuple[tuple[str, ...], dict]] = [
        (("-c", "--config"), {"type": str, "help": "Path to configuration file"}),
        (("-H", "--host"), {"type": str, "help": "Server host (overrides config)"}),
        (("-p", "--port"), {"type": int, "help": "Server port (overrides config)"}),
        (("--connect",), {"action": "store_true", "help": "Auto-connect on startup"}),
        (("--benchmark",), {"action": "store_true", "help": "Run benchmark and exit"}),
        (
            ("--iterations",),
            {"type": int, "default": 10, "help": "Benchmark iterations (default: 10)"},
        ),
        (("--export",), {"type": str, "help": "Export benchmark results to file"}),
        (
            ("--format",),
            {
                "type": str,
                "choices": ["json", "csv", "excel"],
                "default": "json",
                "help": "Export format (default: json)",
            },
        ),
        (("-v", "--version"), {"action": "version", "version": "%(prog)s 0.1.0"}),
    ]
    for flags, options in option_specs:
        parser.add_argument(*flags, **options)

    return parser.parse_args(args)
def main(args: Optional[list[str]] = None) -> int:
    """Main entry point.

    Args:
        args: Command-line arguments.

    Returns:
        Exit code (0 on success, 1 on error).
    """
    parsed = parse_args(args)

    # Load configuration; bail out early if an explicitly named file is missing.
    config_path = parsed.config
    if config_path:
        config_path = Path(config_path)
        if not config_path.exists():
            print(f"Error: Configuration file not found: {config_path}", file=sys.stderr)
            return 1

    # NOTE(review): when no -c flag is given, load_config receives None —
    # presumably it falls back to defaults/search paths; confirm in core.config.
    config = load_config(config_path)

    # Apply command-line overrides (falsy values such as port 0 are ignored).
    if parsed.host:
        config.network.host = parsed.host
    if parsed.port:
        config.network.port = parsed.port

    # Non-interactive benchmark mode: run once and exit.
    if parsed.benchmark:
        return run_benchmark_mode(config, parsed)

    # Interactive console mode.
    try:
        run_console(config)
        return 0
    except KeyboardInterrupt:
        # Ctrl+C is a normal way to leave the console, not an error.
        return 0
    except Exception as e:
        print(f"Error: {e}", file=sys.stderr)
        return 1


def run_benchmark_mode(
    config: "Config",
    parsed: argparse.Namespace,
) -> int:
    """Run benchmark mode without interactive console.

    Args:
        config: Configuration object.
        parsed: Parsed arguments.

    Returns:
        Exit code (0 on success, 1 on error).
    """
    # Local imports so the heavier client/benchmark stack is only loaded
    # when benchmark mode is actually requested.
    import asyncio
    from af_server_client.core.tcp_client import TCPClient
    from af_server_client.benchmark.rtt_test import RTTBenchmark
    from af_server_client.benchmark.export import BenchmarkExporter

    async def run() -> int:
        # One connect/run/report/export session, always disconnecting at the end.
        client = TCPClient(config)

        try:
            print(f"Connecting to {config.network.host}:{config.network.port}...")
            await client.connect()
            print("Connected!")

            print(f"Running benchmark ({parsed.iterations} iterations)...")
            benchmark = RTTBenchmark(client, config)

            def progress(current: int, total: int) -> None:
                # Carriage-return progress line, overwritten in place.
                pct = (current / total) * 100
                print(f"\rProgress: {current}/{total} ({pct:.1f}%)", end="", flush=True)

            benchmark.set_progress_callback(progress)
            result = await benchmark.run(iterations=parsed.iterations)
            print()  # New line after progress

            # Print results
            print("\nResults:")
            print(f"  Samples: {result.stats['samples_count']}")
            print(f"  Duration: {result.timing['duration']:.2f}s")
            print(f"  Avg Latency: {result.stats['avg_total_latency']:.2f}ms")
            print(f"  P95 Latency: {result.stats['p95_latency']:.2f}ms")
            print(f"  P99 Latency: {result.stats['p99_latency']:.2f}ms")
            print(f"  Jitter: {result.stats['jitter']:.2f}ms")

            # Export if requested
            if parsed.export:
                exporter = BenchmarkExporter(result)
                if parsed.format == "json":
                    path = exporter.to_json(parsed.export)
                elif parsed.format == "csv":
                    path = exporter.to_csv(parsed.export)
                else:
                    path = exporter.to_excel(parsed.export)
                print(f"\nExported to: {path}")

            return 0

        except Exception as e:
            print(f"\nError: {e}", file=sys.stderr)
            return 1

        finally:
            # Always tear the connection down, even on failure.
            await client.disconnect()

    return asyncio.run(run())


if __name__ == "__main__":
    sys.exit(main())
class AFServerConsole(App):
    """Interactive TUI for AF-Server client.

    Provides a full-featured console interface with:
    - Status panel showing connection and latency metrics
    - Rich log for command output
    - Command input with history
    - Progress display for benchmarks
    """

    TITLE = "AF-Server Client"
    # Vertical layout: status (fixed) / progress (hidden until work is pending)
    # / log (flexible height) / input (fixed).
    CSS = """
    Screen {
        layout: vertical;
    }

    #status {
        height: 5;
        border: round $primary;
        padding: 0 1;
    }

    #progress {
        height: 3;
        border: round $accent;
        padding: 0 1;
        display: none;
    }

    #progress.visible {
        display: block;
    }

    #log {
        height: 1fr;
        border: round $secondary;
        padding: 0 1;
    }

    #input {
        height: 3;
        border: round $primary;
    }
    """

    BINDINGS = [
        Binding("ctrl+c", "quit", "Quit", priority=True),
        Binding("ctrl+l", "clear_log", "Clear"),
        Binding("ctrl+d", "disconnect", "Disconnect"),
    ]

    def __init__(
        self,
        client: Optional[TCPClient] = None,
        config: Optional[Config] = None,
    ) -> None:
        """Initialize the console application.

        Args:
            client: TCP client instance. If None, creates a new one.
            config: Configuration object. If None, loads default.
        """
        super().__init__()
        self.config = config or Config()
        self.client = client or TCPClient(self.config)
        # Executor routes its log output back into this app's RichLog.
        self.executor = CommandExecutor(
            self.client,
            self.config,
            log_callback=self._log_message,
        )
        # NOTE(review): _history_index is initialized but not used anywhere in
        # this class — presumably reserved for input-history navigation.
        self._history_index = -1

    def compose(self) -> ComposeResult:
        """Compose the application UI."""
        yield Header()
        yield StatusPanel(self.client, id="status")
        yield ProgressPanel(id="progress")
        yield RichLog(id="log", highlight=True, markup=True)
        yield Input(placeholder="Enter command... (type 'help' for commands)", id="input")
        yield Footer()

    def on_mount(self) -> None:
        """Handle application mount."""
        # Set up client callbacks so connect/disconnect/errors appear in the log.
        self.client.add_callback(self._handle_client_event)

        # Benchmark progress updates are routed to the progress panel.
        self.executor.set_progress_callback(self._update_progress)

        # Welcome message
        log = self.query_one("#log", RichLog)
        log.write("[bold cyan]AF-Server Client Console[/bold cyan]")
        log.write("Type [bold]help[/bold] for available commands")
        log.write("")

        # Focus on input
        self.query_one("#input", Input).focus()

    def _log_message(self, level: str, message: str) -> None:
        """Log a message to the console.

        Args:
            level: Log level (info, warning, error). Unknown levels fall
                through to the INFO style.
            message: Message to log.
        """
        log = self.query_one("#log", RichLog)

        if level == "error":
            log.write(f"[red]ERROR:[/red] {message}")
        elif level == "warning":
            log.write(f"[yellow]WARNING:[/yellow] {message}")
        else:
            log.write(f"[dim]INFO:[/dim] {message}")

    def _handle_client_event(self, event_type: str, data: object) -> None:
        """Handle TCP client events.

        Args:
            event_type: Type of event ("connected", "disconnected", "error").
            data: Event data; only dict payloads are rendered.
        """
        log = self.query_one("#log", RichLog)

        if event_type == "connected":
            if isinstance(data, dict):
                log.write(
                    f"[green]✓ Connected to {data.get('host')}:{data.get('port')}[/green]"
                )
        elif event_type == "disconnected":
            log.write("[red]✗ Disconnected[/red]")
        elif event_type == "error":
            if isinstance(data, dict):
                log.write(f"[red]Error: {data.get('error', 'Unknown error')}[/red]")

    def _update_progress(self, current: int, total: int, message: str) -> None:
        """Update progress display.

        Args:
            current: Current progress.
            total: Total count.
            message: Progress message.
        """
        progress = self.query_one("#progress", ProgressPanel)
        # The panel stays hidden (CSS display: none) except while work is pending.
        if current >= total:
            progress.remove_class("visible")
        else:
            progress.add_class("visible")
            progress.set_progress(current, total, message)

    async def on_input_submitted(self, event: Input.Submitted) -> None:
        """Handle input submission.

        Args:
            event: Input submitted event.
        """
        command = event.value.strip()
        input_widget = event.input
        input_widget.value = ""  # clear the field immediately

        if not command:
            return

        # Echo the command into the log before executing it.
        log = self.query_one("#log", RichLog)
        log.write(f"[cyan]> {command}[/cyan]")

        # Execute command
        result = await self.executor.execute(command)

        # Handle special actions signalled via result.data ("quit"/"clear").
        if result.data and isinstance(result.data, dict):
            action = result.data.get("action")
            if action == "quit":
                self.exit()
                return
            elif action == "clear":
                log.clear()
                return

        # Display result: green on success, red on failure; a successful
        # result with an empty message prints nothing extra.
        if result.success:
            if result.message:
                log.write(f"[green]{result.message}[/green]")
        else:
            log.write(f"[red]{result.message}[/red]")

        log.write("")

    def action_clear_log(self) -> None:
        """Clear the log widget (Ctrl+L)."""
        log = self.query_one("#log", RichLog)
        log.clear()

    async def action_disconnect(self) -> None:
        """Disconnect from server (Ctrl+D)."""
        await self.client.disconnect()

    async def action_quit(self) -> None:
        """Quit the application (Ctrl+C), disconnecting first."""
        await self.client.disconnect()
        self.exit()
@dataclass
class ParsedCommand:
    """A console command split into its constituent parts.

    Attributes:
        name: Command name (e.g., 'connect', 'benchmark').
        args: List of positional arguments.
        kwargs: Dictionary of keyword arguments.
        raw: Original raw command string.
    """

    name: str
    args: list[str]
    kwargs: dict[str, str]
    raw: str


class CommandParseError(Exception):
    """Raised when a command string cannot be parsed."""
class CommandParser:
    """Parser for console commands.

    Supports:
    - Simple commands: `connect host port`
    - Commands with flags: `benchmark --iterations 10`
    - Quoted arguments: `send "hello world"`
    """

    # Command signatures for validation. Each spec lists the positional
    # argument names ("args"), whether those may be omitted ("optional"),
    # and a human-readable description used by `help`.
    COMMANDS: dict[str, dict[str, Any]] = {
        "connect": {
            "args": ["host", "port"],
            "optional": True,
            "description": "Connect to server (uses config defaults if no args)",
        },
        "disconnect": {
            "args": [],
            "optional": False,
            "description": "Disconnect from server",
        },
        "status": {
            "args": [],
            "optional": False,
            "description": "Show connection status",
        },
        "send": {
            "args": ["command"],
            "optional": False,
            "description": "Send a command to server",
        },
        "benchmark": {
            "args": ["iterations"],
            "optional": True,
            "description": "Run RTT benchmark (default: 10 iterations)",
        },
        "launch-psychopy": {
            "args": ["experiment"],
            "optional": True,
            "description": "Launch PsychoPy experiment",
        },
        "config": {
            "args": ["action"],
            "optional": False,
            "description": "Config management (show/set/reload)",
        },
        "export": {
            "args": ["format", "filename"],
            "optional": False,
            "description": "Export benchmark data (json/csv/excel)",
        },
        "clear": {
            "args": [],
            "optional": False,
            "description": "Clear console log",
        },
        "help": {
            "args": ["command"],
            "optional": True,
            "description": "Show help for command",
        },
        "quit": {
            "args": [],
            "optional": False,
            "description": "Exit the application",
        },
        "exit": {
            "args": [],
            "optional": False,
            "description": "Exit the application",
        },
    }

    def __init__(self) -> None:
        """Initialize the command parser."""
        self._history: list[str] = []  # most recent command last
        self._max_history = 100        # oldest entries dropped beyond this

    @property
    def history(self) -> list[str]:
        """Get command history (a defensive copy)."""
        return self._history.copy()

    def add_to_history(self, command: str) -> None:
        """Add a command to history.

        Empty strings and immediate duplicates are skipped; the history is
        capped at _max_history entries (oldest dropped first).

        Args:
            command: Command string to add.
        """
        if command and (not self._history or self._history[-1] != command):
            self._history.append(command)
            if len(self._history) > self._max_history:
                self._history.pop(0)

    def parse(self, command_str: str) -> ParsedCommand:
        """Parse a command string.

        Args:
            command_str: Raw command string.

        Returns:
            Parsed command object.

        Raises:
            CommandParseError: If parsing fails.
        """
        command_str = command_str.strip()
        if not command_str:
            raise CommandParseError("Empty command")

        try:
            # shlex honors shell-style quoting, so `send "hello world"`
            # yields a single argument.
            parts = shlex.split(command_str)
        except ValueError as e:
            raise CommandParseError(f"Invalid command syntax: {e}") from e

        if not parts:
            raise CommandParseError("Empty command")

        name = parts[0].lower()  # command names are case-insensitive
        args: list[str] = []
        kwargs: dict[str, str] = {}

        # Parse remaining parts
        i = 1
        while i < len(parts):
            part = parts[i]
            if part.startswith("--"):
                # Keyword argument: `--key value`, or bare `--flag` -> "true".
                key = part[2:]
                if i + 1 < len(parts) and not parts[i + 1].startswith("--"):
                    kwargs[key] = parts[i + 1]
                    i += 2
                else:
                    kwargs[key] = "true"
                    i += 1
            elif part.startswith("-") and len(part) == 2:
                # Short flag: `-k value`, or bare `-k` -> "true".
                # NOTE: a following token starting with "-" (e.g. a negative
                # number) is not consumed as the flag's value.
                key = part[1]
                if i + 1 < len(parts) and not parts[i + 1].startswith("-"):
                    kwargs[key] = parts[i + 1]
                    i += 2
                else:
                    kwargs[key] = "true"
                    i += 1
            else:
                # Positional argument
                args.append(part)
                i += 1

        return ParsedCommand(
            name=name,
            args=args,
            kwargs=kwargs,
            raw=command_str,
        )
    def validate(self, parsed: ParsedCommand) -> Optional[str]:
        """Validate a parsed command.

        Only the positional-argument count is checked; keyword arguments are
        passed through unvalidated.

        Args:
            parsed: Parsed command to validate.

        Returns:
            Error message if invalid, None if valid.
        """
        if parsed.name not in self.COMMANDS:
            return f"Unknown command: {parsed.name}. Type 'help' for available commands."

        cmd_spec = self.COMMANDS[parsed.name]
        required_args = cmd_spec["args"]
        optional = cmd_spec.get("optional", False)

        # Arity is only enforced when the spec marks its args as required.
        if not optional and len(parsed.args) < len(required_args):
            missing = required_args[len(parsed.args) :]
            return f"Missing required arguments: {', '.join(missing)}"

        return None

    def get_completions(self, partial: str) -> list[str]:
        """Get command completions for partial input.

        Args:
            partial: Partial command input.

        Returns:
            List of possible completions (all commands for empty input).
        """
        partial = partial.lower().strip()
        if not partial:
            return list(self.COMMANDS.keys())

        # Check if we're completing the command name or arguments
        parts = partial.split()
        if len(parts) == 1:
            # Completing command name: prefix match against known commands.
            return [cmd for cmd in self.COMMANDS.keys() if cmd.startswith(parts[0])]
        else:
            # For now, just return empty list for argument completion
            # Could be extended for specific commands
            return []

    def get_help(self, command: Optional[str] = None) -> str:
        """Get help text.

        Args:
            command: Specific command to get help for; None (or an unknown
                name) produces the general command listing.

        Returns:
            Help text string.
        """
        if command and command in self.COMMANDS:
            cmd_spec = self.COMMANDS[command]
            args_str = " ".join(f"<{arg}>" for arg in cmd_spec["args"])
            optional_marker = " (args optional)" if cmd_spec.get("optional") else ""
            return (
                f"{command} {args_str}{optional_marker}\n"
                f"  {cmd_spec['description']}"
            )

        # General help: every command with its args, `*` marking optional ones.
        lines = ["Available commands:", ""]
        for cmd_name, cmd_spec in sorted(self.COMMANDS.items()):
            args_str = " ".join(f"<{arg}>" for arg in cmd_spec["args"])
            optional_marker = "*" if cmd_spec.get("optional") else ""
            lines.append(f"  {cmd_name} {args_str}{optional_marker}")
            lines.append(f"    {cmd_spec['description']}")
        lines.append("")
        lines.append("* = arguments optional")
        return "\n".join(lines)
@dataclass
class CommandResult:
    """Result of a command execution.

    Attributes:
        success: Whether the command succeeded.
        message: Result message.
        data: Optional additional data (e.g. {"action": "quit"} sentinels
            that the console app interprets).
    """

    success: bool
    message: str
    data: Optional[Any] = None


# An async handler takes the parsed command and resolves to a CommandResult.
CommandHandler = Callable[[ParsedCommand], Coroutine[Any, Any, CommandResult]]


class CommandExecutor:
    """Execute console commands.

    Handles all console commands including connection management,
    benchmarking, and configuration.
    """

    def __init__(
        self,
        client: "TCPClient",
        config: Config,
        log_callback: Optional[Callable[[str, str], None]] = None,
    ) -> None:
        """Initialize the command executor.

        Args:
            client: TCP client instance.
            config: Configuration object.
            log_callback: Optional callback for logging (level, message).
        """
        self.client = client
        self.config = config
        self.parser = CommandParser()
        self._log_callback = log_callback
        # Last benchmark result, kept so `export` can run after `benchmark`.
        self._benchmark_result: Optional[Any] = None
        self._progress_callback: Optional[Callable[[int, int, str], None]] = None
        self._launcher = ExperimentLauncher(config)

    def set_progress_callback(
        self, callback: Optional[Callable[[int, int, str], None]]
    ) -> None:
        """Set the progress callback.

        Args:
            callback: Callback function(current, total, message).
        """
        self._progress_callback = callback

    def _log(self, level: str, message: str) -> None:
        """Log a message via the optional callback (no-op when unset).

        Args:
            level: Log level (info, warning, error).
            message: Message to log.
        """
        if self._log_callback:
            self._log_callback(level, message)

    async def execute(self, command_str: str) -> CommandResult:
        """Execute a command string.

        Parses, validates, records history, then dispatches to the handler.
        Handler exceptions are caught and reported as failed results rather
        than propagated.

        Args:
            command_str: Raw command string.

        Returns:
            Command result.
        """
        try:
            parsed = self.parser.parse(command_str)
        except Exception as e:
            return CommandResult(False, f"Parse error: {e}")

        # Validate command
        error = self.parser.validate(parsed)
        if error:
            return CommandResult(False, error)

        # Add to history (only commands that parsed and validated).
        self.parser.add_to_history(command_str)

        # Execute command
        handler = self._get_handler(parsed.name)
        if handler:
            try:
                return await handler(parsed)
            except Exception as e:
                return CommandResult(False, f"Execution error: {e}")

        return CommandResult(False, f"No handler for command: {parsed.name}")

    def _get_handler(self, command: str) -> Optional[CommandHandler]:
        """Get the handler for a command.

        Args:
            command: Command name.

        Returns:
            Handler function or None.
        """
        # Dispatch table; "quit" and "exit" share one handler.
        handlers: dict[str, CommandHandler] = {
            "connect": self._cmd_connect,
            "disconnect": self._cmd_disconnect,
            "status": self._cmd_status,
            "send": self._cmd_send,
            "benchmark": self._cmd_benchmark,
            "launch-psychopy": self._cmd_launch_psychopy,
            "config": self._cmd_config,
            "export": self._cmd_export,
            "clear": self._cmd_clear,
            "help": self._cmd_help,
            "quit": self._cmd_quit,
            "exit": self._cmd_quit,
        }
        return handlers.get(command)

    async def _cmd_connect(self, parsed: ParsedCommand) -> CommandResult:
        """Handle connect command (config defaults used for missing host/port)."""
        host = parsed.args[0] if len(parsed.args) > 0 else None
        port = int(parsed.args[1]) if len(parsed.args) > 1 else None

        try:
            await self.client.connect(host=host, port=port)
            # Report the effective target, not just the raw arguments.
            target_host = host or self.client.host
            target_port = port or self.client.port
            return CommandResult(True, f"Connected to {target_host}:{target_port}")
        except Exception as e:
            return CommandResult(False, f"Connection failed: {e}")

    async def _cmd_disconnect(self, parsed: ParsedCommand) -> CommandResult:
        """Handle disconnect command."""
        await self.client.disconnect()
        return CommandResult(True, "Disconnected")
    async def _cmd_status(self, parsed: ParsedCommand) -> CommandResult:
        """Handle status command: render client statistics as text."""
        stats = self.client.get_statistics()

        lines = []
        if stats["connected"]:
            lines.append(f"✓ Connected to {stats['host']}:{stats['port']}")
        else:
            lines.append("✗ Disconnected")

        lines.append(f"Uptime: {stats['uptime_seconds']:.1f}s")
        lines.append(f"Avg Latency: {stats['avg_latency_ms']:.2f}ms")
        lines.append(f"P95 Latency: {stats['p95_latency_ms']:.2f}ms")
        lines.append(f"P99 Latency: {stats['p99_latency_ms']:.2f}ms")
        lines.append(f"Pending: {stats['pending_count']}")
        lines.append(f"Completed: {stats['completed_count']}")
        lines.append(f"Failed: {stats['failed_count']}")
        lines.append(f"Timeouts: {stats['timeout_count']}")

        return CommandResult(True, "\n".join(lines), data=stats)

    async def _cmd_send(self, parsed: ParsedCommand) -> CommandResult:
        """Handle send command: forward a raw command string to the server."""
        if not self.client.connected:
            return CommandResult(False, "Not connected")

        # Positional args are re-joined; quoting is already resolved by shlex.
        command = " ".join(parsed.args)
        self._log("info", f"Sending: {command}")

        try:
            from af_server_client.core.protocol import Protocol

            protocol = Protocol()
            protocol.data = command
            response = await self.client.send_command(protocol)
            if response:
                return CommandResult(
                    response.success,
                    f"Response: {response.message}",
                    data=response,
                )
            return CommandResult(False, "No response received")
        except Exception as e:
            return CommandResult(False, f"Send failed: {e}")

    async def _cmd_benchmark(self, parsed: ParsedCommand) -> CommandResult:
        """Handle benchmark command.

        Usage: `benchmark [iterations [payload_size ...]]`; unspecified
        values come from the benchmark section of the config.
        """
        if not self.client.connected:
            return CommandResult(False, "Not connected")

        iterations = (
            int(parsed.args[0])
            if len(parsed.args) > 0
            else self.config.benchmark.default_iterations
        )

        # Parse payload sizes from remaining args or use config defaults
        if len(parsed.args) > 1:
            sizes = [int(x) for x in parsed.args[1:]]
        else:
            sizes = self.config.benchmark.payload_sizes

        benchmark = RTTBenchmark(self.client, self.config)

        # Set up progress callback, adapting (current, total) to the
        # (current, total, message) signature the UI callback expects.
        if self._progress_callback:
            def progress_handler(current: int, total: int) -> None:
                if self._progress_callback:
                    self._progress_callback(current, total, f"Running: {current}/{total}")

            benchmark.set_progress_callback(progress_handler)

        try:
            result = await benchmark.run(iterations=iterations, payload_sizes=sizes)
            # Keep the result so a later `export` command can use it.
            self._benchmark_result = result

            lines = [
                f"Benchmark completed: {result.stats['samples_count']} samples",
                f"Duration: {result.timing['duration']:.2f}s",
                f"Avg Total Latency: {result.stats['avg_total_latency']:.2f}ms",
                f"Avg Network Latency: {result.stats['avg_network_latency']:.2f}ms",
                f"Avg Exec Time: {result.stats['avg_exec_time']:.2f}ms",
                f"Jitter: {result.stats['jitter']:.2f}ms",
                f"P95 Latency: {result.stats['p95_latency']:.2f}ms",
                f"P99 Latency: {result.stats['p99_latency']:.2f}ms",
            ]
            return CommandResult(True, "\n".join(lines), data=result)
        except Exception as e:
            return CommandResult(False, f"Benchmark failed: {e}")

    async def _cmd_launch_psychopy(self, parsed: ParsedCommand) -> CommandResult:
        """Handle launch-psychopy command (config default when no arg given)."""
        experiment = (
            parsed.args[0]
            if len(parsed.args) > 0
            else self.config.psychopy.default_experiment
        )

        try:
            success = await self._launcher.launch(experiment)
            if success:
                return CommandResult(True, f"Launched experiment: {experiment}")
            return CommandResult(False, f"Failed to launch experiment: {experiment}")
        except Exception as e:
            return CommandResult(False, f"Launch failed: {e}")

    async def _cmd_config(self, parsed: ParsedCommand) -> CommandResult:
        """Handle config command (actions: show, set, reload)."""
        if not parsed.args:
            # NOTE(review): usage text appears truncated — likely meant
            # "Usage: config <show|set|reload>"; confirm against upstream.
            return CommandResult(False, "Usage: config ")

        action = parsed.args[0].lower()

        if action == "show":
            config_dict = self.config.to_dict()
            lines = []
            for section, values in config_dict.items():
                lines.append(f"[{section}]")
                for key, value in values.items():
                    lines.append(f"  {key} = {value}")
            return CommandResult(True, "\n".join(lines), data=config_dict)

        elif action == "set":
            if len(parsed.args) < 3:
                # NOTE(review): usage text appears truncated — likely meant
                # "Usage: config set <key> <value>".
                return CommandResult(False, "Usage: config set ")
            key = parsed.args[1]
            # Remaining tokens are re-joined so values may contain spaces.
            value = " ".join(parsed.args[2:])
            if self.config.set(key, value):
                self.config.save()  # persist immediately on success
                return CommandResult(True, f"Set {key} = {value}")
            return CommandResult(False, f"Failed to set {key}")

        elif action == "reload":
            self.config = reload_config(self.config)
            return CommandResult(True, "Configuration reloaded")

        return CommandResult(False, f"Unknown config action: {action}")

    async def _cmd_export(self, parsed: ParsedCommand) -> CommandResult:
        """Handle export command for the most recent benchmark result."""
        if self._benchmark_result is None:
            return CommandResult(False, "No benchmark results to export. Run benchmark first.")

        if len(parsed.args) < 2:
            # NOTE(review): usage text appears truncated — likely meant
            # "Usage: export <format> <filename>".
            return CommandResult(False, "Usage: export ")

        export_format = parsed.args[0].lower()
        filename = parsed.args[1]

        exporter = BenchmarkExporter(self._benchmark_result)

        try:
            if export_format == "json":
                path = exporter.to_json(filename)
            elif export_format == "csv":
                path = exporter.to_csv(filename)
            elif export_format == "excel":
                path = exporter.to_excel(filename)
            else:
                return CommandResult(False, f"Unknown format: {export_format}")

            return CommandResult(True, f"Exported to: {path}")
        except Exception as e:
            return CommandResult(False, f"Export failed: {e}")

    async def _cmd_clear(self, parsed: ParsedCommand) -> CommandResult:
        """Handle clear command: signal the UI via the "clear" action sentinel."""
        return CommandResult(True, "", data={"action": "clear"})

    async def _cmd_help(self, parsed: ParsedCommand) -> CommandResult:
        """Handle help command (general listing, or one command's details)."""
        command = parsed.args[0] if len(parsed.args) > 0 else None
        help_text = self.parser.get_help(command)
        return CommandResult(True, help_text)

    async def _cmd_quit(self, parsed: ParsedCommand) -> CommandResult:
        """Handle quit/exit command: signal the UI via the "quit" action sentinel."""
        return CommandResult(True, "Goodbye!", data={"action": "quit"})
class StatusPanel(Static):
    """Panel displaying connection status and statistics.

    Displays:
    - Connection status (connected/disconnected)
    - Latency metrics (avg, p95, p99)
    - Request counts (pending, completed, failed)
    - Session uptime
    """

    DEFAULT_CSS = """
    StatusPanel {
        height: 5;
        border: round $primary;
        padding: 0 1;
    }
    """

    def __init__(
        self,
        client: "TCPClient",
        *args: Any,
        **kwargs: Any,
    ) -> None:
        """Initialize the status panel.

        Args:
            client: TCP client to display status for.
            *args: Additional positional arguments for Static.
            **kwargs: Additional keyword arguments for Static.
        """
        super().__init__(*args, **kwargs)
        self.client = client
        self._update_interval = 1.0  # refresh cadence in seconds

    def on_mount(self) -> None:
        """Start the update timer when mounted."""
        self.set_interval(self._update_interval, self._update_display)
        self._update_display()  # paint immediately, don't wait one interval

    def _format_uptime(self, seconds: float) -> str:
        """Format uptime as HH:MM:SS.

        Args:
            seconds: Uptime in seconds.

        Returns:
            Formatted uptime string.
        """
        hours = int(seconds // 3600)
        minutes = int((seconds % 3600) // 60)
        secs = int(seconds % 60)
        return f"{hours:02d}:{minutes:02d}:{secs:02d}"

    def _update_display(self) -> None:
        """Update the status display from the client's current statistics."""
        stats = self.client.get_statistics()

        # Build status text
        lines = []

        # Connection status
        if stats["connected"]:
            conn_status = Text("✓ Connected", style="green bold")
            conn_status.append(f" to {stats['host']}:{stats['port']}", style="white")
        else:
            conn_status = Text("✗ Disconnected", style="red bold")
        lines.append(Text("Connection: ") + conn_status)

        # Latency
        avg_latency = stats.get("avg_latency_ms", 0.0)
        p95_latency = stats.get("p95_latency_ms", 0.0)
        p99_latency = stats.get("p99_latency_ms", 0.0)
        latency_line = Text("Latency: ")
        latency_line.append(f"Avg: {avg_latency:.1f}ms", style="cyan")
        latency_line.append(" | ", style="dim")
        latency_line.append(f"P95: {p95_latency:.1f}ms", style="yellow")
        latency_line.append(" | ", style="dim")
        latency_line.append(f"P99: {p99_latency:.1f}ms", style="red")
        lines.append(latency_line)

        # Requests
        pending = stats.get("pending_count", 0)
        completed = stats.get("completed_count", 0)
        # NOTE(review): timeouts are folded into the single "Failed" figure
        # shown here (the `status` console command lists them separately).
        failed = stats.get("failed_count", 0) + stats.get("timeout_count", 0)
        request_line = Text("Requests: ")
        request_line.append(f"Pending: {pending}", style="blue")
        request_line.append(" | ", style="dim")
        request_line.append(f"Completed: {completed}", style="green")
        request_line.append(" | ", style="dim")
        request_line.append(f"Failed: {failed}", style="red")
        lines.append(request_line)

        # Uptime
        uptime = stats.get("uptime_seconds", 0.0)
        uptime_line = Text("Uptime: ")
        uptime_line.append(f"Session: {self._format_uptime(uptime)}", style="magenta")
        lines.append(uptime_line)

        # Combine all lines
        output = Text()
        for i, line in enumerate(lines):
            output.append_text(line)
            if i < len(lines) - 1:
                output.append("\n")

        self.update(output)
class ProgressPanel(Static):
    """Panel for displaying benchmark progress.

    Renders a textual progress bar plus a one-line status message.
    """

    DEFAULT_CSS = """
    ProgressPanel {
        height: 3;
        border: round $accent;
        padding: 0 1;
    }
    """

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        """Initialize the progress panel in its idle state."""
        super().__init__(*args, **kwargs)
        self._current = 0
        self._total = 0
        self._progress = 0.0
        self._message = "Ready"

    def set_progress(self, current: int, total: int, message: str = "") -> None:
        """Update progress display.

        Args:
            current: Current progress count.
            total: Total items count.
            message: Optional progress message.
        """
        self._current = current
        self._total = total
        # Guard against division by zero for an empty/unknown total.
        self._progress = 0.0 if total <= 0 else current / total
        self._message = message if message else f"Progress: {current}/{total}"
        self._update_display()

    def reset(self) -> None:
        """Reset progress display to the idle state."""
        self._current = 0
        self._total = 0
        self._progress = 0.0
        self._message = "Ready"
        self._update_display()

    def _update_display(self) -> None:
        """Render the bar and message into the widget."""
        bar_width = 40
        filled = int(bar_width * self._progress)

        bar = Text()
        bar.append("[")
        bar.append("=" * filled, style="green")
        bar.append(" " * (bar_width - filled), style="dim")
        bar.append("]")
        bar.append(f" {self._progress * 100:.1f}%", style="cyan")

        display = Text()
        display.append(f"{self._message}\n")
        display.append_text(bar)

        self.update(display)
af_server_client.core.config import Config, load_config + +__all__ = [ + "TCPClient", + "Protocol", + "EchoCommand", + "Response", + "ResponseHandler", + "Config", + "load_config", +] diff --git a/af_server_client/core/config.py b/af_server_client/core/config.py new file mode 100644 index 0000000..5779571 --- /dev/null +++ b/af_server_client/core/config.py @@ -0,0 +1,354 @@ +"""Configuration loader and manager for AF-Server client. + +This module provides TOML-based configuration management with hot-reload +support for non-critical settings. +""" + +import sys +from dataclasses import dataclass, field +from pathlib import Path +from typing import Any, Optional + +if sys.version_info >= (3, 11): + import tomllib +else: + import tomli as tomllib + + +DEFAULT_CONFIG = { + "network": { + "host": "127.0.0.1", + "port": 2001, + "timeout": 5.0, + "buffer_size": 256, + "enable_nodelay": True, + "enable_qos": True, + "heartbeat_interval": 30.0, + }, + "console": { + "log_level": "DEBUG", + "history_size": 100, + "max_log_lines": 1000, + "theme": "monokai", + }, + "benchmark": { + "default_iterations": 10, + "payload_sizes": [0, 128, 256, 512, 1024, 2048, 4096, 8192], + "export_format": "json", + "export_directory": "./benchmark_results", + }, + "psychopy": { + "experiments_dir": "./experiments", + "default_experiment": "feedback_loop.py", + "subject_gui_width": 800, + "subject_gui_height": 600, + }, +} + + +@dataclass +class NetworkConfig: + """Network configuration settings.""" + + host: str = "127.0.0.1" + port: int = 2001 + timeout: float = 5.0 + buffer_size: int = 256 + enable_nodelay: bool = True + enable_qos: bool = True + heartbeat_interval: float = 30.0 + + +@dataclass +class ConsoleConfig: + """Console TUI configuration settings.""" + + log_level: str = "DEBUG" + history_size: int = 100 + max_log_lines: int = 1000 + theme: str = "monokai" + + +@dataclass +class BenchmarkConfig: + """Benchmark configuration settings.""" + + default_iterations: int = 10 + 
payload_sizes: list[int] = field( + default_factory=lambda: [0, 128, 256, 512, 1024, 2048, 4096, 8192] + ) + export_format: str = "json" + export_directory: str = "./benchmark_results" + + +@dataclass +class PsychoPyConfig: + """PsychoPy integration configuration settings.""" + + experiments_dir: str = "./experiments" + default_experiment: str = "feedback_loop.py" + subject_gui_width: int = 800 + subject_gui_height: int = 600 + + +@dataclass +class Config: + """Main configuration container. + + Attributes: + network: Network connection settings. + console: Console TUI settings. + benchmark: Benchmark settings. + psychopy: PsychoPy integration settings. + config_path: Path to the configuration file. + """ + + network: NetworkConfig = field(default_factory=NetworkConfig) + console: ConsoleConfig = field(default_factory=ConsoleConfig) + benchmark: BenchmarkConfig = field(default_factory=BenchmarkConfig) + psychopy: PsychoPyConfig = field(default_factory=PsychoPyConfig) + config_path: Optional[Path] = None + + def to_dict(self) -> dict[str, Any]: + """Convert configuration to dictionary. + + Returns: + Dictionary representation of the configuration. 
+ """ + return { + "network": { + "host": self.network.host, + "port": self.network.port, + "timeout": self.network.timeout, + "buffer_size": self.network.buffer_size, + "enable_nodelay": self.network.enable_nodelay, + "enable_qos": self.network.enable_qos, + "heartbeat_interval": self.network.heartbeat_interval, + }, + "console": { + "log_level": self.console.log_level, + "history_size": self.console.history_size, + "max_log_lines": self.console.max_log_lines, + "theme": self.console.theme, + }, + "benchmark": { + "default_iterations": self.benchmark.default_iterations, + "payload_sizes": self.benchmark.payload_sizes, + "export_format": self.benchmark.export_format, + "export_directory": self.benchmark.export_directory, + }, + "psychopy": { + "experiments_dir": self.psychopy.experiments_dir, + "default_experiment": self.psychopy.default_experiment, + "subject_gui_width": self.psychopy.subject_gui_width, + "subject_gui_height": self.psychopy.subject_gui_height, + }, + } + + def get(self, key: str, default: Any = None) -> Any: + """Get a configuration value by dot-notation key. + + Args: + key: Dot-notation key (e.g., 'network.host'). + default: Default value if key not found. + + Returns: + Configuration value or default. + """ + parts = key.split(".") + value: Any = self.to_dict() + for part in parts: + if isinstance(value, dict) and part in value: + value = value[part] + else: + return default + return value + + def set(self, key: str, value: Any) -> bool: + """Set a configuration value by dot-notation key. + + Args: + key: Dot-notation key (e.g., 'network.host'). + value: Value to set. + + Returns: + True if successful, False otherwise. 
+ """ + parts = key.split(".") + if len(parts) != 2: + return False + + section, attr = parts + config_section = getattr(self, section, None) + if config_section is None: + return False + + if not hasattr(config_section, attr): + return False + + # Type conversion based on existing type + current_value = getattr(config_section, attr) + converted: Any + try: + if isinstance(current_value, bool): + converted = str(value).lower() in ("true", "1", "yes") + elif isinstance(current_value, int): + converted = int(value) + elif isinstance(current_value, float): + converted = float(value) + elif isinstance(current_value, list): + if isinstance(value, list): + converted = value + else: + # Parse comma-separated list + converted = [int(x.strip()) for x in str(value).split(",")] + else: + converted = str(value) + + setattr(config_section, attr, converted) + return True + except (ValueError, TypeError): + return False + + + def save(self) -> bool: + """Save configuration to TOML file. + + Returns: + True if saved successfully, False otherwise. + """ + if self.config_path is None: + return False + + toml_content = _generate_toml(self.to_dict()) + try: + with open(self.config_path, "w") as f: + f.write(toml_content) + return True + except OSError: + return False + + +def _generate_toml(config: dict[str, Any]) -> str: + """Generate TOML content from configuration dictionary. + + Args: + config: Configuration dictionary. + + Returns: + TOML-formatted string. 
+ """ + lines = [] + for section, values in config.items(): + lines.append(f"[{section}]") + for key, value in values.items(): + if isinstance(value, str): + lines.append(f'{key} = "{value}"') + elif isinstance(value, bool): + lines.append(f"{key} = {str(value).lower()}") + elif isinstance(value, list): + list_str = ", ".join(str(v) for v in value) + lines.append(f"{key} = [{list_str}]") + else: + lines.append(f"{key} = {value}") + lines.append("") + return "\n".join(lines) + + +def load_config(path: Optional[str | Path] = None) -> Config: + """Load configuration from TOML file. + + Args: + path: Path to configuration file. If None, searches for + af_server_config.toml in current directory and user home. + + Returns: + Loaded configuration object. + """ + config = Config() + + # Find config file + if path is not None: + config_path = Path(path) + else: + # Search order: current dir, user home + search_paths = [ + Path.cwd() / "af_server_config.toml", + Path.home() / ".af_server_config.toml", + Path.home() / "af_server_config.toml", + ] + config_path = None + for p in search_paths: + if p.exists(): + config_path = p + break + + if config_path is None or not config_path.exists(): + return config + + config.config_path = config_path + + # Load TOML + try: + with open(config_path, "rb") as f: + data = tomllib.load(f) + except (OSError, tomllib.TOMLDecodeError): + return config + + # Apply loaded values + if "network" in data: + net = data["network"] + config.network = NetworkConfig( + host=net.get("host", config.network.host), + port=net.get("port", config.network.port), + timeout=net.get("timeout", config.network.timeout), + buffer_size=net.get("buffer_size", config.network.buffer_size), + enable_nodelay=net.get("enable_nodelay", config.network.enable_nodelay), + enable_qos=net.get("enable_qos", config.network.enable_qos), + heartbeat_interval=net.get("heartbeat_interval", config.network.heartbeat_interval), + ) + + if "console" in data: + con = data["console"] + 
config.console = ConsoleConfig( + log_level=con.get("log_level", config.console.log_level), + history_size=con.get("history_size", config.console.history_size), + max_log_lines=con.get("max_log_lines", config.console.max_log_lines), + theme=con.get("theme", config.console.theme), + ) + + if "benchmark" in data: + bench = data["benchmark"] + config.benchmark = BenchmarkConfig( + default_iterations=bench.get( + "default_iterations", config.benchmark.default_iterations + ), + payload_sizes=bench.get("payload_sizes", config.benchmark.payload_sizes), + export_format=bench.get("export_format", config.benchmark.export_format), + export_directory=bench.get("export_directory", config.benchmark.export_directory), + ) + + if "psychopy" in data: + psy = data["psychopy"] + config.psychopy = PsychoPyConfig( + experiments_dir=psy.get("experiments_dir", config.psychopy.experiments_dir), + default_experiment=psy.get("default_experiment", config.psychopy.default_experiment), + subject_gui_width=psy.get("subject_gui_width", config.psychopy.subject_gui_width), + subject_gui_height=psy.get("subject_gui_height", config.psychopy.subject_gui_height), + ) + + return config + + +def reload_config(config: Config) -> Config: + """Reload configuration from file. + + Args: + config: Existing configuration object. + + Returns: + New configuration object with reloaded values. + """ + if config.config_path is not None: + return load_config(config.config_path) + return load_config() diff --git a/af_server_client/core/protocol.py b/af_server_client/core/protocol.py new file mode 100644 index 0000000..b358bc1 --- /dev/null +++ b/af_server_client/core/protocol.py @@ -0,0 +1,89 @@ +"""Protocol classes for AF-Serializer communication. + +This module defines the @lvclass decorated classes for communication +with the LabVIEW server using AF-Serializer protocol. 
+""" + +try: + from af_serializer import lvclass, LVString, LVDouble, LVBoolean +except ImportError: + from af_server_client._mock.af_serializer import lvclass, LVString, LVDouble, LVBoolean + + +@lvclass(library="", class_name="Protocol") +class Protocol: + """General protocol wrapper for commands. + + Attributes: + data: Serialized command data. + """ + + data: LVString = "" + + +@lvclass(library="Benchmark", class_name="EchoCommand") +class EchoCommand: + """Echo command for benchmarking. + + Used to measure round-trip time by sending a payload + that the server echoes back. + + Attributes: + payload: Variable size payload for testing. + timestamp: Client timestamp when the command was sent. + """ + + payload: LVString = "" + timestamp: LVDouble = 0.0 + + +@lvclass(library="", class_name="Response") +class Response: + """Server response to client commands. + + Attributes: + request_id: Correlation ID matching the original request. + success: Whether the command executed successfully. + exec_time: Server execution time in microseconds. + message: Result or error message. + """ + + request_id: LVString = "" + success: LVBoolean = False + exec_time: LVDouble = 0.0 + message: LVString = "" + + +@lvclass(library="Experiment", class_name="StartTrialCommand") +class StartTrialCommand: + """Command to start an experimental trial. + + Attributes: + trial_id: Unique identifier for the trial. + parameters: JSON-encoded trial parameters. + """ + + trial_id: LVString = "" + parameters: LVString = "" + + +@lvclass(library="Experiment", class_name="SetIntensityCommand") +class SetIntensityCommand: + """Command to set intensity parameter. + + Attributes: + value: Intensity value (0-100). + """ + + value: LVDouble = 0.0 + + +@lvclass(library="Experiment", class_name="StopTrialCommand") +class StopTrialCommand: + """Command to stop the current trial. + + Attributes: + reason: Reason for stopping the trial. 
+ """ + + reason: LVString = "" diff --git a/af_server_client/core/response_handler.py b/af_server_client/core/response_handler.py new file mode 100644 index 0000000..b3690fd --- /dev/null +++ b/af_server_client/core/response_handler.py @@ -0,0 +1,236 @@ +"""Response handler for request-response correlation. + +This module provides functionality to correlate requests with their +corresponding responses using unique request IDs. +""" + +import asyncio +import time +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any, Optional + +if TYPE_CHECKING: + from af_server_client.core.protocol import Response + + +@dataclass +class PendingRequest: + """Represents a pending request awaiting response. + + Attributes: + request_id: Unique identifier for the request. + timestamp: Time when the request was sent (perf_counter). + timeout: Timeout in seconds for the request. + future: Asyncio future to be resolved with response. + command: Original command object that was sent. + """ + + request_id: str + timestamp: float + timeout: float + future: asyncio.Future[Any] + command: Optional[Any] = None + + +@dataclass +class RequestMetrics: + """Metrics for a completed request. + + Attributes: + request_id: Unique identifier for the request. + request_timestamp: Time when request was sent. + response_timestamp: Time when response was received. + total_latency_ms: Total round-trip time in milliseconds. + exec_time_ms: Server execution time in milliseconds. + network_latency_ms: Network latency (total - exec_time). + success: Whether the request succeeded. + """ + + request_id: str + request_timestamp: float + response_timestamp: float + total_latency_ms: float + exec_time_ms: float + network_latency_ms: float + success: bool + + +class ResponseHandler: + """Correlate requests with responses. + + This class manages pending requests and matches incoming responses + to their corresponding requests using unique request IDs. 
+ + Attributes: + pending: Dictionary of pending requests keyed by request_id. + completed_count: Number of successfully completed requests. + failed_count: Number of failed requests. + timeout_count: Number of timed out requests. + metrics_history: List of recent request metrics. + """ + + def __init__(self, max_history: int = 1000) -> None: + """Initialize the response handler. + + Args: + max_history: Maximum number of metrics to keep in history. + """ + self.pending: dict[str, PendingRequest] = {} + self.completed_count: int = 0 + self.failed_count: int = 0 + self.timeout_count: int = 0 + self._max_history = max_history + self.metrics_history: list[RequestMetrics] = [] + self._lock = asyncio.Lock() + + @property + def pending_count(self) -> int: + """Get the number of pending requests.""" + return len(self.pending) + + def register_request( + self, + request_id: str, + timeout: float, + command: Optional[Any] = None, + ) -> asyncio.Future[Any]: + """Register a new request and return a future for its response. + + Args: + request_id: Unique identifier for the request. + timeout: Timeout in seconds for the request. + command: Optional original command object. + + Returns: + Future that will be resolved with the response. + """ + try: + loop = asyncio.get_running_loop() + except RuntimeError: + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + future: asyncio.Future[Any] = loop.create_future() + self.pending[request_id] = PendingRequest( + request_id=request_id, + timestamp=time.perf_counter(), + timeout=timeout, + future=future, + command=command, + ) + return future + + def handle_response(self, response: "Response") -> Optional[RequestMetrics]: + """Handle an incoming response and resolve the corresponding future. + + Args: + response: Response object from the server. + + Returns: + Request metrics if successful, None if request not found. 
+ """ + request_id = response.request_id + if request_id not in self.pending: + return None + + req = self.pending.pop(request_id) + response_timestamp = time.perf_counter() + + if req.future.done(): + return None + + # Calculate metrics + total_latency_ms = (response_timestamp - req.timestamp) * 1000 + exec_time_ms = response.exec_time / 1000.0 # Convert from μs to ms + network_latency_ms = total_latency_ms - exec_time_ms + + metrics = RequestMetrics( + request_id=request_id, + request_timestamp=req.timestamp, + response_timestamp=response_timestamp, + total_latency_ms=total_latency_ms, + exec_time_ms=exec_time_ms, + network_latency_ms=max(0, network_latency_ms), # Avoid negative values + success=response.success, + ) + + # Update counters + if response.success: + self.completed_count += 1 + else: + self.failed_count += 1 + + # Add to history + self.metrics_history.append(metrics) + if len(self.metrics_history) > self._max_history: + self.metrics_history.pop(0) + + # Resolve the future + req.future.set_result(response) + + return metrics + + def handle_timeout(self, request_id: str) -> bool: + """Handle a request timeout. + + Args: + request_id: ID of the timed-out request. + + Returns: + True if request was found and handled, False otherwise. + """ + if request_id not in self.pending: + return False + + req = self.pending.pop(request_id) + self.timeout_count += 1 + + if not req.future.done(): + req.future.set_exception(TimeoutError(f"Request {request_id} timed out")) + + return True + + def cancel_all(self) -> int: + """Cancel all pending requests. + + Returns: + Number of requests cancelled. + """ + count = len(self.pending) + for request_id, req in list(self.pending.items()): + if not req.future.done(): + req.future.cancel() + del self.pending[request_id] + return count + + def get_statistics(self) -> dict[str, Any]: + """Get response handler statistics. + + Returns: + Dictionary with statistics including averages and percentiles. 
+ """ + stats: dict[str, Any] = { + "pending_count": self.pending_count, + "completed_count": self.completed_count, + "failed_count": self.failed_count, + "timeout_count": self.timeout_count, + "avg_latency_ms": 0.0, + "p95_latency_ms": 0.0, + "p99_latency_ms": 0.0, + } + + if not self.metrics_history: + return stats + + latencies = [m.total_latency_ms for m in self.metrics_history] + latencies_sorted = sorted(latencies) + + stats["avg_latency_ms"] = sum(latencies) / len(latencies) + + # Percentiles + p95_idx = int(len(latencies_sorted) * 0.95) + p99_idx = int(len(latencies_sorted) * 0.99) + + stats["p95_latency_ms"] = latencies_sorted[min(p95_idx, len(latencies_sorted) - 1)] + stats["p99_latency_ms"] = latencies_sorted[min(p99_idx, len(latencies_sorted) - 1)] + + return stats diff --git a/af_server_client/core/tcp_client.py b/af_server_client/core/tcp_client.py new file mode 100644 index 0000000..902801c --- /dev/null +++ b/af_server_client/core/tcp_client.py @@ -0,0 +1,392 @@ +"""Async TCP client with low-latency optimizations. + +This module provides an asyncio-based TCP client optimized for +low-latency communication with the LabVIEW server. 
+""" + +import asyncio +import socket +import sys +import time +import uuid +from typing import Any, Callable, Optional + +try: + from af_serializer import lvflatten, lvunflatten +except ImportError: + from af_server_client._mock.af_serializer import lvflatten, lvunflatten + +from af_server_client.core.config import Config +from af_server_client.core.protocol import Protocol, Response +from af_server_client.core.response_handler import ResponseHandler + +# Try to use uvloop for better performance on Unix systems +try: + import uvloop + + if sys.platform != "win32": + asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) +except ImportError: + pass + + +class TCPClientError(Exception): + """Base exception for TCP client errors.""" + + +class ConnectionError(TCPClientError): + """Connection-related error.""" + + +class SendError(TCPClientError): + """Error sending data.""" + + +class TCPClient: + """Async TCP client with low-latency optimizations. + + This client implements socket optimizations including: + - Nagle's algorithm disabled (TCP_NODELAY) + - QoS with DSCP EF for low latency + - Minimal buffer sizes + + Attributes: + config: Configuration object. + connected: Whether the client is connected. + response_handler: Handler for request-response correlation. + """ + + def __init__(self, config: Config) -> None: + """Initialize the TCP client. + + Args: + config: Configuration object with network settings. 
+ """ + self.config = config + self.reader: Optional[asyncio.StreamReader] = None + self.writer: Optional[asyncio.StreamWriter] = None + self.response_handler = ResponseHandler() + self.connected = False + self._listen_task: Optional[asyncio.Task[None]] = None + self._heartbeat_task: Optional[asyncio.Task[None]] = None + self._callbacks: list[Callable[[str, Any], None]] = [] + self._connection_time: Optional[float] = None + + @property + def host(self) -> str: + """Get the configured host.""" + return self.config.network.host + + @property + def port(self) -> int: + """Get the configured port.""" + return self.config.network.port + + @property + def uptime(self) -> float: + """Get connection uptime in seconds.""" + if self._connection_time is None: + return 0.0 + return time.time() - self._connection_time + + def add_callback(self, callback: Callable[[str, Any], None]) -> None: + """Add a callback for events. + + Args: + callback: Callback function(event_type, data). + """ + self._callbacks.append(callback) + + def remove_callback(self, callback: Callable[[str, Any], None]) -> None: + """Remove a callback. + + Args: + callback: Callback function to remove. + """ + if callback in self._callbacks: + self._callbacks.remove(callback) + + def _emit_event(self, event_type: str, data: Any = None) -> None: + """Emit an event to all callbacks. + + Args: + event_type: Type of event. + data: Event data. + """ + for callback in self._callbacks: + try: + callback(event_type, data) + except Exception: + pass # Don't let callback errors propagate + + async def connect( + self, + host: Optional[str] = None, + port: Optional[int] = None, + ) -> bool: + """Establish connection with socket optimizations. + + Args: + host: Optional host override. + port: Optional port override. + + Returns: + True if connected successfully, False otherwise. + + Raises: + ConnectionError: If connection fails. 
+ """ + if self.connected: + return True + + target_host = host or self.host + target_port = port or self.port + + # Create socket with optimizations + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + + # Disable Nagle's algorithm + if self.config.network.enable_nodelay: + sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) + + # Set QoS with DSCP EF for low latency + if self.config.network.enable_qos: + try: + sock.setsockopt(socket.IPPROTO_IP, socket.IP_TOS, 0xB8) + except OSError: + # QoS not supported on this system - emit event for logging + self._emit_event( + "warning", + {"message": "QoS (DSCP EF) not supported on this system"} + ) + + # Set buffer size + sock.setsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF, self.config.network.buffer_size) + sock.setsockopt(socket.SOL_SOCKET, socket.SO_RCVBUF, self.config.network.buffer_size) + + sock.setblocking(False) + + try: + # Connect with timeout + loop = asyncio.get_event_loop() + await asyncio.wait_for( + loop.sock_connect(sock, (target_host, target_port)), + timeout=self.config.network.timeout, + ) + + # Create streams + self.reader, self.writer = await asyncio.open_connection(sock=sock) + self.connected = True + self._connection_time = time.time() + self._reconnect_attempts = 0 + + # Start background tasks + self._listen_task = asyncio.create_task(self._listen_loop()) + + if self.config.network.heartbeat_interval > 0: + self._heartbeat_task = asyncio.create_task(self._heartbeat_loop()) + + self._emit_event("connected", {"host": target_host, "port": target_port}) + return True + + except asyncio.TimeoutError as e: + sock.close() + raise ConnectionError(f"Connection timed out to {target_host}:{target_port}") from e + except OSError as e: + sock.close() + raise ConnectionError(f"Connection failed to {target_host}:{target_port}: {e}") from e + + async def disconnect(self) -> None: + """Gracefully close the connection.""" + if not self.connected: + return + + self.connected = False + + # Cancel 
background tasks + if self._listen_task is not None: + self._listen_task.cancel() + try: + await self._listen_task + except asyncio.CancelledError: + pass + self._listen_task = None + + if self._heartbeat_task is not None: + self._heartbeat_task.cancel() + try: + await self._heartbeat_task + except asyncio.CancelledError: + pass + self._heartbeat_task = None + + # Cancel pending requests + self.response_handler.cancel_all() + + # Close writer + if self.writer is not None: + self.writer.close() + try: + await self.writer.wait_closed() + except Exception: + pass + self.writer = None + self.reader = None + + self._connection_time = None + self._emit_event("disconnected", None) + + async def send_command( + self, + cmd: Any, + wait_response: bool = True, + timeout: Optional[float] = None, + ) -> Optional[Response]: + """Send command and optionally wait for response. + + Args: + cmd: Command object to send. + wait_response: Whether to wait for response. + timeout: Timeout in seconds (uses config default if None). + + Returns: + Response object if wait_response is True, None otherwise. + + Raises: + SendError: If sending fails. + TimeoutError: If response times out. 
+ """ + if not self.connected or self.writer is None: + raise SendError("Not connected") + + timeout = timeout or self.config.network.timeout + request_id = str(uuid.uuid4()) + + # Create protocol wrapper + protocol = Protocol() + protocol.data = request_id + + # Serialize command + try: + # Serialize both protocol and command + protocol_data = lvflatten(protocol) + cmd_data = lvflatten(cmd) + data = protocol_data + cmd_data + except Exception as e: + raise SendError(f"Serialization failed: {e}") from e + + # Send with length prefix + try: + length = len(data).to_bytes(4, byteorder="big") + self.writer.write(length + data) + await self.writer.drain() + except Exception as e: + self.connected = False + raise SendError(f"Send failed: {e}") from e + + self._emit_event("sent", {"request_id": request_id, "command": type(cmd).__name__}) + + if wait_response: + future = self.response_handler.register_request(request_id, timeout, cmd) + try: + response: Response = await asyncio.wait_for(future, timeout=timeout) + return response + except asyncio.TimeoutError: + self.response_handler.handle_timeout(request_id) + raise + + return None + + + async def send_raw(self, data: bytes) -> None: + """Send raw data without protocol wrapping. + + Args: + data: Raw bytes to send. + + Raises: + SendError: If sending fails. 
+ """ + if not self.connected or self.writer is None: + raise SendError("Not connected") + + try: + length = len(data).to_bytes(4, byteorder="big") + self.writer.write(length + data) + await self.writer.drain() + except Exception as e: + self.connected = False + raise SendError(f"Send failed: {e}") from e + + async def _listen_loop(self) -> None: + """Continuously receive and process responses.""" + while self.connected and self.reader is not None: + try: + # Read length prefix + length_data = await self.reader.readexactly(4) + msg_len = int.from_bytes(length_data, byteorder="big") + + # Read message + data = await self.reader.readexactly(msg_len) + + # Deserialize + try: + response = lvunflatten(data) + + # Handle response + if isinstance(response, Response): + metrics = self.response_handler.handle_response(response) + self._emit_event( + "response", + {"response": response, "metrics": metrics}, + ) + else: + self._emit_event("message", {"data": response}) + + except Exception as e: + self._emit_event("error", {"type": "deserialize", "error": str(e)}) + + except asyncio.CancelledError: + break + except asyncio.IncompleteReadError: + # Connection closed + self.connected = False + self._emit_event("disconnected", {"reason": "connection_closed"}) + break + except Exception as e: + self._emit_event("error", {"type": "receive", "error": str(e)}) + self.connected = False + break + + async def _heartbeat_loop(self) -> None: + """Send periodic heartbeats.""" + interval = self.config.network.heartbeat_interval + while self.connected: + try: + await asyncio.sleep(interval) + if self.connected: + # Send a simple Protocol message as heartbeat + protocol = Protocol() + protocol.data = "heartbeat" + await self.send_raw(lvflatten(protocol)) + except asyncio.CancelledError: + break + except Exception: + pass # Ignore heartbeat errors + + def get_statistics(self) -> dict[str, Any]: + """Get client statistics. + + Returns: + Dictionary with connection and request statistics. 
+ """ + stats = self.response_handler.get_statistics() + stats.update( + { + "connected": self.connected, + "host": self.host, + "port": self.port, + "uptime_seconds": self.uptime, + } + ) + return stats diff --git a/af_server_client/psychopy_bridge/__init__.py b/af_server_client/psychopy_bridge/__init__.py new file mode 100644 index 0000000..dba9bd2 --- /dev/null +++ b/af_server_client/psychopy_bridge/__init__.py @@ -0,0 +1,9 @@ +"""PsychoPy bridge module for experiment integration.""" + +from af_server_client.psychopy_bridge.shared_client import SharedClient +from af_server_client.psychopy_bridge.launcher import ExperimentLauncher + +__all__ = [ + "SharedClient", + "ExperimentLauncher", +] diff --git a/af_server_client/psychopy_bridge/launcher.py b/af_server_client/psychopy_bridge/launcher.py new file mode 100644 index 0000000..5d6cade --- /dev/null +++ b/af_server_client/psychopy_bridge/launcher.py @@ -0,0 +1,197 @@ +"""PsychoPy experiment launcher. + +This module provides functionality to launch PsychoPy experiments +as subprocesses with shared TCP client access. +""" + +import asyncio +import subprocess +import sys +from pathlib import Path +from typing import Optional, TYPE_CHECKING + +if TYPE_CHECKING: + from af_server_client.core.config import Config + + +class ExperimentLauncher: + """Launch PsychoPy experiments as subprocesses. + + Manages the lifecycle of PsychoPy experiment processes. + """ + + def __init__(self, config: "Config") -> None: + """Initialize the launcher. + + Args: + config: Configuration object. + """ + self.config = config + self._processes: dict[str, subprocess.Popen[bytes]] = {} + self._experiments_dir = Path(config.psychopy.experiments_dir) + + @property + def experiments_dir(self) -> Path: + """Get the experiments directory path.""" + return self._experiments_dir + + def list_experiments(self) -> list[str]: + """List available experiments. + + Returns: + List of experiment filenames. 
+ """ + if not self._experiments_dir.exists(): + return [] + + return [ + f.name + for f in self._experiments_dir.iterdir() + if f.is_file() and f.suffix == ".py" + ] + + async def launch( + self, + experiment: str, + wait: bool = False, + ) -> bool: + """Launch a PsychoPy experiment. + + Args: + experiment: Experiment filename or path. + wait: Whether to wait for the experiment to complete. + + Returns: + True if launched successfully. + """ + # Resolve experiment path + experiment_path = self._resolve_experiment_path(experiment) + if experiment_path is None: + return False + + # Build command + cmd = [ + sys.executable, + str(experiment_path), + "--host", + self.config.network.host, + "--port", + str(self.config.network.port), + ] + + try: + process = subprocess.Popen( + cmd, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + ) + self._processes[experiment] = process + + if wait: + await self._wait_for_process(experiment) + + return True + except Exception: + return False + + async def stop(self, experiment: str) -> bool: + """Stop a running experiment. + + Args: + experiment: Experiment name. + + Returns: + True if stopped successfully. + """ + if experiment not in self._processes: + return False + + process = self._processes[experiment] + try: + process.terminate() + await asyncio.sleep(0.5) + if process.poll() is None: + process.kill() + del self._processes[experiment] + return True + except Exception: + return False + + async def stop_all(self) -> None: + """Stop all running experiments.""" + for experiment in list(self._processes.keys()): + await self.stop(experiment) + + def is_running(self, experiment: str) -> bool: + """Check if an experiment is running. + + Args: + experiment: Experiment name. + + Returns: + True if running. + """ + if experiment not in self._processes: + return False + + return self._processes[experiment].poll() is None + + def get_running(self) -> list[str]: + """Get list of running experiments. 
+ + Returns: + List of running experiment names. + """ + return [ + exp + for exp, proc in self._processes.items() + if proc.poll() is None + ] + + def _resolve_experiment_path(self, experiment: str) -> Optional[Path]: + """Resolve experiment path. + + Args: + experiment: Experiment filename or path. + + Returns: + Resolved path or None if not found. + """ + # Check if it's an absolute path + if Path(experiment).is_absolute(): + path = Path(experiment) + if path.exists(): + return path + return None + + # Check in experiments directory + path = self._experiments_dir / experiment + if path.exists(): + return path + + # Try adding .py extension + if not experiment.endswith(".py"): + path = self._experiments_dir / f"{experiment}.py" + if path.exists(): + return path + + return None + + async def _wait_for_process(self, experiment: str) -> Optional[int]: + """Wait for experiment process to complete. + + Args: + experiment: Experiment name. + + Returns: + Return code or None if not found. + """ + if experiment not in self._processes: + return None + + process = self._processes[experiment] + + while process.poll() is None: + await asyncio.sleep(0.1) + + return process.returncode diff --git a/af_server_client/psychopy_bridge/shared_client.py b/af_server_client/psychopy_bridge/shared_client.py new file mode 100644 index 0000000..d0e065a --- /dev/null +++ b/af_server_client/psychopy_bridge/shared_client.py @@ -0,0 +1,173 @@ +"""Shared TCP client singleton for PsychoPy integration. + +This module provides a singleton TCP client that can be shared +between the main console and PsychoPy experiments. +""" + +import asyncio +from typing import Optional, TYPE_CHECKING + +if TYPE_CHECKING: + from af_server_client.core.tcp_client import TCPClient + from af_server_client.core.config import Config + + +class SharedClient: + """Singleton TCP client for shared access. 
+ + Provides a single TCP client instance that can be shared between + the main console application and PsychoPy experiments. + """ + + _instance: Optional["SharedClient"] = None + _client: Optional["TCPClient"] = None + _config: Optional["Config"] = None + _lock = asyncio.Lock() + + def __new__(cls) -> "SharedClient": + """Create singleton instance.""" + if cls._instance is None: + cls._instance = super().__new__(cls) + return cls._instance + + @classmethod + def get_instance(cls) -> "SharedClient": + """Get the singleton instance. + + Returns: + SharedClient singleton instance. + """ + if cls._instance is None: + cls._instance = cls() + return cls._instance + + @classmethod + def initialize( + cls, + config: Optional["Config"] = None, + ) -> "SharedClient": + """Initialize the shared client with configuration. + + Args: + config: Optional configuration object. + + Returns: + SharedClient instance. + """ + instance = cls.get_instance() + + if config is not None: + cls._config = config + + if cls._client is None and cls._config is not None: + from af_server_client.core.tcp_client import TCPClient + + cls._client = TCPClient(cls._config) + + return instance + + @property + def client(self) -> Optional["TCPClient"]: + """Get the TCP client instance. + + Returns: + TCP client or None if not initialized. + """ + return self._client + + @property + def config(self) -> Optional["Config"]: + """Get the configuration. + + Returns: + Configuration or None if not set. + """ + return self._config + + @property + def is_connected(self) -> bool: + """Check if the client is connected. + + Returns: + True if connected, False otherwise. + """ + return self._client is not None and self._client.connected + + async def connect( + self, + host: Optional[str] = None, + port: Optional[int] = None, + ) -> bool: + """Connect to the server. + + Args: + host: Optional host override. + port: Optional port override. + + Returns: + True if connected successfully. 
+ """ + async with self._lock: + if self._client is None: + return False + + if self._client.connected: + return True + + try: + return await self._client.connect(host=host, port=port) + except Exception: + return False + + async def disconnect(self) -> None: + """Disconnect from the server.""" + async with self._lock: + if self._client is not None and self._client.connected: + await self._client.disconnect() + + async def send_command( + self, + cmd: object, + wait_response: bool = True, + timeout: Optional[float] = None, + ) -> Optional[object]: + """Send a command through the shared client. + + Args: + cmd: Command object to send. + wait_response: Whether to wait for response. + timeout: Optional timeout override. + + Returns: + Response if wait_response is True, None otherwise. + """ + if self._client is None or not self._client.connected: + return None + + return await self._client.send_command( + cmd, + wait_response=wait_response, + timeout=timeout, + ) + + @classmethod + def reset(cls) -> None: + """Reset the singleton instance. + + Useful for testing or reinitialization. + """ + if cls._client is not None: + # Don't await - just set to None + cls._client = None + cls._config = None + cls._instance = None + + +# Convenience function for getting the shared client +def get_shared_client() -> SharedClient: + """Get the shared client instance. + + Returns: + SharedClient singleton instance. + """ + return SharedClient.get_instance() diff --git a/af_server_client/psychopy_bridge/subject_gui.py b/af_server_client/psychopy_bridge/subject_gui.py new file mode 100644 index 0000000..2bdf28a --- /dev/null +++ b/af_server_client/psychopy_bridge/subject_gui.py @@ -0,0 +1,304 @@ +"""Subject GUI template for PsychoPy experiments. + +This module provides a template for creating subject interfaces +with buttons and sliders for real-time feedback in experiments. + +Note: PsychoPy is an optional dependency. 
This module provides +a template that can be used when PsychoPy is installed. +""" + +import asyncio +from typing import TYPE_CHECKING, Any, Callable, Optional + +if TYPE_CHECKING: + from af_server_client.psychopy_bridge.shared_client import SharedClient + + +# Try to import PsychoPy +try: + from psychopy import visual, event + + PSYCHOPY_AVAILABLE = True +except ImportError: + PSYCHOPY_AVAILABLE = False + + +class SubjectGUIBase: + """Base class for subject GUI interfaces. + + Provides common functionality for experiment subject interfaces. + This base class works without PsychoPy for testing purposes. + """ + + def __init__( + self, + client: "SharedClient", + width: int = 800, + height: int = 600, + ) -> None: + """Initialize the GUI base. + + Args: + client: Shared TCP client for communication. + width: Window width. + height: Window height. + """ + self.client = client + self.width = width + self.height = height + self.running = False + self._callbacks: dict[str, Callable[..., Any]] = {} + + def register_callback(self, event_name: str, callback: Callable[..., Any]) -> None: + """Register a callback for an event. + + Args: + event_name: Name of the event. + callback: Callback function. + """ + self._callbacks[event_name] = callback + + def _emit_event(self, event_name: str, *args: Any, **kwargs: Any) -> None: + """Emit an event to registered callbacks. + + Args: + event_name: Name of the event. + *args: Positional arguments. + **kwargs: Keyword arguments. + """ + if event_name in self._callbacks: + self._callbacks[event_name](*args, **kwargs) + + async def send_command(self, cmd: object) -> Optional[object]: + """Send a command through the shared client. + + Args: + cmd: Command to send. + + Returns: + Response or None. 
+ """ + return await self.client.send_command(cmd) + + def start(self) -> None: + """Start the GUI main loop.""" + self.running = True + + def stop(self) -> None: + """Stop the GUI main loop.""" + self.running = False + + +class SubjectGUI(SubjectGUIBase): + """PsychoPy-based subject GUI. + + Provides a visual interface for experiment subjects with + buttons, sliders, and text displays. + + Example usage: + ```python + from af_server_client.psychopy_bridge import SharedClient + from af_server_client.psychopy_bridge.subject_gui import SubjectGUI + + client = SharedClient.initialize(config) + await client.connect() + + gui = SubjectGUI(client, width=800, height=600) + + # Register callbacks + gui.register_callback('start_trial', on_start_trial) + gui.register_callback('intensity_changed', on_intensity_changed) + + # Run the GUI + gui.run() + ``` + """ + + def __init__( + self, + client: "SharedClient", + width: int = 800, + height: int = 600, + ) -> None: + """Initialize the subject GUI. + + Args: + client: Shared TCP client. + width: Window width. + height: Window height. + + Raises: + ImportError: If PsychoPy is not installed. + """ + super().__init__(client, width, height) + + if not PSYCHOPY_AVAILABLE: + raise ImportError( + "PsychoPy is not installed. 
" + "Install with: pip install af-server-client[psychopy]" + ) + + self.win: Optional[Any] = None + self.btn_start: Optional[Any] = None + self.btn_stop: Optional[Any] = None + self.slider_intensity: Optional[Any] = None + self.text_status: Optional[Any] = None + + def create_window(self) -> None: + """Create the PsychoPy window and components.""" + self.win = visual.Window( + size=(self.width, self.height), + fullscr=False, + units="pix", + ) + + # Start button + self.btn_start = visual.ButtonStim( + self.win, + text="Start Trial", + pos=(0, 150), + size=(200, 50), + ) + + # Stop button + self.btn_stop = visual.ButtonStim( + self.win, + text="Stop", + pos=(0, 50), + size=(200, 50), + ) + + # Intensity slider + self.slider_intensity = visual.Slider( + self.win, + ticks=(0, 50, 100), + pos=(0, -100), + size=(400, 30), + labels=("0%", "50%", "100%"), + granularity=1, + ) + + # Status text + self.text_status = visual.TextStim( + self.win, + text="Ready", + pos=(0, -200), + height=24, + ) + + def run(self) -> None: + """Run the GUI main loop.""" + if self.win is None: + self.create_window() + + self.running = True + last_intensity: Optional[float] = None + + while self.running: + # Check for button clicks + if self.btn_start is not None and self.btn_start.isClicked: + self._emit_event("start_trial") + asyncio.get_event_loop().run_until_complete( + self._on_start_trial() + ) + + if self.btn_stop is not None and self.btn_stop.isClicked: + self._emit_event("stop_trial") + asyncio.get_event_loop().run_until_complete( + self._on_stop_trial() + ) + + # Check slider value + if self.slider_intensity is not None: + intensity = self.slider_intensity.getRating() + if intensity is not None and intensity != last_intensity: + last_intensity = intensity + self._emit_event("intensity_changed", intensity) + asyncio.get_event_loop().run_until_complete( + self._on_intensity_changed(intensity) + ) + + # Draw components + if self.btn_start is not None: + self.btn_start.draw() + if 
self.btn_stop is not None: + self.btn_stop.draw() + if self.slider_intensity is not None: + self.slider_intensity.draw() + if self.text_status is not None: + self.text_status.draw() + + if self.win is not None: + self.win.flip() + + # Check for escape key + if "escape" in event.getKeys(): + self.stop() + + def stop(self) -> None: + """Stop the GUI and close window.""" + self.running = False + if self.win is not None: + self.win.close() + self.win = None + + def set_status(self, text: str) -> None: + """Set the status text. + + Args: + text: Status text to display. + """ + if self.text_status is not None: + self.text_status.text = text + + async def _on_start_trial(self) -> None: + """Handle start trial button click.""" + from af_server_client.core.protocol import StartTrialCommand + + cmd = StartTrialCommand() + await self.send_command(cmd) + self.set_status("Trial started") + + async def _on_stop_trial(self) -> None: + """Handle stop trial button click.""" + from af_server_client.core.protocol import StopTrialCommand + + cmd = StopTrialCommand() + cmd.reason = "User stopped" + await self.send_command(cmd) + self.set_status("Trial stopped") + + async def _on_intensity_changed(self, value: float) -> None: + """Handle intensity slider change. + + Args: + value: New intensity value (0-100). + """ + from af_server_client.core.protocol import SetIntensityCommand + + cmd = SetIntensityCommand() + cmd.value = value + await self.send_command(cmd) + + +# Factory function for creating subject GUIs +def create_subject_gui( + client: "SharedClient", + width: int = 800, + height: int = 600, +) -> SubjectGUIBase: + """Create a subject GUI. + + Creates a PsychoPy GUI if available, otherwise returns base class. + + Args: + client: Shared TCP client. + width: Window width. + height: Window height. + + Returns: + Subject GUI instance. 
+ """ + if PSYCHOPY_AVAILABLE: + return SubjectGUI(client, width, height) + return SubjectGUIBase(client, width, height) diff --git a/af_server_config.toml b/af_server_config.toml new file mode 100644 index 0000000..6127b82 --- /dev/null +++ b/af_server_config.toml @@ -0,0 +1,26 @@ +[network] +host = "127.0.0.1" +port = 2001 +timeout = 5.0 +buffer_size = 256 +enable_nodelay = true +enable_qos = true +heartbeat_interval = 30.0 + +[console] +log_level = "DEBUG" +history_size = 100 +max_log_lines = 1000 +theme = "monokai" + +[benchmark] +default_iterations = 10 +payload_sizes = [0, 128, 256, 512, 1024, 2048, 4096, 8192] +export_format = "json" +export_directory = "./benchmark_results" + +[psychopy] +experiments_dir = "./experiments" +default_experiment = "feedback_loop.py" +subject_gui_width = 800 +subject_gui_height = 600 diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..7f3ea73 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,49 @@ +[project] +name = "af-server-client" +version = "0.1.0" +description = "Low-latency TCP client for LabVIEW server communication with AF-Serializer" +readme = "README.md" +requires-python = ">=3.10" +authors = [{name = "Faragoz", email = "aragon.froylan@gmail.com"}] +license = {text = "MIT"} + +dependencies = [ + "af-serializer @ git+https://github.com/Faragoz/AF-Serializer.git", + "textual>=0.50.0", + "uvloop>=0.19.0; sys_platform != 'win32'", + "rich>=13.7.0", + "pandas>=2.2.0", + "numpy>=1.26.0", + "openpyxl>=3.1.0", + "tomli>=2.0.1; python_version < '3.11'", +] + +[project.optional-dependencies] +psychopy = ["psychopy>=2024.1.0"] +dev = [ + "pytest>=8.0.0", + "pytest-asyncio>=0.23.0", + "pytest-cov>=4.1.0", + "mypy>=1.8.0", + "ruff>=0.2.0", +] + +[project.scripts] +af-server-client = "af_server_client.cli:main" + +[build-system] +requires = ["setuptools>=68.0.0", "wheel"] +build-backend = "setuptools.build_meta" + +[tool.ruff] +line-length = 100 +target-version = "py310" + +[tool.mypy] +python_version = 
"3.10" +warn_return_any = true +warn_unused_configs = true + +[tool.pytest.ini_options] +asyncio_mode = "auto" +testpaths = ["tests"] diff --git a/test.py b/test.py index 14f8a29..64169b9 100644 --- a/test.py +++ b/test.py @@ -1,4 +1,4 @@ -from af_serializer import lvclass, lvflatten, lvunflatten, LVU16 +from af_serializer import lvflatten print(lvflatten("Hello")) # Expected output: [] diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..983534a --- /dev/null +++ b/tests/__init__.py @@ -0,0 +1 @@ +"""Tests package initialization.""" diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..2482076 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,28 @@ +"""Pytest configuration and fixtures. + +Sets up the mock af_serializer if the real package is not available. +""" + +import sys + +# Add the mock af_serializer to sys.modules if the real one isn't available +try: + import af_serializer # noqa: F401 +except ImportError: + # Import the mock + from af_server_client._mock import af_serializer as mock_af_serializer + + sys.modules["af_serializer"] = mock_af_serializer + + +import asyncio + +import pytest + + +@pytest.fixture(scope="session") +def event_loop(): + """Create an instance of the default event loop for the test session.""" + loop = asyncio.get_event_loop_policy().new_event_loop() + yield loop + loop.close() diff --git a/tests/test_benchmark.py b/tests/test_benchmark.py new file mode 100644 index 0000000..516f445 --- /dev/null +++ b/tests/test_benchmark.py @@ -0,0 +1,237 @@ +"""Tests for benchmark functionality.""" + +import pytest + +from af_server_client.benchmark.metrics import Sample, BenchmarkRun +from af_server_client.benchmark.rtt_test import LatencyHistogram +from af_server_client.benchmark.export import BenchmarkExporter + + +class TestSample: + """Tests for Sample class.""" + + def test_sample_creation(self): + """Test creating a Sample.""" + sample = Sample() + + assert 
sample.request["timestamp"] is None + assert sample.request["payload_size"] is None + assert sample.response["timestamp"] is None + assert sample.metrics["total_latency"] == 0.0 + assert sample.metrics["network_latency"] == 0.0 + assert sample.metrics["exec_time"] == 0.0 + + def test_calculate_metrics(self): + """Test calculating metrics from timestamps.""" + sample = Sample() + sample.request["timestamp"] = 1000.0 + sample.response["timestamp"] = 1000.005 # 5ms later + sample.metrics["exec_time"] = 1.0 # 1ms exec time + + sample.calculate_metrics() + + assert abs(sample.metrics["total_latency"] - 5.0) < 0.001 # 5ms + assert abs(sample.metrics["network_latency"] - 4.0) < 0.001 # 5 - 1 = 4ms + + + def test_calculate_metrics_no_timestamps(self): + """Test calculating metrics without timestamps.""" + sample = Sample() + + sample.calculate_metrics() + + assert sample.metrics["total_latency"] == 0.0 + + +class TestBenchmarkRun: + """Tests for BenchmarkRun class.""" + + def test_benchmark_run_creation(self): + """Test creating a BenchmarkRun.""" + run = BenchmarkRun() + + assert run.timing["start_time"] is None + assert run.timing["end_time"] is None + assert run.stats["samples_count"] == 0 + assert len(run.samples) == 0 + + def test_add_sample(self): + """Test adding samples to benchmark run.""" + run = BenchmarkRun() + sample = Sample() + sample.request["timestamp"] = 1000.0 + sample.response["timestamp"] = 1000.002 + sample.metrics["total_latency"] = 2.0 + + run.add_sample("sample-1", sample) + + assert len(run.samples) == 1 + assert "sample-1" in run.samples + + def test_compute_statistics(self): + """Test computing statistics from samples.""" + run = BenchmarkRun() + + # Add some samples with known latencies + for i, latency in enumerate([1.0, 2.0, 3.0, 4.0, 5.0]): + sample = Sample() + sample.metrics["total_latency"] = latency + sample.metrics["network_latency"] = latency - 0.5 + sample.metrics["exec_time"] = 0.5 + run.add_sample(f"sample-{i}", sample) + + 
run.timing["start_time"] = 0.0 + run.timing["end_time"] = 1.0 + + run.compute_statistics() + + assert run.stats["samples_count"] == 5 + assert run.stats["avg_total_latency"] == 3.0 + assert run.stats["min_latency"] == 1.0 + assert run.stats["max_latency"] == 5.0 + assert run.stats["median_latency"] == 3.0 + assert run.timing["duration"] == 1.0 + + def test_compute_statistics_empty(self): + """Test computing statistics with no samples.""" + run = BenchmarkRun() + + run.compute_statistics() + + assert run.stats["samples_count"] == 0 + + def test_to_dict(self): + """Test converting benchmark run to dictionary.""" + run = BenchmarkRun() + sample = Sample() + run.add_sample("sample-1", sample) + + data = run.to_dict() + + assert "timing" in data + assert "stats" in data + assert "samples" in data + assert "metadata" in data + assert "sample-1" in data["samples"] + + +class TestLatencyHistogram: + """Tests for LatencyHistogram class.""" + + def test_histogram_creation(self): + """Test creating a histogram.""" + hist = LatencyHistogram(min_value=0, max_value=10, num_buckets=10) + + assert hist.count == 0 + assert len(hist.buckets) == 10 + + def test_add_values(self): + """Test adding values to histogram.""" + hist = LatencyHistogram(min_value=0, max_value=10, num_buckets=10) + + hist.add(0.5) + hist.add(5.5) + hist.add(9.5) + + assert hist.count == 3 + assert hist.buckets[0] == 1 # 0-1ms + assert hist.buckets[5] == 1 # 5-6ms + assert hist.buckets[9] == 1 # 9-10ms + + def test_overflow(self): + """Test overflow handling.""" + hist = LatencyHistogram(min_value=0, max_value=10, num_buckets=10) + + hist.add(15.0) + + assert hist.overflow == 1 + + def test_underflow(self): + """Test underflow handling.""" + hist = LatencyHistogram(min_value=1, max_value=10, num_buckets=9) + + hist.add(0.5) + + assert hist.underflow == 1 + + def test_get_distribution(self): + """Test getting distribution.""" + hist = LatencyHistogram(min_value=0, max_value=10, num_buckets=10) + hist.add(1.5) + 
+ dist = hist.get_distribution() + + assert len(dist) == 10 + assert dist[1] == (1.0, 2.0, 1) # 1-2ms bucket has 1 value + + def test_to_ascii(self): + """Test ASCII representation.""" + hist = LatencyHistogram(min_value=0, max_value=10, num_buckets=10) + hist.add(1.5) + hist.add(1.8) + + ascii_repr = hist.to_ascii(width=20) + + assert "ms" in ascii_repr + assert "(" in ascii_repr # Count markers + + +class TestBenchmarkExporter: + """Tests for BenchmarkExporter class.""" + + @pytest.fixture + def benchmark_run(self): + """Create a test benchmark run.""" + run = BenchmarkRun() + run.timing["start_time"] = 0.0 + run.timing["end_time"] = 1.0 + + for i in range(5): + sample = Sample() + sample.request["timestamp"] = i * 0.1 + sample.request["payload_size"] = 128 + sample.response["timestamp"] = i * 0.1 + 0.002 + sample.metrics["total_latency"] = 2.0 + sample.metrics["network_latency"] = 1.5 + sample.metrics["exec_time"] = 0.5 + run.add_sample(f"sample-{i}", sample) + + run.compute_statistics() + return run + + def test_to_json(self, benchmark_run, tmp_path): + """Test exporting to JSON.""" + exporter = BenchmarkExporter(benchmark_run) + + path = exporter.to_json("test", str(tmp_path)) + + assert path.exists() + assert path.suffix == ".json" + + def test_to_csv(self, benchmark_run, tmp_path): + """Test exporting to CSV.""" + exporter = BenchmarkExporter(benchmark_run) + + path = exporter.to_csv("test", str(tmp_path)) + + assert path.exists() + assert path.suffix == ".csv" + + def test_to_excel(self, benchmark_run, tmp_path): + """Test exporting to Excel.""" + exporter = BenchmarkExporter(benchmark_run) + + path = exporter.to_excel("test", str(tmp_path)) + + assert path.exists() + assert path.suffix == ".xlsx" + + def test_get_dataframe(self, benchmark_run): + """Test getting DataFrame.""" + exporter = BenchmarkExporter(benchmark_run) + + df = exporter.get_dataframe() + + assert len(df) == 5 + assert "sample_id" in df.columns + assert "total_latency_ms" in df.columns 
diff --git a/tests/test_config.py b/tests/test_config.py new file mode 100644 index 0000000..03c662f --- /dev/null +++ b/tests/test_config.py @@ -0,0 +1,197 @@ +"""Tests for configuration module.""" + +import tempfile +from pathlib import Path + +from af_server_client.core.config import ( + Config, + NetworkConfig, + ConsoleConfig, + BenchmarkConfig, + PsychoPyConfig, + load_config, + reload_config, +) + + +class TestNetworkConfig: + """Tests for NetworkConfig dataclass.""" + + def test_default_values(self): + """Test default network configuration values.""" + config = NetworkConfig() + + assert config.host == "127.0.0.1" + assert config.port == 2001 + assert config.timeout == 5.0 + assert config.buffer_size == 256 + assert config.enable_nodelay is True + assert config.enable_qos is True + assert config.heartbeat_interval == 30.0 + + def test_custom_values(self): + """Test custom network configuration values.""" + config = NetworkConfig( + host="192.168.1.100", + port=3000, + timeout=10.0, + ) + + assert config.host == "192.168.1.100" + assert config.port == 3000 + assert config.timeout == 10.0 + + +class TestConfig: + """Tests for Config class.""" + + def test_default_config(self): + """Test default configuration.""" + config = Config() + + assert isinstance(config.network, NetworkConfig) + assert isinstance(config.console, ConsoleConfig) + assert isinstance(config.benchmark, BenchmarkConfig) + assert isinstance(config.psychopy, PsychoPyConfig) + + def test_to_dict(self): + """Test converting config to dictionary.""" + config = Config() + data = config.to_dict() + + assert "network" in data + assert "console" in data + assert "benchmark" in data + assert "psychopy" in data + assert data["network"]["host"] == "127.0.0.1" + + def test_get_value(self): + """Test getting config value by key.""" + config = Config() + + assert config.get("network.host") == "127.0.0.1" + assert config.get("network.port") == 2001 + assert config.get("console.log_level") == "DEBUG" + + def 
test_get_value_default(self): + """Test getting config value with default.""" + config = Config() + + assert config.get("invalid.key", "default") == "default" + + def test_set_value(self): + """Test setting config value.""" + config = Config() + + result = config.set("network.host", "192.168.1.1") + + assert result is True + assert config.network.host == "192.168.1.1" + + def test_set_value_int(self): + """Test setting integer config value.""" + config = Config() + + result = config.set("network.port", "3000") + + assert result is True + assert config.network.port == 3000 + + def test_set_value_bool(self): + """Test setting boolean config value.""" + config = Config() + + result = config.set("network.enable_nodelay", "false") + + assert result is True + assert config.network.enable_nodelay is False + + def test_set_value_invalid_section(self): + """Test setting value with invalid section.""" + config = Config() + + result = config.set("invalid.key", "value") + + assert result is False + + def test_set_value_invalid_key(self): + """Test setting value with invalid key.""" + config = Config() + + result = config.set("network.invalid", "value") + + assert result is False + + +class TestLoadConfig: + """Tests for load_config function.""" + + def test_load_default_config(self): + """Test loading default configuration when no file exists.""" + config = load_config() + + assert isinstance(config, Config) + assert config.network.host == "127.0.0.1" + + def test_load_config_from_file(self): + """Test loading configuration from TOML file.""" + toml_content = """ +[network] +host = "192.168.1.100" +port = 3000 +timeout = 10.0 + +[console] +log_level = "INFO" +""" + with tempfile.NamedTemporaryFile(mode="w", suffix=".toml", delete=False) as f: + f.write(toml_content) + f.flush() + + config = load_config(f.name) + + assert config.network.host == "192.168.1.100" + assert config.network.port == 3000 + assert config.console.log_level == "INFO" + + # Cleanup + Path(f.name).unlink() 
+ + def test_load_config_missing_file(self): + """Test loading config with missing file returns defaults.""" + config = load_config("/nonexistent/path/config.toml") + + assert isinstance(config, Config) + assert config.network.host == "127.0.0.1" + + +class TestReloadConfig: + """Tests for reload_config function.""" + + def test_reload_without_path(self): + """Test reloading config without config_path.""" + config = Config() + + reloaded = reload_config(config) + + assert isinstance(reloaded, Config) + + def test_reload_with_path(self): + """Test reloading config from file.""" + toml_content = """ +[network] +host = "10.0.0.1" +""" + with tempfile.NamedTemporaryFile(mode="w", suffix=".toml", delete=False) as f: + f.write(toml_content) + f.flush() + + config = load_config(f.name) + config.network.host = "changed" + + reloaded = reload_config(config) + + assert reloaded.network.host == "10.0.0.1" + + # Cleanup + Path(f.name).unlink() diff --git a/tests/test_console.py b/tests/test_console.py new file mode 100644 index 0000000..8534f2e --- /dev/null +++ b/tests/test_console.py @@ -0,0 +1,255 @@ +"""Tests for console components.""" + +import pytest +from unittest.mock import AsyncMock, MagicMock + +from af_server_client.console.command_parser import ( + CommandParser, + ParsedCommand, + CommandParseError, +) +from af_server_client.console.commands import CommandExecutor +from af_server_client.core.config import Config + + +class TestParsedCommand: + """Tests for ParsedCommand dataclass.""" + + def test_parsed_command_creation(self): + """Test creating a ParsedCommand.""" + cmd = ParsedCommand( + name="connect", + args=["localhost", "2001"], + kwargs={"timeout": "5"}, + raw="connect localhost 2001 --timeout 5", + ) + + assert cmd.name == "connect" + assert cmd.args == ["localhost", "2001"] + assert cmd.kwargs == {"timeout": "5"} + + +class TestCommandParser: + """Tests for CommandParser class.""" + + @pytest.fixture + def parser(self): + """Create a parser instance.""" 
+ return CommandParser() + + def test_parse_simple_command(self, parser): + """Test parsing a simple command.""" + result = parser.parse("connect") + + assert result.name == "connect" + assert result.args == [] + assert result.kwargs == {} + + def test_parse_command_with_args(self, parser): + """Test parsing command with arguments.""" + result = parser.parse("connect localhost 2001") + + assert result.name == "connect" + assert result.args == ["localhost", "2001"] + + def test_parse_command_with_kwargs(self, parser): + """Test parsing command with keyword arguments.""" + result = parser.parse("benchmark --iterations 20") + + assert result.name == "benchmark" + assert result.kwargs == {"iterations": "20"} + + def test_parse_quoted_args(self, parser): + """Test parsing command with quoted arguments.""" + result = parser.parse('send "hello world"') + + assert result.name == "send" + assert result.args == ["hello world"] + + def test_parse_empty_command(self, parser): + """Test parsing empty command raises error.""" + with pytest.raises(CommandParseError, match="Empty command"): + parser.parse("") + + def test_parse_whitespace_only(self, parser): + """Test parsing whitespace-only command raises error.""" + with pytest.raises(CommandParseError, match="Empty command"): + parser.parse(" ") + + def test_validate_known_command(self, parser): + """Test validating a known command.""" + parsed = parser.parse("connect localhost 2001") + + error = parser.validate(parsed) + + assert error is None + + def test_validate_unknown_command(self, parser): + """Test validating an unknown command.""" + parsed = parser.parse("foobar") + + error = parser.validate(parsed) + + assert error is not None + assert "Unknown command" in error + + def test_validate_missing_args(self, parser): + """Test validating command with missing required args.""" + parsed = parser.parse("export json") # Missing filename + + error = parser.validate(parsed) + + assert error is not None + assert "Missing required 
arguments" in error + + def test_get_completions(self, parser): + """Test getting command completions.""" + completions = parser.get_completions("con") + + assert "connect" in completions + assert "config" in completions + + def test_get_completions_empty(self, parser): + """Test completions with empty input returns all commands.""" + completions = parser.get_completions("") + + assert len(completions) == len(parser.COMMANDS) + + def test_get_help_all(self, parser): + """Test getting general help.""" + help_text = parser.get_help() + + assert "Available commands" in help_text + assert "connect" in help_text + assert "benchmark" in help_text + + def test_get_help_specific(self, parser): + """Test getting help for specific command.""" + help_text = parser.get_help("connect") + + assert "connect" in help_text + assert "host" in help_text + + def test_history(self, parser): + """Test command history.""" + parser.add_to_history("connect") + parser.add_to_history("status") + + assert parser.history == ["connect", "status"] + + def test_history_no_duplicates(self, parser): + """Test that consecutive duplicates aren't added.""" + parser.add_to_history("status") + parser.add_to_history("status") + + assert parser.history == ["status"] + + +class TestCommandExecutor: + """Tests for CommandExecutor class.""" + + @pytest.fixture + def mock_client(self): + """Create a mock TCP client.""" + client = MagicMock() + client.connected = False + client.host = "127.0.0.1" + client.port = 2001 + client.connect = AsyncMock(return_value=True) + client.disconnect = AsyncMock() + client.get_statistics = MagicMock(return_value={ + "connected": False, + "host": "127.0.0.1", + "port": 2001, + "uptime_seconds": 0.0, + "avg_latency_ms": 0.0, + "p95_latency_ms": 0.0, + "p99_latency_ms": 0.0, + "pending_count": 0, + "completed_count": 0, + "failed_count": 0, + "timeout_count": 0, + }) + return client + + @pytest.fixture + def config(self): + """Create a test configuration.""" + return Config() + + 
@pytest.fixture + def executor(self, mock_client, config): + """Create a command executor.""" + return CommandExecutor(mock_client, config) + + @pytest.mark.asyncio + async def test_execute_connect(self, executor, mock_client): + """Test executing connect command.""" + result = await executor.execute("connect localhost 2001") + + assert result.success is True + mock_client.connect.assert_called_once() + + @pytest.mark.asyncio + async def test_execute_disconnect(self, executor, mock_client): + """Test executing disconnect command.""" + result = await executor.execute("disconnect") + + assert result.success is True + mock_client.disconnect.assert_called_once() + + @pytest.mark.asyncio + async def test_execute_status(self, executor): + """Test executing status command.""" + result = await executor.execute("status") + + assert result.success is True + assert result.data is not None + + @pytest.mark.asyncio + async def test_execute_help(self, executor): + """Test executing help command.""" + result = await executor.execute("help") + + assert result.success is True + assert "Available commands" in result.message + + @pytest.mark.asyncio + async def test_execute_quit(self, executor): + """Test executing quit command.""" + result = await executor.execute("quit") + + assert result.success is True + assert result.data["action"] == "quit" + + @pytest.mark.asyncio + async def test_execute_clear(self, executor): + """Test executing clear command.""" + result = await executor.execute("clear") + + assert result.success is True + assert result.data["action"] == "clear" + + @pytest.mark.asyncio + async def test_execute_config_show(self, executor): + """Test executing config show command.""" + result = await executor.execute("config show") + + assert result.success is True + assert "network" in result.message + + @pytest.mark.asyncio + async def test_execute_unknown(self, executor): + """Test executing unknown command.""" + result = await executor.execute("foobar") + + assert 
result.success is False + assert "Unknown command" in result.message + + @pytest.mark.asyncio + async def test_execute_parse_error(self, executor): + """Test executing with parse error.""" + result = await executor.execute("") + + assert result.success is False + assert "Parse error" in result.message diff --git a/tests/test_protocol.py b/tests/test_protocol.py new file mode 100644 index 0000000..4f6be9e --- /dev/null +++ b/tests/test_protocol.py @@ -0,0 +1,166 @@ +"""Tests for protocol classes.""" + + +try: + from af_serializer import lvflatten, lvunflatten +except ImportError: + from af_server_client._mock.af_serializer import lvflatten, lvunflatten + +from af_server_client.core.protocol import ( + Protocol, + EchoCommand, + Response, + StartTrialCommand, + SetIntensityCommand, + StopTrialCommand, +) + + +class TestProtocol: + """Tests for Protocol class.""" + + def test_protocol_creation(self): + """Test creating a Protocol instance.""" + protocol = Protocol() + assert protocol.data == "" + + def test_protocol_with_data(self): + """Test Protocol with data.""" + protocol = Protocol() + protocol.data = "test-request-id" + assert protocol.data == "test-request-id" + + def test_protocol_serialization(self): + """Test Protocol serialization roundtrip.""" + protocol = Protocol() + protocol.data = "request-123" + + # Serialize + data = lvflatten(protocol) + assert isinstance(data, bytes) + + # Deserialize + restored = lvunflatten(data) + assert isinstance(restored, Protocol) + assert restored.data == "request-123" + + +class TestEchoCommand: + """Tests for EchoCommand class.""" + + def test_echo_command_creation(self): + """Test creating an EchoCommand instance.""" + cmd = EchoCommand() + assert cmd.payload == "" + assert cmd.timestamp == 0.0 + + def test_echo_command_with_data(self): + """Test EchoCommand with data.""" + cmd = EchoCommand() + cmd.payload = "test-payload" + cmd.timestamp = 1234567890.123 + + assert cmd.payload == "test-payload" + assert cmd.timestamp 
== 1234567890.123 + + def test_echo_command_serialization(self): + """Test EchoCommand serialization roundtrip.""" + cmd = EchoCommand() + cmd.payload = "hello world" + cmd.timestamp = 1000.5 + + data = lvflatten(cmd) + restored = lvunflatten(data) + + assert isinstance(restored, EchoCommand) + assert restored.payload == "hello world" + assert restored.timestamp == 1000.5 + + def test_echo_command_large_payload(self): + """Test EchoCommand with large payload.""" + cmd = EchoCommand() + cmd.payload = "x" * 8192 + cmd.timestamp = 999.999 + + data = lvflatten(cmd) + restored = lvunflatten(data) + + assert len(restored.payload) == 8192 + + +class TestResponse: + """Tests for Response class.""" + + def test_response_creation(self): + """Test creating a Response instance.""" + response = Response() + assert response.request_id == "" + assert response.success is False + assert response.exec_time == 0.0 + assert response.message == "" + + def test_response_with_data(self): + """Test Response with data.""" + response = Response() + response.request_id = "req-456" + response.success = True + response.exec_time = 1500.5 + response.message = "Operation completed" + + assert response.request_id == "req-456" + assert response.success is True + assert response.exec_time == 1500.5 + assert response.message == "Operation completed" + + def test_response_serialization(self): + """Test Response serialization roundtrip.""" + response = Response() + response.request_id = "test-id" + response.success = True + response.exec_time = 2000.0 + response.message = "Success" + + data = lvflatten(response) + restored = lvunflatten(data) + + assert isinstance(restored, Response) + assert restored.request_id == "test-id" + assert restored.success is True + assert restored.exec_time == 2000.0 + assert restored.message == "Success" + + +class TestExperimentCommands: + """Tests for experiment-related commands.""" + + def test_start_trial_command(self): + """Test StartTrialCommand.""" + cmd = 
StartTrialCommand() + cmd.trial_id = "trial-001" + cmd.parameters = '{"duration": 60}' + + data = lvflatten(cmd) + restored = lvunflatten(data) + + assert restored.trial_id == "trial-001" + assert restored.parameters == '{"duration": 60}' + + def test_set_intensity_command(self): + """Test SetIntensityCommand.""" + cmd = SetIntensityCommand() + cmd.value = 75.5 + + data = lvflatten(cmd) + restored = lvunflatten(data) + + assert restored.value == 75.5 + + def test_stop_trial_command(self): + """Test StopTrialCommand.""" + cmd = StopTrialCommand() + cmd.reason = "User cancelled" + + data = lvflatten(cmd) + restored = lvunflatten(data) + + assert restored.reason == "User cancelled" diff --git a/tests/test_response_handler.py b/tests/test_response_handler.py new file mode 100644 index 0000000..b1c7ca2 --- /dev/null +++ b/tests/test_response_handler.py @@ -0,0 +1,219 @@ +"""Tests for response handler.""" + +import asyncio +import pytest + +from af_server_client.core.response_handler import ( + ResponseHandler, + PendingRequest, + RequestMetrics, +) +from af_server_client.core.protocol import Response + + +class TestPendingRequest: + """Tests for PendingRequest dataclass.""" + + def test_pending_request_creation(self): + """Test creating a PendingRequest.""" + try: + loop = asyncio.get_running_loop() + except RuntimeError: + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + future = loop.create_future() + + req = PendingRequest( + request_id="test-id", + timestamp=1234.5, + timeout=5.0, + future=future, + ) + + assert req.request_id == "test-id" + assert req.timestamp == 1234.5 + assert req.timeout == 5.0 + assert req.future == future + assert req.command is None + + +class TestResponseHandler: + """Tests for ResponseHandler class.""" + + def test_handler_creation(self): + """Test creating a ResponseHandler.""" + handler = ResponseHandler() + + assert handler.pending_count == 0 + assert handler.completed_count == 0 + assert handler.failed_count == 0 + 
assert handler.timeout_count == 0 + + def test_register_request(self): + """Test registering a request.""" + handler = ResponseHandler() + + future = handler.register_request("req-1", timeout=5.0) + + assert handler.pending_count == 1 + assert "req-1" in handler.pending + assert not future.done() + + def test_handle_response(self): + """Test handling a response.""" + handler = ResponseHandler() + + # Register request + future = handler.register_request("req-1", timeout=5.0) + + # Create response + response = Response() + response.request_id = "req-1" + response.success = True + response.exec_time = 1000.0 # μs + response.message = "OK" + + # Handle response + metrics = handler.handle_response(response) + + assert metrics is not None + assert metrics.request_id == "req-1" + assert metrics.success is True + assert handler.pending_count == 0 + assert handler.completed_count == 1 + assert future.done() + assert future.result() == response + + def test_handle_failed_response(self): + """Test handling a failed response.""" + handler = ResponseHandler() + + _ = handler.register_request("req-2", timeout=5.0) + + response = Response() + response.request_id = "req-2" + response.success = False + response.message = "Error" + + handler.handle_response(response) + + assert handler.completed_count == 0 + assert handler.failed_count == 1 + + def test_handle_unknown_response(self): + """Test handling response with unknown request_id.""" + handler = ResponseHandler() + + response = Response() + response.request_id = "unknown" + + metrics = handler.handle_response(response) + + assert metrics is None + assert handler.pending_count == 0 + + def test_handle_timeout(self): + """Test handling request timeout.""" + handler = ResponseHandler() + + future = handler.register_request("req-3", timeout=5.0) + + result = handler.handle_timeout("req-3") + + assert result is True + assert handler.timeout_count == 1 + assert handler.pending_count == 0 + assert future.done() + + with 
pytest.raises(TimeoutError): + future.result() + + def test_handle_timeout_unknown(self): + """Test timeout for unknown request.""" + handler = ResponseHandler() + + result = handler.handle_timeout("unknown") + + assert result is False + assert handler.timeout_count == 0 + + def test_cancel_all(self): + """Test cancelling all pending requests.""" + handler = ResponseHandler() + + futures = [ + handler.register_request(f"req-{i}", timeout=5.0) + for i in range(5) + ] + + cancelled = handler.cancel_all() + + assert cancelled == 5 + assert handler.pending_count == 0 + for future in futures: + assert future.cancelled() + + def test_get_statistics_empty(self): + """Test statistics with no history.""" + handler = ResponseHandler() + + stats = handler.get_statistics() + + assert stats["pending_count"] == 0 + assert stats["completed_count"] == 0 + assert stats["avg_latency_ms"] == 0.0 + + def test_get_statistics_with_data(self): + """Test statistics with response history.""" + handler = ResponseHandler() + + # Add some requests and responses + for i in range(10): + handler.register_request(f"req-{i}", timeout=5.0) + response = Response() + response.request_id = f"req-{i}" + response.success = True + response.exec_time = 100.0 # μs + handler.handle_response(response) + + stats = handler.get_statistics() + + assert stats["completed_count"] == 10 + assert stats["avg_latency_ms"] >= 0 + assert stats["p95_latency_ms"] >= 0 + assert stats["p99_latency_ms"] >= 0 + + def test_metrics_history_limit(self): + """Test that metrics history is limited.""" + handler = ResponseHandler(max_history=10) + + # Add more than max_history responses + for i in range(20): + handler.register_request(f"req-{i}", timeout=5.0) + response = Response() + response.request_id = f"req-{i}" + response.success = True + handler.handle_response(response) + + assert len(handler.metrics_history) == 10 + + +class TestRequestMetrics: + """Tests for RequestMetrics dataclass.""" + + def test_metrics_creation(self): 
+ """Test creating RequestMetrics.""" + metrics = RequestMetrics( + request_id="test", + request_timestamp=1000.0, + response_timestamp=1000.005, + total_latency_ms=5.0, + exec_time_ms=1.0, + network_latency_ms=4.0, + success=True, + ) + + assert metrics.request_id == "test" + assert metrics.total_latency_ms == 5.0 + assert metrics.network_latency_ms == 4.0 + assert metrics.success is True diff --git a/tests/test_tcp_client.py b/tests/test_tcp_client.py new file mode 100644 index 0000000..029a8a7 --- /dev/null +++ b/tests/test_tcp_client.py @@ -0,0 +1,156 @@ +"""Tests for TCP client.""" + +import asyncio +import pytest +from unittest.mock import AsyncMock, MagicMock, patch + +from af_server_client.core.tcp_client import TCPClient, SendError +from af_server_client.core.config import Config + + +class TestTCPClient: + """Tests for TCPClient class.""" + + @pytest.fixture + def config(self): + """Create a test configuration.""" + config = Config() + config.network.host = "127.0.0.1" + config.network.port = 2001 + config.network.timeout = 1.0 + config.network.heartbeat_interval = 0 # Disable for tests + return config + + @pytest.fixture + def client(self, config): + """Create a test client.""" + return TCPClient(config) + + def test_client_creation(self, client, config): + """Test creating a TCPClient.""" + assert client.host == "127.0.0.1" + assert client.port == 2001 + assert client.connected is False + + def test_client_uptime_disconnected(self, client): + """Test uptime when disconnected.""" + assert client.uptime == 0.0 + + def test_add_remove_callback(self, client): + """Test adding and removing callbacks.""" + callback = MagicMock() + + client.add_callback(callback) + assert callback in client._callbacks + + client.remove_callback(callback) + assert callback not in client._callbacks + + def test_emit_event(self, client): + """Test emitting events to callbacks.""" + callback = MagicMock() + client.add_callback(callback) + + client._emit_event("test", {"data": 
"value"}) + + callback.assert_called_once_with("test", {"data": "value"}) + + def test_emit_event_callback_error(self, client): + """Test that callback errors don't propagate.""" + callback = MagicMock(side_effect=Exception("Callback error")) + client.add_callback(callback) + + # Should not raise + client._emit_event("test", None) + + def test_get_statistics_disconnected(self, client): + """Test getting statistics when disconnected.""" + stats = client.get_statistics() + + assert stats["connected"] is False + assert stats["host"] == "127.0.0.1" + assert stats["port"] == 2001 + assert stats["uptime_seconds"] == 0.0 + + @pytest.mark.asyncio + async def test_send_command_not_connected(self, client): + """Test sending command when not connected.""" + from af_server_client.core.protocol import Protocol + + cmd = Protocol() + cmd.data = "test" + + with pytest.raises(SendError, match="Not connected"): + await client.send_command(cmd) + + @pytest.mark.asyncio + async def test_disconnect_when_not_connected(self, client): + """Test disconnecting when already disconnected.""" + # Should not raise + await client.disconnect() + assert client.connected is False + + @pytest.mark.asyncio + async def test_connect_already_connected(self, client): + """Test connect returns True when already connected.""" + client.connected = True + + result = await client.connect() + + assert result is True + + +class TestTCPClientWithMockConnection: + """Tests for TCPClient with mocked connection.""" + + @pytest.fixture + def config(self): + """Create a test configuration.""" + config = Config() + config.network.host = "127.0.0.1" + config.network.port = 2001 + config.network.timeout = 1.0 + config.network.heartbeat_interval = 0 + return config + + @pytest.mark.asyncio + async def test_connect_success(self, config): + """Test successful connection.""" + client = TCPClient(config) + + mock_reader = AsyncMock() + mock_writer = AsyncMock() + mock_writer.close = MagicMock() + mock_writer.wait_closed = 
AsyncMock() + + with patch("socket.socket") as mock_socket_class: + mock_socket = MagicMock() + mock_socket_class.return_value = mock_socket + + with patch("asyncio.get_event_loop") as mock_loop: + mock_loop.return_value.sock_connect = AsyncMock() + + with patch("asyncio.open_connection", return_value=(mock_reader, mock_writer)): + result = await client.connect() + + assert result is True + assert client.connected is True + + # Cleanup + await client.disconnect() + + @pytest.mark.asyncio + async def test_connect_timeout(self, config): + """Test connection timeout.""" + client = TCPClient(config) + + with patch("socket.socket"): + with patch("asyncio.get_event_loop") as mock_loop: + mock_loop.return_value.sock_connect = AsyncMock( + side_effect=asyncio.TimeoutError() + ) + + with pytest.raises(Exception, match="timed out"): + await client.connect() + + assert client.connected is False