diff --git a/ruyi/pipeline/pipeline_ruyi_inpaint.py b/ruyi/pipeline/pipeline_ruyi_inpaint.py index f725c2c..2daf747 100644 --- a/ruyi/pipeline/pipeline_ruyi_inpaint.py +++ b/ruyi/pipeline/pipeline_ruyi_inpaint.py @@ -912,7 +912,8 @@ def __call__( base_size = 512 // 8 // self.transformer.config.patch_size grid_crops_coords = get_resize_crop_region_for_grid((grid_height, grid_width), base_size) image_rotary_emb = get_2d_rotary_pos_embed( - self.transformer.inner_dim // self.transformer.num_heads, grid_crops_coords, (grid_height, grid_width) + self.transformer.inner_dim // self.transformer.num_heads, grid_crops_coords, (grid_height, grid_width), + output_type="pt" ) style = torch.tensor([0], device=device)