AssertionError: Float32 matmul precision check fails when calling model with numpy arrays #42

@SamueleBumbaca

Description

When attempting to use RoMaV2 with numpy arrays as input, the model raises an AssertionError during the forward pass. The error occurs in the precision check before any actual computation, making the model unusable with standard image inputs.

Environment

  • Python Version: 3.12.3
  • romav2 Version: 2.0.0
  • PyTorch Version: 2.10.0
  • Operating System: Linux (Ubuntu/Debian-based)
  • Installation Method: pip install --no-deps romav2 (see note below about installation issues)

Steps to Reproduce

import numpy as np
from PIL import Image
from romav2 import RoMaV2

# Load images as numpy arrays
with Image.open('image1.png') as img:
    CSI0 = np.array(img)
with Image.open('image2.png') as img:
    CSI1 = np.array(img)

print(f"CSI0 shape: {CSI0.shape}, dtype: {CSI0.dtype}")  # (480, 640, 4), dtype: uint8
print(f"CSI1 shape: {CSI1.shape}, dtype: {CSI1.dtype}")  # (448, 800, 4), dtype: uint8

# Load pretrained model
model = RoMaV2()

# Attempt to match images
preds = model(CSI0, CSI1)  # Fails here

Behavior

The model raises an AssertionError with a cryptic error traceback:

AssertionError: (incomplete traceback showing torch compilation issues)

File ~/romav2/romav2.py:170, in RoMaV2.forward(...)
--> 170     if torch.get_float32_matmul_precision() != "highest":
    171         raise RuntimeError("Float32 matmul precision must be set to highest")
    172     assert not self.training, "Currently only inference mode released"

The error occurs deep in PyTorch's compilation stack, making it difficult to debug (full traceback at the end).

Root Cause Analysis

The error appears to be related to:

  1. Type mismatch: The model expects PyTorch tensors, not numpy arrays
  2. Value range: Images may need to be normalized to [-1, 1] or [0, 1] range
  3. Batch dimension: Images may need to be batched (B, C, H, W) instead of (H, W, C)
  4. Missing precision setting: torch.set_float32_matmul_precision("highest") may need to be called before model instantiation
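Putting the hypotheses above together, here is a minimal conversion sketch. The helper name `to_model_input`, the `(1, 3, H, W)` layout, and the `[0, 1]` value range are assumptions I have not verified against romav2's documentation; only `torch.set_float32_matmul_precision` is a confirmed PyTorch API.

```python
import numpy as np
import torch

def to_model_input(img: np.ndarray) -> torch.Tensor:
    """Convert an (H, W, C) uint8 image to a (1, 3, H, W) float32 tensor in [0, 1].

    Hypothetical helper; the layout and range the model expects are unconfirmed.
    """
    if img.shape[-1] == 4:                 # drop the alpha channel from RGBA input
        img = img[..., :3]
    t = torch.from_numpy(np.ascontiguousarray(img))  # (H, W, 3), uint8
    t = t.permute(2, 0, 1).float() / 255.0           # (3, H, W), float32 in [0, 1]
    return t.unsqueeze(0)                            # add batch dim -> (1, 3, H, W)

# Hypothesis 4: set the precision before instantiating/calling the model.
torch.set_float32_matmul_precision("highest")
```

With this, the failing call would become `preds = model(to_model_input(CSI0), to_model_input(CSI1))` — again assuming the model accepts batched CHW float tensors.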

Additional Installation Issue

Note: The package has a secondary bug in its dependency specification. Installation with pip install romav2 fails with:

ERROR: Could not find a version that satisfies the requirement dataclasses>=0.8 (from romav2) (from versions: 0.1, 0.2, 0.3, 0.4, 0.5, 0.6)
ERROR: No matching distribution found for dataclasses>=0.8

Reason: The dataclasses package only exists for Python <3.7 (as a backport) and maxes out at version 0.6. For Python >=3.7, dataclasses is built-in. The dependency should either be:

  • Removed entirely for python_requires>=3.10
  • Made conditional: dataclasses>=0.6; python_version<'3.7'

Workaround: Install with pip install --no-deps romav2 and manually install dependencies: pip install torch torchvision numpy pillow

Please fix the dependency: remove the dataclasses>=0.8 requirement, or make it conditional, in setup.py/pyproject.toml.
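For reference, a sketch of what the corrected dependency specification could look like. This is a hypothetical pyproject.toml fragment — I have not seen romav2's actual packaging files, and the dependency list is taken from the manual-install workaround above:

```toml
[project]
requires-python = ">=3.10"
dependencies = [
    "torch",
    "torchvision",
    "numpy",
    "pillow",
    # dataclasses is in the stdlib since Python 3.7; if the backport must stay,
    # guard it with an environment marker instead of requiring it unconditionally:
    # "dataclasses>=0.6; python_version < '3.7'",
]
```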

Full traceback:


AssertionError Traceback (most recent call last)
Cell In[49], line 6
4 model = RoMaV2()
5 # Match densely for any image-like pair of inputs
----> 6 preds = model(CSI0, CSI1)
7 # Sample 5000 matches for estimation
8 matches, overlaps, precision_AB, precision_BA = model.sample(preds, 5000)

File ~/projects/jetmobilewaggledancetracker/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py:1774, in Module._wrapped_call_impl(self, *args, **kwargs)
1772 def _wrapped_call_impl(self, *args, **kwargs):
1773 if self._compiled_call_impl is not None:
-> 1774 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1775 else:
1776 return self._call_impl(*args, **kwargs)

File ~/projects/jetmobilewaggledancetracker/.venv/lib/python3.12/site-packages/torch/_dynamo/eval_frame.py:953, in _TorchDynamoContext.__call__.<locals>.compile_wrapper(*args, **kwargs)
950 _maybe_set_eval_frame(_callback_from_stance(callback))
952 try:
--> 953 return fn(*args, **kwargs)
954 except Unsupported as e:
955 if config.verbose:

File ~/projects/jetmobilewaggledancetracker/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py:1787, in Module._call_impl(self, *args, **kwargs)
1782 # If we don't have any hooks, we want to skip the rest of the logic in
1783 # this function, and just call forward.
1784 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1785 or _global_backward_pre_hooks or _global_backward_hooks
1786 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1787 return forward_call(*args, **kwargs)
1789 result = None
1790 called_always_called_hooks = set()

File ~/projects/jetmobilewaggledancetracker/.venv/lib/python3.12/site-packages/torch/utils/_contextlib.py:124, in context_decorator.<locals>.decorate_context(*args, **kwargs)
120 @functools.wraps(func)
121 def decorate_context(*args, **kwargs):
122 # pyrefly: ignore [bad-context-manager]
123 with ctx_factory():
--> 124 return func(*args, **kwargs)

File ~/projects/jetmobilewaggledancetracker/.venv/lib/python3.12/site-packages/romav2/romav2.py:170, in RoMaV2.forward(self, img_A_lr, img_B_lr, img_A_hr, img_B_hr)
162 @torch.inference_mode()
163 def forward(
164 self,
(...) 168 img_B_hr: torch.Tensor | None = None,
169 ) -> dict[str, tuple[torch.Tensor, torch.Tensor] | torch.Tensor]:
--> 170 if torch.get_float32_matmul_precision() != "highest":
171 raise RuntimeError("Float32 matmul precision must be set to highest")
172 assert not self.training, "Currently only inference mode released"

File ~/projects/jetmobilewaggledancetracker/.venv/lib/python3.12/site-packages/torch/_dynamo/convert_frame.py:2202, in CatchErrorsWrapper.__call__(self, frame, cache_entry, frame_state)
2196 return hijacked_callback(
2197 frame, cache_entry, self.hooks, frame_state
2198 )
2200 with compile_lock, _disable_current_modes():
2201 # skip=1: skip this frame
-> 2202 result = self._torchdynamo_orig_backend(
2203 frame, cache_entry, self.hooks, frame_state, skip=1
2204 )
2205 return result

File ~/projects/jetmobilewaggledancetracker/.venv/lib/python3.12/site-packages/torch/_dynamo/convert_frame.py:1945, in ConvertFrame.__call__(self, frame, cache_entry, hooks, frame_state, skip)
1943 counters["frames"]["total"] += 1
1944 try:
-> 1945 result = self._inner_convert(
1946 frame, cache_entry, hooks, frame_state, skip=skip + 1
1947 )
1948 counters["frames"]["ok"] += 1
1949 return result

File ~/projects/jetmobilewaggledancetracker/.venv/lib/python3.12/site-packages/torch/_dynamo/convert_frame.py:707, in ConvertFrameAssert.__call__(self, frame, cache_entry, hooks, frame_state, skip)
704 dynamo_tls.traced_frame_infos.append(info)
706 with compile_context(CompileContext(compile_id)):
--> 707 result = _compile(
708 frame.f_code,
709 frame.f_globals,
710 frame.f_locals,
711 frame.f_builtins,
712 frame.closure,
713 self._torchdynamo_orig_backend,
714 self._one_graph,
715 self._export,
716 self._export_constraints,
717 hooks,
718 cache_entry,
719 cache_size,
720 frame,
721 frame_state=frame_state,
722 compile_id=compile_id,
723 skip=skip + 1,
724 package=self._package,
725 convert_frame_box=self._box,
726 )
728 if config.caching_precompile and self._package is not None:
729 from .package import DynamoCache

File ~/projects/jetmobilewaggledancetracker/.venv/lib/python3.12/site-packages/torch/_dynamo/convert_frame.py:1752, in _compile(code, globals, locals, builtins, closure, compiler_fn, one_graph, export, export_constraints, hooks, cache_entry, cache_size, frame, frame_state, compile_id, skip, package, convert_frame_box)
1750 tracer_output = None
1751 try:
-> 1752 guarded_code, tracer_output = compile_inner(code, one_graph, hooks)
1754 # NB: We only put_code_state in success case. Success case here
1755 # does include graph breaks; specifically, if a graph break still
1756 # resulted in a partially compiled graph, we WILL return here. An
(...) 1761 # to upload for graph break though, because this can prevent
1762 # extra graph break compilations.)
1763 put_code_state()

File ~/projects/jetmobilewaggledancetracker/.venv/lib/python3.12/site-packages/torch/_utils_internal.py:97, in compile_time_strobelight_meta.<locals>.compile_time_strobelight_meta_inner.<locals>.wrapper_function(*args, **kwargs)
94 # This is not needed but we have it here to avoid having profile_compile_time
95 # in stack traces when profiling is not enabled.
96 if not StrobelightCompileTimeProfiler.enabled:
---> 97 return function(*args, **kwargs)
99 return StrobelightCompileTimeProfiler.profile_compile_time(
100 function, phase_name, *args, **kwargs
101 )

File ~/projects/jetmobilewaggledancetracker/.venv/lib/python3.12/site-packages/torch/_dynamo/convert_frame.py:1433, in _compile.<locals>.compile_inner(code, one_graph, hooks)
1427 stack.enter_context(
1428 torch._dynamo.callback_handler.install_callbacks(
1429 CallbackTrigger.DYNAMO, str(CompileContext.current_compile_id())
1430 )
1431 )
1432 stack.enter_context(CompileTimeInstructionCounter.record())
-> 1433 return _compile_inner(code, one_graph, hooks)
1435 return (
1436 ConvertFrameReturn(),
1437 None,
1438 )

File ~/projects/jetmobilewaggledancetracker/.venv/lib/python3.12/site-packages/torch/_dynamo/convert_frame.py:1567, in _compile.<locals>._compile_inner(code, one_graph, hooks)
1565 nonlocal cache_entry
1566 with dynamo_timed("build_guards", log_pt2_compile_event=True):
-> 1567 check_fn = dynamo_output.build_guards(
1568 code,
1569 hooks=hooks,
1570 save=package is not None,
1571 cache_entry=cache_entry,
1572 )
1574 if package is not None:
1575 assert check_fn.guards_state is not None

File ~/projects/jetmobilewaggledancetracker/.venv/lib/python3.12/site-packages/torch/_dynamo/convert_frame.py:879, in DynamoOutput.build_guards(self, code, hooks, save, cache_entry, strict_error)
877 output_graph = self.tracer_output.output_graph
878 assert output_graph is not None
--> 879 return CheckFunctionManager(
880 code,
881 output_graph,
882 cache_entry,
883 hooks.guard_fail_fn if hooks else None,
884 hooks.guard_filter_fn if hooks else None,
885 save_guards=save,
886 strict_error=strict_error,
887 )

File ~/projects/jetmobilewaggledancetracker/.venv/lib/python3.12/site-packages/torch/_dynamo/guards.py:3729, in CheckFunctionManager.__init__(self, f_code, output_graph, cache_entry, guard_fail_fn, guard_filter_fn, shape_code_parts, runtime_global_scope, save_guards, strict_error, source_get_cache)
3722 if not self.guard_manager.check(output_graph.local_scope):
3723 reasons = get_guard_fail_reason_helper(
3724 self.guard_manager,
3725 output_graph.local_scope,
3726 CompileContext.current_compile_id(),
3727 backend=None, # no need to set this because we are trying to find the offending guard entry
3728 )
-> 3729 raise AssertionError(
3730 "Guard failed on the same frame it was created. This is a bug - please create an issue."
3731 f"Guard fail reason: {reasons}"
3732 )
3734 if guard_manager_testing_hook_fn is not None:
3735 guard_manager_testing_hook_fn(
3736 self.guard_manager, output_graph.local_scope, builder
3737 )

AssertionError: Guard failed on the same frame it was created. This is a bug - please create an issue.Guard fail reason: 1/0: tensor '___from_numpy(img_A_lr)' dispatch key set mismatch. expected DispatchKeySet(CPU, BackendSelect, ADInplaceOrView), actual DispatchKeySet(CPU, BackendSelect)
