RuntimeError in convert_model_output due to tensor shape mismatch (201 vs 96) #6

@hanm2019

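Hi, I encountered a runtime error when running main.py in your project. It is a tensor shape mismatch raised in convert_model_output in scheduling_dpmsolver_multistep.py, which is reached through batch_step_no_noise in paradpmsolver_scheduler.py. In the log below, the parawarmup, paraddpm, and paraddim runs all finish with the parallel pipeline; the error occurs in the run that starts right after scheduler=paraddim.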

(diffusion) root@p-74811dfc49bc-ackcs-00gjeicq:~/shared-nvme/paradigms# python main.py
Loading pipeline components...: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 6/6 [00:07<00:00,  1.17s/it]
Loading pipeline components...: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 6/6 [00:05<00:00,  1.12it/s]
Loading pipeline components...: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 6/6 [00:08<00:00,  1.37s/it]
Loading pipeline components...: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 6/6 [00:06<00:00,  1.09s/it]
parallel pipeline!
/root/shared-nvme/conda/envs/diffusion/lib/python3.9/site-packages/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py:313: FutureWarning: `_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple.
  deprecate("_encode_prompt()", "1.0.0", deprecation_message, standard_warn=False)
/root/shared-nvme/paradigms/paradigms/stablediffusion_paradigms.py:142: FutureWarning: Accessing config attribute `in_channels` directly via 'UNet2DConditionModel' object attribute is deprecated. Please access 'in_channels' over 'UNet2DConditionModel's config object instead, e.g. 'unet.config.in_channels'.
  num_channels_latents = self.unet.in_channels
pass count 200
flop count 200
19284.71484375
done
/root/shared-nvme/conda/envs/diffusion/lib/python3.9/site-packages/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py:599: FutureWarning: The decode_latents method is deprecated and will be removed in 1.0.0. Please use VaeImageProcessor.postprocess(...) instead
  deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False)
ngpu=1, parallel=1, scheduler=parawarmup, time=19284.71484375
parallel pipeline!
pass count 1000
flop count 1000
97392.1484375
done
ngpu=1, parallel=1, scheduler=paraddpm, time=97392.1484375
parallel pipeline!
pass count 200
flop count 200
19692.98828125
done
ngpu=1, parallel=1, scheduler=paraddim, time=19692.98828125
parallel pipeline!
/root/shared-nvme/conda/envs/diffusion/lib/python3.9/site-packages/diffusers/schedulers/scheduling_dpmsolver_multistep.py:651: FutureWarning: `timesteps` is deprecated and will be removed in version 1.0.0. Passing `timesteps` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`
  deprecate(
Traceback (most recent call last):
  File "/root/shared-nvme/paradigms/main.py", line 110, in <module>
    main()
  File "/root/shared-nvme/paradigms/main.py", line 83, in main
    output, stats = run_stable_diffusion(pipes[name], ngpu, parallel, num_inference_steps, prompts)
  File "/root/shared-nvme/paradigms/main.py", line 61, in run_stable_diffusion
    output, stats = pipe.paradigms_forward(prompts, num_inference_steps=num_inference_steps, full_return=False, **options)
  File "/root/shared-nvme/conda/envs/diffusion/lib/python3.9/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
  File "/root/shared-nvme/paradigms/paradigms/stablediffusion_paradigms.py", line 229, in paradigms_forward
    block_latents_denoise = scheduler.batch_step_no_noise(
  File "/root/shared-nvme/paradigms/paradigms/paradpmsolver_scheduler.py", line 68, in batch_step_no_noise
    model_output = self.convert_model_output(model_output, t, sample)
  File "/root/shared-nvme/conda/envs/diffusion/lib/python3.9/site-packages/diffusers/schedulers/scheduling_dpmsolver_multistep.py", line 671, in convert_model_output
    x0_pred = alpha_t * sample - sigma_t * model_output
RuntimeError: The size of tensor a (201) must match the size of tensor b (96) at non-singleton dimension 3
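
A note on the two numbers in the error, offered as a hypothesis and not a confirmed diagnosis: 96 is consistent with the latent width of the `sample` tensor, and 201 would match a per-timestep table of length num_inference_steps + 1 if this run also used 200 steps, as the parawarmup and paraddim runs above did. The FutureWarning just before the traceback says the `timesteps` argument is ignored in favor of `self.step_index`, so `alpha_t` and `sigma_t` may be arriving with a trailing dimension of 201 instead of a per-sample shape such as (batch, 1, 1, 1). The sketch below uses only the standard library and a hypothetical helper, `broadcast_shape`, to show why such a pair of shapes fails at dimension 3 while a per-sample coefficient broadcasts cleanly. The concrete shapes (batch of 1, 4 latent channels, a (1, 201) coefficient) are assumptions for illustration.

```python
def broadcast_shape(a, b):
    """Broadcast two shapes with the NumPy/PyTorch rule.

    Hypothetical helper for illustration; not part of torch or diffusers.
    """
    rank = max(len(a), len(b))
    # Left-pad the shorter shape with 1s so both have the same rank.
    a = (1,) * (rank - len(a)) + tuple(a)
    b = (1,) * (rank - len(b)) + tuple(b)
    out = [0] * rank
    # Walk from the trailing dimension toward the leading one.
    for dim in range(rank - 1, -1, -1):
        x, y = a[dim], b[dim]
        if x != y and x != 1 and y != 1:
            raise ValueError(
                f"size of tensor a ({x}) must match size of tensor b ({y}) "
                f"at non-singleton dimension {dim}"
            )
        # A size-1 dimension stretches to the other size.
        out[dim] = y if x == 1 else x
    return tuple(out)


# Assumed shapes: only the 96 and the 201 come from the error message.
sample_shape = (1, 4, 96, 96)  # (batch, channels, height, width) of the latents
per_sample_coeff = (1, 1, 1, 1)  # one coefficient per sample, broadcastable
table_coeff = (1, 201)  # a whole per-timestep table, assumed shape

print(broadcast_shape(per_sample_coeff, sample_shape))
try:
    broadcast_shape(table_coeff, sample_shape)
except ValueError as err:
    print(err)
```

If this reading is right, printing `alpha_t.shape`, `sample.shape`, and `self.step_index` just before the failing line in convert_model_output would confirm or rule it out.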

main.py is running on an NVIDIA 3090 Ti with CUDA 12.4. For reference, the conda environment is as follows:

# packages in environment at /root/shared-nvme/conda/envs/diffusion:
#
# Name                    Version                   Build  Channel
_libgcc_mutex             0.1                        main
_openmp_mutex             5.1                       1_gnu
accelerate                1.5.2              pyhd8ed1ab_0    conda-forge
asttokens                 3.0.0                    pypi_0    pypi
blas                      1.0                         mkl
bottleneck                1.4.2            py39ha9d4c09_0
brotli-python             1.0.9            py39h6a678d5_9
bzip2                     1.0.8                h5eee18b_6
c-ares                    1.19.1               h5eee18b_0
ca-certificates           2025.2.25            h06a4308_0
certifi                   2025.1.31        py39h06a4308_0
charset-normalizer        3.3.2              pyhd3eb1b0_0
click                     8.1.8            py39h06a4308_0
comm                      0.2.2                    pypi_0    pypi
contourpy                 1.2.1            py39hdb19cb5_1
cuda-cudart               12.4.127                      0    nvidia
cuda-cupti                12.4.127                      0    nvidia
cuda-libraries            12.4.1               h06a4308_1
cuda-nvrtc                12.4.127                      0    nvidia
cuda-nvtx                 12.4.127                      0    nvidia
cuda-opencl               12.4.127                      0    nvidia
cuda-runtime              12.4.1               hb982923_0
cudatoolkit               11.6.0               habf752d_9    nvidia
cycler                    0.11.0             pyhd3eb1b0_0
cyrus-sasl                2.1.28               h52b45da_1
debugpy                   1.8.14                   pypi_0    pypi
decorator                 5.2.1                    pypi_0    pypi
diffusers                 0.33.1                   pypi_0    pypi
exceptiongroup            1.2.2                    pypi_0    pypi
executing                 2.2.0                    pypi_0    pypi
expat                     2.7.1                h6a678d5_0
ffmpeg                    4.3                  hf484d3e_0    pytorch
filelock                  3.13.1                   pypi_0    pypi
fontconfig                2.14.1               h55d465d_3
fonttools                 4.55.3           py39h5eee18b_0
freetype                  2.13.3               h4a9f257_0
fsspec                    2024.6.1                 pypi_0    pypi
giflib                    5.2.2                h5eee18b_0
gmp                       6.3.0                h6a678d5_0
gmpy2                     2.2.1            py39h5eee18b_0
gnutls                    3.6.15               he1e5248_0
huggingface_hub           0.30.2             pyhd8ed1ab_0    conda-forge
icu                       73.1                 h6a678d5_0
idna                      3.7              py39h06a4308_0
imageio                   2.37.0           py39h06a4308_0
imageio-ffmpeg            0.6.0                    pypi_0    pypi
importlib-metadata        8.6.1                    pypi_0    pypi
importlib_resources       6.4.0            py39h06a4308_0
intel-openmp              2023.1.0         hdb19cb5_46306
ipykernel                 6.29.5                   pypi_0    pypi
ipython                   8.18.1                   pypi_0    pypi
jedi                      0.19.2                   pypi_0    pypi
jinja2                    3.1.4                    pypi_0    pypi
jpeg                      9e                   h5eee18b_3
jupyter-client            8.6.3                    pypi_0    pypi
jupyter-core              5.7.2                    pypi_0    pypi
kiwisolver                1.4.4            py39h6a678d5_0
krb5                      1.20.1               h143b758_1
lame                      3.100                h7b6447c_0
lcms2                     2.16                 h92b89f2_1
ld_impl_linux-64          2.40                 h12ee557_0
lerc                      4.0.0                h6a678d5_0
libabseil                 20250127.0      cxx17_h6a678d5_0
libcublas                 12.4.5.8                      0    nvidia
libcufft                  11.2.1.3                      0    nvidia
libcufile                 1.9.1.3                       0    nvidia
libcups                   2.4.2                h2d74bed_1
libcurand                 10.3.5.147                    0    nvidia
libcurl                   8.12.1               hc9e6f67_0
libcusolver               11.6.1.9                      0    nvidia
libcusparse               12.3.1.170                    0    nvidia
libdeflate                1.22                 h5eee18b_0
libedit                   3.1.20230828         h5eee18b_0
libev                     4.33                 h7f8727e_1
libffi                    3.4.4                h6a678d5_1
libgcc                    14.2.0               h767d61c_2    conda-forge
libgcc-ng                 14.2.0               h69a702a_2    conda-forge
libgfortran-ng            11.2.0               h00389a5_1
libgfortran5              11.2.0               h1234567_1
libglib                   2.78.4               hdc74915_0
libgomp                   14.2.0               h767d61c_2    conda-forge
libiconv                  1.16                 h5eee18b_3
libidn2                   2.3.4                h5eee18b_0
libjpeg-turbo             2.0.0                h9bf148f_0    pytorch
libnghttp2                1.57.0               h2d74bed_0
libnpp                    12.2.5.30                     0    nvidia
libnvfatbin               12.4.127                      0    nvidia
libnvjitlink              12.4.127                      0    nvidia
libnvjpeg                 12.3.1.117                    0    nvidia
libpng                    1.6.39               h5eee18b_0
libpq                     17.4                 hdbd6064_0
libprotobuf               5.29.3               hc99497a_0
libssh2                   1.11.1               h251f7ec_0
libstdcxx-ng              11.2.0               h1234567_1
libtasn1                  4.19.0               h5eee18b_0
libtiff                   4.7.0                hde9077f_0
libunistring              0.9.10               h27cfd23_0
libuuid                   1.41.5               h5eee18b_0
libwebp                   1.3.2                h9f374a3_1
libwebp-base              1.3.2                h5eee18b_1
libxcb                    1.15                 h7f8727e_0
libxkbcommon              1.0.1                h097e994_2
libxml2                   2.13.7               hfdd30dd_0
llvm-openmp               14.0.6               h9e868ea_0
lz4-c                     1.9.4                h6a678d5_1
markupsafe                2.1.5                    pypi_0    pypi
matplotlib                3.9.2            py39h06a4308_1
matplotlib-base           3.9.2            py39hbfdbfaf_1
matplotlib-inline         0.1.7                    pypi_0    pypi
mkl                       2023.1.0         h213fc3f_46344
mkl-service               2.4.0            py39h5eee18b_2
mkl_fft                   1.3.11           py39h5eee18b_0
mkl_random                1.2.8            py39h1128e8f_0
mpc                       1.3.1                h5eee18b_0
mpfr                      4.2.1                h5eee18b_0
mpmath                    1.3.0            py39h06a4308_0
mysql                     8.4.0                h721767e_2
ncurses                   6.4                  h6a678d5_0
nest-asyncio              1.6.0                    pypi_0    pypi
nettle                    3.7.3                hbbd107a_1
networkx                  3.2.1            py39h06a4308_0
numexpr                   2.10.1           py39h3c60e43_0
numpy                     1.26.3                   pypi_0    pypi
nvidia-cublas-cu12        12.4.5.8                 pypi_0    pypi
nvidia-cuda-cupti-cu12    12.4.127                 pypi_0    pypi
nvidia-cuda-nvrtc-cu12    12.4.127                 pypi_0    pypi
nvidia-cuda-runtime-cu12  12.4.127                 pypi_0    pypi
nvidia-cudnn-cu12         9.1.0.70                 pypi_0    pypi
nvidia-cufft-cu12         11.2.1.3                 pypi_0    pypi
nvidia-curand-cu12        10.3.5.147               pypi_0    pypi
nvidia-cusolver-cu12      11.6.1.9                 pypi_0    pypi
nvidia-cusparse-cu12      12.3.1.170               pypi_0    pypi
nvidia-cusparselt-cu12    0.6.2                    pypi_0    pypi
nvidia-nccl-cu12          2.21.5                   pypi_0    pypi
nvidia-nvjitlink-cu12     12.4.127                 pypi_0    pypi
nvidia-nvtx-cu12          12.4.127                 pypi_0    pypi
openh264                  2.1.1                h4ff587b_0
openjpeg                  2.5.2                h0d4d230_1
openldap                  2.6.4                h42fbc30_0
openssl                   3.5.0                h7b32b05_0    conda-forge
packaging                 25.0                     pypi_0    pypi
pandas                    2.2.3            py39h6a678d5_0
parso                     0.8.4                    pypi_0    pypi
pcre2                     10.42                hebb0a14_1
pexpect                   4.9.0                    pypi_0    pypi
pillow                    11.0.0                   pypi_0    pypi
pip                       25.0             py39h06a4308_0
platformdirs              4.3.7                    pypi_0    pypi
prompt-toolkit            3.0.51                   pypi_0    pypi
psutil                    5.9.0            py39h5eee18b_1
ptyprocess                0.7.0                    pypi_0    pypi
pure-eval                 0.2.3                    pypi_0    pypi
pybind11-abi              4                    hd3eb1b0_1
pygments                  2.19.1                   pypi_0    pypi
pyparsing                 3.2.0            py39h06a4308_0
pyqt                      6.7.1            py39h6a678d5_1
pyqt6-sip                 13.9.1           py39h5eee18b_1
pysocks                   1.7.1            py39h06a4308_0
pyspng                    0.1.3                    pypi_0    pypi
python                    3.9.21               he870216_1
python-dateutil           2.9.0post0       py39h06a4308_2
python-tzdata             2023.3             pyhd3eb1b0_0
python_abi                3.9                      2_cp39    conda-forge
pytorch-cuda              12.4                 hc786d27_7    pytorch
pytorch-mutex             1.0                        cuda    pytorch
pytz                      2024.1           py39h06a4308_0
pyyaml                    6.0.2            py39h5eee18b_0
pyzmq                     26.4.0                   pypi_0    pypi
qtbase                    6.7.3                hdaa5aa8_0
qtdeclarative             6.7.3                h6a678d5_0
qtsvg                     6.7.3                he621ea3_0
qttools                   6.7.3                h80c7b02_0
qtwebchannel              6.7.3                h6a678d5_0
qtwebsockets              6.7.3                h6a678d5_0
readline                  8.2                  h5eee18b_0
regex                     2024.11.6                pypi_0    pypi
requests                  2.32.3           py39h06a4308_1
safetensors               0.5.3            py39he612d8f_0    conda-forge
scipy                     1.13.1           py39h5f9d8c6_1
setuptools                72.1.0           py39h06a4308_0
sip                       6.10.0           py39h6a678d5_0
six                       1.17.0           py39h06a4308_0
sqlite                    3.45.3               h5eee18b_0
stack-data                0.6.3                    pypi_0    pypi
sympy                     1.13.1                   pypi_0    pypi
tbb                       2021.8.0             hdb19cb5_0
tk                        8.6.14               h39e8969_0
tokenizers                0.21.1                   pypi_0    pypi
tomli                     2.0.1            py39h06a4308_0
torch                     2.6.0+cu124              pypi_0    pypi
torchaudio                2.5.1                py39_cu124    pytorch
tornado                   6.4.2            py39h5eee18b_0
tqdm                      4.67.1           py39h2f386ee_0
traitlets                 5.14.3                   pypi_0    pypi
transformers              4.51.3                   pypi_0    pypi
triton                    3.2.0                    pypi_0    pypi
typing-extensions         4.12.2           py39h06a4308_0
typing_extensions         4.12.2           py39h06a4308_0
tzdata                    2025a                h04d1e81_0
unicodedata2              15.1.0           py39h5eee18b_1
urllib3                   2.3.0            py39h06a4308_0
wcwidth                   0.2.13                   pypi_0    pypi
wheel                     0.45.1           py39h06a4308_0
xcb-util-cursor           0.1.4                h5eee18b_0
xformers                  0.0.29.post3             pypi_0    pypi
xz                        5.6.4                h5eee18b_1
yaml                      0.2.5                h7b6447c_0
zipp                      3.21.0           py39h06a4308_0
zlib                      1.2.13               h5eee18b_1
zstd                      1.5.6                hc292b87_0

Could you help identify whether this is a bug in the custom parallel scheduler or a compatibility issue with the installed version of diffusers (0.33.1)? Any workaround or patch would also be appreciated.

Thanks for your great work on this project!
