From 0ea9a04cec6695ae22e5372ce7512b63b29b38b4 Mon Sep 17 00:00:00 2001 From: Vladyslav Shyrokyi Date: Mon, 26 Jan 2026 19:17:25 +0200 Subject: [PATCH] feat(manual-correspondence): add src_mask and tgt_mask parameters Introduces binary masks for tissue regions to relax rigidity constraints at boundaries during correspondence field processing. This enhancement allows for more precise control over the smoothing process in image alignment tasks. --- web_api/app/alignment.py | 33 ++++++++++++++++++++++++++++++++- 1 file changed, 32 insertions(+), 1 deletion(-) diff --git a/web_api/app/alignment.py b/web_api/app/alignment.py index 9e90ecf81..53a109516 100644 --- a/web_api/app/alignment.py +++ b/web_api/app/alignment.py @@ -31,7 +31,10 @@ class CorrespondencesDict(BaseModel): ..., description="List of correspondence lines" ) - +# TODO: Refactor endpoint: +# - receive binary blobs (json-descriptor, image-data, src-mask-data, tgt-mask-data) +# - decode blobs to numpy arrays +# - return blob files (relaxed-field-data, warped-image-data) as gzipped stream class ApplyCorrespondencesRequest(BaseModel): correspondences_dict: CorrespondencesDict = Field( ..., description="Dictionary with correspondence lines" @@ -48,6 +51,16 @@ class ApplyCorrespondencesRequest(BaseModel): optimizer_type: str = Field( "adam", description="Optimizer to use (adam, lbfgs, sgd, adamw)" ) + src_mask: list[list[list[list[float]]]] | None = Field( + None, + description="Binary tissue mask (1=tissue, 0=non-tissue), shape (1, H, W, 1)." + "Zeros break rigidity constraints between regions.", + ) + tgt_mask: list[list[list[list[float]]]] | None = Field( + None, + description="Binary tissue mask (1=tissue, 0=non-tissue), shape (1, H, W, 1)." + "Zeros break rigidity constraints between regions.", + ) class ApplyCorrespondencesResponse(BaseModel): @@ -106,6 +119,22 @@ async def apply_correspondences(request: ApplyCorrespondencesRequest): device=device ) + src_mask_tensor = None + if request.src_mask is not None: + src_mask_tensor = torch.tensor( + request.src_mask, + dtype=torch.float32, + device=device + ) + + tgt_mask_tensor = None + if request.tgt_mask is not None: + tgt_mask_tensor = torch.tensor( + request.tgt_mask, + dtype=torch.float32, + device=device + ) + relaxed_field, warped_image = apply_correspondences_to_image( correspondences_dict=correspondences_dict, image=image_tensor, @@ -113,6 +142,8 @@ async def apply_correspondences(request: ApplyCorrespondencesRequest): rig=request.rig, lr=request.lr, optimizer_type=request.optimizer_type, + src_mask=src_mask_tensor, + tgt_mask=tgt_mask_tensor, ) # Convert to numpy arrays on CPU