Commit 7eb5679

Add attention_mask argument to loss_fn() and lm_cross_entropy_loss() and adjust the cross entropy calculation to ignore masked (padding) tokens.
1 parent 15ae297 commit 7eb5679

File tree

2 files changed (+1 line, −1 line)


pyproject.toml

Lines changed: 1 addition & 1 deletion

@@ -3,7 +3,7 @@
 description="An implementation of transformers tailored for mechanistic interpretability."
 license="MIT"
 name="transformer-lens"
-packages=[{include="transformer_lens"}]
+packages=[{include="transformer_lens"}, {include="transformer_lens/py.typed"}]
 readme="README.md"
 # Version is automatically set by the pipeline on release
 version="0.0.0"

transformer_lens/py.typed

Whitespace-only changes (new, empty PEP 561 type-marker file).
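The pyproject.toml change bundles the py.typed marker into the built distribution; per PEP 561, that marker is what tells type checkers to trust the package's inline annotations. A hypothetical way to verify the marker ships with an installed package (the helper name is an assumption, not part of this commit):

```python
import importlib.resources as res


def has_py_typed(pkg: str) -> bool:
    """Return True if the installed package bundles a PEP 561 py.typed marker."""
    try:
        return res.files(pkg).joinpath("py.typed").is_file()
    except ModuleNotFoundError:
        # Package is not installed at all.
        return False
```

Running has_py_typed("transformer_lens") against a release that includes this commit should return True.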
