
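Hello, I found that the MCC part of the code does not implement the Straight-Through Estimator (STE) described in the original paper; instead, it computes the loss directly with softmax #7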

@chenshming


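import torch
import torch.nn as nn

# recalls_and_ndcgs_for_ks is a helper defined elsewhere in the repository.


class Reconstruct:
    def __init__(self):
        self.ce = nn.CrossEntropyLoss(label_smoothing=0.2)

    def compute(self, token_prediction_prob, tokens):
        hits = torch.sum(torch.argmax(token_prediction_prob, dim=-1) == tokens)
        NDCG10 = recalls_and_ndcgs_for_ks(token_prediction_prob.view(-1, token_prediction_prob.shape[-1]),
                                          tokens.reshape(-1, 1), 10)
        # CrossEntropyLoss applies log-softmax to its input internally, so the loss
        # is computed on the softmax directly; no straight-through estimator is used.
        reconstruct_loss = self.ce(token_prediction_prob.view(-1, token_prediction_prob.shape[-1]), tokens.view(-1))
        return reconstruct_loss, hits, NDCG10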
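For comparison, here is what a straight-through estimator usually looks like, as a minimal PyTorch sketch. It assumes the STE is applied over the last dimension of the prediction logits. `straight_through_one_hot` is a hypothetical helper and does not come from the repository or the paper. The sketch also does not show how MCC would consume the discrete codes. The forward pass produces a hard one-hot vector (the argmax), and the backward pass routes gradients through the softmax.

```python
import torch
import torch.nn.functional as F


def straight_through_one_hot(logits: torch.Tensor) -> torch.Tensor:
    """Hard one-hot forward, softmax gradient backward (over the last dim).

    Hypothetical helper for illustration; not part of the repository.
    """
    soft = F.softmax(logits, dim=-1)
    # Discrete choice: one-hot of the argmax, cast to the softmax dtype.
    hard = F.one_hot(soft.argmax(dim=-1), num_classes=soft.shape[-1]).to(soft.dtype)
    # Forward value equals `hard` (up to float rounding); because `soft.detach()`
    # carries no gradient, backprop sees only the `soft` term.
    return hard - soft.detach() + soft


if __name__ == "__main__":
    torch.manual_seed(0)
    logits = torch.randn(2, 5, requires_grad=True)
    codes = straight_through_one_hot(logits)
    print(codes)  # rows are numerically one-hot
    # Hypothetical downstream objective on the discrete codes.
    weights = torch.arange(5, dtype=torch.float32)
    (codes * weights).sum().backward()
    print(logits.grad)  # non-zero: gradients flow through the softmax
```

PyTorch also provides `torch.nn.functional.gumbel_softmax(logits, tau=1.0, hard=True)`. It uses the same straight-through trick but first adds Gumbel noise to the logits, so it is a related but different technique from a plain argmax STE.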
