Skip to content

AssertionError: set_torch_image input must be BCHW with long side 1024. #1

@Yaojun-Lai

Description

@Yaojun-Lai

Hi! I have an image `image.png` in the current directory, and I ran `sam_masks = samclip.generate_image_sam_mask('image.png')`.

The assertion error in Colab is as follows:

AssertionError Traceback (most recent call last)
in <cell line: 2>()
1 classes = ["sky", "building"]
----> 2 sam_masks = samclip.generate_image_sam_mask('image.png')
3 # sam_masks = mask_generator.generate(image)
4 overlayed_img = samclip.get_overlay_img(sam_masks, classes)
5 overlay_pth = "output.jpg"

7 frames
in generate_image_sam_mask(self, img_pth, stability_score_threshold, predicted_iou_threshold)
151
152 mask_generator = SamAutomaticMaskGenerator(self.sam)
--> 153 masks = mask_generator.generate(self.cur_img_np)
154
155 og_masks = masks.copy()

/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py in decorate_context(*args, **kwargs)
113 def decorate_context(*args, **kwargs):
114 with ctx_factory():
--> 115 return func(*args, **kwargs)
116
117 return decorate_context

/usr/local/lib/python3.10/dist-packages/segment_anything/automatic_mask_generator.py in generate(self, image)
161
162 # Generate masks
--> 163 mask_data = self._generate_masks(image)
164
165 # Filter small disconnected regions and holes in masks

/usr/local/lib/python3.10/dist-packages/segment_anything/automatic_mask_generator.py in _generate_masks(self, image)
204 data = MaskData()
205 for crop_box, layer_idx in zip(crop_boxes, layer_idxs):
--> 206 crop_data = self._process_crop(image, crop_box, layer_idx, orig_size)
207 data.cat(crop_data)
208

/usr/local/lib/python3.10/dist-packages/segment_anything/automatic_mask_generator.py in _process_crop(self, image, crop_box, crop_layer_idx, orig_size)
234 cropped_im = image[y0:y1, x0:x1, :]
235 cropped_im_size = cropped_im.shape[:2]
--> 236 self.predictor.set_image(cropped_im)
237
238 # Get points for this crop

/usr/local/lib/python3.10/dist-packages/segment_anything/predictor.py in set_image(self, image, image_format)
58 input_image_torch = input_image_torch.permute(2, 0, 1).contiguous()[None, :, :, :]
59
---> 60 self.set_torch_image(input_image_torch, image.shape[:2])
61
62 @torch.no_grad()

/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py in decorate_context(*args, **kwargs)
113 def decorate_context(*args, **kwargs):
114 with ctx_factory():
--> 115 return func(*args, **kwargs)
116
117 return decorate_context

/usr/local/lib/python3.10/dist-packages/segment_anything/predictor.py in set_torch_image(self, transformed_image, original_image_size)
80 len(transformed_image.shape) == 4
81 and transformed_image.shape[1] == 3
---> 82 and max(*transformed_image.shape[2:]) == self.model.image_encoder.img_size
83 ), f"set_torch_image input must be BCHW with long side {self.model.image_encoder.img_size}."
84 self.reset_image()

AssertionError: set_torch_image input must be BCHW with long side 1024.

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions