diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index 022407e..8b13b15 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -42,9 +42,10 @@ jobs: run: | python -m pip install pip-audit python -m pip_audit - - name: Set development API_URL for non-main branches + - name: Set development api urls for non-main branches run: | - echo "API_URL=https://api-dev.rouast.com/vitallens-dev/file" >> $GITHUB_ENV + echo "API_FILE_URL=https://api-dev.rouast.com/vitallens-dev/file" >> $GITHUB_ENV + echo "API_STREAM_URL=https://api-dev.rouast.com/vitallens-dev/stream" >> $GITHUB_ENV echo "API_RESOLVE_URL=https://api-dev.rouast.com/vitallens-dev/resolve-model" >> $GITHUB_ENV - name: Lint with flake8 run: | @@ -88,9 +89,10 @@ jobs: run: | python -m pip install pip-audit python -m pip_audit - - name: Set development API_URL for non-main branches + - name: Set development api urls for non-main branches run: | - echo "API_URL=https://api-dev.rouast.com/vitallens-dev/file" | Out-File -FilePath $env:GITHUB_ENV -Append + echo "API_FILE_URL=https://api-dev.rouast.com/vitallens-dev/file" | Out-File -FilePath $env:GITHUB_ENV -Append + echo "API_STREAM_URL=https://api-dev.rouast.com/vitallens-dev/stream" | Out-File -FilePath $env:GITHUB_ENV -Append echo "API_RESOLVE_URL=https://api-dev.rouast.com/vitallens-dev/resolve-model" | Out-File -FilePath $env:GITHUB_ENV -Append - name: Lint with flake8 run: | diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml new file mode 100644 index 0000000..b0e7af4 --- /dev/null +++ b/.github/workflows/release.yml @@ -0,0 +1,44 @@ +name: Publish and release + +on: + push: + tags: + - 'v*' + +jobs: + build-publish-release: + name: Build, publish to PyPI, and create GitHub Release + runs-on: ubuntu-latest + environment: + name: pypi + url: https://pypi.org/p/vitallens + permissions: + id-token: write + contents: write + steps: + - name: Checkout code + uses: actions/checkout@v4 + with: + fetch-depth: 
0 + + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: "3.10" + + - name: Install build tool + run: python -m pip install --upgrade build + + - name: Build source and wheel distributions + run: python -m build + + - name: Publish package to PyPI + uses: pypa/gh-action-pypi-publish@release/v1 + + - name: Create GitHub Release + uses: softprops/action-gh-release@v2 + with: + files: dist/* + generate_release_notes: true + draft: false + prerelease: ${{ contains(github.ref_name, 'beta') }} \ No newline at end of file diff --git a/LICENSE b/LICENSE index 5bd51e2..df5cb27 100644 --- a/LICENSE +++ b/LICENSE @@ -1,6 +1,6 @@ MIT License -Copyright (c) 2024 Rouast Labs +Copyright (c) 2026 Rouast Labs Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal diff --git a/README.md b/README.md index 5cfcef8..13a37e0 100644 --- a/README.md +++ b/README.md @@ -22,6 +22,7 @@ The library provides: - **High-Fidelity Accuracy:** A simple interface to the VitalLens API for state-of-the-art estimation (heart rate, respiratory rate, HRV). - **Local Fallbacks:** Implementations of classic rPPG algorithms (`pos`, `chrom`, `g`) for local, API-free processing. - **Flexible Input:** Support for video files and in-memory `np.ndarray`. +- **Real-time Streaming:** `stream()` context manager for low-latency live inference. - **Face Detection:** Integrated fast face detection and ROI management. Using a different language? Check out our [JavaScript client](https://github.com/Rouast-Labs/vitallens.js) and [iOS SDK](https://github.com/Rouast-Labs/vitallens-ios). 
@@ -45,7 +46,7 @@ from vitallens import VitalLens vl = VitalLens(method="pos") results = vl("path/to/video.mp4") -print("Heart Rate:", results[0]['vital_signs']['heart_rate']['value']) +print("Heart Rate:", results[0]['vitals']['heart_rate']['value']) ``` ### Get High-Fidelity Accuracy (with API Key) @@ -53,13 +54,13 @@ print("Heart Rate:", results[0]['vital_signs']['heart_rate']['value']) To get improved accuracy and advanced metrics like **Respiratory Rate** and **HRV**, use the `vitallens` method. You can get a free key from the [API Dashboard](https://www.rouast.com/api). ```python -from vitallens import VitalLens, Method +from vitallens import VitalLens # Automatically selects the best model for your plan vl = VitalLens(method="vitallens", api_key="YOUR_API_KEY") results = vl("path/to/video.mp4") -vitals = results[0]['vital_signs'] +vitals = results[0]['vitals'] print(f"Heart Rate: {vitals['heart_rate']['value']:.1f} bpm") print(f"Respiratory Rate: {vitals['respiratory_rate']['value']:.1f} rpm") @@ -68,6 +69,23 @@ print(f"Respiratory Rate: {vitals['respiratory_rate']['value']:.1f} rpm") if 'hrv_sdnn' in vitals: print(f"HRV (SDNN): {vitals['hrv_sdnn']['value']:.1f} ms") ``` + +### Real-time Streaming + +Process live frames from a webcam or stream. + +```python +import time +from vitallens import VitalLens + +# Process live frames +vl = VitalLens(method="vitallens", api_key="YOUR_API_KEY") + +with vl.stream() as session: + # In your capture loop (e.g., OpenCV) + session.push(frame, timestamp=time.time()) + results = session.get_result(block=False) +``` ## Documentation diff --git a/docs/results.md b/docs/results.md index c15e44a..de235ed 100644 --- a/docs/results.md +++ b/docs/results.md @@ -23,7 +23,7 @@ Each face entry follows this structure. Note that strict types (like `np.ndarray "confidence": [0.6115, 0.9207, 0.9183, ...], "note": "Face detection coordinates..." 
}, - "vital_signs": { + "vitals": { "heart_rate": { "value": 60.5, "unit": "bpm", @@ -43,6 +43,14 @@ Each face entry follows this structure. Note that strict types (like `np.ndarray "note": "Global estimate of Heart Rate Variability (SDNN)..." } }, + "waveforms": { + "ppg_waveform": { + "data": [0.1, -0.2, ...], + "unit": "unitless", + "confidence": [0.9, 0.9, ...], + "note": "..." + } + }, "message": "The provided values are estimates..." } ] diff --git a/examples/README.md b/examples/README.md index 83fd953..973c15a 100644 --- a/examples/README.md +++ b/examples/README.md @@ -3,6 +3,15 @@ The `examples/` folder contains sample scripts and video data to help you evaluate `vitallens` against ground truth data, run it in Docker, or integrate it into your own pipeline. +## Real-time Webcam Demo (`live.py`) + +This script opens your webcam and streams frames to the API in chunks. + +```bash +pip install opencv-python +python examples/live.py --method=vitallens --api_key=YOUR_API_KEY +``` + ## Evaluation Script (`test.py`) This directory contains sample scripts and video data to help you evaluate `vitallens` against ground truth data, run it in Docker, or integrate it into your own pipeline. @@ -54,7 +63,7 @@ vl = VitalLens(method="vitallens", api_key="YOUR_API_KEY") results = vl("path/to/video.mp4") # Access results -print("Heart Rate:", results[0]['vital_signs']['heart_rate']['value']) +print("Heart Rate:", results[0]['vitals']['heart_rate']['value']) ``` ### Processing Raw Frames (Numpy/OpenCV) @@ -88,6 +97,51 @@ video_arr = np.array(frames) results = vl(video_arr, fps=fps) ``` +### Real-time Streaming + +For live feeds or webcams. There are two ways to handle results: + +#### Polling (Non-blocking) + +Best for applications with their own main loop (e.g., OpenCV display). 
+ +```python +import time +from vitallens import VitalLens + +vl = VitalLens(method="vitallens", api_key="YOUR_API_KEY") + +with vl.stream() as session: + while True: + frame, ts = get_frame() # Your capture logic + session.push(frame, timestamp=ts) + + # Check for results whenever you want + results = session.get_result(block=False) + if results: + print(results[0]['vitals']['heart_rate']['value']) +``` + +#### Callback + +Best for event-driven applications. The callback is triggered automatically as soon as inference finishes. + +```python +import time +from vitallens import VitalLens + +def my_callback(results): + print(f"Callback received HR: {results[0]['vitals']['heart_rate']['value']}") + +vl = VitalLens(method="vitallens", api_key="YOUR_API_KEY") + +with vl.stream(on_result=my_callback) as session: + while True: + frame, ts = get_frame() + session.push(frame, timestamp=ts) + # No need to call get_result() +``` + ## Running with Docker If you encounter dependency issues (e.g., with `onnxruntime` or `ffmpeg`), you can run the example scripts inside our Docker container. @@ -114,14 +168,3 @@ Since the plot cannot display inside the container, copy it out after running: ```bash docker cp :/app/results.png . ``` - -## Real-time Webcam Demo (`live.py`) - -> **Note:** This script is experimental and optimized for testing the `Mode.BURST` functionality. - -This script opens your webcam and streams frames to the API in chunks. 
- -```bash -pip install opencv-python -python examples/live.py --method=vitallens --api_key=YOUR_API_KEY -``` diff --git a/examples/live.py b/examples/live.py index 9554c52..2fdb145 100644 --- a/examples/live.py +++ b/examples/live.py @@ -1,191 +1,188 @@ import argparse -import concurrent.futures +from collections import deque import cv2 import numpy as np -from prpy.constants import SECONDS_PER_MINUTE -from prpy.numpy.face import get_upper_body_roi_from_det -from prpy.numpy.physio import estimate_rate_from_signal, EScope, EMethod -from prpy.numpy.physio import HR_MIN, HR_MAX, RR_MIN, RR_MAX -import sys -import threading import time -import warnings - -sys.path.append('../vitallens-python') -from vitallens import VitalLens, Mode, Method -from vitallens.buffer import SignalBuffer, MultiSignalBuffer -from vitallens.constants import API_MIN_FRAMES - -def draw_roi(frame, roi): - roi = np.asarray(roi).astype(np.int32) - frame = cv2.rectangle(frame, (roi[0], roi[1]), (roi[2], roi[3]), (0, 255, 0), 1) - -def draw_signal(frame, roi, sig, sig_name, sig_conf_name, draw_area_tl_x, draw_area_tl_y, color): - def _draw(frame, vals, display_height, display_width, min_val, max_val, color, thickness): - height_mult = display_height/(max_val - min_val) - width_mult = display_width/(vals.shape[0] - 1) - p1 = (int(draw_area_tl_x), int(draw_area_tl_y + (max_val - vals[0]) * height_mult)) - for i, s in zip(range(1, len(vals)), vals[1:]): - p2 = (int(draw_area_tl_x + i * width_mult), int(draw_area_tl_y + (max_val - s) * height_mult)) - frame = cv2.line(frame, p1, p2, color, thickness) - p1 = p2 - # Derive dims from roi - display_height = (roi[3] - roi[1]) / 2.0 - display_width = (roi[2] - roi[0]) * 1.2 - # Draw signal - if sig_name in sig: - vals = np.asarray(sig[sig_name]) - min_val = np.min(vals) - max_val = np.max(vals) - if max_val - min_val == 0: - return frame - _draw(frame=frame, vals=vals, display_height=display_height, display_width=display_width, - min_val=min_val, 
max_val=max_val, color=color, thickness=2) - # Draw confidence - if sig_conf_name in sig: - vals = np.asarray(sig[sig_conf_name]) - _draw(frame=frame, vals=vals, display_height=display_height, display_width=display_width, - min_val=0., max_val=1., color=color, thickness=1) - -def draw_fps(frame, fps, text, draw_area_bl_x, draw_area_bl_y): - cv2.putText(frame, text=f"{text}: {fps:.1f}", org=(draw_area_bl_x, draw_area_bl_y), - fontFace=cv2.FONT_HERSHEY_SIMPLEX, fontScale=0.6, color=(0,255,0), thickness=1) - -def draw_vital(frame, sig, text, sig_name, fps, color, draw_area_bl_x, draw_area_bl_y): - if sig_name in sig: - f_range = (HR_MIN/SECONDS_PER_MINUTE, HR_MAX/SECONDS_PER_MINUTE) if 'ppg' in sig_name else (RR_MIN/SECONDS_PER_MINUTE, RR_MAX/SECONDS_PER_MINUTE) - val = estimate_rate_from_signal(signal=sig[sig_name], f_s=fps, f_range=f_range, scope=EScope.GLOBAL, method=EMethod.PERIODOGRAM) - cv2.putText(frame, text=f"{text}: {val:.1f}", org=(draw_area_bl_x, draw_area_bl_y), - fontFace=cv2.FONT_HERSHEY_SIMPLEX, fontScale=0.6, color=color, thickness=2) - -class VitalLensRunnable: - def __init__(self, method, api_key, proxies=None): - self.active = threading.Event() - self.result = [] - self.vl = VitalLens(method=method, - mode=Mode.BURST, - api_key=api_key, - proxies=proxies, - detect_faces=True, - estimate_rolling_vitals=True, - export_to_json=False) - def __call__(self, inputs, fps): - self.active.set() - self.result = self.vl(np.asarray(inputs), fps=fps) - self.active.clear() - -def run(args): +from vitallens import VitalLens +import vitallens_core as vc + +def hex_to_bgr(hex_color): + if not hex_color: return (255, 255, 255) + hex_color = hex_color.lstrip('#') + r, g, b = tuple(int(hex_color[i:i+2], 16) for i in (0, 2, 4)) + return (b, g, r) + +def draw_waveform(frame, data, color, rect, title): + x, y, w, h = rect + cv2.putText(frame, title, (x, y - 5), cv2.FONT_HERSHEY_SIMPLEX, 0.4, (200, 200, 200), 1) + if len(data) < 2: return + + min_val, max_val = min(data), 
max(data) + if max_val - min_val == 0: return + + pts = [] + step = w / (len(data) - 1) + for i, val in enumerate(data): + px = int(x + i * step) + py = int(y + h - ((val - min_val) / (max_val - min_val)) * h) + pts.append((px, py)) + + pts = np.array(pts, np.int32).reshape((-1, 1, 2)) + cv2.polylines(frame, [pts], isClosed=False, color=color, thickness=2, lineType=cv2.LINE_AA) + +def main(): + parser = argparse.ArgumentParser(description="VitalLens Live Webcam Demo") + parser.add_argument('--method', type=str, default='pos', help='Method to use (e.g., pos, chrom, g, vitallens)') + parser.add_argument('--api_key', type=str, default=None, help='API key (required for vitallens method)') + args = parser.parse_args() + + vl = VitalLens(method=args.method, api_key=args.api_key) + vl.rppg.fps_target = 15.0 + cap = cv2.VideoCapture(0) - executor = concurrent.futures.ThreadPoolExecutor(max_workers=1) - proxies = None - if args.proxy: - proxies = {"https": args.proxy, "http": args.proxy} - vl = VitalLensRunnable(method=args.method, api_key=args.api_key, proxies=proxies) - signal_buffer = MultiSignalBuffer(size=240, ndim=1, ignore_k=['face']) - fps_buffer = SignalBuffer(size=240, ndim=1, pad_val=np.nan) - frame_buffer = [] - # Sample frames from cv2 video stream attempting to achieve this framerate - target_fps = 30. 
- # Check if the webcam is opened correctly if not cap.isOpened(): - raise IOError("Cannot open webcam") - # Read first frame to get dims - _, frame = cap.read() - height, width, _ = frame.shape - roi = None - i = 0 - t, p_t = time.time(), time.time() - fps, p_fps = 30.0, 30.0 - ds_factor = 1 - n_frames = 0 - signals = None - while True: - ret, frame = cap.read() - if not ret: - break - # Measure frequency - t_prev = t - t = time.time() - if not vl.active.is_set(): - # Process result if available - if len(vl.result) > 0: - # Results are available - fetch and reset - result = vl.result[0] - vl.result = [] - # Update the buffer - signals = signal_buffer.update({ - **{ - f"{key}_sig": value['value'] if 'value' in value else np.array(value['data']) - for key, value in result['vital_signs'].items() - }, - **{ - f"{key}_conf": value['confidence'] if isinstance(value['confidence'], np.ndarray) else np.array(value['confidence']) - for key, value in result['vital_signs'].items() - }, - 'face_conf': result['face']['confidence'], - }, dt=n_frames) - with warnings.catch_warnings(): - warnings.simplefilter("ignore", category=RuntimeWarning) - # Measure actual effective sampling frequency at which neural net input was sampled - fps = np.nanmean(fps_buffer.update([(1./(t - t_prev))/ds_factor], dt=n_frames)) - roi = get_upper_body_roi_from_det(result['face']['coordinates'][-1], clip_dims=(width, height), cropped=True) - # Measure prediction frequency - how often predictions are made - p_t_prev = p_t - p_t = time.time() - p_fps = 1./(p_t - p_t_prev) + print("Error: Could not open webcam.") + return + + print(f"Starting live stream using {args.method}. 
Press 'q' to quit.") + + font = cv2.FONT_HERSHEY_SIMPLEX + latest_vitals = {} + latest_coords = None + ppg_history = deque(maxlen=150) + ppg_conf = deque(maxlen=150) + resp_history = deque(maxlen=150) + resp_conf = deque(maxlen=150) + vital_conf_thresh = 0.8 + hrv_conf_thresh = 0.7 + start_time = time.time() + + ppg_meta = vc.get_vital_info("ppg_waveform") + hr_meta = vc.get_vital_info("heart_rate") + sdnn_meta = vc.get_vital_info("hrv_sdnn") + rmssd_meta = vc.get_vital_info("hrv_rmssd") + resp_meta = vc.get_vital_info("respiratory_waveform") + rr_meta = vc.get_vital_info("respiratory_rate") + + with vl.stream() as session: + while True: + ret, frame = cap.read() + if not ret: + break + + timestamp = time.time() - start_time + rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) + + session.push(rgb_frame, timestamp) + + if session.current_face is not None: + if latest_coords is None: + print("[DEBUG] Face initially detected!") + latest_coords = session.current_face + + res = session.get_result(block=False) + + if res and len(res) > 0: + new_vitals = res[0].get('vitals', {}) + if new_vitals: + latest_vitals.update(new_vitals) + + waveforms = res[0].get('waveforms', {}) + if 'ppg_waveform' in waveforms: + ppg_history.extend(waveforms['ppg_waveform']['data']) + ppg_conf.extend(waveforms['ppg_waveform']['confidence']) + if 'respiratory_waveform' in waveforms: + resp_history.extend(waveforms['respiratory_waveform']['data']) + resp_conf.extend(waveforms['respiratory_waveform']['confidence']) + + if session.current_face is None: + state_text = "Searching" + msg_text = "Position your face in the frame" + color = (0, 165, 255) + elif 'heart_rate' not in latest_vitals: + state_text = "Calibrating" + msg_text = "Hold still and ensure good lighting..." 
+ color = (255, 0, 255) else: - # No results available - roi = None - signal_buffer.clear() - # Start next prediction - is_api_method = str(args.method).lower().startswith("vitallens") - is_api_method = is_api_method or (isinstance(args.method, Method) and args.method == Method.VITALLENS) - if len(frame_buffer) >= (API_MIN_FRAMES if is_api_method else 1): - n_frames = len(frame_buffer) - executor.submit(vl, frame_buffer.copy(), fps) - frame_buffer.clear() - # Sample frames - if i % ds_factor == 0: - # Add current frame to the buffer (BGR -> RGB) - frame_buffer.append(frame[...,::-1]) - i += 1 - # Display - if roi is not None: - draw_roi(frame, roi) - draw_signal( - frame=frame, roi=roi, sig=signals, sig_name='ppg_waveform_sig', sig_conf_name='ppg_waveform_conf', - draw_area_tl_x=roi[2]+20, draw_area_tl_y=roi[1], color=(0, 0, 255)) - draw_signal( - frame=frame, roi=roi, sig=signals, sig_name='respiratory_waveform_sig', sig_conf_name='respiratory_waveform_conf', - draw_area_tl_x=roi[2]+20, draw_area_tl_y=int(roi[1]+(roi[3]-roi[1])/2.0), color=(255, 0, 0)) - draw_fps(frame, fps=fps, text="fps", draw_area_bl_x=roi[0], draw_area_bl_y=roi[3]+20) - draw_fps(frame, fps=p_fps, text="p_fps", draw_area_bl_x=int(roi[0]+0.4*(roi[2]-roi[0])), draw_area_bl_y=roi[3]+20) - draw_vital(frame, sig=signals, text="hr [bpm]", sig_name='ppg_waveform_sig', fps=fps, color=(0,0,255), draw_area_bl_x=roi[2]+20, draw_area_bl_y=int(roi[1]+(roi[3]-roi[1])/2.0)) - draw_vital(frame, sig=signals, text="rr [rpm]", sig_name='respiratory_waveform_sig', fps=fps, color=(255,0,0), draw_area_bl_x=roi[2]+20, draw_area_bl_y=roi[3]) - cv2.imshow('Live', frame) - c = cv2.waitKey(1) - if c == 27: - break - # Even out fps - dt_req = 1./target_fps - (time.time() - t) - if dt_req > 0: time.sleep(dt_req) + state_text = "Tracking" + msg_text = "Tracking vitals" + color = (0, 255, 0) + + fh, fw = frame.shape[:2] + + # Top bar + cv2.rectangle(frame, (0, 0), (fw, 55), (30, 30, 30), -1) + cv2.putText(frame, state_text, 
(20, 25), font, 0.7, color, 2) + cv2.putText(frame, msg_text, (20, 45), font, 0.5, (200, 200, 200), 1) + + if session.current_face is not None: + x1, y1, x2, y2 = map(int, session.current_face) + cv2.rectangle(frame, (x1, y1), (x2, y2), color, 2) + + # Bottom panel + panel_h = 240 + cv2.rectangle(frame, (0, fh - panel_h), (fw, fh), (30, 30, 30), -1) + + n_samples = int(vl.rppg.fps_target) + ppg_last_sec = list(ppg_conf)[-n_samples:] + resp_last_sec = list(resp_conf)[-n_samples:] + avg_ppg_conf = sum(ppg_last_sec) / len(ppg_last_sec) if ppg_last_sec else 0 + avg_resp_conf = sum(resp_last_sec) / len(resp_last_sec) if resp_last_sec else 0 + + wave_w = fw - 240 + + # Fetch metadata + ppg_meta = vc.get_vital_info("ppg_waveform") + hr_meta = vc.get_vital_info("heart_rate") + sdnn_meta = vc.get_vital_info("hrv_sdnn") + rmssd_meta = vc.get_vital_info("hrv_rmssd") + resp_meta = vc.get_vital_info("respiratory_waveform") + rr_meta = vc.get_vital_info("respiratory_rate") + + # --- PPG Row --- + row1_y = fh - panel_h + 30 + if avg_ppg_conf >= vital_conf_thresh: + draw_waveform(frame, list(ppg_history), hex_to_bgr(ppg_meta.color), (20, row1_y, wave_w - 40, 60), ppg_meta.display_name) + + # HR + hr_val = "--" + if 'heart_rate' in latest_vitals and latest_vitals['heart_rate']['confidence'] >= vital_conf_thresh: + hr_val = f"{latest_vitals['heart_rate']['value']:.0f}" + cv2.putText(frame, f"{hr_meta.short_name}", (wave_w, row1_y), font, 0.5, (200, 200, 200), 1) + cv2.putText(frame, f"{hr_val} {hr_meta.unit}", (wave_w, row1_y + 30), font, 0.7, hex_to_bgr(hr_meta.color), 2) + + # SDNN + sdnn_val = "--" + if 'hrv_sdnn' in latest_vitals and latest_vitals['hrv_sdnn']['confidence'] >= hrv_conf_thresh: + sdnn_val = f"{latest_vitals['hrv_sdnn']['value']:.0f}" + cv2.putText(frame, f"{sdnn_meta.short_name}", (wave_w + 120, row1_y), font, 0.4, (200, 200, 200), 1) + cv2.putText(frame, f"{sdnn_val} {sdnn_meta.unit}", (wave_w + 120, row1_y + 20), font, 0.5, hex_to_bgr(sdnn_meta.color), 1) + + # 
RMSSD + rmssd_val = "--" + if 'hrv_rmssd' in latest_vitals and latest_vitals['hrv_rmssd']['confidence'] >= hrv_conf_thresh: + rmssd_val = f"{latest_vitals['hrv_rmssd']['value']:.0f}" + cv2.putText(frame, f"{rmssd_meta.short_name}", (wave_w + 120, row1_y + 50), font, 0.4, (200, 200, 200), 1) + cv2.putText(frame, f"{rmssd_val} {rmssd_meta.unit}", (wave_w + 120, row1_y + 70), font, 0.5, hex_to_bgr(rmssd_meta.color), 1) + + # --- Resp Row --- + row2_y = fh - panel_h + 140 + if avg_resp_conf >= vital_conf_thresh: + draw_waveform(frame, list(resp_history), hex_to_bgr(resp_meta.color), (20, row2_y, wave_w - 40, 60), resp_meta.display_name) + + rr_val = "--" + if 'respiratory_rate' in latest_vitals and latest_vitals['respiratory_rate']['confidence'] >= vital_conf_thresh: + rr_val = f"{latest_vitals['respiratory_rate']['value']:.0f}" + cv2.putText(frame, f"{rr_meta.short_name}", (wave_w, row2_y), font, 0.5, (200, 200, 200), 1) + cv2.putText(frame, f"{rr_val} {rr_meta.unit}", (wave_w, row2_y + 30), font, 0.7, hex_to_bgr(rr_meta.color), 2) + + cv2.imshow("VitalLens Live", frame) + + if cv2.waitKey(1) & 0xFF == ord('q'): + break cap.release() cv2.destroyAllWindows() -def method_type(name): - try: - return Method[name.upper()] - except KeyError: - pass - if name.lower().startswith("vitallens"): - return name - raise argparse.ArgumentTypeError(f"{name} is not a valid Method") - -if __name__ == '__main__': - parser = argparse.ArgumentParser() - parser.add_argument('--api_key', type=str, default=None, help='Your API key (Optional if using proxy auth).') - parser.add_argument('--method', type=method_type, default='vitallens', help='Choice of method (vitallens, pos, chrom, g, or specific version like vitallens-2.0)') - parser.add_argument('--proxy', type=str, default=None, help='Proxy URL (e.g., http://user:pass@10.10.1.10:3128)') - args = parser.parse_args() - run(args) \ No newline at end of file +if __name__ == "__main__": + main() \ No newline at end of file diff --git 
a/examples/test.py b/examples/test.py index ebc6b5c..2947d89 100644 --- a/examples/test.py +++ b/examples/test.py @@ -70,56 +70,56 @@ def run(args=None): if not result: print("No faces detected, cannot plot results.") return - vital_signs = result[0]['vital_signs'] - if "respiratory_waveform" in vital_signs: + vitals = result[0].get('vitals', {}) + waveforms = result[0].get('waveforms', {}) + rolling_vitals = result[0].get('rolling_vitals', {}) + if "respiratory_waveform" in waveforms: fig, (ax1, ax2) = plt.subplots(2, sharex=True, figsize=(15, 8)) else: fig, ax1 = plt.subplots(1, figsize=(15, 6)) method_name = args.method.name if hasattr(args.method, 'name') else str(args.method) method_color = get_method_color(args.method) fig.suptitle(f"Vital signs from {args.video_path} using {method_name} ({time_ms:.2f} ms)") - # PPG Waveform and Heart Rate Plot (ax1) ax1.set_ylabel('Waveform (unitless)', color=method_color) ax1.tick_params(axis='y', labelcolor=method_color) - if "ppg_waveform" in vital_signs: - hr_string = f" -> Global HR: {vital_signs['heart_rate']['value']:.1f} bpm" if "heart_rate" in vital_signs else "" - ax1.plot(vital_signs['ppg_waveform']['data'], color=method_color, label=f"PPG Waveform{hr_string}", zorder=10) - ax1.plot(vital_signs['ppg_waveform']['confidence'], color=method_color, linestyle='--', label='PPG Confidence', zorder=5) + if "ppg_waveform" in waveforms: + hr_string = f" -> Global HR: {vitals['heart_rate']['value']:.1f} bpm" if "heart_rate" in vitals else "" + ax1.plot(waveforms['ppg_waveform']['data'], color=method_color, label=f"PPG Waveform{hr_string}", zorder=10) + ax1.plot(waveforms['ppg_waveform']['confidence'], color=method_color, linestyle='--', label='PPG Confidence', zorder=5) if ppg_gt is not None: hr_gt = estimate_hr_from_signal(signal=ppg_gt, f_s=fps, scope=EScope.GLOBAL, method=EMethod.PERIODOGRAM) ax1.plot(ppg_gt, color=COLOR_GT, label=f"Ground Truth PPG -> HR: {hr_gt:.1f} bpm", zorder=0) ax1_handles, ax1_labels = 
ax1.get_legend_handles_labels() - if "rolling_heart_rate" in vital_signs: + if "heart_rate" in rolling_vitals: ax1_hr = ax1.twinx() hr_color = '#ff7f0e' ax1_hr.set_ylabel('Heart Rate (bpm)', color=hr_color) ax1_hr.tick_params(axis='y', labelcolor=hr_color) ax1_hr.set_ylim(35, 180) - rolling_hr_data = vital_signs['rolling_heart_rate']['data'] + rolling_hr_data = rolling_vitals['heart_rate']['data'] line_hr, = ax1_hr.plot(rolling_hr_data, color=hr_color, linestyle='-', label='Rolling Heart Rate') ax1_handles.append(line_hr) ax1_labels.append('Rolling Heart Rate') ax1.legend(ax1_handles, ax1_labels, loc='upper left') ax1.grid(True, linestyle=':', alpha=0.6) - # Respiratory Waveform and Rate Plot (ax2) - if "respiratory_waveform" in vital_signs: + if "respiratory_waveform" in waveforms: ax2.set_xlabel('Frame Index') ax2.set_ylabel('Waveform (unitless)', color=method_color) ax2.tick_params(axis='y', labelcolor=method_color) - rr_string = f" -> Global RR: {vital_signs['respiratory_rate']['value']:.1f} bpm" if "respiratory_rate" in vital_signs else "" - ax2.plot(vital_signs['respiratory_waveform']['data'], color=method_color, label=f"Respiratory Waveform{rr_string}", zorder=10) - ax2.plot(vital_signs['respiratory_waveform']['confidence'], color=method_color, linestyle='--', label='Respiratory Confidence', zorder=5) + rr_string = f" -> Global RR: {vitals['respiratory_rate']['value']:.1f} bpm" if "respiratory_rate" in vitals else "" + ax2.plot(waveforms['respiratory_waveform']['data'], color=method_color, label=f"Respiratory Waveform{rr_string}", zorder=10) + ax2.plot(waveforms['respiratory_waveform']['confidence'], color=method_color, linestyle='--', label='Respiratory Confidence', zorder=5) if resp_gt is not None: rr_gt = estimate_rr_from_signal(signal=resp_gt, f_s=fps, scope=EScope.GLOBAL, method=EMethod.PERIODOGRAM) - ax2.plot(resp_gt, color=COLOR_GT, label=f"Ground Truth Respiration -> RR: {rr_gt:.1f} bpm", zorder=0) + ax2.plot(resp_gt, color=COLOR_GT, label=f"Ground 
Truth Respiration -> RR: {rr_gt:.1f} bpm", zorder=0) ax2_handles, ax2_labels = ax2.get_legend_handles_labels() - if "rolling_respiratory_rate" in vital_signs: + if "respiratory_rate" in rolling_vitals: ax2_rr = ax2.twinx() rr_color = '#d62728' ax2_rr.set_ylabel('Respiratory Rate (bpm)', color=rr_color) ax2_rr.tick_params(axis='y', labelcolor=rr_color) ax2_rr.set_ylim(0, 45) - rolling_rr_data = vital_signs['rolling_respiratory_rate']['data'] + rolling_rr_data = rolling_vitals['respiratory_rate']['data'] line_rr, = ax2_rr.plot(rolling_rr_data, color=rr_color, linestyle='-', label='Rolling Respiratory Rate') ax2_handles.append(line_rr) ax2_labels.append('Rolling Respiratory Rate') diff --git a/pyproject.toml b/pyproject.toml index fc0b146..f9db0b6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -29,6 +29,7 @@ dependencies = [ "python-dotenv>=1.0", "pyyaml>=6.0.1", "requests>=2.32.0", + "vitallens-core==0.2.3", ] dynamic = ["version"] diff --git a/tests/conftest.py b/tests/conftest.py index 8d03ad7..66ba78d 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,23 +1,3 @@ -# Copyright (c) 2024 Philipp Rouast -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in all -# copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -# SOFTWARE. - import numpy as np import os from prpy.ffmpeg.probe import probe_video diff --git a/tests/test_buffer.py b/tests/test_buffer.py deleted file mode 100644 index b31a3b1..0000000 --- a/tests/test_buffer.py +++ /dev/null @@ -1,111 +0,0 @@ -# Copyright (c) 2024 Rouast Labs -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in all -# copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -# SOFTWARE. 
- -import numpy as np -import pytest - -import sys -sys.path.append('../vitallens-python') - -from vitallens.buffer import SignalBuffer, MultiSignalBuffer - -@pytest.mark.parametrize("pad_val", [0, -1]) -def test_signal_buffer(pad_val): - # 1 dim - buffer = SignalBuffer(size=8, ndim=1, pad_val=pad_val) - with pytest.raises(Exception): - buffer.get() - np.testing.assert_allclose( - buffer.update(signal=[.2, .4,], dt=2), - np.asarray([pad_val, pad_val, pad_val, pad_val, pad_val, pad_val, .2, .4])) - np.testing.assert_allclose( - buffer.update(signal=[.1, .3, .5, .6], dt=2), - np.asarray([pad_val, pad_val, pad_val, pad_val, .15, .35, .5, .6])) - np.testing.assert_allclose( - buffer.update(signal=[.6, .7], dt=1), - np.asarray([pad_val, pad_val, pad_val, .15, .35, .5, .6, .7])) - np.testing.assert_allclose( - buffer.update(signal=[.8, .9, .8, .7, .6], dt=4), - np.asarray([.35, .5, .6, .75, .9, .8, .7, .6])) - # 2 dim - buffer = SignalBuffer(size=4, ndim=2, pad_val=pad_val) - with pytest.raises(Exception): - buffer.get() - np.testing.assert_allclose( - buffer.update(signal=[[.1, .2,]], dt=1), - np.asarray([[pad_val, pad_val], [pad_val, pad_val], [pad_val, pad_val], [.1, .2]])) - np.testing.assert_allclose( - buffer.update(signal=[[.1, .3], [.5, .6]], dt=3), - np.asarray([[.1, .2], [pad_val, pad_val], [.1, .3], [.5, .6]])) - np.testing.assert_allclose( - buffer.update(signal=[[.3, .2], [.5, .6], [.5, .6]], dt=2), - np.asarray([[.1, .3], [.4, .4], [.5, .6], [.5, .6]])) - -@pytest.mark.parametrize("pad_val", [0, -1]) -def test_multi_signal_buffer(pad_val): - # 1 dim, 2 signals - buffer = MultiSignalBuffer(size=8, ndim=1, ignore_k=[], pad_val=pad_val) - with pytest.raises(Exception): - buffer.get() - with pytest.raises(Exception): - buffer.update(signals=[0., 1.]) - out = buffer.update( - signals={"a": [.2, .4,], "b": [.1, .2]}, dt=1) - np.testing.assert_allclose( - out["a"], - np.asarray([pad_val, pad_val, pad_val, pad_val, pad_val, pad_val, .2, .4])) - 
np.testing.assert_allclose( - out["b"], - np.asarray([pad_val, pad_val, pad_val, pad_val, pad_val, pad_val, .1, .2])) - out = buffer.update( - signals={"a": [.1, .3, .5, .6], "b": [.1, .4, .5]}, dt=2) - np.testing.assert_allclose( - out["a"], - np.asarray([pad_val, pad_val, pad_val, pad_val, .15, .35, .5, .6])) - np.testing.assert_allclose( - out["b"], - np.asarray([pad_val, pad_val, pad_val, pad_val, .1, .15, .4, .5])) - out = buffer.update( - signals={"a": [.6, .7], "b": [.3], "c": [.4, .8]}, dt=1) - np.testing.assert_allclose( - out["a"], - np.asarray([pad_val, pad_val, pad_val, .15, .35, .5, .6, .7])) - np.testing.assert_allclose( - out["b"], - np.asarray([pad_val, pad_val, pad_val, .1, .15, .4, .5, .3])) - np.testing.assert_allclose( - out["c"], - np.asarray([pad_val, pad_val, pad_val, pad_val, pad_val, pad_val, .4, .8])) - # 2 dim, 2 signals - buffer = MultiSignalBuffer(size=4, ndim=2, pad_val=pad_val, ignore_k=['c']) - out = buffer.update( - signals={"a": [[.2, .4,], [.1, .4]], "b": [[.1, .2], [.6, .6]]}, dt=1) - np.testing.assert_allclose( - out["a"], - np.asarray([[pad_val, pad_val], [pad_val, pad_val], [.2, .4], [.1, .4]])) - np.testing.assert_allclose( - out["b"], - np.asarray([[pad_val, pad_val], [pad_val, pad_val], [.1, .2], [.6, .6]])) - out = buffer.update( - signals={"a": [[.1, .1,], [.1, .1,], [.1, .1]], "c": [[.1], [.1]]}, dt=2) - assert len(out) == 1 - np.testing.assert_allclose( - out["a"], - np.asarray([[.2, .4], [.1, .25], [.1, .1], [.1, .1]])) diff --git a/tests/test_client.py b/tests/test_client.py index 85a101a..dd9c9a1 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -1,28 +1,8 @@ -# Copyright (c) 2024 Philipp Rouast -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, 
and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in all -# copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -# SOFTWARE. - import json import numpy as np import os import pytest -from unittest.mock import patch, ANY +from unittest.mock import patch import sys sys.path.append('../vitallens-python') @@ -43,12 +23,12 @@ def test_VitalLens(request, method_arg, detect_faces, file, export): test_video_fps = request.getfixturevalue('test_video_fps') result = vl(test_video_ndarray, fps=test_video_fps, faces = None if detect_faces else [247, 57, 440, 334], export_filename="test") assert len(result) == 1 - assert result[0]['face']['coordinates'].shape == (630, 4) - assert result[0]['face']['confidence'].shape == (630,) - assert result[0]['vital_signs']['ppg_waveform']['data'].shape == (630,) - assert result[0]['vital_signs']['ppg_waveform']['confidence'].shape == (630,) - np.testing.assert_allclose(result[0]['vital_signs']['heart_rate']['value'], 60, atol=10) - assert result[0]['vital_signs']['heart_rate']['confidence'] == 1.0 + assert np.asarray(result[0]['face']['coordinates']).shape == (630, 4) + assert np.asarray(result[0]['face']['confidence']).shape == (630,) + assert np.asarray(result[0]['waveforms']['ppg_waveform']['data']).shape == (630,) + assert np.asarray(result[0]['waveforms']['ppg_waveform']['confidence']).shape == (630,) + 
np.testing.assert_allclose(result[0]['vitals']['heart_rate']['value'], 60, atol=10) + assert result[0]['vitals']['heart_rate']['confidence'] == 1.0 if export: test_json_path = os.path.join("test.json") assert os.path.exists(test_json_path) @@ -56,10 +36,10 @@ def test_VitalLens(request, method_arg, detect_faces, file, export): data = json.load(f) assert np.asarray(data[0]['face']['coordinates']).shape == (630, 4) assert np.asarray(data[0]['face']['confidence']).shape == (630,) - assert np.asarray(data[0]['vital_signs']['ppg_waveform']['data']).shape == (630,) - assert np.asarray(data[0]['vital_signs']['ppg_waveform']['confidence']).shape == (630,) - np.testing.assert_allclose(data[0]['vital_signs']['heart_rate']['value'], 60, atol=10) - assert data[0]['vital_signs']['heart_rate']['confidence'] == 1.0 + assert np.asarray(data[0]['waveforms']['ppg_waveform']['data']).shape == (630,) + assert np.asarray(data[0]['waveforms']['ppg_waveform']['confidence']).shape == (630,) + np.testing.assert_allclose(data[0]['vitals']['heart_rate']['value'], 60, atol=10) + assert data[0]['vitals']['heart_rate']['confidence'] == 1.0 os.remove(test_json_path) def test_VitalLens_API(request): @@ -69,32 +49,29 @@ def test_VitalLens_API(request): test_video_fps = request.getfixturevalue('test_video_fps') result = vl(test_video_ndarray, fps=test_video_fps, faces=None, export_filename="test") assert len(result) == 1 - assert result[0]['face']['coordinates'].shape == (630, 4) - assert result[0]['vital_signs']['ppg_waveform']['data'].shape == (630,) - assert result[0]['vital_signs']['ppg_waveform']['confidence'].shape == (630,) - assert result[0]['vital_signs']['respiratory_waveform']['data'].shape == (630,) - assert result[0]['vital_signs']['respiratory_waveform']['confidence'].shape == (630,) - assert 'hrv_sdnn' in result[0]['vital_signs'] - assert 'hrv_rmssd' in result[0]['vital_signs'] - assert 'hrv_lfhf' in result[0]['vital_signs'] - 
np.testing.assert_allclose(result[0]['vital_signs']['heart_rate']['value'], 60, atol=0.5) - np.testing.assert_allclose(result[0]['vital_signs']['heart_rate']['confidence'], 1.0, atol=0.1) - np.testing.assert_allclose(result[0]['vital_signs']['respiratory_rate']['value'], 13, atol=1.0) - np.testing.assert_allclose(result[0]['vital_signs']['respiratory_rate']['confidence'], 1.0, atol=0.1) + assert np.asarray(result[0]['face']['coordinates']).shape == (630, 4) + assert np.asarray(result[0]['waveforms']['ppg_waveform']['data']).shape == (630,) + assert np.asarray(result[0]['waveforms']['ppg_waveform']['confidence']).shape == (630,) + assert np.asarray(result[0]['waveforms']['respiratory_waveform']['data']).shape == (630,) + assert np.asarray(result[0]['waveforms']['respiratory_waveform']['confidence']).shape == (630,) + assert 'hrv_sdnn' in result[0]['vitals'] + assert 'hrv_rmssd' in result[0]['vitals'] + np.testing.assert_allclose(result[0]['vitals']['heart_rate']['value'], 60, atol=0.5) + np.testing.assert_allclose(result[0]['vitals']['heart_rate']['confidence'], 1.0, atol=0.1) + np.testing.assert_allclose(result[0]['vitals']['respiratory_rate']['value'], 13, atol=1.0) + np.testing.assert_allclose(result[0]['vitals']['respiratory_rate']['confidence'], 1.0, atol=0.1) supports_hrv = result[0].get('model_used', '') == 'vitallens-2.0' if supports_hrv: - np.testing.assert_allclose(result[0]['vital_signs']['hrv_sdnn']['value'], 40, atol=20) - np.testing.assert_allclose(result[0]['vital_signs']['hrv_rmssd']['value'], 30, atol=20) - np.testing.assert_allclose(result[0]['vital_signs']['hrv_lfhf']['value'], 1.5, atol=1.0) + np.testing.assert_allclose(result[0]['vitals']['hrv_sdnn']['value'], 40, atol=20) + np.testing.assert_allclose(result[0]['vitals']['hrv_rmssd']['value'], 30, atol=20) assert not os.path.exists("test.json") def test_VitalLens_proxies(): """Test that proxies are passed down correctly.""" proxies = {"https": "http://10.10.1.10:3128"} with 
patch('vitallens.client.VitalLensRPPGMethod') as MockRPPG: - vl = VitalLens(method="vitallens-2.0", api_key="test", proxies=proxies) + _ = VitalLens(method="vitallens-2.0", api_key="test", proxies=proxies) MockRPPG.assert_called_with( - mode=ANY, api_key="test", requested_model_name="vitallens-2.0", proxies=proxies @@ -104,10 +81,8 @@ def test_VitalLens_auth_offloading(): """Test initialization without API key (for proxy auth offloading).""" proxies = {"https": "http://auth-proxy:8080"} with patch('vitallens.client.VitalLensRPPGMethod') as MockRPPG: - # Should NOT raise an error despite api_key being None - vl = VitalLens(method="vitallens", api_key=None, proxies=proxies) + _ = VitalLens(method="vitallens", api_key=None, proxies=proxies) MockRPPG.assert_called_with( - mode=ANY, api_key=None, requested_model_name="vitallens", proxies=proxies diff --git a/tests/test_signal.py b/tests/test_signal.py index fdc3e82..3b8e429 100644 --- a/tests/test_signal.py +++ b/tests/test_signal.py @@ -1,31 +1,10 @@ -# Copyright (c) 2024 Rouast Labs -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in all -# copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -# SOFTWARE. - import numpy as np import pytest import sys sys.path.append('../vitallens-python') -from vitallens.enums import Method -from vitallens.signal import reassemble_from_windows, assemble_results, estimate_rolling_vitals +from vitallens.signal import reassemble_from_windows @pytest.mark.parametrize( "name, x_batches, idxs_batches, expected_x, expected_idxs", @@ -77,106 +56,3 @@ def test_reassemble_from_windows_edge_cases(name, x_batches, idxs_batches, expec out_x, out_idxs = reassemble_from_windows(x=x_batches, idxs=idxs_batches) np.testing.assert_allclose(out_x, expected_x, err_msg=f"Failed on case: {name} (data)") np.testing.assert_equal(out_idxs, expected_idxs, err_msg=f"Failed on case: {name} (indices)") - -@pytest.mark.parametrize("signals", [('ppg_waveform',), ('respiratory_waveform',), ('ppg_waveform','respiratory_waveform')]) -@pytest.mark.parametrize("can_provide_confidence", [True, False]) -@pytest.mark.parametrize("min_t_too_long", [True, False]) -@pytest.mark.parametrize("method_arg", [Method.G, "g", "vitallens-2.0"]) -def test_assemble_results(signals, can_provide_confidence, min_t_too_long, method_arg): - sample_video_hr = 60.5 - sample_video_rr = 12. 
- sig_ppg = [-0.4083,-0.4870,-0.5349,-0.5744,-0.5844,-0.2449,0.1615,0.3899,0.2642,0.4055,0.8509,1.2770,1.2852,1.1860,1.0460,0.8798,0.6355,0.3384,0.1496,-0.0557,-0.2071,-0.3669,-0.5196,-0.6868,-0.8333,-0.9207,-0.9780,-1.0261,-1.0442,-0.9659,-0.9299,-0.9369,-0.9989,-0.9606,-0.9246,-0.9377,-0.9436,-0.4722,0.3687,1.2693,1.7929,1.9902,1.9927,1.7719,1.4279,1.0693,0.7988,0.5610,0.3588,0.1722,0.0160,-0.1632,-0.3328,-0.5028,-0.6100,-0.7312,-0.8427,-0.9521,-1.0169,-1.0677,-1.0762,-1.0810,-1.0155,-1.0183,-0.9524,-0.5942,0.1621,1.0288,1.6288,1.8838,1.8947,1.6794,1.3318,0.9841,0.7191,0.5162,0.3285,0.1507,-0.0154,-0.1985,-0.3625,-0.5233,-0.6556,-0.7937,-0.9154,-1.0269,-1.1071,-1.1602,-1.1696,-1.1418,-1.0680,-1.0450,-1.0053,-0.7943,-0.1595,0.5459,1.2173,1.5796,1.8557,1.8893,1.7867,1.5623,1.3052,0.9894,0.6957,0.4454,0.2840,0.1159,-0.0753,-0.2834,-0.4472,-0.5744,-0.6988,-0.8263,-0.9340,-1.0200,-1.0884,-1.1461,-1.1877,-1.2084,-1.2175,-1.1728,-1.0763,-1.0812,-0.9748,-0.5586,0.2922,1.0402,1.5650,1.7920,1.9558,1.9332,1.7531,1.5007,1.2016,0.9378,0.6770,0.4527,0.2762,0.1168,-0.0452,-0.2213,-0.3531,-0.4610,-0.5556,-0.6738,-0.7797,-0.8611,-0.9315,-1.0020,-1.0779,-1.1226,-1.1486,-1.1595,-1.1606,-1.1465,-1.0825,-0.7966,-0.1147,0.7163,1.4678,1.8762,2.0371,1.9188,1.6311,1.3107,1.0236,0.7714,0.5440,0.3526,0.1949,0.0479,-0.0887,-0.2237,-0.3338,-0.4495,-0.5715,-0.7017,-0.8126,-0.9056,-0.9881,-1.0832,-1.1496,-1.1891,-1.1639,-1.0625,-0.9432,-0.8945,-0.7143,-0.1024,0.7737,1.5254,1.8308,1.8765,1.7334,1.4570,1.1337,0.8512,0.6467,0.4641,0.2975,0.1319,-0.0166,-0.1772,-0.3255,-0.4689,-0.5952,-0.7292,-0.8450,-0.9470,-1.0155,-1.0691,-1.0933,-1.0868,-0.9516,-0.9493,-0.7276,-0.1801,0.6854,1.3954,1.6719,1.7254,1.5889,1.3148,0.9823,0.7029,0.4945,0.3113,0.1323,-0.0327,-0.1748,-0.3303,-0.4868,-0.6315,-0.7396,-0.8467,-0.9409,-1.0106,-1.0768,-1.1120,-1.1232,-1.0409,-0.9668,-0.9050,-0.5287,0.2341,1.0636,1.6463,1.8035,1.7923,1.5829,1.3407,1.0977,0.8118,0.5318,0.2946,0.1286,-0.0579,-0.2498,-0.4388,-0.5757,-0.7087,-0.8
260,-0.9523,-1.0450,-1.1196,-1.1497,-1.1653,-1.1699,-1.1829,-1.0498,-0.5364,0.2799,1.1249,1.6516,1.8845,1.8873,1.7041,1.4127,1.0875,0.7965,0.5300,0.2986,0.1053,-0.0560,-0.2257,-0.3977,-0.5707,-0.7084,-0.8404,-0.9615,-1.0957,-1.1583,-1.1935,-1.1824,-1.1457,-1.1101,-1.0692,-0.9919,-0.7167,-0.1740,0.5670,1.2549,1.7388,1.9519,1.9377,1.7372,1.4552,1.1682,0.8869,0.5943,0.3481,0.1584,0.0013,-0.1499,-0.2991,-0.4343,-0.5679,-0.6806,-0.7918,-0.8780,-0.9716,-1.0572,-1.1466,-1.1819,-1.1744,-1.1190,-1.0441,-0.9120,-0.8688,-0.4748,0.2072,1.1213,1.7113,1.9401,1.9397,1.7687,1.4714,1.1611,0.8781,0.6450,0.4126,0.2196,0.0540,-0.0807,-0.2300,-0.3880,-0.5466,-0.6817,-0.8037,-0.9018,-0.9984,-1.0749,-1.1407,-1.1673,-1.1641,-1.1091,-1.0658,-0.8800,-0.6699,-0.2174,-0.0367,0.4601,0.8853,1.5025,1.7526,1.8027,1.7299,1.5605,1.2811,0.9786,0.7044,0.5232,0.3470,0.1675,-0.0036,-0.1306,-0.2650,-0.3877,-0.5090,-0.6084,-0.7226,-0.8290,-0.9122,-0.9587,-0.9864,-1.0103,-1.0283,-0.9953,-0.9831,-0.9334,-0.9280,-0.8389,-0.8085,-0.5427,0.0905,0.9084,1.5709,1.8341,1.9146,1.7963,1.5712,1.2910,1.0356,0.8188,0.6085,0.4131,0.2185,0.0654,-0.0746,-0.2145,-0.3641,-0.4905,-0.6023,-0.7098,-0.7949,-0.8602,-0.8976,-0.9368,-0.9790,-1.0263,-1.0601,-1.0585,-0.9929,-0.9197,-0.8934,-0.5820,-0.0033,0.8495,1.3758,1.6934,1.7771,1.7700,1.5628,1.2534,0.9562,0.7120,0.4919,0.2770,0.1065,-0.0485,-0.2027,-0.3606,-0.5129,-0.6067,-0.6866,-0.7421,-0.8288,-0.9179,-0.9830,-1.0401,-1.0513,-1.0079,-0.9813,-0.9353,-0.8247,-0.4010,-0.1069,0.4898,0.9894,1.6220,1.7962,1.7257,1.5288,1.3169,1.0462,0.7720,0.5112,0.3459,0.1939,0.0050,-0.2020,-0.3345,-0.4689,-0.5922,-0.7290,-0.8134,-0.8942,-0.9755,-1.0423,-1.0725,-1.1050,-1.0997,-1.0037,-0.7404,-0.6626,-0.4402,0.1287,0.9602,1.6379,1.7996,1.7824,1.5706,1.2648,0.9476,0.6820,0.4924,0.3139,0.1485,-0.0451,-0.2120,-0.3790,-0.5063,-0.6320,-0.7392,-0.8419,-0.9265,-0.9984,-1.0418,-1.0621,-1.0510,-1.0458,-0.9745,-0.6929,-0.1304,0.5416,1.1707,1.5759,1.8158,1.7314,1.4575,1.1060,0.8223,0.5809,0.3580,0.1750,0.017
9,-0.1880,-0.3995,-0.5799,-0.6943,-0.8105,-0.9105,-1.0006,-1.0555,-1.0929,-1.0890,-1.0717,-1.0691,-1.0835,-1.0090,-0.5589,0.1409,0.7784,1.2456,1.5164,1.7678,1.8317,1.7655,1.6013,1.3243,0.9859,0.6393,0.3173,0.0752,-0.1309,-0.2915,-0.4423,-0.5697,-0.6980,-0.8111,-0.9306,-1.0014,-1.0585,-1.0806,-1.1191,-1.1529,-1.1912,-1.2005,-1.0287,-0.8734,-0.6632,-0.1333,0.6502,1.4507,1.8187,1.9089,1.8552,1.6615,1.3969,1.0882,0.7952,0.5774,0.3760,0.1681,-0.0522,-0.2297,-0.3979,-0.5552,-0.7063,-0.8194,-0.9158,-0.9969,-1.0783,-1.1462,-1.2057,-1.2415,-1.2600,-1.2322,-1.2379,-1.1944,-1.0021,-0.4662,0.1166,0.8041,1.3121,1.7669,1.9082,1.8469,1.6798,1.4911,1.2721,1.0299,0.7710,0.5731,0.4015,0.2350,0.0763,-0.0481,-0.1557,-0.2654,-0.3696,-0.4470,-0.4934,-0.5307,-0.5684,-0.5941,-0.6083,-0.5997,-0.5896] - sig_resp = [-0.4324,-0.4442,-0.4730,-0.5232,-0.5757,-0.6074,-0.6844,-0.7499,-0.7832,-0.8396,-0.8881,-0.9281,-0.9436,-0.9487,-0.9491,-0.9269,-0.9221,-0.9240,-0.9347,-0.9437,-0.9585,-0.9540,-0.9338,-0.9078,-0.8627,-0.8189,-0.7855,-0.7440,-0.7117,-0.6426,-0.5554,-0.4623,-0.3637,-0.2596,-0.1547,-0.0324,0.0944,0.2257,0.3440,0.4705,0.6045,0.7422,0.8750,1.0062,1.1484,1.2928,1.4215,1.5167,1.5979,1.6471,1.6939,1.7510,1.8110,1.8573,1.8936,1.9008,1.9097,1.9078,1.9006,1.8969,1.9110,1.9067,1.9157,1.9061,1.8858,1.8603,1.8291,1.7905,1.7706,1.7287,1.6765,1.6192,1.5532,1.4840,1.4081,1.3216,1.2261,1.1254,1.0129,0.9080,0.7978,0.6857,0.5737,0.4790,0.3823,0.2906,0.1970,0.1181,0.0383,-0.0394,-0.1151,-0.1816,-0.2455,-0.2982,-0.3536,-0.4018,-0.4663,-0.5228,-0.5761,-0.6221,-0.6669,-0.6920,-0.7127,-0.7386,-0.7638,-0.7934,-0.8393,-0.8664,-0.8951,-0.9204,-0.9440,-0.9696,-0.9861,-1.0070,-1.0347,-1.0560,-1.0746,-1.0844,-1.1015,-1.0939,-1.1073,-1.1189,-1.1326,-1.1467,-1.1533,-1.1623,-1.1715,-1.1809,-1.1908,-1.1942,-1.2037,-1.2043,-1.2101,-1.2070,-1.1965,-1.1833,-1.1704,-1.1642,-1.1589,-1.1561,-1.1513,-1.1412,-1.1347,-1.1252,-1.1139,-1.1010,-1.0913,-1.0819,-1.0680,-1.0562,-1.0411,-1.0304,-1.0120,-0.9971,-0.9912,-0.9839,-0.9
804,-0.9773,-0.9752,-0.9684,-0.9603,-0.9545,-0.9491,-0.9377,-0.9265,-0.9218,-0.9166,-0.9055,-0.8931,-0.8822,-0.8694,-0.8535,-0.8382,-0.8242,-0.8076,-0.7919,-0.7743,-0.7459,-0.7080,-0.6586,-0.6052,-0.5505,-0.4848,-0.4215,-0.3650,-0.2922,-0.2100,-0.1125,-0.0058,0.1081,0.2148,0.3224,0.4412,0.5610,0.6832,0.8138,0.9512,1.0997,1.2302,1.3378,1.4186,1.5184,1.6057,1.6866,1.7678,1.8244,1.8790,1.9368,1.9751,2.0002,2.0167,2.0295,2.0488,2.0849,2.0859,2.0811,2.0841,2.0746,2.0665,2.0609,2.0420,2.0159,1.9937,1.9644,1.9397,1.9156,1.8799,1.8457,1.8103,1.7617,1.7115,1.6584,1.5931,1.5253,1.4613,1.3873,1.3171,1.2425,1.1544,1.0590,0.9636,0.8712,0.7767,0.6814,0.5764,0.4750,0.3885,0.3049,0.2085,0.1181,0.0326,-0.0465,-0.1085,-0.1795,-0.2547,-0.3140,-0.3735,-0.4235,-0.4584,-0.5092,-0.5709,-0.6075,-0.6466,-0.6928,-0.7457,-0.7992,-0.8393,-0.8734,-0.9147,-0.9457,-0.9727,-1.0151,-1.0535,-1.0728,-1.0942,-1.1186,-1.1384,-1.1405,-1.1404,-1.1360,-1.1368,-1.1337,-1.1298,-1.1241,-1.1200,-1.1043,-1.0916,-1.0852,-1.0707,-1.0577,-1.0443,-1.0293,-1.0272,-1.0183,-1.0117,-1.0043,-0.9976,-0.9870,-0.9833,-0.9759,-0.9643,-0.9624,-0.9561,-0.9534,-0.9613,-0.9557,-0.9561,-0.9541,-0.9521,-0.9456,-0.9377,-0.9280,-0.9239,-0.9188,-0.9110,-0.8997,-0.8774,-0.8478,-0.8161,-0.7824,-0.7494,-0.7210,-0.6965,-0.6715,-0.6421,-0.6246,-0.6005,-0.5708,-0.5459,-0.5227,-0.5015,-0.4871,-0.4636,-0.4325,-0.4133,-0.3928,-0.3751,-0.3681,-0.3518,-0.3344,-0.3214,-0.3038,-0.2856,-0.2579,-0.2166,-0.1874,-0.1588,-0.1067,-0.0367,0.0487,0.1506,0.2599,0.3652,0.4692,0.5817,0.6779,0.7226,0.7310,0.7650,0.8024,0.8559,0.9108,0.9558,1.0005,1.0394,1.0869,1.1470,1.2036,1.2719,1.3737,1.4867,1.5720,1.6496,1.6978,1.7176,1.7153,1.7012,1.6804,1.6463,1.6054,1.5412,1.4662,1.3951,1.3317,1.2737,1.2064,1.1354,1.0632,1.0007,0.9330,0.8620,0.7905,0.7077,0.6304,0.5644,0.4979,0.4172,0.3288,0.2425,0.1608,0.0793,-0.0018,-0.0814,-0.1521,-0.2260,-0.3034,-0.3680,-0.4493,-0.5257,-0.5760,-0.6128,-0.6412,-0.6783,-0.7052,-0.7345,-0.7585,-0.7879,-0.8096,-0.8138,-0.8272,-0.832
4,-0.8417,-0.8547,-0.8643,-0.8809,-0.8880,-0.9008,-0.9105,-0.9224,-0.9311,-0.9307,-0.9414,-0.9524,-0.9543,-0.9593,-0.9597,-0.9615,-0.9631,-0.9528,-0.9398,-0.9298,-0.9131,-0.8907,-0.8762,-0.8626,-0.8450,-0.8260,-0.8044,-0.7859,-0.7786,-0.7676,-0.7738,-0.7786,-0.7674,-0.7566,-0.7501,-0.7418,-0.7263,-0.7146,-0.7056,-0.6904,-0.6724,-0.6512,-0.6283,-0.5931,-0.5534,-0.5171,-0.4914,-0.4563,-0.4155,-0.3779,-0.3248,-0.2643,-0.1956,-0.1338,-0.0715,0.0031,0.0715,0.1554,0.2518,0.3645,0.4471,0.5403,0.6558,0.7771,0.8941,1.0039,1.1138,1.2320,1.3445,1.4598,1.5576,1.6224,1.6718,1.7437,1.8025,1.8454,1.8709,1.8890,1.9007,1.9109,1.9123,1.9020,1.8805,1.8468,1.8263,1.8030,1.7746,1.7301,1.6743,1.6189,1.5589,1.4977,1.4392,1.3745,1.3065,1.2448,1.1911,1.1365,1.0858,1.0443,1.0021,0.9629,0.9137,0.8608,0.8075,0.7466,0.6832,0.6214,0.5552,0.4891,0.4234,0.3538,0.2734,0.2010,0.1345,0.0692,0.0093,-0.0520,-0.1129,-0.1637,-0.2060,-0.2496,-0.2922,-0.3331,-0.3685,-0.3916,-0.4148,-0.4489,-0.4686,-0.4783,-0.4870,-0.4890,-0.4876,-0.4888,-0.4839,-0.4793,-0.4687,-0.4597,-0.4563,-0.4561,-0.4462,-0.4493,-0.4491,-0.4449,-0.4439,-0.4555,-0.4746,-0.4861,-0.4966,-0.5116,-0.5201,-0.5310,-0.5374,-0.5412,-0.5408,-0.5554,-0.5644,-0.5713,-0.5683,-0.5611,-0.5641,-0.5617,-0.5542,-0.5501,-0.5416,-0.5289,-0.5235,-0.5142,-0.4928,-0.4791,-0.4654,-0.4430,-0.4232,-0.3989,-0.3727,-0.3625,-0.3501,-0.3473,-0.3508,-0.3454,-0.3384,-0.3338,-0.3229,-0.3140,-0.3142,-0.3018,-0.2893,-0.2950,-0.2807,-0.2658,-0.2481,-0.2267,-0.2118,-0.1976,-0.1737,-0.1605,-0.1410,-0.1267,-0.1216,-0.1157,-0.1018,-0.0941,-0.0848,-0.0718,-0.0644,-0.0583,-0.0512,-0.0590,-0.0598] - if min_t_too_long: - target_len = 100 - sig_ppg = sig_ppg[:target_len] - sig_resp = sig_resp[:target_len] - else: - target_len = len(sig_ppg) - sig, train_signals, pred_signals = [], [], [] - for s in ('ppg_waveform','respiratory_waveform'): - if s in signals: - sig.append(sig_ppg if s == 'ppg_waveform' else sig_resp) - train_signals.append('ppg_waveform' if s == 'ppg_waveform' 
else 'respiratory_waveform') - pred_signals.append('ppg_waveform' if s == 'ppg_waveform' else 'respiratory_waveform') - pred_signals.append('heart_rate' if s == 'ppg_waveform' else 'respiratory_rate') - sig = np.asarray(sig) - conf = np.ones_like(sig) - if can_provide_confidence: - live = np.asarray([0.0044,0.6019,0.0187,0.2596,0.8126,0.917,0.9268,0.8473,0.6633,0.7306,0.7722,0.7069,0.4417,0.4826,0.6779,0.8098,0.4191,0.6346,0.7833,0.5752,0.715,0.8923,0.8444,0.9169,0.85,0.9615,0.9797,0.9907,0.9819,0.9908,0.9856,0.974,0.979,0.9916,0.9928,0.997,0.9906,0.9916,0.9975,0.9989,0.9968,0.9989,0.9996,0.9999,0.9995,0.9994,0.9988,0.9993,0.9974,0.9987,0.9995,0.998,0.9942,0.9984,0.9972,0.9993,0.9986,0.9987,0.9992,0.9991,0.9967,0.9988,0.9991,0.9983,0.997,0.9978,0.999,0.9986,0.9941,0.9997,0.9998,0.9997,0.9995,0.9989,0.9995,0.9981,0.9987,0.9985,0.998,0.9985,0.9938,0.9989,0.9982,0.9987,0.996,0.9965,0.998,0.9979,0.9954,0.9965,0.9919,0.9912,0.9902,0.9949,0.9954,0.9971,0.9942,0.9968,0.9996,0.9997,0.998,0.9994,0.9966,0.9934,0.9835,0.9863,0.9899,0.9938,0.993,0.9907,0.9937,0.9933,0.9809,0.9896,0.9917,0.9961,0.9913,0.9961,0.9972,0.997,0.9927,0.9938,0.9851,0.9785,0.9816,0.9868,0.9946,0.9949,0.9951,0.9973,0.9988,0.9991,0.9969,0.9975,0.9923,0.9918,0.9489,0.9819,0.9672,0.9722,0.9354,0.9632,0.9769,0.982,0.947,0.9703,0.9911,0.994,0.9864,0.9886,0.9892,0.9801,0.9805,0.9799,0.9735,0.9897,0.9816,0.9734,0.9765,0.9928,0.9753,0.9893,0.9968,0.9993,0.9974,0.998,0.9934,0.9716,0.9208,0.9156,0.9118,0.9587,0.9399,0.9594,0.985,0.9688,0.9514,0.9673,0.9799,0.9766,0.9799,0.9917,0.9873,0.9925,0.9875,0.9909,0.9939,0.9958,0.9921,0.9952,0.9988,0.9986,0.9995,0.9999,0.9998,0.9998,0.9983,0.9989,0.9984,0.9987,0.9831,0.9991,0.9989,0.9989,0.9957,0.9988,0.9988,0.9994,0.9976,0.9984,0.9991,0.9979,0.9976,0.9991,0.9993,0.9982,0.9975,0.9987,0.999,0.9988,0.9982,0.9997,0.9998,0.9997,0.9937,0.9981,0.9985,0.9966,0.9984,0.9964,0.9962,0.9976,0.9952,0.9981,0.9974,0.9987,0.9955,0.9975,0.9969,0.9942,0.9921,0.9925,0.9908,0.9955,0.9935,0.997
4,0.9961,0.9982,0.9963,0.9988,0.9984,0.999,0.9967,0.9969,0.9925,0.9983,0.9958,0.996,0.9877,0.9924,0.996,0.9985,0.9805,0.9906,0.9978,0.9989,0.9907,0.9943,0.9955,0.9904,0.9796,0.9932,0.9936,0.9942,0.9977,0.9994,0.9997,0.9996,0.998,0.9961,0.9931,0.993,0.9741,0.9842,0.9874,0.9908,0.993,0.9963,0.994,0.9757,0.9825,0.9953,0.9929,0.9964,0.9901,0.9895,0.985,0.9602,0.9788,0.976,0.9867,0.9893,0.9634,0.996,0.9983,0.9994,0.997,0.9969,0.9843,0.9885,0.9561,0.9733,0.9894,0.9841,0.9685,0.9813,0.9654,0.9602,0.9312,0.9795,0.9803,0.9911,0.963,0.985,0.9835,0.9832,0.9664,0.98,0.971,0.9576,0.9514,0.9916,0.9889,0.9875,0.9966,0.9969,0.9994,0.9983,0.9821,0.971,0.9825,0.9859,0.9729,0.9824,0.9866,0.9894,0.979,0.9961,0.998,0.9969,0.9864,0.9965,0.999,0.9989,0.9966,0.9983,0.9939,0.9933,0.9912,0.9893,0.9873,0.9905,0.9783,0.9956,0.994,0.9973,0.9989,0.9993,0.9988,0.9968,0.9914,0.9823,0.9798,0.9924,0.9727,0.9976,0.9938,0.9951,0.9959,0.9982,0.9975,0.9975,0.9962,0.9965,0.9981,0.9985,0.9974,0.9982,0.9951,0.9878,0.9834,0.9906,0.992,0.9915,0.9903,0.9938,0.997,0.9983,0.9829,0.9987,0.9993,0.999,0.9877,0.9958,0.9813,0.9781,0.9662,0.9879,0.9895,0.9888,0.9533,0.9747,0.9791,0.9726,0.9817,0.989,0.9951,0.9908,0.9836,0.9914,0.9928,0.9952,0.9566,0.9848,0.9908,0.9748,0.9769,0.9924,0.9913,0.9836,0.9941,0.9971,0.9992,0.9994,0.992,0.988,0.9466,0.9786,0.9292,0.9763,0.9878,0.9888,0.9551,0.9852,0.9514,0.9689,0.9072,0.9636,0.9827,0.9254,0.5153,0.927,0.8591,0.8552,0.7245,0.864,0.8398,0.9472,0.9488,0.9813,0.9679,0.9933,0.9655,0.9905,0.9891,0.9762,0.9358,0.9778,0.9239,0.9572,0.8879,0.9795,0.9783,0.9898,0.8942,0.9667,0.9849,0.9866,0.9873,0.9833,0.9965,0.9968,0.9941,0.9964,0.9946,0.9866,0.9843,0.9951,0.9947,0.9906,0.9987,0.9993,0.9988,0.9989,0.972,0.9862,0.9909,0.9969,0.9858,0.9905,0.9946,0.9919,0.9912,0.996,0.9954,0.9982,0.9978,0.9975,0.9985,0.9988,0.9983,0.9977,0.9935,0.9942,0.9968,0.9984,0.9736,0.9983,0.9994,0.9994,0.9992,0.9985,0.999,0.9984,0.9954,0.998,0.9987,0.9988,0.9909,0.9983,0.9968,0.9974,0.9973,0.9983,0.9976,0.9961,0
.9946,0.9957,0.9969,0.9965,0.9942,0.9944,0.995,0.9966,0.9983,0.9985,0.9987,0.9996,0.9996,0.9989,0.998,0.9969,0.9788,0.9939,0.9957,0.9964,0.9938,0.9981,0.9976,0.9974,0.9818,0.9947,0.9972,0.9988,0.99,0.9968,0.9959,0.9955,0.9876,0.9895,0.9921,0.9884,0.9879,0.9877,0.9978,0.9981,0.9979,0.9991,0.9987,0.9982,0.9786,0.988,0.9688,0.985,0.947,0.9824,0.9853,0.98,0.933,0.9825,0.9815,0.9875,0.9719,0.9896,0.9916,0.9955,0.9787,0.9852,0.9799,0.9653,0.9122,0.9627,0.9731,0.9746,0.9802,0.9858,0.988,0.9987,0.9979,0.9975,0.9926,0.9842,0.8907,0.9364,0.9357,0.9767,0.8572,0.9597,0.9523,0.9594,0.9192,0.9429,0.9818,0.9881,0.9765,0.9808,0.977,0.9845,0.9615,0.9823]) - live = live[:target_len] - else: - live = np.ones((sig.shape[1],)) - fps = 30. - sig_dict = {} - conf_dict = {} - sig_iter = [sig] if sig.ndim == 1 else sig - conf_iter = [conf] if conf.ndim == 1 else conf - keys = [k for k in ('ppg_waveform', 'respiratory_waveform') if k in signals] - sig_dict = dict(zip(keys, sig_iter)) - conf_dict = dict(zip(keys, conf_iter)) - method_str = method_arg.value if isinstance(method_arg, Method) else str(method_arg) - out_data, out_conf, out_live, out_note, out_live = assemble_results( - sig=sig_dict, - conf=conf_dict, - live=live, - fps=fps, - pred_signals=pred_signals, - method_name=method_str, - can_provide_confidence=can_provide_confidence - ) - if 'ppg_waveform' in signals: - np.testing.assert_allclose(out_data['ppg_waveform'], sig_ppg) - assert method_str in out_note['ppg_waveform'] - if min_t_too_long: - assert np.isnan(out_data.get('heart_rate', np.nan)) - else: - np.testing.assert_allclose(out_data['heart_rate'], sample_video_hr, atol=0.5) - if 'respiratory_waveform' in signals: - np.testing.assert_allclose(out_data['respiratory_waveform'], sig_resp) - assert method_str in out_note['respiratory_waveform'] - if min_t_too_long: - assert np.isnan(out_data.get('respiratory_rate', np.nan)) - else: - np.testing.assert_allclose(out_data['respiratory_rate'], sample_video_rr, atol=1.) 
- -def test_estimate_rolling_vitals(): - """Test rolling vitals estimation with signal filtering and duration checks.""" - fps = 30. - n_frames = int(60 * fps) - data = { - 'ppg_waveform': np.random.rand(n_frames), - 'respiratory_waveform': np.random.rand(n_frames) - } - conf = { - 'ppg_waveform': np.random.rand(n_frames), - 'respiratory_waveform': np.random.rand(n_frames) - } - # Test 1: Selective signal processing - signals_available = {'heart_rate'} - vital_signs = {} - estimate_rolling_vitals( - vital_signs_dict=vital_signs, - data=data, - conf=conf, - signals_available=signals_available, - fps=fps, - video_duration_s=60. - ) - assert 'rolling_heart_rate' in vital_signs - assert 'rolling_respiratory_rate' not in vital_signs - # Test 2: Duration constraints - short_n = int(5 * fps) - short_data = {k: v[:short_n] for k,v in data.items()} - short_conf = {k: v[:short_n] for k,v in conf.items()} - vital_signs_short = {} - estimate_rolling_vitals( - vital_signs_dict=vital_signs_short, - data=short_data, - conf=short_conf, - signals_available={'heart_rate', 'hrv_sdnn'}, - fps=fps, - video_duration_s=5.0 - ) - assert 'rolling_heart_rate' not in vital_signs_short - assert 'rolling_hrv_sdnn' not in vital_signs_short \ No newline at end of file diff --git a/tests/test_simple_rppg_method.py b/tests/test_simple_rppg_method.py index 50a4dae..3a713e0 100644 --- a/tests/test_simple_rppg_method.py +++ b/tests/test_simple_rppg_method.py @@ -1,99 +1,86 @@ -# Copyright (c) 2024 Rouast Labs -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this 
permission notice shall be included in all -# copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -# SOFTWARE. - import numpy as np -import pytest +import vitallens_core as vc import sys sys.path.append('../vitallens-python') -from vitallens.enums import Mode from vitallens.methods.chrom import CHROMRPPGMethod from vitallens.methods.g import GRPPGMethod from vitallens.methods.pos import POSRPPGMethod +from vitallens.methods.simple_rppg_method import SimpleRPPGMethod -@pytest.mark.parametrize("override_fps_target", [None, 15]) -def test_CHROMRPPGMethod(request, override_fps_target): - method = CHROMRPPGMethod(mode=Mode.BATCH) - res = method.algorithm(np.random.rand(100, 3), fps=30.) - assert res.shape == (100,) - res = method.pulse_filter(res, fps=30.) 
+def test_CHROMRPPGMethod(): + method = CHROMRPPGMethod() + assert method.method.value == "chrom" + assert method.session_config.model_name == "chrom" + assert method.est_window_length == 48 + assert method.est_window_overlap == 47 + assert "heart_rate" in method.session_config.supported_vitals + rgb_input = np.random.rand(100, 3) + res = method.algorithm(rgb_input, fps=30.0) + assert isinstance(res, np.ndarray) assert res.shape == (100,) - test_video_ndarray = request.getfixturevalue('test_video_ndarray') - test_video_fps = request.getfixturevalue('test_video_fps') - test_video_faces = request.getfixturevalue('test_video_faces') - data, unit, conf, note, live = method( - inputs=test_video_ndarray, faces=test_video_faces, - fps=test_video_fps, override_fps_target=override_fps_target) - assert all(key in data for key in method.signals) - assert all(key in unit for key in method.signals) - assert all(key in conf for key in method.signals) - assert all(key in note for key in method.signals) - assert data['ppg_waveform'].shape == (test_video_ndarray.shape[0],) - assert conf['ppg_waveform'].shape == (test_video_ndarray.shape[0],) - np.testing.assert_equal(conf['ppg_waveform'], np.ones((test_video_ndarray.shape[0],), np.float32)) - assert conf['heart_rate'] == 1.0 - np.testing.assert_equal(live, np.ones((test_video_ndarray.shape[0],), np.float32)) -@pytest.mark.parametrize("override_fps_target", [None, 15]) -def test_GRPPGMethod(request, override_fps_target): - method = GRPPGMethod(mode=Mode.BATCH) - res = method.algorithm(np.random.rand(100, 3), fps=30.) - assert res.shape == (100,) - res = method.pulse_filter(res, fps=30.) 
+def test_GRPPGMethod(): + method = GRPPGMethod() + assert method.method.value == "g" + assert method.session_config.model_name == "g" + assert method.est_window_length == 64 + assert method.est_window_overlap == 0 + assert "heart_rate" in method.session_config.supported_vitals + rgb_input = np.random.rand(100, 3) + res = method.algorithm(rgb_input, fps=30.0) + assert isinstance(res, np.ndarray) assert res.shape == (100,) - test_video_ndarray = request.getfixturevalue('test_video_ndarray') - test_video_fps = request.getfixturevalue('test_video_fps') - test_video_faces = request.getfixturevalue('test_video_faces') - data, unit, conf, note, live = method( - inputs=test_video_ndarray, faces=test_video_faces, - fps=test_video_fps, override_fps_target=override_fps_target) - assert all(key in data for key in method.signals) - assert all(key in unit for key in method.signals) - assert all(key in conf for key in method.signals) - assert all(key in note for key in method.signals) - assert data['ppg_waveform'].shape == (test_video_ndarray.shape[0],) - assert conf['ppg_waveform'].shape == (test_video_ndarray.shape[0],) - np.testing.assert_equal(conf['ppg_waveform'], np.ones((test_video_ndarray.shape[0],), np.float32)) - assert conf['heart_rate'] == 1.0 - np.testing.assert_equal(live, np.ones((test_video_ndarray.shape[0],), np.float32)) -@pytest.mark.parametrize("override_fps_target", [None, 15]) -def test_POSRPPGMethod(request, override_fps_target): - method = POSRPPGMethod(mode=Mode.BATCH) - res = method.algorithm(np.random.rand(100, 3), fps=30.) - assert res.shape == (100,) - res = method.pulse_filter(res, fps=30.) 
+def test_POSRPPGMethod(): + method = POSRPPGMethod() + assert method.method.value == "pos" + assert method.session_config.model_name == "pos" + assert method.est_window_length == 48 + assert method.est_window_overlap == 47 + assert "heart_rate" in method.session_config.supported_vitals + rgb_input = np.random.rand(100, 3) + res = method.algorithm(rgb_input, fps=30.0) + assert isinstance(res, np.ndarray) assert res.shape == (100,) + +class DummySimpleRPPG(SimpleRPPGMethod): + def __init__(self): + super().__init__() + config = vc.SessionConfig( + model_name="dummy", + supported_vitals=["heart_rate"], + return_waveforms=["ppg_waveform"], + fps_target=30.0, + input_size=100, + n_inputs=1, + roi_method="face" + ) + self.parse_config(config, est_window_length=30, est_window_overlap=0) + def algorithm(self, rgb: np.ndarray, fps: float) -> np.ndarray: + return np.mean(rgb, axis=-1) + +def test_SimpleRPPGMethod_infer_batch(request): + method = DummySimpleRPPG() test_video_ndarray = request.getfixturevalue('test_video_ndarray') test_video_fps = request.getfixturevalue('test_video_fps') test_video_faces = request.getfixturevalue('test_video_faces') - data, unit, conf, note, live = method( - inputs=test_video_ndarray, faces=test_video_faces, - fps=test_video_fps, override_fps_target=override_fps_target) - assert all(key in data for key in method.signals) - assert all(key in unit for key in method.signals) - assert all(key in conf for key in method.signals) - assert all(key in note for key in method.signals) - assert data['ppg_waveform'].shape == (test_video_ndarray.shape[0],) - assert conf['ppg_waveform'].shape == (test_video_ndarray.shape[0],) - np.testing.assert_equal(conf['ppg_waveform'], np.ones((test_video_ndarray.shape[0],), np.float32)) - assert conf['heart_rate'] == 1.0 - np.testing.assert_equal(live, np.ones((test_video_ndarray.shape[0],), np.float32)) + sig_dict, conf_dict, live = method.infer_batch( + inputs=test_video_ndarray, faces=test_video_faces, 
fps=test_video_fps + ) + assert 'ppg_waveform' in sig_dict + assert sig_dict['ppg_waveform'].shape == (test_video_ndarray.shape[0],) + assert conf_dict['ppg_waveform'].shape == (test_video_ndarray.shape[0],) + assert live.shape == (test_video_ndarray.shape[0],) + +def test_SimpleRPPGMethod_infer_stream(): + method = DummySimpleRPPG() + frames = np.random.rand(16, 3) + fps = 30.0 + sig_dict, conf_dict, live, state = method.infer_stream(frames, fps, None) + assert 'ppg_waveform' in sig_dict + assert sig_dict['ppg_waveform'].shape == (16,) + assert conf_dict['ppg_waveform'].shape == (16,) + assert live.shape == (16,) + assert state is None \ No newline at end of file diff --git a/tests/test_ssd.py b/tests/test_ssd.py index 56c2afd..854eb35 100644 --- a/tests/test_ssd.py +++ b/tests/test_ssd.py @@ -1,23 +1,3 @@ -# Copyright (c) Rouast Labs -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in all -# copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -# SOFTWARE. 
- import numpy as np import pytest diff --git a/tests/test_stream.py b/tests/test_stream.py new file mode 100644 index 0000000..d3cc8b9 --- /dev/null +++ b/tests/test_stream.py @@ -0,0 +1,77 @@ +import numpy as np +import time +import vitallens_core as vc +from vitallens.stream import FrameBuffer, InferenceContext, BufferManager, StreamSession +from vitallens.methods.simple_rppg_method import SimpleRPPGMethod + +class DummyStreamRPPG(SimpleRPPGMethod): + def __init__(self): + super().__init__() + config = vc.SessionConfig( + model_name="dummy", + supported_vitals=["heart_rate"], + return_waveforms=["ppg_waveform"], + fps_target=30.0, + input_size=10, + n_inputs=1, + roi_method="face" + ) + self.parse_config(config, est_window_length=5, est_window_overlap=0) + self.input_size = config.input_size + + def algorithm(self, rgb: np.ndarray, fps: float) -> np.ndarray: + return np.mean(rgb, axis=-1) + + def pulse_filter(self, sig: np.ndarray, fps: float) -> np.ndarray: + return sig + +def test_frame_buffer(): + roi = vc.Rect(x=0.0, y=0.0, width=10.0, height=10.0) + buf = FrameBuffer(buffer_id="test_id", roi=roi, max_capacity=5, timestamp=1.0) + for i in range(7): + ctx = InferenceContext(timestamp=float(i), roi=roi, face_conf=1.0) + buf.append(np.zeros((10, 10, 3)), ctx) + assert buf.count == 5 + assert buf.last_seen == 6.0 + payload = buf.execute(take_count=3, keep_count=1) + assert len(payload) == 3 + assert buf.count == 3 + +def test_buffer_manager(): + config = vc.SessionConfig( + model_name="dummy", + supported_vitals=[], + fps_target=30.0, + input_size=10, + n_inputs=4, + roi_method="face" + ) + buf_config = vc.compute_buffer_config(config) + manager = BufferManager(buf_config) + roi = vc.Rect(x=0.0, y=0.0, width=10.0, height=10.0) + buf_id = manager.register_target(roi, 1.0) + assert buf_id is not None + for i in range(16): + ctx = InferenceContext(timestamp=float(i)/30.0, roi=roi, face_conf=1.0) + manager.append(buf_id, np.zeros((10, 10, 3)), ctx) + command = 
manager.poll() + assert command is not None + assert command.buffer_id == buf_id + payload = manager.execute(command) + assert payload is not None + assert len(payload) == command.take_count + +def test_stream_session(): + rppg = DummyStreamRPPG() + with StreamSession(rppg_method=rppg, face_detector=None) as session: + for i in range(25): + frame = np.zeros((480, 640, 3), dtype=np.uint8) + face = np.array([100, 100, 200, 200]) + session.push(frame, timestamp=float(i)/30.0, face=face) + time.sleep(0.5) + res = session.get_result(block=False) + assert res is not None + assert len(res) > 0 + assert 'vitals' in res[0] + assert 'waveforms' in res[0] + assert 'ppg_waveform' in res[0]['waveforms'] \ No newline at end of file diff --git a/tests/test_utils.py b/tests/test_utils.py index 92bf857..2f736f5 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1,23 +1,3 @@ -# Copyright (c) Rouast Labs -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in all -# copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -# SOFTWARE. 
- import numpy as np from prpy.numpy.image import parse_image_inputs, probe_image_inputs import pytest diff --git a/tests/test_vitallens.py b/tests/test_vitallens.py index a77c773..65ceba5 100644 --- a/tests/test_vitallens.py +++ b/tests/test_vitallens.py @@ -1,37 +1,16 @@ -# Copyright (c) Rouast Labs -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in all -# copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -# SOFTWARE. 
- import base64 import json import numpy as np from prpy.numpy.image import parse_image_inputs -from prpy.numpy.physio import CALC_HR_MIN_T, CALC_RR_MIN_T, CALC_HRV_SDNN_MIN_T, CALC_HRV_RMSSD_MIN_T, CALC_HRV_LFHF_MIN_T import pytest import requests from unittest.mock import Mock, patch +import vitallens_core as vc import sys sys.path.append('../vitallens-python') -from vitallens.constants import API_MAX_FRAMES, API_MIN_FRAMES, API_URL, API_RESOLVE_URL -from vitallens.enums import Method, Mode +from vitallens.constants import API_MAX_FRAMES, API_MIN_FRAMES, API_FILE_URL, API_RESOLVE_URL from vitallens.methods.vitallens import VitalLensRPPGMethod, _resolve_model_config from vitallens.errors import VitalLensAPIKeyError, VitalLensAPIError, VitalLensAPIQuotaExceededError @@ -62,14 +41,15 @@ def create_mock_response( def mock_resolve_config(): """Provides a default successful mock for the resolve_model_config call.""" with patch('vitallens.methods.vitallens._resolve_model_config') as mock: - mock.return_value = { - 'model': 'vitallens-2.0', - 'n_inputs': 5, - 'input_size': 40, - 'fps_target': 30, - 'roi_method': 'upper_body_cropped', - 'signals': {'heart_rate', 'respiratory_rate', 'hrv_sdnn', 'hrv_rmssd', 'hrv_lfhf', 'ppg_waveform', 'respiratory_waveform'} - } + mock.return_value = vc.SessionConfig( + model_name="vitallens-2.0", + n_inputs=5, + input_size=40, + fps_target=30.0, + roi_method="upper_body_cropped", + supported_vitals=["heart_rate", "respiratory_rate", "hrv_sdnn", "hrv_rmssd", "hrv_lfhf"], + return_waveforms=["ppg_waveform", "respiratory_waveform"] + ) yield mock def create_mock_api_response( @@ -90,7 +70,7 @@ def create_mock_api_response( api_key = headers.get("x-api-key") if api_key is None or not isinstance(api_key, str) or len(api_key) < 30: return create_mock_response( - status_code=403, json_data={"vital_signs": None, "face": None, "state": None, "message": "Error"}) + status_code=403, json_data={"vitals": None, "waveforms": None, "face": None, 
"state": None, "message": "Error"}) if api_key == "QUOTA_EXCEEDED": return create_mock_response( status_code=429, json_data={"message": "Quota Exceeded"}) @@ -101,12 +81,12 @@ def create_mock_api_response( video_base64 = json["video"] if video_base64 is None: return create_mock_response( - status_code=400, json_data={"vital_signs": None, "face": None, "state": None, "message": "Error"}) + status_code=400, json_data={"vitals": None, "waveforms": None, "face": None, "state": None, "message": "Error"}) try: video = np.frombuffer(base64.b64decode(video_base64), dtype=np.uint8) video = video.reshape((-1, 40, 40, 3)) except Exception as e: - return create_mock_response(status_code=422, json_data={"vital_signs": None, "face": None, "state": None, "message": f"Unprocessable video: {e}"}) + return create_mock_response(status_code=422, json_data={"vitals": None, "waveforms": None, "face": None, "state": None, "message": f"Unprocessable video: {e}"}) n_frames = video.shape[0] if "state" in json: n_frames_out = n_frames - 4 @@ -117,20 +97,23 @@ def create_mock_api_response( min_frames = 5 if "state" in json else API_MIN_FRAMES if n_frames < min_frames or n_frames > API_MAX_FRAMES: return create_mock_response(status_code=400, json_data={"message": "Incorrect number of frames."}) - vital_signs_data = { + vitals_data = { "heart_rate": {"value": 60.0, "unit": "bpm", "confidence": 0.99, "note": "Note"}, - "respiratory_rate": {"value": 15.0, "unit": "bpm", "confidence": 0.97, "note": "Note"}, + "respiratory_rate": {"value": 15.0, "unit": "bpm", "confidence": 0.97, "note": "Note"} + } + waveforms_data = { "ppg_waveform": {"data": np.random.rand(n_frames_out).tolist(), "unit": "unitless", "confidence": np.ones(n_frames_out).tolist(), "note": "Note"}, "respiratory_waveform": {"data": np.random.rand(n_frames_out).tolist(), "unit": "unitless", "confidence": np.ones(n_frames_out).tolist(), "note": "Note"} } if 'vitallens-2.0' in model: - vital_signs_data["hrv_sdnn"] = {"value": 45.0, 
"unit": "ms", "confidence": 0.95, "note": "Note"} - vital_signs_data["hrv_rmssd"] = {"value": 35.0, "unit": "ms", "confidence": 0.96, "note": "Note"} - vital_signs_data["hrv_lfhf"] = {"value": 1.5, "unit": "unitless", "confidence": 0.92, "note": "Note"} + vitals_data["hrv_sdnn"] = {"value": 45.0, "unit": "ms", "confidence": 0.95, "note": "Note"} + vitals_data["hrv_rmssd"] = {"value": 35.0, "unit": "ms", "confidence": 0.96, "note": "Note"} + vitals_data["hrv_lfhf"] = {"value": 1.5, "unit": "unitless", "confidence": 0.92, "note": "Note"} return create_mock_response( status_code=200, json_data={ - "vital_signs": vital_signs_data, + "vitals": vitals_data, + "waveforms": waveforms_data, "face": {"confidence": np.random.rand(n_frames_out).tolist(), "note": "Note"}, "state": {"data": np.zeros((256,), dtype=np.float32).tolist(), "note": "Note"}, "model_used": model, @@ -145,17 +128,17 @@ def test_VitalLensRPPGMethod_file_mock(mock_post, mock_resolve_config, request, test_video_path = request.getfixturevalue('test_video_path') test_video_ndarray = request.getfixturevalue('test_video_ndarray') test_video_faces = request.getfixturevalue('test_video_faces') - method = VitalLensRPPGMethod(mode=Mode.BATCH, - api_key=api_key, - requested_model_name=requested_model) - data, unit, conf, note, live = method( + method = VitalLensRPPGMethod( + api_key=api_key, + requested_model_name=requested_model + ) + data, conf, live = method.infer_batch( inputs=test_video_path, faces=test_video_faces, override_fps_target=override_fps_target, - override_global_parse=override_global_parse) - assert all(key in data for key in method.signals) - assert all(key in unit for key in method.signals) - assert all(key in conf for key in method.signals) - assert all(key in note for key in method.signals) + override_global_parse=override_global_parse + ) + assert 'ppg_waveform' in data + assert 'respiratory_waveform' in data assert data['ppg_waveform'].shape == (test_video_ndarray.shape[0],) assert 
conf['ppg_waveform'].shape == (test_video_ndarray.shape[0],) assert data['respiratory_waveform'].shape == (test_video_ndarray.shape[0],) @@ -172,51 +155,79 @@ def test_VitalLensRPPGMethod_ndarray_mock(mock_post, mock_resolve_config, reques test_video_ndarray = request.getfixturevalue('test_video_ndarray') test_video_fps = request.getfixturevalue('test_video_fps') test_video_faces = request.getfixturevalue('test_video_faces') - method = VitalLensRPPGMethod(mode=Mode.BATCH, - api_key=api_key, - requested_model_name=requested_model) + method = VitalLensRPPGMethod( + api_key=api_key, + requested_model_name=requested_model + ) if long: n_repeats = (API_MAX_FRAMES * 3) // test_video_ndarray.shape[0] + 1 test_video_ndarray = np.repeat(test_video_ndarray, repeats=n_repeats, axis=0) test_video_faces = np.repeat(test_video_faces, repeats=n_repeats, axis=0) - data, unit, conf, note, live = method( + data, conf, live = method.infer_batch( inputs=test_video_ndarray, faces=test_video_faces, - fps=test_video_fps, override_fps_target=override_fps_target, - override_global_parse=override_global_parse) - assert all(key in data for key in method.signals) - assert all(key in unit for key in method.signals) - assert all(key in conf for key in method.signals) - assert all(key in note for key in method.signals) + fps=test_video_fps, + override_fps_target=override_fps_target, + override_global_parse=override_global_parse + ) + assert 'ppg_waveform' in data + assert 'respiratory_waveform' in data assert data['ppg_waveform'].shape == (test_video_ndarray.shape[0],) assert conf['ppg_waveform'].shape == (test_video_ndarray.shape[0],) assert data['respiratory_waveform'].shape == (test_video_ndarray.shape[0],) assert conf['respiratory_waveform'].shape == (test_video_ndarray.shape[0],) assert live.shape == (test_video_ndarray.shape[0],) -@patch('requests.post', side_effect=create_mock_api_response) -def test_VitalLensRPPGMethod_burst_mock(mock_post, mock_resolve_config, request): 
+@patch('requests.Session.post') +def test_VitalLensRPPGMethod_stream_mock(mock_post, mock_resolve_config, request): api_key = request.getfixturevalue('test_dev_api_key') - test_video_ndarray = request.getfixturevalue('test_video_ndarray') - test_video_fps = request.getfixturevalue('test_video_fps') - test_video_faces = request.getfixturevalue('test_video_faces') - method = VitalLensRPPGMethod(mode=Mode.BURST, api_key=api_key, requested_model_name="vitallens") - # First call, initializes state - chunk1 = test_video_ndarray[:16] - faces1 = test_video_faces[:16] - data1, _, _, _, live1 = method(inputs=chunk1, faces=faces1, fps=test_video_fps) + method = VitalLensRPPGMethod( + api_key=api_key, + requested_model_name="vitallens" + ) + frames1 = np.random.randint(0, 255, (16, 40, 40, 3), dtype=np.uint8) + frames2 = np.random.randint(0, 255, (5, 40, 40, 3), dtype=np.uint8) + mock_resp1 = Mock() + mock_resp1.status_code = 200 + mock_resp1.json.return_value = { + "waveforms": { + "ppg_waveform": {"data": [0.1]*16, "confidence": [1.0]*16}, + "respiratory_waveform": {"data": [0.2]*16, "confidence": [0.9]*16} + }, + "face": {"confidence": [0.99]*16}, + "state": {"data": [0.5]*256} + } + mock_resp2 = Mock() + mock_resp2.status_code = 200 + mock_resp2.json.return_value = { + "waveforms": { + "ppg_waveform": {"data": [0.1]*5, "confidence": [1.0]*5}, + "respiratory_waveform": {"data": [0.2]*5, "confidence": [0.9]*5} + }, + "face": {"confidence": [0.99]*5}, + "state": {"data": [0.6]*256} + } + mock_post.side_effect = [mock_resp1, mock_resp2] + data1, conf1, live1, state1 = method.infer_stream(frames=frames1, fps=30.0, state=None) assert 'ppg_waveform' in data1 + assert 'respiratory_waveform' in data1 assert data1['ppg_waveform'].shape == (16,) + assert conf1['ppg_waveform'].shape == (16,) assert live1.shape == (16,) - assert method.state is not None - assert mock_post.call_args.kwargs['json'].get('state') is None - # Second call, should use state - chunk2 = 
test_video_ndarray[16:21] - faces2 = test_video_faces[16:21] - data2, _, _, _, live2 = method(inputs=chunk2, faces=faces2, fps=test_video_fps) + assert state1 is not None + assert len(state1) == 256 + args1, kwargs1 = mock_post.call_args_list[0] + assert 'X-State' not in kwargs1['headers'] + assert kwargs1['headers']['x-api-key'] == api_key + data2, conf2, live2, state2 = method.infer_stream(frames=frames2, fps=30.0, state=state1) assert 'ppg_waveform' in data2 + assert 'respiratory_waveform' in data2 assert data2['ppg_waveform'].shape == (5,) + assert conf2['ppg_waveform'].shape == (5,) assert live2.shape == (5,) - assert mock_post.call_args.kwargs['json'].get('state') is not None + assert state2 is not None + args2, kwargs2 = mock_post.call_args_list[1] + assert 'X-State' in kwargs2['headers'] + assert kwargs2['headers']['x-api-key'] == api_key @pytest.mark.parametrize("process_signals", [True, False]) @pytest.mark.parametrize("n_frames", [16, 250]) @@ -233,30 +244,39 @@ def test_VitalLens_API_valid_response(request, process_signals, n_frames): if process_signals: payload['fps'] = str(30) payload['process_signals'] = "True" - response = requests.post(API_URL, headers=headers, json=payload) + response = requests.post(API_FILE_URL, headers=headers, json=payload) response_body = json.loads(response.text) assert response.status_code == 200 - assert all(key in response_body for key in ["face", "vital_signs", "state", "message"]) - vital_signs = response_body["vital_signs"] - assert all(key in vital_signs for key in ["ppg_waveform", "respiratory_waveform"]) - ppg_waveform_data = np.asarray(response_body["vital_signs"]["ppg_waveform"]["data"]) - ppg_waveform_conf = np.asarray(response_body["vital_signs"]["ppg_waveform"]["confidence"]) - resp_waveform_data = np.asarray(response_body["vital_signs"]["respiratory_waveform"]["data"]) - resp_waveform_conf = np.asarray(response_body["vital_signs"]["respiratory_waveform"]["confidence"]) + assert all(key in response_body for key 
in ["face", "vitals", "waveforms", "state", "message"]) + waveforms = response_body["waveforms"] + assert all(key in waveforms for key in ["ppg_waveform", "respiratory_waveform"]) + ppg_waveform_data = np.asarray(waveforms["ppg_waveform"]["data"]) + ppg_waveform_conf = np.asarray(waveforms["ppg_waveform"]["confidence"]) + resp_waveform_data = np.asarray(waveforms["respiratory_waveform"]["data"]) + resp_waveform_conf = np.asarray(waveforms["respiratory_waveform"]["confidence"]) assert ppg_waveform_data.shape == (n_frames,) assert ppg_waveform_conf.shape == (n_frames,) assert resp_waveform_data.shape == (n_frames,) assert resp_waveform_conf.shape == (n_frames,) - if process_signals: assert "heart_rate" in vital_signs - else: assert "heart_rate" not in vital_signs - if process_signals: assert "respiratory_rate" in vital_signs - else: assert "respiratory_rate" not in vital_signs - if process_signals: assert "hrv_sdnn" in vital_signs - else: assert "hrv_sdnn" not in vital_signs - if process_signals: assert "hrv_rmssd" in vital_signs - else: assert "hrv_rmssd" not in vital_signs - if process_signals: assert "hrv_lfhf" in vital_signs - else: assert "hrv_lfhf" not in vital_signs + vitals = response_body["vitals"] + if process_signals and n_frames >= 150: + assert "heart_rate" in vitals + else: + assert "heart_rate" not in vitals + if process_signals and n_frames >= 300: + assert "respiratory_rate" in vitals + else: + assert "respiratory_rate" not in vitals + if process_signals and n_frames >= 600: + assert "hrv_sdnn" in vitals + assert "hrv_rmssd" in vitals + else: + assert "hrv_sdnn" not in vitals + assert "hrv_rmssd" not in vitals + if process_signals and n_frames >= 1650: + assert "hrv_lfhf" in vitals + else: + assert "hrv_lfhf" not in vitals live = np.asarray(response_body["face"]["confidence"]) assert live.shape == (n_frames,) state = np.asarray(response_body["state"]["data"]) @@ -271,21 +291,21 @@ def test_VitalLens_API_wrong_api_key(request): 
roi=test_video_faces[0].tolist(), library='prpy', scale_algorithm='bilinear') headers = {"x-api-key": "WRONG_API_KEY"} payload = {"video": base64.b64encode(frames[:16].tobytes()).decode('utf-8'), "origin": "vitallens-python"} - response = requests.post(API_URL, headers=headers, json=payload) + response = requests.post(API_FILE_URL, headers=headers, json=payload) assert response.status_code == 403 def test_VitalLens_API_no_video(request): api_key = request.getfixturevalue('test_dev_api_key') headers = {"x-api-key": api_key} payload = {"some_key": "irrelevant", "origin": "vitallens-python"} - response = requests.post(API_URL, headers=headers, json=payload) + response = requests.post(API_FILE_URL, headers=headers, json=payload) assert response.status_code == 400 def test_VitalLens_API_no_parseable_video(request): api_key = request.getfixturevalue('test_dev_api_key') headers = {"x-api-key": api_key} payload = {"video": "not_parseable", "origin": "vitallens-python"} - response = requests.post(API_URL, headers=headers, json=payload) + response = requests.post(API_FILE_URL, headers=headers, json=payload) assert response.status_code == 422 def test_resolve_model_config_errors(): @@ -302,7 +322,7 @@ def test_resolve_model_config_errors(): def test_VitalLens_API_integration(mock_resolve_config, request): api_key = request.getfixturevalue('test_dev_api_key') - method = VitalLensRPPGMethod(requested_model_name="vitallens", api_key=api_key, mode=Mode.BATCH) + method = VitalLensRPPGMethod(requested_model_name="vitallens", api_key=api_key) assert method.resolved_model == 'vitallens-2.0' assert method.input_size == 40 @@ -311,16 +331,15 @@ def test_VitalLensRPPGMethod_init_errors(mock_resolve): """Tests that exceptions from _resolve_model_config are propagated.""" mock_resolve.side_effect = VitalLensAPIKeyError("Test key error") with pytest.raises(VitalLensAPIKeyError, match="Test key error"): - VitalLensRPPGMethod(mode=Mode.BATCH, api_key="any_key", 
requested_model_name="vitallens") + VitalLensRPPGMethod(api_key="any_key", requested_model_name="vitallens") mock_resolve.side_effect = VitalLensAPIError("Test server error") with pytest.raises(VitalLensAPIError, match="Test server error"): - VitalLensRPPGMethod(mode=Mode.BATCH, api_key="any_key", requested_model_name="vitallens") + VitalLensRPPGMethod(api_key="any_key", requested_model_name="vitallens") -@patch('requests.post') +@patch('requests.Session.post') @patch('requests.get') def test_proxy_and_auth_offloading(mock_get, mock_post, request): """Verify that proxies are passed and headers are handled correctly.""" - # Mock config resolution success mock_get.return_value = create_mock_response(200, { 'resolved_model': 'vitallens-2.0', 'config': { @@ -330,36 +349,33 @@ def test_proxy_and_auth_offloading(mock_get, mock_post, request): 'fps_target': 30 } }) - # Mock inference success mock_post.return_value = create_mock_response(200, { - "vital_signs": {"ppg_waveform": {"data": [0.1]*16, "confidence": [1.0]*16}, "respiratory_waveform": {"data": [0.1]*16, "confidence": [1.0]*16}}, + "vitals": {}, + "waveforms": { + "ppg_waveform": {"data": [0.1]*16, "confidence": [1.0]*16}, + "respiratory_waveform": {"data": [0.1]*16, "confidence": [1.0]*16} + }, "face": {"confidence": [1.0]*16}, "state": {"data": []} }) proxies = {"https": "http://proxy:8080"} # Case 1: API Key + Proxy - # Check resolve call - method = VitalLensRPPGMethod(mode=Mode.BATCH, api_key="my_key", requested_model_name="vitallens", proxies=proxies) + method = VitalLensRPPGMethod(api_key="my_key", requested_model_name="vitallens", proxies=proxies) args, kwargs = mock_get.call_args assert kwargs['proxies'] == proxies assert kwargs['headers']['x-api-key'] == "my_key" - # Run inference test_video_ndarray = request.getfixturevalue('test_video_ndarray') faces = request.getfixturevalue('test_video_faces') - method(test_video_ndarray[:16], faces[:16], fps=30.0) - # Check inference call + 
method.infer_batch(test_video_ndarray[:16], faces[:16], fps=30.0) args, kwargs = mock_post.call_args assert kwargs['proxies'] == proxies assert kwargs['headers']['x-api-key'] == "my_key" # Case 2: No API Key (Auth Offloading) + Proxy - method_offload = VitalLensRPPGMethod(mode=Mode.BATCH, api_key=None, requested_model_name="vitallens", proxies=proxies) - # Check resolve call + method_offload = VitalLensRPPGMethod(api_key=None, requested_model_name="vitallens", proxies=proxies) args, kwargs = mock_get.call_args assert kwargs['proxies'] == proxies - assert 'x-api-key' not in kwargs['headers'] # Should be missing - # Run inference - method_offload(test_video_ndarray[:16], faces[:16], fps=30.0) - # Check inference call + assert 'x-api-key' not in kwargs['headers'] + method_offload.infer_batch(test_video_ndarray[:16], faces[:16], fps=30.0) args, kwargs = mock_post.call_args assert kwargs['proxies'] == proxies - assert 'x-api-key' not in kwargs['headers'] # Should be missing \ No newline at end of file + assert 'x-api-key' not in kwargs['headers'] \ No newline at end of file diff --git a/vitallens/__init__.py b/vitallens/__init__.py index fa1d9f1..8a633ef 100644 --- a/vitallens/__init__.py +++ b/vitallens/__init__.py @@ -18,4 +18,4 @@ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. 
-from .client import VitalLens, Method, Mode +from .client import VitalLens, Method diff --git a/vitallens/buffer.py b/vitallens/buffer.py deleted file mode 100644 index 19e9edb..0000000 --- a/vitallens/buffer.py +++ /dev/null @@ -1,183 +0,0 @@ -# Copyright (c) 2024 Rouast Labs -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in all -# copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -# SOFTWARE. - -import numpy as np -from typing import Union -import warnings - -class SignalBuffer: - """A buffer for an arbitrary float signal. - - Builds the buffer up over time. - - Returns overlaps using mean. - """ - def __init__( - self, - size: int, - ndim: int, - pad_val: float = 0, - init_t: int = 0 - ): - """Initialise the signal buffer. - Args: - size: The temporal length of the buffer. Signals are pushed through in FIFO order. 
- ndim: The number of dims for the signal (e.g., ndim=1 means scalar signal with time dimension) - pad_val: Constant scalar to pad empty time steps with - init_t: Time for initialisation - """ - self.size = size - self.pad_val = pad_val - self.min_increment = 1 - self.ndim = ndim - # Buffer is a nested list with up to self.size lists - # Each element is a tuple - # - 0: t_start - # - 1: t_end - # - 2: signal - self.buffer = [] - self.t = init_t - self.out = None - def update( - self, - signal: Union[list, np.ndarray], - dt: int - ) -> np.ndarray: - """Update the buffer with new signal series and amount of time steps passed. - - Support arbitrary number of signal dims, but mostly intended for scalar series or one more dim - (e.g., RGB signal over time). - Args: - signal: The signal. list or ndarray, shape (n_frames, dim1, dim2, ...) - dt: The number of time steps passed. Scalar - Returns: - out: The signal of buffer size, with overlaps averaged. Shape (self.size, dim1, dim2, ...) - """ - if not isinstance(signal, np.ndarray): signal = np.asarray(signal) - if signal.size == 1 and self.ndim == 1: signal = np.full((dt,), fill_value=signal) - assert isinstance(signal, np.ndarray), f"signal should be np.ndarray but is {type(signal)}" - assert len(signal.shape) == self.ndim - assert len(signal) >= 1 - assert dt >= self.min_increment - # Initialise self.out if necessary - if self.out is None: - self.out = np.empty((self.size, self.size,) + signal.shape[1:], signal.dtype) - self.out[:] = np.nan - # Update self.t - self.t += dt - # Update self.buffer - self.buffer.append((self.t - signal.shape[0], self.t, signal)) - # Delete old buffer elements - i = 0 - while i < len(self.buffer): - if self.buffer[i][1] <= self.t - self.size: - self.buffer.pop(0) - else: - i += 1 - return self.get() - def get(self) -> np.ndarray: - """Get the series of current buffer contents, with overlaps averaged. - Returns: - out: The signal of buffer size, with overlaps averaged. 
Shape (self.size, dim1, dim2, ...) - """ - # No elements yet - assert self.t > 0, "Update at least once before calling get()" - # Assign buffer elements to self.out - for i in range(len(self.buffer)): - adj_t = self.t - self.size - adj_t_start = self.buffer[i][0] - adj_t - adj_t_end = self.buffer[i][1] - adj_t - outside = 0 if adj_t_start >= 0 else abs(adj_t_start) - self.out[i][adj_t_start+outside:adj_t_end] = self.buffer[i][2][outside:] - # Reduce via mean (ignore warnings due to nan) - with warnings.catch_warnings(): - warnings.simplefilter("ignore", category=RuntimeWarning) - result = np.nanmean(self.out, axis=0) - # Replace np.nan with pad_val - return np.nan_to_num(result, nan=self.pad_val) - def clear(self): - """Clear the contents""" - self.buffer.clear() - self.t = 0 - self.out = None - -class MultiSignalBuffer: - """Manages a dict of SignalBuffer instances""" - def __init__( - self, - size: int, - ndim: int, - ignore_k: list, - pad_val: float = 0 - ): - """Initialise the multi signal buffer. - Args: - size: The temporal length of each buffer. Signals are pushed through in FIFO order. - ndim: The number of dims for each signal (e.g., ndim=1 means scalar signal with time dimension) - ignore_k: List of keys to ignore in update step - pad_val: Constant scalar to pad empty time steps with - """ - self.size = size - self.ndim = ndim - self.min_increment = 1 - self.ignore_k = ignore_k - self.pad_val = pad_val - self.signals = {} - self.t = 0 - def update( - self, - signals: dict, - dt: int - ) -> dict: - """Initialise or update each of the buffers corresponding to the entries in signals dict. - Args: - signals: Dictionary of signal updates. Each entry ndarray or list of shape (n_frames, dim1, dim2, ...) - dt: The number of time steps passed. Scalar - Returns: - out: Dictionary of buffered signals, with overlaps averaged. Each entry of shape (self.size, dim1, dim2, ...) 
- """ - assert isinstance(signals, dict) - result = {} - for k in signals: - if k in self.ignore_k: - continue - if k not in self.signals: - # Add k to self.signals - self.signals[k] = SignalBuffer( - size=self.size, ndim=self.ndim, pad_val=self.pad_val, init_t=self.t) - # Run update - result[k] = self.signals[k].update(signals[k], dt) - self.t += dt - return result - def get(self) -> dict: - """Get the series of current buffer contents, with overlaps averaged. - Returns: - out: Dictionary of buffered signals, with overlaps averaged. Each entry of shape (self.size, dim1, dim2, ...) - """ - assert self.t > 0, "Update at least once before calling get()" - result = {} - for k in self.signals: - if k in self.ignore_k: - continue - result[k] = self.signals[k].get() - return result - def clear(self): - """Clear the contents""" - for k in self.signals: - self.signals[k].clear() - self.signals = {} - \ No newline at end of file diff --git a/vitallens/client.py b/vitallens/client.py index b56fcbc..64ca072 100644 --- a/vitallens/client.py +++ b/vitallens/client.py @@ -1,4 +1,4 @@ -# Copyright (c) 2024 Rouast Labs +# Copyright (c) 2026 Rouast Labs # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -24,16 +24,16 @@ import numpy as np import os from prpy.numpy.image import probe_image_inputs -from typing import Union +from typing import Union, Callable +import vitallens_core as vc -from vitallens.constants import DISCLAIMER, API_MAX_FRAMES -from vitallens.enums import Method, Mode +from vitallens.enums import Method from vitallens.methods.g import GRPPGMethod from vitallens.methods.chrom import CHROMRPPGMethod from vitallens.methods.pos import POSRPPGMethod from vitallens.methods.vitallens import VitalLensRPPGMethod -from vitallens.signal import estimate_rolling_vitals from vitallens.ssd import FaceDetector +from vitallens.stream import StreamSession from 
vitallens.utils import check_faces, convert_ndarray_to_list logging.getLogger().setLevel("INFO") @@ -42,7 +42,6 @@ class VitalLens: def __init__( self, method: Union[Method, str] = "vitallens", - mode: Mode = Mode.BATCH, api_key: str = None, proxies: dict = None, detect_faces: bool = True, @@ -52,7 +51,8 @@ def __init__( fdet_score_threshold: float = 0.9, fdet_iou_threshold: float = 0.3, export_to_json: bool = True, - export_dir: str = "." + export_dir: str = ".", + mode=None # Deprecated ): """Initialises the client. Loads face detection model if necessary. @@ -66,7 +66,6 @@ def __init__( Args: method: The rPPG method to be used for inference. - mode: Operate in batch or burst mode api_key: Usage key for the VitalLens API (required for vitallens methods, unless using proxy) proxies: Dictionary mapping protocol to the URL of the proxy. detect_faces: `True` if faces need to be detected, otherwise `False`. @@ -79,35 +78,39 @@ def __init__( export_to_json: If `True`, write results to a json file. export_dir: The directory to which json files are written. """ + if mode is not None: + logging.warning("The 'mode' argument is deprecated. 
Use vl(video) for batch and vl.stream() for live streams.") + if isinstance(method, Method): self.method_name = method.value else: self.method_name = str(method) - self.mode = mode + if self.method_name.startswith("vitallens"): self.rppg = VitalLensRPPGMethod( - mode=mode, api_key=api_key, requested_model_name=self.method_name, proxies=proxies ) elif self.method_name == "g": - self.rppg = GRPPGMethod(mode=mode) + self.rppg = GRPPGMethod() elif self.method_name == "chrom": - self.rppg = CHROMRPPGMethod(mode=mode) + self.rppg = CHROMRPPGMethod() elif self.method_name == "pos": - self.rppg = POSRPPGMethod(mode=mode) + self.rppg = POSRPPGMethod() else: raise ValueError(f"Unknown method or model: {self.method_name}") + self.detect_faces = detect_faces self.estimate_rolling_vitals = estimate_rolling_vitals self.export_to_json = export_to_json self.export_dir = export_dir + self.fdet_fs = fdet_fs + if detect_faces: self.face_detector = FaceDetector( max_faces=fdet_max_faces, fs=fdet_fs, score_threshold=fdet_score_threshold, iou_threshold=fdet_iou_threshold) - assert not (fdet_max_faces > 1 and mode == Mode.BURST), "burst mode only supported for one face" def __call__( self, @@ -142,19 +145,28 @@ def __call__( [ { 'face': { - 'coordinates': , - 'confidence': , - 'note': + 'coordinates': [[247, 52, 444, 332], ...], + 'confidence': [0.6115, 0.9207, 0.9183, ...], + 'note': "Face detection coordinates..." }, - 'vital_signs': { + 'vitals': { 'heart_rate': { - 'value': , - 'unit': , - 'confidence': , - 'note': + 'value': 60.5, + 'unit': 'bpm', + 'confidence': 0.9242, + 'note': 'Global estimate of heart rate...' }, }, + 'waveforms': { + 'ppg_waveform': { + 'data': [0.1, 0.2, ...], + 'unit': 'unitless', + 'confidence': [0.9, 0.9, ...], + 'note': '...' 
+ }, + + }, 'message': }, { @@ -164,11 +176,6 @@ def __call__( ] """ # Probe inputs - if self.mode == Mode.BURST and not isinstance(video, np.ndarray): - raise ValueError("Must provide `np.ndarray` inputs for burst mode.") - if self.mode == Mode.BURST and isinstance(self.rppg, VitalLensRPPGMethod): - if video.shape[0] > (API_MAX_FRAMES - self.rppg.n_inputs + 1): - raise ValueError(f"Maximum number of frames in burst mode is {API_MAX_FRAMES - self.rppg.n_inputs + 1}, but received {video.shape[0]}.") inputs_shape, fps, _ = probe_image_inputs(video, fps=fps, allow_image=False) # Warning if using long video with simple rPPG method target_fps = override_fps_target if override_fps_target is not None else self.rppg.fps_target @@ -177,10 +184,7 @@ def __call__( _, height, width, _ = inputs_shape if self.detect_faces: # Detect faces - if self.mode == Mode.BURST: - faces_rel, _ = self.face_detector(inputs=video[-1:], n_frames=1, fps=fps) - else: - faces_rel, _ = self.face_detector(inputs=video, n_frames=inputs_shape[0], fps=fps) + faces_rel, _ = self.face_detector(inputs=video, n_frames=inputs_shape[0], fps=fps) # If no faces detected: return empty list if len(faces_rel) == 0: logging.warning("No faces detected to in the video") @@ -191,53 +195,52 @@ def __call__( faces = np.transpose(faces, (1, 0, 2)) # Check if the faces are valid faces = check_faces(faces, inputs_shape) - # Get video duration for rolling vital calculations - video_duration_s = inputs_shape[0] / fps # Run separately for each face results = [] for face in faces: - # Run selected rPPG method - data, unit, conf, note, live = self.rppg( + sig, conf, live = self.rppg.infer_batch( inputs=video, faces=face, fps=fps, override_fps_target=override_fps_target, override_global_parse=override_global_parse ) - # Rounding live = np.round(live, 4) - # Parse face results + + self.rppg.session_config.estimate_rolling_vitals = self.estimate_rolling_vitals + session = vc.Session(self.rppg.session_config) + signals_input = {k: 
vc.SignalInput(data=v.tolist(), confidence=conf[k].tolist()) for k, v in sig.items()} + face_input = vc.FaceInput(coordinates=face.tolist(), confidence=live.tolist()) + timestamps = (np.arange(inputs_shape[0]) / fps).tolist() + session_input = vc.SessionInput(face=face_input, signals=signals_input, timestamp=timestamps) + session_result = session.process(session_input, "Global") + face_result = { 'face': { - 'coordinates': face, - 'confidence': live, - 'note': "Face detection coordinates for this face with live confidence levels." + 'coordinates': session_result.face.coordinates if session_result.face else face.tolist(), + 'confidence': session_result.face.confidence if session_result.face else live.tolist(), + 'note': session_result.face.note if session_result.face and session_result.face.note else "Face detection coordinates for this face with live confidence levels." + }, + 'vitals': { + k: {'value': v.value, 'unit': v.unit, 'confidence': v.confidence, 'note': v.note} + for k, v in session_result.vitals.items() }, - 'vital_signs': {}, - 'message': DISCLAIMER + 'waveforms': { + k: {'data': v.data, 'unit': v.unit, 'confidence': v.confidence, 'note': v.note} + for k, v in session_result.waveforms.items() + }, + 'rolling_vitals': { + k: {'data': v.data, 'unit': v.unit, 'confidence': v.confidence, 'note': v.note} + for k, v in session_result.rolling_vitals.items() + } if session_result.rolling_vitals else {}, + 'message': session_result.message, + 'fps': session_result.fps, + 'n': len(session_result.timestamp), + 'time': session_result.timestamp } - # Parse vital signs results - for name in self.rppg.signals: - if name in data and name in unit and name in conf and name in note: - is_scalar = np.ndim(data[name]) == 0 - val = np.round(data[name], 2 if is_scalar else 8) - c_val = np.round(conf[name], 4) - if is_scalar: - val, c_val = float(val), float(c_val) - face_result['vital_signs'][name] = { - ('value' if is_scalar else 'data'): val, - 'unit': unit[name], - 
'confidence': c_val, - 'note': note[name] - } - if self.estimate_rolling_vitals: - estimate_rolling_vitals( - vital_signs_dict=face_result['vital_signs'], data=data, conf=conf, - signals_available=set(self.rppg.signals), fps=fps, video_duration_s=video_duration_s) + results.append(face_result) - for res in results: - res['fps'] = fps - res['n'] = inputs_shape[0] + # Export to json if self.export_to_json: os.makedirs(self.export_dir, exist_ok=True) @@ -245,6 +248,42 @@ def __call__( with open(os.path.join(self.export_dir, export_filename), 'w') as f: json.dump(convert_ndarray_to_list(results), f, indent=4) return results - def reset(self): - """Resets the client state if applicable.""" - self.rppg.reset() + + def stream(self, on_result: Callable = None) -> StreamSession: + """Returns a context manager for real-time vital sign estimation. + + This method creates a `StreamSession` that manages background inference threads, + sliding window buffers, and signal state via `vitallens-core`. It is designed + for low-latency applications like webcam feeds. + + Usage: + ```python + with vl.stream() as session: + session.push(frame, timestamp) + results = session.get_result(block=False) + ``` + + Args: + on_result: An optional callback function triggered automatically whenever + new inference results are available. The function should accept one + argument (the results list). + + Returns: + session: A `StreamSession` context manager. The session object provides: + * `push(frame, timestamp, face=None)`: Ingests a new RGB frame. + `timestamp` should be in seconds (e.g., `time.time()`). + * `get_result(block=False, timeout=None)`: Pulls the latest analysis + results from the queue. + * `current_face`: The most recently detected face coordinates. + + Note: + The results returned by `get_result()` or the callback follow the same + format as `__call__`, but represent the physiological state of the + current sliding window rather than a global file average. 
+ """ + return StreamSession( + rppg_method=self.rppg, + face_detector=self.face_detector if self.detect_faces else None, + fdet_fs=self.fdet_fs, + on_result=on_result + ) diff --git a/vitallens/constants.py b/vitallens/constants.py index 9c3da7e..f2e7d5b 100644 --- a/vitallens/constants.py +++ b/vitallens/constants.py @@ -1,4 +1,4 @@ -# Copyright (c) 2024 Rouast Labs +# Copyright (c) 2026 Rouast Labs # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -26,16 +26,20 @@ API_MIN_FRAMES = 16 API_MAX_FRAMES = 900 API_OVERLAP = 30 -API_URL = "https://api.rouast.com/vitallens-v3/file" -if 'API_URL' in os.environ: - API_URL = os.getenv('API_URL') +API_FILE_URL = "https://api.rouast.com/vitallens-v3/file" +if 'API_FILE_URL' in os.environ: + API_FILE_URL = os.getenv('API_FILE_URL') +API_STREAM_URL = "https://api.rouast.com/vitallens-v3/stream" +if 'API_STREAM_URL' in os.environ: + API_STREAM_URL = os.getenv('API_STREAM_URL') API_RESOLVE_URL = "https://api.rouast.com/vitallens-v3/resolve-model" if 'API_RESOLVE_URL' in os.environ: API_RESOLVE_URL = os.getenv('API_RESOLVE_URL') # For local development against dev endpoints, create a `.env` file in the root # of the project and set the variables, for example: -# API_URL="https://api-dev.rouast.com/vitallens-dev/file" +# API_FILE_URL="https://api-dev.rouast.com/vitallens-dev/file" +# API_STREAM_URL="https://api-dev.rouast.com/vitallens-dev/stream" # API_RESOLVE_URL="https://api-dev.rouast.com/vitallens-dev/resolve-model" # Video error message diff --git a/vitallens/enums.py b/vitallens/enums.py index 43b4096..d366ab7 100644 --- a/vitallens/enums.py +++ b/vitallens/enums.py @@ -1,4 +1,4 @@ -# Copyright (c) Rouast Labs +# Copyright (c) 2026 Rouast Labs # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -32,7 
+32,3 @@ class Method(Enum): VITALLENS_1_0 = "vitallens-1.0" VITALLENS_1_1 = "vitallens-1.1" VITALLENS_2_0 = "vitallens-2.0" - -class Mode(IntEnum): - BATCH = 1 - BURST = 2 diff --git a/vitallens/methods/__init__.py b/vitallens/methods/__init__.py index c9a81eb..e99b3a3 100644 --- a/vitallens/methods/__init__.py +++ b/vitallens/methods/__init__.py @@ -1,4 +1,4 @@ -# Copyright (c) 2024 Rouast Labs +# Copyright (c) 2026 Rouast Labs # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal diff --git a/vitallens/methods/chrom.py b/vitallens/methods/chrom.py index 4743ed0..630187d 100644 --- a/vitallens/methods/chrom.py +++ b/vitallens/methods/chrom.py @@ -1,4 +1,4 @@ -# Copyright (c) 2024 Rouast Labs +# Copyright (c) 2026 Rouast Labs # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -21,34 +21,31 @@ import logging import numpy as np from prpy.constants import SECONDS_PER_MINUTE -from prpy.numpy.core import standardize, div0 -from prpy.numpy.filters import detrend, butter_bandpass -from prpy.numpy.physio import detrend_lambda_for_hr_response +from prpy.numpy.core import div0 +from prpy.numpy.filters import butter_bandpass from prpy.numpy.stride_tricks import window_view, reduce_window_view +import vitallens_core as vc -from vitallens.enums import Method, Mode +from vitallens.enums import Method from vitallens.methods.simple_rppg_method import SimpleRPPGMethod class CHROMRPPGMethod(SimpleRPPGMethod): """The CHROM algorithm by De Haan and Jeanne (2013)""" - def __init__( - self, - mode: Mode - ): - """Initialize the `CHROMRPPGMethod` - - Args: - mode: The operation mode - """ - super(CHROMRPPGMethod, self).__init__(mode=mode) + def __init__(self): + """Initialize the `CHROMRPPGMethod`""" + super(CHROMRPPGMethod, self).__init__() self.method = Method.CHROM - 
self.parse_config({ - 'signals': ['heart_rate', 'ppg_waveform'], - 'roi_method': 'face', - 'fps_target': 30, - 'est_window_length': 48, - 'est_window_overlap': 47 - }) + config = vc.SessionConfig( + model_name="chrom", + supported_vitals=["heart_rate"], + return_waveforms=["ppg_waveform"], + fps_target=30.0, + input_size=100, + n_inputs=48, + roi_method="face" + ) + self.parse_config(config, est_window_length=48, est_window_overlap=47) + def algorithm( self, rgb: np.ndarray, @@ -89,23 +86,3 @@ def algorithm( chrom = -1 * chrom # Return result return chrom - def pulse_filter( - self, - sig: np.ndarray, - fps: float - ) -> np.ndarray: - """Apply filters to the estimated pulse signal. - - Args: - sig: The estimated pulse signal. Shape (n_frames,) - fps: The rate at which video was sampled. Scalar - Returns: - x: The filtered pulse signal. Shape (n_frames,) - """ - # Detrend (high-pass equivalent) - Lambda = detrend_lambda_for_hr_response(fps) - sig = detrend(sig, Lambda) - # Standardize - sig = standardize(sig) - # Return - return sig diff --git a/vitallens/methods/g.py b/vitallens/methods/g.py index 95785d1..ae48c0d 100644 --- a/vitallens/methods/g.py +++ b/vitallens/methods/g.py @@ -1,4 +1,4 @@ -# Copyright (c) 2024 Rouast Labs +# Copyright (c) 2026 Rouast Labs # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -19,35 +19,28 @@ # SOFTWARE. 
import numpy as np -from prpy.numpy.core import standardize -from prpy.numpy.filters import detrend, moving_average -from prpy.numpy.physio import detrend_lambda_for_hr_response -from prpy.numpy.physio import moving_average_size_for_hr_response +import vitallens_core as vc -from vitallens.enums import Method, Mode +from vitallens.enums import Method from vitallens.methods.simple_rppg_method import SimpleRPPGMethod class GRPPGMethod(SimpleRPPGMethod): """The G algorithm by Verkruysse (2008)""" - def __init__( - self, - mode: Mode - ): - """Initialize the `GRPPGMethod` - - Args: - config: The configuration dict - mode: The operation mode - """ - super(GRPPGMethod, self).__init__(mode=mode) + def __init__(self): + """Initialize the `GRPPGMethod`""" + super(GRPPGMethod, self).__init__() self.method = Method.G - self.parse_config({ - 'signals': ['heart_rate', 'ppg_waveform'], - 'roi_method': 'face', - 'fps_target': 30, - 'est_window_length': 64, - 'est_window_overlap': 0 - }) + config = vc.SessionConfig( + model_name="g", + supported_vitals=["heart_rate"], + return_waveforms=["ppg_waveform"], + fps_target=30.0, + input_size=100, + n_inputs=64, + roi_method="face" + ) + self.parse_config(config, est_window_length=64, est_window_overlap=0) + def algorithm( self, rgb: np.ndarray, @@ -67,26 +60,3 @@ def algorithm( g = -1 * g # Return return g - def pulse_filter(self, - sig: np.ndarray, - fps: float - ) -> np.ndarray: - """Apply filters to the estimated pulse signal. - - Args: - sig: The estimated pulse signal. Shape (n_frames,) - fps: The rate at which signal was sampled. - Returns: - out: The filtered pulse signal. 
Shape (n_frames,) - """ - # Detrend (high-pass equivalent) - Lambda = detrend_lambda_for_hr_response(fps) - sig = detrend(sig, Lambda) - # Moving average (low-pass equivalent) - size = moving_average_size_for_hr_response(fps) - sig = moving_average(sig, size) - # Standardize - sig = standardize(sig) - # Return - return sig - \ No newline at end of file diff --git a/vitallens/methods/pos.py b/vitallens/methods/pos.py index 2143cd0..8e163e3 100644 --- a/vitallens/methods/pos.py +++ b/vitallens/methods/pos.py @@ -1,4 +1,4 @@ -# Copyright (c) 2024 Rouast Labs +# Copyright (c) 2026 Rouast Labs # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -20,36 +20,30 @@ import logging import numpy as np -from prpy.numpy.core import standardize, div0 -from prpy.numpy.filters import detrend, moving_average -from prpy.numpy.physio import detrend_lambda_for_hr_response -from prpy.numpy.physio import moving_average_size_for_hr_response +from prpy.numpy.core import div0 from prpy.numpy.stride_tricks import window_view, reduce_window_view +import vitallens_core as vc -from vitallens.enums import Method, Mode +from vitallens.enums import Method from vitallens.methods.simple_rppg_method import SimpleRPPGMethod class POSRPPGMethod(SimpleRPPGMethod): """The POS algorithm by Wang et al. 
(2017)""" - def __init__( - self, - mode: Mode - ): - """Initialize the `POSRPPGMethod` - - Args: - config: The configuration dict - mode: The operation mode - """ - super(POSRPPGMethod, self).__init__(mode=mode) + def __init__(self): + """Initialize the `POSRPPGMethod`""" + super(POSRPPGMethod, self).__init__() self.method = Method.POS - self.parse_config({ - 'signals': ['heart_rate', 'ppg_waveform'], - 'roi_method': 'face', - 'fps_target': 30, - 'est_window_length': 48, - 'est_window_overlap': 47 - }) + config = vc.SessionConfig( + model_name="pos", + supported_vitals=["heart_rate"], + return_waveforms=["ppg_waveform"], + fps_target=30.0, + input_size=100, + n_inputs=48, + roi_method="face" + ) + self.parse_config(config, est_window_length=48, est_window_overlap=47) + def algorithm( self, rgb: np.ndarray, @@ -89,27 +83,3 @@ def algorithm( pos = -1 * pos # Return return pos - def pulse_filter( - self, - sig: np.ndarray, - fps: float - ) -> np.ndarray: - """Apply filters to the estimated pulse signal. - - Args: - sig: The estimated pulse signal. Shape (n_frames,) - fps: The rate at which signal was sampled. Scalar - Returns: - out: The filtered pulse signal. Shape (n_frames,) - """ - # Detrend (high-pass equivalent) - Lambda = detrend_lambda_for_hr_response(fps) - sig = detrend(sig, Lambda) - # Moving average (low-pass equivalent) - size = moving_average_size_for_hr_response(fps) - sig = moving_average(sig, size) - # Standardize - sig = standardize(sig) - # Return - return sig - \ No newline at end of file diff --git a/vitallens/methods/rppg_method.py b/vitallens/methods/rppg_method.py index c25e406..4872804 100644 --- a/vitallens/methods/rppg_method.py +++ b/vitallens/methods/rppg_method.py @@ -1,4 +1,4 @@ -# Copyright (c) 2024 Rouast Labs +# Copyright (c) 2026 Rouast Labs # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -19,38 +19,30 @@ # SOFTWARE. 
import abc - -from vitallens.enums import Mode +import vitallens_core as vc class RPPGMethod(metaclass=abc.ABCMeta): """Abstract superclass for rPPG methods""" - def __init__( - self, - mode: Mode - ): - """Initialize the `RPPGMethod` - - Args: - mode: The operation mode - """ - self.op_mode = mode - def parse_config( - self, - config: dict - ): + def __init__(self): + """Initialize the `RPPGMethod`""" + pass + + def parse_config(self, config: vc.SessionConfig): """Set properties based on the config. Args: - config: The method's config dict + config: The method's config """ - self.roi_method = config['roi_method'] - self.fps_target = config['fps_target'] + self.session_config = config + self.roi_method = config.roi_method + self.fps_target = config.fps_target + @abc.abstractmethod - def __call__(self, frames, faces, fps, override_fps_target, override_global_parse): - """Run inference. Abstract method to be implemented in subclasses.""" + def infer_batch(self, frames, faces, fps, override_fps_target, override_global_parse): + """Run batch inference. Abstract method to be implemented in subclasses.""" pass + @abc.abstractmethod - def reset(self): - """Reset. Abstract method to be implemented in subclasses.""" + def infer_stream(self, frames, fps, state): + """Run stream inference. 
Abstract method to be implemented in subclasses.""" pass - \ No newline at end of file diff --git a/vitallens/methods/simple_rppg_method.py b/vitallens/methods/simple_rppg_method.py index e186bb9..091b51c 100644 --- a/vitallens/methods/simple_rppg_method.py +++ b/vitallens/methods/simple_rppg_method.py @@ -1,4 +1,4 @@ -# Copyright (c) 2024 Rouast Labs +# Copyright (c) 2026 Rouast Labs # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -24,64 +24,49 @@ from prpy.numpy.image import reduce_roi, parse_image_inputs from prpy.numpy.interp import interpolate_filtered from typing import Union, Tuple +import vitallens_core as vc -from vitallens.buffer import SignalBuffer -from vitallens.enums import Mode from vitallens.methods.rppg_method import RPPGMethod -from vitallens.signal import assemble_results from vitallens.utils import merge_faces class SimpleRPPGMethod(RPPGMethod): """A simple rPPG method using a handcrafted algorithm based on RGB signal trace""" - def __init__( - self, - mode: Mode - ): - """Initialize the `SimpleRPPGMethod` - - Args: - mode: The operation mode - """ - super(SimpleRPPGMethod, self).__init__(mode=mode) + def __init__(self): + """Initialize the `SimpleRPPGMethod`""" + super(SimpleRPPGMethod, self).__init__() self.n_inputs = 1 + def parse_config( self, - config: dict + config: vc.SessionConfig, + est_window_length: int, + est_window_overlap: int ): """Set properties based on the config. 
Args: - config: The method's config dict + config: The method's config object + est_window_length: The length of the estimation window + est_window_overlap: The overlap of consecutive estimation windows """ super(SimpleRPPGMethod, self).parse_config(config=config) - self.signals = config['signals'] - self.est_window_length = config['est_window_length'] - self.est_window_overlap = config['est_window_overlap'] + self.signals = config.supported_vitals + (config.return_waveforms or []) + self.est_window_length = est_window_length + self.est_window_overlap = est_window_overlap self.est_window_flexible = self.est_window_length == 0 - if self.op_mode == Mode.BURST: - self.buffer = SignalBuffer(size=self.est_window_length, ndim=2) + @abc.abstractmethod - def algorithm( - self, - rgb: np.ndarray, - fps: float - ): + def algorithm(self, rgb: np.ndarray, fps: float): """The algorithm. Abstract method to be implemented by subclasses.""" pass - @abc.abstractmethod - def pulse_filter(self, - sig: np.ndarray, - fps: float - ) -> np.ndarray: - """The post-processing filter to be applied to estimated pulse signal. Abstract method to be implemented by subclasses.""" - pass - def __call__( + + def infer_batch( self, inputs: Union[np.ndarray, str], faces: np.ndarray, fps: float, override_fps_target: float = None, - override_global_parse: float = None, + override_global_parse: bool = None, ) -> Tuple[dict, dict, dict, dict, np.ndarray]: """Estimate pulse signal from video frames using the subclass algorithm. 
@@ -104,9 +89,9 @@ def __call__( u_roi = merge_faces(faces) faces = faces - [u_roi[0], u_roi[1], u_roi[0], u_roi[1]] # Parse the inputs + target_fps = override_fps_target if override_fps_target is not None else self.fps_target frames_ds, fps, inputs_shape, ds_factor, _ = parse_image_inputs( - inputs=inputs, fps=fps, roi=u_roi, target_size=None, - target_fps=override_fps_target if override_fps_target is not None else self.fps_target, + inputs=inputs, fps=fps, roi=u_roi, target_size=None, target_fps=target_fps, preserve_aspect_ratio=False, library='prpy', scale_algorithm='bilinear', trim=None, allow_image=False, videodims=True) assert inputs_shape[0] == faces.shape[0], "Need same number of frames as face detections" @@ -114,14 +99,8 @@ def __call__( assert frames_ds.shape[0] == faces_ds.shape[0], "Need same number of frames as face detections" fps_ds = fps*1.0/ds_factor # Extract rgb signal (n_frames_ds, 3) - if self.op_mode == Mode.BATCH: - roi_ds = np.asarray([get_roi_from_det(f, roi_method=self.roi_method) for f in faces_ds], dtype=np.int64) # roi for each frame (n, 4) - rgb_ds = reduce_roi(video=frames_ds, roi=roi_ds) - else: - # Use the last face detection for cropping (n_frames, 3) - rgb_ds = reduce_roi(video=frames_ds, roi=np.asarray(get_roi_from_det(faces_ds[-1], roi_method=self.roi_method), dtype=np.int64)) - # Push to buffer and get buffered vals (pred_window_length, 3) - rgb_ds = self.buffer.update(rgb_ds, dt=inputs.shape[0]) + roi_ds = np.asarray([get_roi_from_det(f, roi_method=self.roi_method) for f in faces_ds], dtype=np.int64) # roi for each frame (n, 4) + rgb_ds = reduce_roi(video=frames_ds, roi=roi_ds) # Perform rppg algorithm step (n_frames_ds,) sig_ds = self.algorithm(rgb_ds, fps_ds) # Interpolate to original sampling rate (n_frames,) @@ -130,24 +109,35 @@ def __call__( t_out=np.arange(inputs_shape[0]), axis=0, extrapolate=True) # Filter (n_frames,) - sig = self.pulse_filter(sig, fps) + # sig = self.pulse_filter(sig, fps) # Package into dict 
sig_dict = {'ppg_waveform': sig} conf_dict = {'ppg_waveform': np.ones_like(sig)} live = np.ones_like(sig) - # Resolve method name - method_name = self.method.value if hasattr(self.method, 'value') else str(self.method) # Assemble and return the results - return assemble_results( - sig=sig_dict, - conf=conf_dict, - live=live, - fps=fps, - pred_signals=self.signals, - method_name=method_name, - can_provide_confidence=False - ) - def reset(self): - """Reset""" - if self.op_mode == Mode.BURST: - self.buffer.clear() + return sig_dict, conf_dict, live + + def infer_stream( + self, + frames: np.ndarray, + fps: float, + state + ): + """Estimate pulse signal from a sequence of frames in a streaming context. + + Args: + frames: The input video frames of shape (n_frames, h, w, 3). + fps: The sampling frequency of the input frames. + state: The internal state of the rPPG method (unused for simple methods). + Returns: + Tuple of + - sig_dict: A dictionary of the estimated signals. + - conf_dict: A dictionary of the estimated confidences. + - live_out: Dummy live confidence estimation (set to always 1). Shape (n_frames,) + - state: The updated internal state of the rPPG method (None). 
+ """ + sig = self.algorithm(frames, fps) + sig_dict = {'ppg_waveform': sig} + conf_dict = {'ppg_waveform': np.ones_like(sig)} + live_out = np.ones_like(sig) + return sig_dict, conf_dict, live_out, None diff --git a/vitallens/methods/vitallens.py b/vitallens/methods/vitallens.py index f6da1bc..006c9a3 100644 --- a/vitallens/methods/vitallens.py +++ b/vitallens/methods/vitallens.py @@ -1,4 +1,4 @@ -# Copyright (c) 2024 Rouast Labs +# Copyright (c) 2026 Rouast Labs # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -20,34 +20,32 @@ import base64 import concurrent.futures +import gzip import math import numpy as np -from prpy.numpy.core import standardize from prpy.numpy.face import get_roi_from_det -from prpy.numpy.filters import detrend, detrend_lambda_for_cutoff -from prpy.numpy.filters import moving_average, moving_average_size_for_response from prpy.numpy.image import probe_image_inputs, parse_image_inputs from prpy.numpy.interp import interpolate_filtered -from prpy.numpy.physio import VITAL_REGISTRY, get_vital_key_from_alias from prpy.numpy.utils import enough_memory_for_ndarray import json import logging import requests import os from typing import Union, Tuple +import vitallens_core as vc -from vitallens.constants import API_MAX_FRAMES, API_URL, API_OVERLAP, API_RESOLVE_URL -from vitallens.enums import Mode +from vitallens.constants import API_MAX_FRAMES, API_OVERLAP +from vitallens.constants import API_FILE_URL, API_STREAM_URL, API_RESOLVE_URL from vitallens.errors import VitalLensAPIKeyError, VitalLensAPIQuotaExceededError, VitalLensAPIError from vitallens.methods.rppg_method import RPPGMethod -from vitallens.signal import reassemble_from_windows, assemble_results +from vitallens.signal import reassemble_from_windows from vitallens.utils import check_faces_in_roi def _resolve_model_config( api_key: str, model_name: str, proxies: dict = None - ) 
-> dict: + ) -> vc.SessionConfig: """Calls the /resolve-model endpoint to get the correct model config. Args: @@ -55,7 +53,7 @@ def _resolve_model_config( model_name: The requested model name (e.g., 'vitallens', 'vitallens-2.0') proxies: Dictionary mapping protocol to the URL of the proxy Returns: - resolved_config: The config of the resolved model + resolved_config: The SessionConfig of the resolved model """ headers = {} if api_key: @@ -67,12 +65,16 @@ def _resolve_model_config( response = requests.get(API_RESOLVE_URL, headers=headers, params=params, proxies=proxies) response.raise_for_status() data = response.json() - resolved_config = data['config'] - resolved_config['model'] = data['resolved_model'] - if 'supported_vitals' in resolved_config: - vitals = resolved_config.pop('supported_vitals') - resolved_config['signals'] = {get_vital_key_from_alias(c) for c in vitals} - return resolved_config + config_dict = data['config'] + return vc.SessionConfig( + model_name=data['resolved_model'], + supported_vitals=config_dict.get('supported_vitals', []), + fps_target=float(config_dict.get('fps_target', 30.0)), + input_size=int(config_dict.get('input_size', 40)), + n_inputs=int(config_dict.get('n_inputs', 4)), + roi_method=config_dict.get('roi_method', 'upper_body_cropped'), + return_waveforms=['ppg_waveform', 'respiratory_waveform'] + ) except requests.exceptions.HTTPError as e: error_msg = f"Failed to resolve model config: Status {e.response.status_code}" try: @@ -92,7 +94,6 @@ class VitalLensRPPGMethod(RPPGMethod): """RPPG method using the VitalLens API for inference""" def __init__( self, - mode: Mode, api_key: str, requested_model_name: str, proxies: dict = None @@ -100,43 +101,37 @@ def __init__( """Initialize the `VitalLensRPPGMethod` Args: - mode: The operation mode api_key: The API key requested_model_name: The requested VitalLens model name, proxies: Dictionary mapping protocol to the URL of the proxy """ - super(VitalLensRPPGMethod, self).__init__(mode=mode) - 
+ super(VitalLensRPPGMethod, self).__init__() if proxies is None and (api_key is None or api_key == ''): raise VitalLensAPIKeyError() self.api_key = api_key self.proxies = proxies - - # Resolve model config - resolved_config = _resolve_model_config(api_key=api_key, model_name=requested_model_name, proxies=proxies) - self.parse_config(resolved_config) - - self.resolved_model = resolved_config['model'] + self.session_config = _resolve_model_config(api_key=api_key, model_name=requested_model_name, proxies=proxies) + self.parse_config(self.session_config) + self.resolved_model = self.session_config.model_name self.requested_model_name = requested_model_name if requested_model_name != "vitallens" else None + self.http_session = requests.Session() - if mode == Mode.BURST: - self.state = None - self.input_buffer = None - - def parse_config(self, config: dict): + def parse_config(self, config: vc.SessionConfig): """Set properties based on the config. Args: - config: The method's config dict + config: The method's SessionConfig """ super(VitalLensRPPGMethod, self).parse_config(config=config) - self.n_inputs = int(config['n_inputs']) - self.input_size = int(config['input_size']) - self.signals = config.get('signals', set()) + self.n_inputs = config.n_inputs + self.input_size = config.input_size + vitals = [vc.get_vital_info(v).id for v in config.supported_vitals if vc.get_vital_info(v) is not None] + waveforms = config.return_waveforms or [] + self.signals = set(vitals + waveforms) self.est_window_length = 0 self.est_window_overlap = 0 - def __call__( + def infer_batch( self, inputs: Union[np.ndarray, str], faces: np.ndarray, @@ -156,16 +151,13 @@ def __call__( If None, choose based on video. Returns: Tuple of - - out_data: The estimated data/value for each signal. - - out_unit: The estimation unit for each signal. - - out_conf: The estimation confidence for each signal. - - out_note: An explanatory note for each signal. + - sig_dict: A dictionary of the estimated signals. 
+ - conf_dict: A dictionary of the estimated confidences. - live: The face live confidence. Shape (n_frames,) """ inputs_shape, fps, video_issues = probe_image_inputs(inputs, fps=fps) - # Check the number of frames to be processed - inputs_n = inputs_shape[0] fps_target = override_fps_target if override_fps_target is not None else self.fps_target + inputs_n = inputs_shape[0] expected_ds_factor = max(round(fps / fps_target), 1) expected_ds_n = math.ceil(inputs_n / expected_ds_factor) # Check if we should parse the video globally @@ -219,29 +211,73 @@ def stack_dicts(dict_list): sig = interpolate_filtered(t_in=idxs, s_in=rec_sig, t_out=np.arange(inputs_n), axis=1, extrapolate=True) conf = interpolate_filtered(t_in=idxs, s_in=rec_conf, t_out=np.arange(inputs_n), axis=1, extrapolate=True) # Unpack - sig_ds = dict(zip(keys, sig)) - conf_ds = dict(zip(keys, conf)) + sig_dict = dict(zip(keys, sig)) + conf_dict = dict(zip(keys, conf)) else: - sig_ds, conf_ds = {}, {} + sig_dict, conf_dict = {}, {} # Aggregate liveness live_stacked = np.array(live_results)[:, np.newaxis, :] rec_live, _ = reassemble_from_windows(x=live_stacked, idxs=idxs_results) live = interpolate_filtered(t_in=idxs, s_in=rec_live[0], t_out=np.arange(inputs_n), axis=0, extrapolate=True) - # Postprocess - if self.op_mode == Mode.BATCH: - for key in sig_ds: - if 'waveform' in key: - sig_ds[key] = self.postprocess(sig_ds[key], fps, type=key) # Assemble and return the results - return assemble_results( - sig=sig_ds, - conf=conf_ds, - live=live, - fps=fps, - pred_signals=list(self.signals), - method_name=self.resolved_model, - can_provide_confidence=True - ) + return sig_dict, conf_dict, live + + def infer_stream(self, frames: np.ndarray, fps: float, state=None): + """Estimate vitals from a sequence of frames using the VitalLens streaming API. + + Args: + frames: The input video frames of shape (n_frames, h, w, 3). + fps: The sampling frequency of the input frames. 
+ state: The internal state of the rPPG method used to maintain temporal continuity. + Returns: + Tuple of + - sig_dict: A dictionary of the estimated signals. + - conf_dict: A dictionary of the estimated confidences. + - live: The face live confidence. Shape (n_frames,) + - new_state: The updated internal state of the rPPG method. + """ + headers = { + "Content-Type": "application/octet-stream", + "X-Encoding": "gzip" + } + if self.api_key: + headers["x-api-key"] = self.api_key + origin = os.getenv('VITALLENS_API_ORIGIN', 'vitallens-python') + headers["X-Origin"] = origin + if self.requested_model_name: + headers["X-Model"] = self.requested_model_name + if state is not None: + state_bytes = np.asarray(state, dtype=np.float32).tobytes() + headers["X-State"] = base64.b64encode(state_bytes).decode('utf-8') + + # Compress the raw video bytes + raw_rgb_bytes = frames.tobytes() + compressed_data = gzip.compress(raw_rgb_bytes) + + # Post the binary payload + response = self.http_session.post(API_STREAM_URL, headers=headers, data=compressed_data, proxies=self.proxies) + + if response.status_code != 200: + response_body = response.json() + logging.error(f"Error {response.status_code}: {response_body.get('message')}") + if response.status_code == 403: raise VitalLensAPIKeyError() + elif response.status_code == 429: raise VitalLensAPIQuotaExceededError() + else: raise VitalLensAPIError(f"API Error: {response_body.get('message')}") + + response_body = response.json() + + api_waveforms = response_body.get("waveforms", {}) + sig_dict, conf_dict = {}, {} + for name, obj in api_waveforms.items(): + if 'data' in obj: + sig_dict[name] = np.asarray(obj['data']) + conf_dict[name] = np.asarray(obj.get('confidence', [1.0] * len(sig_dict[name]))) + + n_res = len(next(iter(sig_dict.values()))) if sig_dict else frames.shape[0] + live = np.asarray(response_body.get("face", {}).get("confidence", [1.0] * n_res)) + new_state = response_body.get("state", {}).get("data") + + return sig_dict, 
conf_dict, live, new_state def process_api_batch( self, @@ -301,16 +337,6 @@ def process_api_batch( else: idxs = list(range(0, inputs_shape[0], ds_factor)) else: - # Buffer inputs for burst mode - if self.op_mode == Mode.BURST: - # Inputs in burst mode are always np.ndarray - if self.state is not None: - # State has been initialized - assert self.input_buffer is not None - if inputs.shape[1:] != self.input_buffer.shape[1:]: - raise ValueError("In burst mode, input dimensions must be consistent.") - inputs = np.concatenate([self.input_buffer, inputs], axis=0) - self.input_buffer = inputs[-(self.n_inputs-1):] # Inputs have not been parsed globally. Parse the inputs frames_ds, _, _, ds_factor, idxs = parse_image_inputs( inputs=inputs, fps=fps, roi=roi, target_size=self.input_size, target_fps=fps_target, @@ -320,7 +346,6 @@ def process_api_batch( # Make sure we have the correct number of frames idxs = np.asarray(idxs) expected_n = math.ceil(((end-start) if start is not None and end is not None else inputs_shape[0]) / ds_factor) - if (self.op_mode == Mode.BURST and self.state is not None): expected_n += (self.n_inputs - 1) if frames_ds.shape[0] != expected_n or idxs.shape[0] != expected_n: raise ValueError("Unexpected number of frames returned. 
Try to set `override_global_parse` to `True` or `False`.") # Prepare API header and payload @@ -331,15 +356,8 @@ def process_api_batch( payload = {"video": base64.b64encode(frames_ds.tobytes()).decode('utf-8'), "origin": origin} if self.requested_model_name: payload['model'] = self.requested_model_name - if self.op_mode == Mode.BURST and self.state is not None: - # State and frame buffer have been initialized - assert self.input_buffer is not None - payload["state"] = base64.b64encode(self.state.astype(np.float32).tobytes()).decode('utf-8') - # Adjust idxs - idxs = idxs[(self.n_inputs-1):] - (self.n_inputs-1) - logging.debug(f"Providing state, which means that {self.n_inputs-1} less frames will be used and results for {self.n_inputs-1} less frames will be returned.") # Ask API to process video - response = requests.post(API_URL, headers=headers, json=payload, proxies=self.proxies) + response = self.http_session.post(API_FILE_URL, headers=headers, json=payload, proxies=self.proxies) response_body = json.loads(response.text) # Check if call was successful if response.status_code != 200: @@ -353,10 +371,10 @@ def process_api_batch( else: raise Exception(f"Error {response.status_code}: {response_body['message']}") # Dynamic dict parsing - api_vitals = response_body.get("vital_signs", {}) + api_waveforms = response_body.get("waveforms", {}) sig_ds = {} conf_ds = {} - for name, obj in api_vitals.items(): + for name, obj in api_waveforms.items(): if 'data' in obj: sig_ds[name] = np.asarray(obj['data']) c_val = obj.get('confidence') @@ -364,52 +382,5 @@ def process_api_batch( conf_ds[name] = np.asarray(c_val) else: conf_ds[name] = np.zeros_like(sig_ds[name]) - live_ds = np.asarray(response_body["face"]["confidence"]) - if self.op_mode == Mode.BURST: - self.state = np.asarray(response_body["state"]["data"], dtype=np.float32) + live_ds = np.asarray(response_body.get("face", {}).get("confidence", [1.0] * frames_ds.shape[0])) return sig_ds, conf_ds, live_ds, idxs - - def 
postprocess( - self, - sig: np.ndarray, - fps: float, - type: str, - filter: bool = True - ) -> np.ndarray: - """Apply filters to the estimated signal. - Args: - sig: The estimated signal. Shape (n_frames,) - fps: The rate at which video was sampled. Scalar - type: The vital registry key (e.g. 'ppg_waveform'). - filter: Whether to apply filters - Returns: - x: The filtered signal. Shape (n_frames,) - """ - n_frames = sig.shape[-1] - meta = VITAL_REGISTRY.get(type) - if not meta or not filter: - return sig - proc = meta.get('processing') - if not proc: - return sig - constraints = proc.get('constraints', {}) - fmin = constraints.get('fmin') - fmax = constraints.get('fmax') - processed_sig = sig.copy() - if proc.get('method') == 'detrend' and fmin: - if processed_sig.shape[-1] < 4 * API_MAX_FRAMES: - lambda_val = detrend_lambda_for_cutoff(fps, fmin) - processed_sig = detrend(processed_sig, lambda_val) - if fmax: - size = moving_average_size_for_response(fps, fmax) - processed_sig = moving_average(processed_sig, size) - if proc.get('standardize', False): - processed_sig = standardize(processed_sig) - assert processed_sig.shape == (n_frames,) - return processed_sig - - def reset(self): - """Reset""" - if self.op_mode == Mode.BURST: - self.state = None - self.input_buffer = None diff --git a/vitallens/signal.py b/vitallens/signal.py index 55c0a1d..1420f91 100644 --- a/vitallens/signal.py +++ b/vitallens/signal.py @@ -1,4 +1,4 @@ -# Copyright (c) 2024 Rouast Labs +# Copyright (c) 2026 Rouast Labs # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -18,11 +18,8 @@ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. 
-import logging import numpy as np -from prpy.numpy.physio import VITAL_REGISTRY, EScope -from prpy.numpy.rolling import rolling_calc -from typing import Tuple, Dict +from typing import Tuple def reassemble_from_windows( x: np.ndarray, @@ -52,185 +49,3 @@ def reassemble_from_windows( result = x.reshape(x.shape[0], -1)[:,mask] idxs = idxs.flatten()[mask] return result, idxs - -def assemble_results( - sig: Dict[str, np.ndarray], - conf: Dict[str, np.ndarray], - live: np.ndarray, - fps: float, - pred_signals: list, - method_name: str, - can_provide_confidence: bool = True - ) -> Tuple[dict, dict, dict, dict, np.ndarray]: - """Assemble rPPG method results in the format expected by the API. - Args: - sig: Dict of estimated signal arrays {name: array}. Each array shape (n_frames,) - conf: Dict of estimation confidence arrays {name: array}. Each array shape (n_frames,) - live: The liveness confidence. Shape (n_frames,) - fps: The sampling rate - pred_signals: The pred signals specs of the method - method_name: The name of the used method - can_provide_confidence: Whether the method can provide a confidence estimate - Returns: - Tuple of - - out_data: The estimated data/value for each signal. - - out_unit: The estimation unit for each signal. - - out_conf: The estimation confidence for each signal. - - out_note: An explanatory note for each signal. - - live: The face live confidence. Shape (n_frames,) - """ - # Infer the signal length in seconds - sig_len = len(next(iter(sig.values()))) if sig else 0 - sig_t = sig_len / fps if fps > 0 else 0 - out_data, out_unit, out_conf, out_note = {}, {}, {}, {} - # Standard confidence strings - conf_txt_scalar = ', with confidence level.' if can_provide_confidence else '.' - conf_txt_wave = ', with frame-wise confidence levels.' if can_provide_confidence else '.' 
- for name in pred_signals: - meta = VITAL_REGISTRY.get(name) - if not meta: - continue - vital_type = meta.get('type') - unit = meta.get('unit', '') - display_name = meta.get('display_name', name) - # --- CASE A: PROVIDED --- - if vital_type == 'provided': - if name in sig: - out_data[name] = sig[name] - out_conf[name] = conf.get(name, np.zeros_like(sig[name])) - out_note[name] = f"Estimate of the {display_name} using {method_name}{conf_txt_wave}" - out_unit[name] = unit - # --- CASE B: DERIVED (calculated locally) --- - elif vital_type == 'derived': - deriv = meta.get('derivation') - if not deriv: continue - source = deriv['source_signal'] - min_t = deriv['window']['required'] - calc_func = deriv['func'] - if source in sig and sig_t >= min_t: - try: - # Execute the function defined in prpy - val = calc_func(sig[source], fps, conf.get(source), scope=EScope.GLOBAL) - if isinstance(val, tuple): - val_est, conf_est = val - else: - val_est = val - conf_src = conf.get(source, np.zeros_like(sig[source])) - if deriv.get('confidenceAggregation') == 'min': - conf_est = np.nanmin(conf_src) - else: - conf_est = np.nanmean(conf_src) - out_data[name] = val_est - out_conf[name] = conf_est - out_note[name] = f"Global estimate of {display_name} using {method_name}{conf_txt_scalar}" - except Exception as e: - logging.warning(f"Failed to derive {name}: {e}") - out_data[name] = np.nan - out_conf[name] = np.nan - out_note[name] = f"Calculation error for {display_name}." - else: - out_data[name] = np.nan - out_conf[name] = np.nan - out_note[name] = f"Video too short ({sig_t:.1f}s) or signal too noisy to derive {display_name}." - out_unit[name] = unit - - return out_data, out_unit, out_conf, out_note, live - -def estimate_rolling_vitals( - vital_signs_dict: dict, - data: dict, - conf: dict, - signals_available: set, - fps: float, - video_duration_s: float - ): - """Helper to calculate and append rolling vitals to the results dictionary. 
- - Args: - vital_signs_dict: The draft dict of vital signs to be modified - data: The estimated data/value for each signal. - conf: The estimation confidence for each signal. - signals_available: The signals supported by the used rPPG method - fps: The frame rate - video_duration_s: The duration - Returns: - None. The results are appended to `vital_signs_dict` in place. - """ - # Helper to standardize output format - def _add_result(name, unit, display_name, val, c): - if np.all(np.isnan(val)): - vital_signs_dict[name] = { - 'data': np.nan, 'unit': unit, 'confidence': np.nan, - 'note': f'Video too short or signal too noisy for rolling {display_name}.' - } - else: - val = np.round(val, 8) - c = np.round(c, 4) - vital_signs_dict[name] = { - 'data': val, 'unit': unit, 'confidence': c, - 'note': f'Rolling estimate of {display_name} with frame-wise confidence levels.' - } - # Identify relevant registry keys based on model availability - target_vitals = set() - for key, meta in VITAL_REGISTRY.items(): - if key in signals_available: - target_vitals.add(key) - elif 'model_aliases' in meta: - for alias in meta['model_aliases']: - if alias in signals_available: - target_vitals.add(key) - break - for name in target_vitals: - meta = VITAL_REGISTRY[name] - if 'derivation' not in meta: - continue - deriv = meta['derivation'] - if video_duration_s <= deriv['window']['required']: - continue - source_name = deriv.get('source_signal') - if source_name not in data: - continue - # Window parameters - min_w_size = int(deriv['window']['required'] * fps) - max_w_size = int(deriv['window']['size'] * fps) - overlap = int(max_w_size * 7 / 8) - output_key = f"rolling_{name}" - display_name = meta.get('display_name', name) - val_roll = np.nan - conf_roll = None - # Calculation - try: - calc_func = deriv.get('func') - res = calc_func( - data[source_name], fps, conf.get(source_name), - scope=EScope.ROLLING, - min_window_size=min_w_size, - max_window_size=max_w_size, - window_size=max_w_size, - 
overlap=overlap - ) - if isinstance(res, tuple): - val_roll, conf_roll = res - else: - val_roll = res - except Exception: - continue - # Confidence aggregation - if conf_roll is None: - agg_method = deriv.get('confidenceAggregation', 'mean') - agg_func = np.nanmin if agg_method == 'min' else np.nanmean - try: - c_src = conf.get(source_name, np.zeros_like(data[source_name])) - conf_roll = rolling_calc( - x=c_src, - calc_fn=lambda x: agg_func(x, axis=-1), - min_window_size=max_w_size, max_window_size=max_w_size, overlap=overlap - ) - except Exception: - conf_roll = np.nan - # Final check for None/NaN fallback - if conf_roll is None or (isinstance(conf_roll, float) and np.isnan(conf_roll)): - conf_roll = np.nan - - # Add to results - _add_result(output_key, meta['unit'], display_name, val_roll, conf_roll) \ No newline at end of file diff --git a/vitallens/stream.py b/vitallens/stream.py new file mode 100644 index 0000000..6f480ab --- /dev/null +++ b/vitallens/stream.py @@ -0,0 +1,422 @@ +# Copyright (c) 2026 Rouast Labs +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +from dataclasses import dataclass +import logging +import numpy as np +from prpy.numpy.image import parse_image_inputs +import queue +import threading +import time +from typing import Callable, Optional +import uuid +import vitallens_core as vc + +from vitallens.methods.simple_rppg_method import SimpleRPPGMethod + +@dataclass +class InferenceContext: + timestamp: float + roi: vc.Rect + face_conf: float + +class FrameBuffer: + def __init__(self, buffer_id: str, roi: vc.Rect, max_capacity: int, timestamp: float): + """Initialize a thread-safe frame buffer for a specific region of interest. + + Args: + buffer_id: Unique identifier for the buffer. + roi: The region of interest (ROI) associated with this buffer. + max_capacity: The maximum number of frames the buffer can hold. + timestamp: The creation timestamp of the buffer. + """ + self.id = buffer_id + self.roi = roi + self.created_at = timestamp + self.last_seen = timestamp + self.max_capacity = max_capacity + self.buffer = [] + + @property + def count(self) -> int: + return len(self.buffer) + + def append(self, frame: np.ndarray, context: InferenceContext): + """Add a frame and its metadata to the buffer, maintaining max capacity. + + Args: + frame: The parsed RGB image data. + context: Metadata including timestamp and ROI. + """ + self.buffer.append((frame, context)) + self.last_seen = context.timestamp + if len(self.buffer) > self.max_capacity: + overflow = len(self.buffer) - self.max_capacity + self.buffer = self.buffer[overflow:] + + def execute(self, take_count: int, keep_count: int) -> list: + """Extract frames for inference and manage buffer overlap. + + Args: + take_count: Number of frames to extract for the current batch. 
+ keep_count: Number of frames to retain for the next sliding window. + Returns: + payload: A list of (frame, context) tuples, or None if insufficient frames. + """ + if take_count <= 0 or len(self.buffer) < take_count: + return None + payload = self.buffer[:take_count] + elements_to_remove = max(0, take_count - keep_count) + self.buffer = self.buffer[elements_to_remove:] + return payload + +class BufferManager: + """Thread-safe manager for multiple streaming frame buffers.""" + def __init__(self, buffer_config: vc.BufferConfig): + """Initialize the manager with specific buffering and overlap constraints. + + Args: + buffer_config: Configuration for stream window sizes and overlaps. + """ + self.buffer_planner = vc.BufferPlanner(buffer_config) + self.buffer_config = buffer_config + self.buffers = {} + self.state = None + self.current_timestamp = 0.0 + self.lock = threading.Lock() + + def _get_active_metadata(self): + return [ + vc.BufferMetadata( + id=buf.id, roi=buf.roi, count=buf.count, + created_at=buf.created_at, last_seen=buf.last_seen + ) for buf in self.buffers.values() + ] + + def register_target(self, target_roi: vc.Rect, timestamp: float): + """Register or update a target ROI in the buffer planner. + + Args: + target_roi: The latest detected face ROI to track. + timestamp: The current frame timestamp. + Returns: + buffer_id: The ID of the buffer assigned to this target, or None.
+ """ + with self.lock: + self.current_timestamp = max(self.current_timestamp, timestamp) + action = self.buffer_planner.evaluate_target(target_roi, timestamp, self._get_active_metadata()) + if action.action == vc.BufferActionType.Create: + new_id = str(uuid.uuid4()) + roi = action.roi if action.roi is not None else target_roi + max_cap = self.buffer_config.stream_max + 50 + self.buffers[new_id] = FrameBuffer(new_id, roi, max_cap, timestamp) + return new_id + elif action.action == vc.BufferActionType.KeepAlive: + if action.matched_id in self.buffers: + self.buffers[action.matched_id].last_seen = timestamp + return action.matched_id + return None + + def append(self, buffer_id: str, frame: np.ndarray, context: InferenceContext): + """Route a frame to a specific active buffer. + + Args: + buffer_id: The identifier for the target buffer. + frame: The parsed RGB image data. + context: Metadata including timestamp and ROI. + """ + with self.lock: + self.current_timestamp = max(self.current_timestamp, context.timestamp) + if buffer_id in self.buffers: + self.buffers[buffer_id].append(frame, context) + + def poll(self, flush: bool = False): + """Poll the planner for the next required inference command. + + Args: + flush: If True, forces the planner to return commands for remaining data. + Returns: + command: A vc.InferenceCommand containing the buffer ID and frame counts. + """ + with self.lock: + has_state = self.state is not None + plan = self.buffer_planner.poll( + self._get_active_metadata(), + self.current_timestamp, + vc.InferenceMode.Stream, + has_state, + flush + ) + for drop_id in plan.buffers_to_drop: + if drop_id in self.buffers: + del self.buffers[drop_id] + return plan.command + + def execute(self, command: vc.InferenceCommand) -> list: + """Execute a specific inference command by extracting data from the relevant buffer. + + Args: + command: The command containing buffer ID and frame counts. 
+ Returns: + window: A list of frames and contexts ready for inference. + """ + with self.lock: + if command.buffer_id in self.buffers: + return self.buffers[command.buffer_id].execute(command.take_count, command.keep_count) + return None + + def get_state(self): + """Retrieve the current physiological state of the rPPG model. + + Returns: + state: The model's internal state vector or None. + """ + with self.lock: + return self.state + + def update_state(self, new_state): + """Update the internal state of the rPPG model after an inference step. + + Args: + new_state: The updated state vector returned by the API. + """ + with self.lock: + self.state = new_state + + def get_all_buffers(self): + """Get a list of all currently active buffers. + + Returns: + buffers: A list of (buffer_id, roi) tuples. + """ + with self.lock: + return [(buf.id, buf.roi) for buf in self.buffers.values()] + + def reset(self): + """Clear all active buffers and reset the model state.""" + with self.lock: + self.buffers.clear() + self.state = None + +class StreamSession: + """Context manager handling background inference and buffer synchronization for live streams.""" + def __init__(self, rppg_method, face_detector=None, fdet_fs=1.0, on_result: Optional[Callable] = None): + """Initialize the streaming session and start the background inference thread. + + Args: + rppg_method: The method instance (e.g., VitalLensRPPGMethod) to use. + face_detector: Optional instance of FaceDetector for automatic tracking. + fdet_fs: Frequency [Hz] at which to run automatic face detection. + on_result: Optional callback for asynchronous result handling. 
+ """ + self.rppg = rppg_method + self.face_detector = face_detector + self.fdet_fs = fdet_fs + self.on_result = on_result + self.result_queue = queue.Queue() + self.buffer_manager = BufferManager(vc.compute_buffer_config(self.rppg.session_config)) + self.vc_session = vc.Session(self.rppg.session_config) + self.last_fdet_time = -1.0 + self.last_processed_time = -1.0 + self.current_face = None + self.current_roi_rust = None + self.running = True + self.thread = threading.Thread(target=self._inference_loop, daemon=True) + self.thread.start() + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + self.close() + + def close(self): + """Stop the background thread and shut down the streaming session.""" + self.running = False + if self.thread.is_alive(): + self.thread.join(timeout=2.0) + + def push(self, frame: np.ndarray, timestamp: float, face: np.ndarray = None): + """Push a single frame and its timestamp into the streaming pipeline. + + Args: + frame: The input video frame of shape (h, w, 3) in RGB format. + timestamp: The absolute timestamp of the frame in seconds. + face: Optional pre-detected face box [x0, y0, x1, y1]. 
+ """ + if not self.running: return + # Throttled face detection + min_interval = 1.0 / self.rppg.fps_target + if (timestamp - self.last_processed_time) < (min_interval - 0.005): + return + self.last_processed_time = timestamp + if face is not None: + self.current_face = face + elif self.face_detector and (timestamp - self.last_fdet_time) >= (1.0 / self.fdet_fs): + faces_rel, _ = self.face_detector(inputs=frame[np.newaxis], n_frames=1, fps=30.0) + if len(faces_rel) > 0 and len(faces_rel[0]) > 0: + h, w = frame.shape[:2] + self.current_face = (faces_rel[0][0] * [w, h, w, h]).astype(np.int64) + self.last_fdet_time = timestamp + if self.current_face is None: + return + # ROI calculation and registration + if self.current_roi_rust is None or (timestamp - self.last_fdet_time) < 0.05: + self.current_roi_rust = vc.calculate_roi( + face=vc.Rect( + x=float(self.current_face[0]), + y=float(self.current_face[1]), + width=float(self.current_face[2] - self.current_face[0]), + height=float(self.current_face[3] - self.current_face[1]) + ), + method=self.rppg.roi_method, + container=(float(frame.shape[1]), float(frame.shape[0])) + ) + self.buffer_manager.register_target(self.current_roi_rust, timestamp) + + active_buffers = self.buffer_manager.get_all_buffers() + if not active_buffers: + return + + # Process the frame for active buffer using its specific ROI + for buf_id, buf_roi in active_buffers: + # Clamp ROI to image boundaries + x0 = max(0, int(buf_roi.x)) + y0 = max(0, int(buf_roi.y)) + x1 = min(frame.shape[1], int(buf_roi.x + buf_roi.width)) + y1 = min(frame.shape[0], int(buf_roi.y + buf_roi.height)) + + if x1 <= x0 or y1 <= y0: + continue + + # Save to parsed_frame so we don't overwrite the original frame + parsed_frame, _, _, _, _ = parse_image_inputs( + inputs=frame, fps=30.0, + roi=(x0, y0, x1, y1), + target_size=self.rppg.input_size, preserve_aspect_ratio=False, + library='prpy', scale_algorithm='bilinear', allow_image=True, videodims=False + ) + + if 
isinstance(self.rppg, SimpleRPPGMethod): + payload = np.mean(parsed_frame, axis=(0, 1)) + else: + payload = parsed_frame + + ctx = InferenceContext(timestamp=timestamp, roi=buf_roi, face_conf=1.0) + self.buffer_manager.append(buf_id, payload, ctx) + + def get_result(self, block=False, timeout=None): + """Pull the latest inference results from the session queue. + + Args: + block: Whether to block until a result is available. + timeout: Maximum time to block if block is True. + Returns: + results: A list of analysis results for detected faces. + """ + try: + return self.result_queue.get(block=block, timeout=timeout) + except queue.Empty: + return None + + def _inference_loop(self): + """Background loop that polls for ready buffers and executes model inference.""" + while self.running: + command = self.buffer_manager.poll() + if not command: + time.sleep(0.01) + continue + window = self.buffer_manager.execute(command) + if not window: + continue + window_frames = np.stack([item[0] for item in window]) + window_faces = np.stack([[item[1].roi.x, item[1].roi.y, item[1].roi.x+item[1].roi.width, item[1].roi.y+item[1].roi.height] for item in window]) + # Dynamically compute fps from timestamps + timestamps = [item[1].timestamp for item in window] + actual_fps = (len(timestamps) - 1) / (timestamps[-1] - timestamps[0]) if len(timestamps) > 1 else self.rppg.fps_target + # Execute strategy + try: + sig, conf, live, new_state = self.rppg.infer_stream( + window_frames, actual_fps, self.buffer_manager.get_state() + ) + self.buffer_manager.update_state(new_state) + n_res = len(live) + timestamps = timestamps[-n_res:] + window_faces = window_faces[-n_res:] + signals_input = {k: vc.SignalInput(data=v.tolist(), confidence=conf[k].tolist()) for k, v in sig.items()} + face_input = vc.FaceInput(coordinates=window_faces.tolist(), confidence=live.tolist()) + session_input = vc.SessionInput(face=face_input, signals=signals_input, timestamp=timestamps) + session_result = 
self.vc_session.process(session_input, "Incremental") + res_dict = self._format_result(session_result) + if self.on_result: + self.on_result(res_dict) + self.result_queue.put(res_dict) + except Exception as e: + logging.error(f"Stream inference error: {e}") + self.buffer_manager.reset() + self.vc_session.reset() + + def _format_result(self, session_result): + """Format the raw session result into the standard vitallens-python result dictionary. + + Args: + session_result: The vc.SessionResult object from vitallens-core. + Returns: + res_dict: A list containing a formatted results dictionary for the face. + """ + face_dict = { + 'coordinates': [], + 'confidence': [], + 'note': 'Face detection coordinates for this face with live confidence levels.' + } + + if session_result.face is not None: + face_dict['coordinates'] = session_result.face.coordinates + face_dict['confidence'] = session_result.face.confidence + if session_result.face.note: + face_dict['note'] = session_result.face.note + + res_dict = { + 'face': face_dict, + 'vitals': {}, + 'waveforms': {}, + 'message': session_result.message, + 'fps': session_result.fps, + 'n': len(session_result.timestamp), + 'time': session_result.timestamp + } + + for key, wave in session_result.waveforms.items(): + res_dict['waveforms'][key] = { + 'data': wave.data, + 'unit': wave.unit, + 'confidence': wave.confidence, + 'note': wave.note + } + + for key, vital in session_result.vitals.items(): + res_dict['vitals'][key] = { + 'value': vital.value, + 'unit': vital.unit, + 'confidence': vital.confidence, + 'note': vital.note + } + + return [res_dict] \ No newline at end of file