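Hello, I found that the MCC part of the code does not implement the Straight-Through Estimator (STE) described in the original paper. Instead, it computes the loss directly with softmax. #7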

@chenshming

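import torch
import torch.nn as nn

class Reconstruct:
    def __init__(self):
        self.ce = nn.CrossEntropyLoss(label_smoothing=0.2)

    def compute(self, token_prediction_prob, tokens):
        # Number of positions where the top-scoring token equals the target token.
        hits = torch.sum(torch.argmax(token_prediction_prob, dim=-1) == tokens)
        # recalls_and_ndcgs_for_ks is defined elsewhere in the repository (not shown here).
        NDCG10 = recalls_and_ndcgs_for_ks(token_prediction_prob.view(-1, token_prediction_prob.shape[-1]),
                                          tokens.reshape(-1, 1), 10)
        # CrossEntropyLoss applies log-softmax internally, so the scores are treated as logits and
        # gradients flow through a plain softmax; there is no straight-through estimator here.
        reconstruct_loss = self.ce(token_prediction_prob.view(-1, token_prediction_prob.shape[-1]), tokens.view(-1))
        return reconstruct_loss, hits, NDCG10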
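For comparison, a generic straight-through estimator can be sketched as follows. This assumes that "STE" here means using a hard one-hot (argmax) code in the forward pass while taking softmax gradients in the backward pass. The sketch uses the common `hard - soft.detach() + soft` trick and may not match the paper's exact formulation. `straight_through_onehot` and `codebook` are hypothetical names and do not come from the repository.

```python
import torch
import torch.nn.functional as F


def straight_through_onehot(logits: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: hard one-hot forward, softmax gradient backward."""
    soft = F.softmax(logits, dim=-1)
    index = soft.argmax(dim=-1, keepdim=True)
    hard = torch.zeros_like(soft).scatter_(-1, index, 1.0)
    # The forward value equals `hard` (up to float rounding); because `soft.detach()`
    # is a constant, the gradient w.r.t. logits is exactly the gradient of `soft`.
    return hard - soft.detach() + soft


logits = torch.tensor([[2.0, 0.5, -1.0]], requires_grad=True)
codes = straight_through_onehot(logits)  # approximately the one-hot [[1, 0, 0]]

# Hypothetical use: select a codebook entry with the hard code, keeping gradients.
codebook = torch.randn(3, 4)
quantized = codes @ codebook
quantized.sum().backward()
print(logits.grad)  # nonzero, even though the forward pass used argmax
```

The difference from the snippet above is that, with STE, the downstream computation sees a discrete code, and only the gradient is borrowed from softmax. With the plain cross-entropy on logits, the soft distribution is used throughout.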
