
empty tensors when using DeMo with FSDP #2

@peter-sk


I am successfully training 1B OLMo models using DeMo with DDP :-)

However, when training the same model using DeMo with FSDP, I run into an issue where both _dct and _idct are called with empty tensors (x.shape[-1] == 0).
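The traceback in this issue shows `TransformDCT.__init__` calling `_dct(I, norm=norm)` while it precomputes bases per chunk size `sc`, presumably with `I` an `sc x sc` identity. If a parameter reaches that loop with a zero-sized dimension, the identity is empty and the reshape inside `_dct` fails. The pure-Python sketch below is not DeMo's code. `dct_basis` builds an orthonormal DCT-II matrix as a stand-in for `_dct` applied to an identity with `norm="ortho"` (up to transposition). The hypothetical `build_bases` skips parameters with zero elements instead of building an empty basis. Shapes are plain tuples so the example runs without PyTorch, and `chunk_for` is an assumed callback that maps a dimension to its chunk size.

```python
import math


def dct_basis(n):
    """Orthonormal DCT-II matrix (n x n, rows are basis vectors).

    Stand-in for a DCT with norm="ortho" applied to an n x n identity.
    """
    if n <= 0:
        # An empty identity is what makes the reshape in _dct fail.
        raise ValueError(f"DCT size must be positive, got {n}")
    basis = []
    for k in range(n):
        scale = math.sqrt(1.0 / n) if k == 0 else math.sqrt(2.0 / n)
        basis.append(
            [scale * math.cos(math.pi * (2 * i + 1) * k / (2 * n)) for i in range(n)]
        )
    return basis


def build_bases(shapes, chunk_for):
    """Precompute one basis per distinct chunk size (hypothetical helper).

    shapes:    mapping of parameter name -> local shape tuple
    chunk_for: assumed callback mapping a dimension size to its chunk size
    Parameters with zero elements are skipped rather than producing an
    empty basis.
    """
    bases = {}
    skipped = []
    for name, shape in shapes.items():
        if math.prod(shape) == 0:
            skipped.append(name)
            continue
        for dim in shape:
            size = chunk_for(dim)
            if size not in bases:
                bases[size] = dct_basis(size)
    return bases, skipped


# One normal parameter and one that is empty on this rank.
bases, skipped = build_bases({"w": (8, 4), "b": (0,)}, chunk_for=lambda d: d)
print(sorted(bases), skipped)  # [4, 8] ['b']
```

Skipping is only one option. An empty local shard also has no gradient to compress, so the optimizer step would need the same guard.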

My config for DeMo as the optimizer is:

optimizer:
  name: demo
  learning_rate: 3.0e-4
  weight_decay: 0.0
  eps: 1e-8
  decay_norm_and_bias: false
  decay_embeddings: false
  compression_decay: 0.99
  compression_topk: 32
  compression_chunk: 64
  metrics_log_interval: 1
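`compression_chunk: 64` sets the target DCT chunk size, and the traceback suggests that a chunk size `sc` is derived for each tensor dimension before its basis is built. The minimal sketch below assumes the chunk is the largest divisor of the dimension that does not exceed the target. That is an assumption, not something confirmed from DeMo's source, and `smaller_split` is a hypothetical name. Under this rule a prime-sized dimension degrades to chunks of 1, and a zero-sized dimension has no valid chunk at all, so it has to be rejected or skipped before any basis is built.

```python
def smaller_split(n, target):
    """Largest divisor of n that does not exceed target (hypothetical helper)."""
    if n <= 0:
        # A zero-sized dimension has no usable chunk size; callers should skip it.
        raise ValueError(f"cannot choose a chunk size for dimension {n}")
    best = 1
    for d in range(1, min(n, target) + 1):
        if n % d == 0:
            best = d
    return best


compression_chunk = 64  # value from the optimizer config
for dim in (2048, 100, 97):
    print(dim, "->", smaller_split(dim, compression_chunk))
# 2048 -> 64
# 100 -> 50
# 97 -> 1
```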

My config for FSDP is:

fsdp:
  wrapping_strategy: null
  precision: mixed
  disable_grad_sync: true
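One plausible source of the empty tensors: the optimizer is built from the wrapped `dist_model` (see `build_optimizer(cfg, dist_model)` in the traceback), and with `wrapping_strategy: null` the whole model is presumably a single FSDP unit whose parameters share one flat buffer split evenly across ranks. When PyTorch FSDP runs with `use_orig_params=True`, each original parameter the optimizer sees is a view into the local shard, and a parameter with no elements on that rank appears as a zero-element tensor. Whether this OLMo branch enables that flag is an assumption. The pure-Python sketch below models only the bookkeeping, using a hypothetical `local_numels` and toy sizes, to show how easily a parameter ends up with zero local elements on some ranks.

```python
import math


def local_numels(numels, world_size):
    """Per-rank element counts for each parameter under a simplified model of
    flat-parameter sharding: parameters are concatenated, padded, and split
    into equal contiguous shards (not FSDP's actual implementation)."""
    total = sum(numels)
    shard = math.ceil(total / world_size)
    per_rank = []
    for rank in range(world_size):
        lo, hi = rank * shard, (rank + 1) * shard  # this rank's slice of the flat buffer
        counts, start = [], 0
        for n in numels:
            end = start + n
            # Overlap of [start, end) with [lo, hi), clamped at zero.
            counts.append(max(0, min(end, hi) - max(start, lo)))
            start = end
        per_rank.append(counts)
    return per_rank


# Three parameters of 6, 2 and 8 elements sharded over 4 ranks.
for rank, counts in enumerate(local_numels([6, 2, 8], world_size=4)):
    print(rank, counts)
# 0 [4, 0, 0]
# 1 [2, 2, 0]
# 2 [0, 0, 4]
# 3 [0, 0, 4]
```

With DDP every rank holds full parameters, which would be consistent with DeMo working there and failing only under FSDP.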

My code base is in the following PR to OLMo:

allenai/OLMo#771

Traceback (most recent call last):
  File "/leonardo_scratch/fast/EUHPC_D07_063/OLMo/scripts/train.py", line 429, in <module>
    main(cfg)
  File "/leonardo_scratch/fast/EUHPC_D07_063/OLMo/scripts/train.py", line 248, in main
    optim = build_optimizer(cfg, dist_model)
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/leonardo_scratch/fast/EUHPC_D07_063/miniconda3/envs/olmo/lib/python3.12/site-packages/olmo/optim.py", line 1138, in build_optimizer
    return DeMo(
           ^^^^^
  File "/leonardo_scratch/fast/EUHPC_D07_063/miniconda3/envs/olmo/lib/python3.12/site-packages/olmo/optim.py", line 702, in __init__
    self.transform = TransformDCT(self.param_groups, self.compression_chunk)
                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/leonardo_scratch/fast/EUHPC_D07_063/miniconda3/envs/olmo/lib/python3.12/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/leonardo_scratch/fast/EUHPC_D07_063/miniconda3/envs/olmo/lib/python3.12/site-packages/olmo/demo_util.py", line 32, in __init__
    self.f_dict[sc] = _dct(I, norm=norm).to(p.dtype).to(p.device)
                      ^^^^^^^^^^^^^^^^^^
  File "/leonardo_scratch/fast/EUHPC_D07_063/miniconda3/envs/olmo/lib/python3.12/site-packages/olmo/demo_util.py", line 169, in _dct
    x = x.contiguous().view(-1, N)
        ^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: cannot reshape tensor of 0 elements into shape [-1, 0] because the unspecified dimension size -1 can be any value and is ambiguous
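The error message follows from how a single `-1` is resolved in `view`: the inferred size is the element count divided by the product of the other dimensions. In `_dct`, `x.contiguous().view(-1, N)` with `N == 0` on a 0-element tensor asks for 0 divided by 0, so any size would satisfy the shape. The pure-Python sketch below models that rule. It is a simplified stand-in written for illustration, not PyTorch's implementation. It also shows that an empty tensor reshapes without error when the known dimensions are non-zero, so the failure comes specifically from the zero-sized last dimension.

```python
import math


def infer_view_shape(numel, shape):
    """Resolve a single -1 in a requested view shape (simplified stand-in for
    PyTorch's shape inference, written for illustration)."""
    shape = tuple(shape)
    if shape.count(-1) > 1:
        raise ValueError("only one dimension can be inferred")
    known = math.prod(d for d in shape if d != -1)  # product of the given sizes
    if -1 in shape:
        if known == 0:
            if numel == 0:
                # 0 == (-1) * 0 holds for any value of -1, so it cannot be inferred.
                raise ValueError(
                    f"cannot reshape tensor of 0 elements into shape {list(shape)} "
                    "because the unspecified dimension size -1 can be any value "
                    "and is ambiguous"
                )
        elif numel % known == 0:
            return tuple(numel // known if d == -1 else d for d in shape)
    elif known == numel:
        return shape
    raise ValueError(f"shape {list(shape)} is invalid for input of size {numel}")


print(infer_view_shape(12, (-1, 4)))  # (3, 4)
print(infer_view_shape(0, (-1, 4)))   # (0, 4): an empty tensor is fine here
```

Calling `infer_view_shape(0, (-1, 0))` raises the same "ambiguous" message as the traceback.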
