Description
I am successfully training 1B OLMo models using DeMo with DDP :-)
However, when training the same model using DeMo with FSDP, I run into an issue where both _dct and _idct are called with empty tensors (x.shape[-1] == 0).
My config for DeMo as the optimizer is:
optimizer:
  name: demo
  learning_rate: 3.0e-4
  weight_decay: 0.0
  eps: 1e-8
  decay_norm_and_bias: false
  decay_embeddings: false
  compression_decay: 0.99
  compression_topk: 32
  compression_chunk: 64
  metrics_log_interval: 1
My config for FSDP is:
fsdp:
  wrapping_strategy: null
  precision: mixed
  disable_grad_sync: true
My code base is in the following PR to OLMo:
Traceback (most recent call last):
  File "/leonardo_scratch/fast/EUHPC_D07_063/OLMo/scripts/train.py", line 429, in <module>
    main(cfg)
  File "/leonardo_scratch/fast/EUHPC_D07_063/OLMo/scripts/train.py", line 248, in main
    optim = build_optimizer(cfg, dist_model)
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/leonardo_scratch/fast/EUHPC_D07_063/miniconda3/envs/olmo/lib/python3.12/site-packages/olmo/optim.py", line 1138, in build_optimizer
    return DeMo(
           ^^^^^
  File "/leonardo_scratch/fast/EUHPC_D07_063/miniconda3/envs/olmo/lib/python3.12/site-packages/olmo/optim.py", line 702, in __init__
    self.transform = TransformDCT(self.param_groups, self.compression_chunk)
                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/leonardo_scratch/fast/EUHPC_D07_063/miniconda3/envs/olmo/lib/python3.12/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/leonardo_scratch/fast/EUHPC_D07_063/miniconda3/envs/olmo/lib/python3.12/site-packages/olmo/demo_util.py", line 32, in __init__
    self.f_dict[sc] = _dct(I, norm=norm).to(p.dtype).to(p.device)
                      ^^^^^^^^^^^^^^^^^^
  File "/leonardo_scratch/fast/EUHPC_D07_063/miniconda3/envs/olmo/lib/python3.12/site-packages/olmo/demo_util.py", line 169, in _dct
    x = x.contiguous().view(-1, N)
        ^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: cannot reshape tensor of 0 elements into shape [-1, 0] because the unspecified dimension size -1 can be any value and is ambiguous
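
To narrow down which parameters reach TransformDCT with a zero-size dimension on a given rank, it may help to scan the parameter shapes just before build_optimizer is called. The following is only a minimal debugging sketch: find_empty_params is a hypothetical helper (not part of OLMo or DeMo), and the parameter names and shapes are made-up stand-ins so the snippet runs without torch. It assumes only that each parameter exposes a `.shape` attribute, as torch parameters do.

```python
from types import SimpleNamespace


def find_empty_params(named_params):
    """Return the names of parameters that have at least one zero-size dimension.

    `named_params` is any iterable of (name, param) pairs where `param.shape`
    is tuple-like, e.g. the result of `dist_model.named_parameters()`.
    """
    return [name for name, p in named_params if 0 in tuple(p.shape)]


# Stand-ins for the parameters seen on one rank (hypothetical names and shapes).
# In a real run, pass dist_model.named_parameters() instead.
fake_params = [
    ("wte.weight", SimpleNamespace(shape=(50304, 2048))),
    ("ln_f.weight", SimpleNamespace(shape=(0,))),
]

empty = find_empty_params(fake_params)
print(empty)
```

If this reports any names under FSDP but none under DDP, that would point at the sharded parameter views rather than at _dct or _idct themselves; whether skipping such parameters in the optimizer is the right fix is a separate question that this sketch does not answer.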