Skip to content

Commit 3fa6f17

Browse files
committed
wip
1 parent 356612f commit 3fa6f17

File tree

4 files changed

+6
-5
lines changed

4 files changed

+6
-5
lines changed

src/twinkle/data_format/input_feature.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
if TYPE_CHECKING:
1414
import torch
1515

16-
InputType = Union[List[List[int]], List[int], np.ndarray, torch.Tensor]
16+
InputType = Union[List[List[int]], List[int], np.ndarray, 'torch.Tensor']
1717

1818

1919
class InputFeature(TypedDict, total=False):

src/twinkle/data_format/output.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
if TYPE_CHECKING:
1313
import torch
1414

15-
OutputType = Union[np.ndarray, torch.Tensor, List[Any]]
15+
OutputType = Union[np.ndarray, 'torch.Tensor', List[Any]]
1616

1717

1818
class ModelOutput(TypedDict, total=False):

src/twinkle/model/base.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -60,15 +60,15 @@ def clip_grad_and_step(self, max_grad_norm: float=1.0, norm_type=2, **kwargs):
6060
...
6161

6262
@abstractmethod
63-
def set_loss(self, loss_cls: Union[Loss, Type[Loss], str, Callable[[InputFeature, ModelOutput, ...], torch.Tensor]], **kwargs):
63+
def set_loss(self, loss_cls: Union[Loss, Type[Loss], str, Callable[[InputFeature, ModelOutput, ...], 'torch.Tensor']], **kwargs):
6464
...
6565

6666
@abstractmethod
67-
def set_optimizer(self, optimizer_cls: Union[Optimizer, Type[Optimizer], str], **kwargs):
67+
def set_optimizer(self, optimizer_cls: Union['Optimizer', Type['Optimizer'], str], **kwargs):
6868
...
6969

7070
@abstractmethod
71-
def set_lr_scheduler(self, scheduler_cls: Union[LRScheduler, Type[LRScheduler], str], **kwargs):
71+
def set_lr_scheduler(self, scheduler_cls: Union['LRScheduler', Type['LRScheduler'], str], **kwargs):
7272
...
7373

7474
@abstractmethod

src/twinkle/model/transformers/transformers.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -384,6 +384,7 @@ def forward_only(self, *, inputs: Union[InputFeature, List[InputFeature], List[T
384384
inputs['labels'] = labels
385385
optimizer_config.inputs = inputs
386386
optimizer_config.outputs = outputs
387+
optimizer_config.loss_value = outputs.get('aux_loss', 0)
387388
return outputs
388389

389390
@remote_function(collect='mean')

0 commit comments

Comments
 (0)