
Conflict with torch.distributed? #145

@Wuzimeng

Description


Hello, I encountered an error when calling flop_count_table() in my distributed training code.
The error message is shown below, but I checked the inputs to all_gather() and did not find anything unusual.

File "/xxx/anaconda3/envs/torch13/lib/python3.9/site-packages/torch/distributed/distributed_c10d.py", line 2275, in all_gather
work = default_pg.allgather([tensor_list], [tensor])
RuntimeError: unsupported input list type: Tensor[]
ERROR:torch.distributed.elastic.multiprocessing.api:failed (exitcode: 1) local_rank: 0 (pid: 628698) of binary: /xxx/anaconda3/envs/torch13/bin/python

Here is a short script that reproduces the error when launched with python -m torch.distributed.run --nproc_per_node=1 --master_port 10603 try.py:

import torch
import torch.nn as nn

from fvcore.nn import FlopCountAnalysis, flop_count_str, flop_count_table

class SimpleModel(nn.Module):
    def __init__(self):
        super(SimpleModel, self).__init__()
        self.fc = nn.Linear(10, 1)

    def forward(self, x):
        y = self.fc(x)
        concat_all_gather(y)
        return y.sum()

@torch.no_grad()
def concat_all_gather(tensor):
    tensors_gather = [torch.ones_like(tensor)
                      for _ in range(torch.distributed.get_world_size())]
    torch.distributed.all_gather(tensors_gather, tensor, async_op=False)
    output = torch.cat(tensors_gather, dim=0)
    return output

torch.distributed.init_process_group(backend='nccl')
local_rank = torch.distributed.get_rank()
world_size = torch.distributed.get_world_size()

model = SimpleModel().cuda()
model = nn.parallel.DistributedDataParallel(model, device_ids=[local_rank], output_device=local_rank)

flop = FlopCountAnalysis(model.module, torch.randn(100, 10).cuda())
print(flop_count_table(flop, max_depth=7, show_param_shapes=True))

torch.distributed.destroy_process_group()
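
One possible workaround sketch, assuming torch.jit.is_tracing() reports True while FlopCountAnalysis traces the forward pass: skip the collective during tracing so that all_gather never receives traced tensors. This mirrors the helper in the repro above and is untested, not a confirmed fix.

@torch.no_grad()
def concat_all_gather(tensor):
    # Sketch only: if a tracer (such as the one FlopCountAnalysis uses) is
    # recording the forward pass, return the local tensor and avoid c10d.
    # Assumes torch.jit.is_tracing() is set during the fvcore trace.
    if torch.jit.is_tracing():
        return tensor
    tensors_gather = [torch.ones_like(tensor)
                      for _ in range(torch.distributed.get_world_size())]
    torch.distributed.all_gather(tensors_gather, tensor, async_op=False)
    return torch.cat(tensors_gather, dim=0)

Note that with world_size > 1 the tensor returned during tracing keeps the local batch size rather than the gathered one, so any FLOPs downstream of the gather would be counted against the smaller shape.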

Additionally, my environment is:
Python 3.9.18, CUDA 11.7, fvcore==0.1.5.post20221221, torch 1.13

Another confusing thing is that in a Python 3.8.18 / CUDA 11.4 / torch 1.10 environment, the same script does not raise this error.
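
To narrow down whether the difference comes from fvcore or from torch itself, one quick check (a sketch, untested here) is to trace the module with plain torch.jit.trace, which exercises a similar tracing path without fvcore:

with torch.no_grad():
    # Hypothetical check: if this bare torch.jit trace raises the same
    # "unsupported input list type: Tensor[]" error on torch 1.13 but not on
    # torch 1.10, the change is in torch's tracer / c10d binding rather than
    # in fvcore's FLOP counting.
    traced = torch.jit.trace(model.module, torch.randn(100, 10).cuda(),
                             check_trace=False)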
