
CUDA memory increases after each loss.backward() #201

@sreetamasarkar

Description

I am trying to use the FMoE layer in a ViT-Base model for a simple classification task. However, CUDA memory grows gradually during training, which eventually leads to an out-of-memory error. Digging deeper, I observe a small increase in memory after every loss.backward() call. Here is what the memory growth looks like (a diagnostic sketch follows the log below):

Training (50/10000 Steps) Data time=0.03(0.04) Batch time=1.00(0.97) Memory=8588.2(8440.4)
Training (51/10000 Steps) Data time=0.04(0.04) Batch time=0.78(0.96) Memory=8590.5(8443.3)
Training (52/10000 Steps) Data time=0.03(0.04) Batch time=1.13(0.97) Memory=8602.8(8446.3)
Training (53/10000 Steps) Data time=0.04(0.04) Batch time=0.67(0.96) Memory=8602.8(8449.2)
Training (54/10000 Steps) Data time=0.08(0.04) Batch time=1.13(0.96) Memory=8602.8(8452.0)
Training (55/10000 Steps) Data time=0.02(0.04) Batch time=0.85(0.96) Memory=8602.8(8454.7)
Training (56/10000 Steps) Data time=0.06(0.04) Batch time=0.96(0.96) Memory=8602.8(8457.3)
Training (57/10000 Steps) Data time=0.03(0.04) Batch time=1.02(0.96) Memory=8602.8(8459.9)
Training (58/10000 Steps) Data time=0.06(0.04) Batch time=0.81(0.96) Memory=8623.7(8462.5)
Training (59/10000 Steps) Data time=0.04(0.04) Batch time=1.08(0.96) Memory=8623.7(8465.2)
Training (60/10000 Steps) Data time=0.04(0.04) Batch time=0.72(0.96) Memory=8623.7(8467.8)
Training (61/10000 Steps) Data time=0.03(0.04) Batch time=1.11(0.96) Memory=8623.7(8470.4)
Training (62/10000 Steps) Data time=0.02(0.04) Batch time=0.77(0.96) Memory=8623.7(8472.8)
Training (63/10000 Steps) Data time=0.04(0.04) Batch time=1.10(0.96) Memory=8655.3(8475.4)
Training (64/10000 Steps) Data time=0.02(0.04) Batch time=0.88(0.96) Memory=8655.3(8478.2)
Training (65/10000 Steps) Data time=0.04(0.04) Batch time=0.92(0.96) Memory=8667.8(8481.0)
Training (66/10000 Steps) Data time=0.03(0.04) Batch time=1.02(0.96) Memory=8667.8(8483.8)
Training (67/10000 Steps) Data time=0.04(0.04) Batch time=0.72(0.95) Memory=8667.8(8486.6)
Training (68/10000 Steps) Data time=0.03(0.04) Batch time=1.10(0.96) Memory=8667.8(8489.2)
Training (69/10000 Steps) Data time=0.02(0.04) Batch time=0.70(0.95) Memory=8667.8(8491.8)
Training (70/10000 Steps) Data time=0.06(0.04) Batch time=1.09(0.96) Memory=8667.8(8494.3)
Training (71/10000 Steps) Data time=0.02(0.04) Batch time=0.89(0.95) Memory=8667.8(8496.7)
Training (72/10000 Steps) Data time=0.05(0.04) Batch time=0.86(0.95) Memory=8725.6(8499.5)
Training (73/10000 Steps) Data time=0.03(0.04) Batch time=1.01(0.95) Memory=8725.6(8502.5)
Training (74/10000 Steps) Data time=0.04(0.04) Batch time=0.71(0.95) Memory=8725.6(8505.5)
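
To rule out allocator noise and confirm that tensors are actually accumulating, a diagnostic I would run is to enumerate the CUDA tensors still reachable after each step and watch whether the count keeps rising. This is only a sketch; the helper name live_cuda_tensors is illustrative and not part of my training code.

    import gc
    import torch

    def live_cuda_tensors():
        """Return (dtype, shape) for every CUDA tensor still reachable.

        If this list grows step over step, some reference, often a loss
        tensor that is still attached to the autograd graph, is being
        kept alive across iterations.
        """
        found = []
        for obj in gc.get_objects():
            try:
                if torch.is_tensor(obj) and obj.is_cuda:
                    found.append((obj.dtype, tuple(obj.shape)))
            except Exception:
                # some gc-tracked objects raise on attribute access; skip them
                pass
        return found

    # e.g. once per iteration:
    # log.info("live CUDA tensors: %d", len(live_cuda_tensors()))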

Here is my training loop. If I replace the FMoE Mlp layer with a regular Mlp layer, the training works fine.

    end = time.time()  # reference point for the data/batch timers
    for step, batch in enumerate(train_loader):
        data_time.update(time.time() - end)
        batch = tuple(t.to(args.device) for t in batch)

        x, y = batch

        pred = model(x, 0)
        loss_ = loss_fct(pred, y)

        loss_.backward()
        model.allreduce_params()
        optimizer.step()
        optimizer.zero_grad()

        # timer and step-counter updates implied by the log line below
        batch_time.update(time.time() - end)
        end = time.time()
        global_step += 1

        MB = 1024 * 1024
        memory_meter.update(torch.cuda.max_memory_allocated() / MB)
        log.info("Training ({}/{} Steps)\tData time={:.2f}({:.2f})\tBatch time={:.2f}({:.2f})\tMemory={:.1f}({:.1f})".format(
            global_step, t_total, data_time.val, data_time.avg, batch_time.val, batch_time.avg, memory_meter.val, memory_meter.avg))
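
One caveat about the measurement itself: torch.cuda.max_memory_allocated() reports the peak since the start of the program (or since the last reset), so the logged value can never decrease on its own. To check whether each individual step really allocates more than the previous one, the peak counter needs to be reset every iteration. A minimal sketch of that pattern, assuming the same model/optimizer names as above:

    # Reset the peak counter so max_memory_allocated() reflects this step
    # only, not the all-time high.
    torch.cuda.reset_peak_memory_stats()

    pred = model(x, 0)
    loss_ = loss_fct(pred, y)
    loss_.backward()
    model.allreduce_params()
    optimizer.step()
    optimizer.zero_grad()

    step_peak_mb = torch.cuda.max_memory_allocated() / (1024 * 1024)
    # memory_allocated() is what survives the step; if this number keeps
    # climbing, something (e.g. a graph-attached tensor) is being retained.
    persistent_mb = torch.cuda.memory_allocated() / (1024 * 1024)
    log.info("step peak=%.1f MB, persistent=%.1f MB", step_peak_mb, persistent_mb)

Even with the monotonic counter, the instantaneous value in the log above still jumps (8588 to 8725 MB over about 25 steps), so the growth looks real rather than a logging artifact.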

Could you please help me understand what might be causing this? Thanks.
