Calculate batch norm statistic loss on parallel training #16

@dohe0342

Description

Hello, I have one question about batch norm statistic loss.

Consider data-parallel training: I have 8 GPUs, and each GPU can hold a batch size of 128.

As you know, the batch norm statistic loss is computed independently on each GPU, and the GPUs only share gradients, not the statistics of the whole batch (1024). I think this can cause image quality degradation.

So here is my question: how can I calculate the batch norm statistic loss during parallel training as if it were computed over the whole batch of 1024, rather than over each per-GPU mini-batch? A sketch of what I have in mind is below.
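One approach I can imagine is to all-reduce the per-GPU channel sums before forming the mean and variance, using the autograd-aware collectives in `torch.distributed.nn.functional` so the gradient of the loss flows back to every GPU's inputs (plain `dist.all_reduce` is not tracked by autograd). Here is a minimal sketch, assuming `torch.distributed` is already initialized; `global_bn_statistic_loss` is a hypothetical helper that would replace the per-GPU statistic loss inside the BN forward hook:

```python
import torch
from torch.distributed import get_world_size
from torch.distributed.nn.functional import all_reduce  # autograd-aware collective


def global_bn_statistic_loss(feat, running_mean, running_var):
    """BN statistic loss over the global batch across all GPUs (hypothetical helper).

    feat: activations entering a BatchNorm2d layer on this GPU, shape (N, C, H, W).
    running_mean / running_var: the BN layer's stored statistics, shape (C,).
    """
    # Per-GPU sums per channel; unlike means/variances, sums aggregate exactly.
    n_local = feat.numel() / feat.size(1)          # elements per channel on this GPU
    local_sum = feat.sum(dim=(0, 2, 3))
    local_sq_sum = (feat * feat).sum(dim=(0, 2, 3))

    # All-reduce through the autograd-aware collective so gradients flow
    # back to every GPU's mini-batch, not just the local one.
    global_sum = all_reduce(local_sum)
    global_sq_sum = all_reduce(local_sq_sum)
    n_global = n_local * get_world_size()

    mean = global_sum / n_global
    var = global_sq_sum / n_global - mean * mean   # E[x^2] - E[x]^2 over the whole batch

    # Match the global batch statistics to the BN layer's running statistics.
    return torch.norm(mean - running_mean, 2) + torch.norm(var - running_var, 2)
```

With the sums aggregated this way, the resulting mean and variance are exactly those of the full 1024-sample batch, so the loss should match the single-machine computation.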
