Weight updates for branch heads #17

@Amakri1020

Hi, I noticed that for every training sample the network outputs predictions for all 5 output branches. The loss is then (correctly) calculated using only the output of the branch that corresponds to that sample's high-level command, and these per-sample losses are summed over the batch to get the total_loss tensor. Is this total loss then used to update all 5 branches? Or is an individual loss calculated somewhere for each branch, using only the samples that branch is supposed to predict on given the high-level command?
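
To make the question concrete, here is a minimal sketch of the setup as I understand it. It is a toy model, not the repo's code: each branch is reduced to a single hypothetical weight `w[k]`, the shared layers are left out, and the names `masked_total_loss` and `branch_gradients` are made up for illustration. It assumes the total loss is a sum of squared errors in which each sample only contributes through the branch selected by its command.

```python
# Toy model (hypothetical, not the repo's code): branch k is a single
# weight w[k], and its prediction for input x is w[k] * x.

def masked_total_loss(w, xs, ys, commands):
    """Sum of squared errors, using only the branch selected by each command."""
    total = 0.0
    for x, y, c in zip(xs, ys, commands):
        for k in range(len(w)):
            mask = 1.0 if k == c else 0.0  # 1 only for the commanded branch
            total += mask * (w[k] * x - y) ** 2
    return total


def branch_gradients(w, xs, ys, commands):
    """Analytic d(total_loss)/d(w[k]) for every branch k."""
    grads = [0.0] * len(w)
    for x, y, c in zip(xs, ys, commands):
        for k in range(len(w)):
            mask = 1.0 if k == c else 0.0
            # Chain rule on mask * (w[k] * x - y) ** 2 with respect to w[k].
            grads[k] += mask * 2.0 * (w[k] * x - y) * x
    return grads


w = [0.5, 0.5, 0.5, 0.5, 0.5]   # one weight per branch (5 branches)
xs = [1.0, 2.0, 3.0]            # three samples in the batch
ys = [1.0, 1.0, 1.0]            # their targets
commands = [0, 2, 2]            # high-level command (branch index) per sample

total_loss = masked_total_loss(w, xs, ys, commands)
grads = branch_gradients(w, xs, ys, commands)
print(total_loss)
print(grads)
```

In this sketch the single total loss gives a nonzero gradient only to branches 0 and 2, and each of those gradients depends only on the samples carrying that branch's command. Branches 1, 3 and 4 get a zero gradient because their outputs are masked out of the loss. If the repo's loss works the same way, one total loss would already behave like separate per-branch losses for the branch heads, while any shared layers would still receive gradients from every sample.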

Hopefully the question is clear. I can try to rephrase it if it isn't!

Thanks a lot for this repo, it has been very useful!
