diff --git a/web_api/app/alignment.py b/web_api/app/alignment.py index 9e90ecf81..53a109516 100644 --- a/web_api/app/alignment.py +++ b/web_api/app/alignment.py @@ -31,7 +31,10 @@ class CorrespondencesDict(BaseModel): ..., description="List of correspondence lines" ) - +# TODO: Refactor endpoint: +# - receive binary blobs (json-descriptor, image-data, src-mask-data, tgt-mask-data) +# - decode blobs to numpy arrays +# - return blob files (relaxed-field-data, warped-image-data) as gzipped stream class ApplyCorrespondencesRequest(BaseModel): correspondences_dict: CorrespondencesDict = Field( ..., description="Dictionary with correspondence lines" @@ -48,6 +51,16 @@ class ApplyCorrespondencesRequest(BaseModel): optimizer_type: str = Field( "adam", description="Optimizer to use (adam, lbfgs, sgd, adamw)" ) + src_mask: list[list[list[list[float]]]] | None = Field( + None, + description="Binary tissue mask (1=tissue, 0=non-tissue), shape (1, H, W, 1)." + "Zeros break rigidity constraints between regions.", + ) + tgt_mask: list[list[list[list[float]]]] | None = Field( + None, + description="Binary tissue mask (1=tissue, 0=non-tissue), shape (1, H, W, 1)." + "Zeros break rigidity constraints between regions.", + ) class ApplyCorrespondencesResponse(BaseModel): @@ -106,6 +119,22 @@ async def apply_correspondences(request: ApplyCorrespondencesRequest): device=device ) + src_mask_tensor = None + if request.src_mask is not None: + src_mask_tensor = torch.tensor( + request.src_mask, + dtype=torch.float32, + device=device + ) + + tgt_mask_tensor = None + if request.tgt_mask is not None: + tgt_mask_tensor = torch.tensor( + request.tgt_mask, + dtype=torch.float32, + device=device + ) + relaxed_field, warped_image = apply_correspondences_to_image( correspondences_dict=correspondences_dict, image=image_tensor, @@ -113,6 +142,8 @@ async def apply_correspondences(request: ApplyCorrespondencesRequest): rig=request.rig, lr=request.lr, optimizer_type=request.optimizer_type, + src_mask=src_mask_tensor, + tgt_mask=tgt_mask_tensor, ) # Convert to numpy arrays on CPU