-
Notifications
You must be signed in to change notification settings - Fork 4
Description
Hi! I have an image, "image.png", in the current directory, and I ran sam_masks = samclip.generate_image_sam_mask('image.png').
In Colab, this raises the following AssertionError:
AssertionError Traceback (most recent call last)
in <cell line: 2>()
1 classes = ["sky", "building"]
----> 2 sam_masks = samclip.generate_image_sam_mask('image.png')
3 # sam_masks = mask_generator.generate(image)
4 overlayed_img = samclip.get_overlay_img(sam_masks, classes)
5 overlay_pth = "output.jpg"
7 frames
in generate_image_sam_mask(self, img_pth, stability_score_threshold, predicted_iou_threshold)
151
152 mask_generator = SamAutomaticMaskGenerator(self.sam)
--> 153 masks = mask_generator.generate(self.cur_img_np)
154
155 og_masks = masks.copy()
/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py in decorate_context(*args, **kwargs)
113 def decorate_context(*args, **kwargs):
114 with ctx_factory():
--> 115 return func(*args, **kwargs)
116
117 return decorate_context
/usr/local/lib/python3.10/dist-packages/segment_anything/automatic_mask_generator.py in generate(self, image)
161
162 # Generate masks
--> 163 mask_data = self._generate_masks(image)
164
165 # Filter small disconnected regions and holes in masks
/usr/local/lib/python3.10/dist-packages/segment_anything/automatic_mask_generator.py in _generate_masks(self, image)
204 data = MaskData()
205 for crop_box, layer_idx in zip(crop_boxes, layer_idxs):
--> 206 crop_data = self._process_crop(image, crop_box, layer_idx, orig_size)
207 data.cat(crop_data)
208
/usr/local/lib/python3.10/dist-packages/segment_anything/automatic_mask_generator.py in _process_crop(self, image, crop_box, crop_layer_idx, orig_size)
234 cropped_im = image[y0:y1, x0:x1, :]
235 cropped_im_size = cropped_im.shape[:2]
--> 236 self.predictor.set_image(cropped_im)
237
238 # Get points for this crop
/usr/local/lib/python3.10/dist-packages/segment_anything/predictor.py in set_image(self, image, image_format)
58 input_image_torch = input_image_torch.permute(2, 0, 1).contiguous()[None, :, :, :]
59
---> 60 self.set_torch_image(input_image_torch, image.shape[:2])
61
62 @torch.no_grad()
/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py in decorate_context(*args, **kwargs)
113 def decorate_context(*args, **kwargs):
114 with ctx_factory():
--> 115 return func(*args, **kwargs)
116
117 return decorate_context
/usr/local/lib/python3.10/dist-packages/segment_anything/predictor.py in set_torch_image(self, transformed_image, original_image_size)
80 len(transformed_image.shape) == 4
81 and transformed_image.shape[1] == 3
---> 82 and max(*transformed_image.shape[2:]) == self.model.image_encoder.img_size
83 ), f"set_torch_image input must be BCHW with long side {self.model.image_encoder.img_size}."
84 self.reset_image()
AssertionError: set_torch_image input must be BCHW with long side 1024.