Skip to content

Commit bea8a11

Browse files
fix (#119)
1 parent 1cf2543 commit bea8a11

File tree

4 files changed

+12
-3
lines changed

4 files changed

+12
-3
lines changed

assets/slogan.png

206 KB
Loading

src/twinkle/patch/base.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
11
# Copyright (c) ModelScope Contributors. All rights reserved.
2-
from typing import TYPE_CHECKING, List, Union
2+
from typing import TYPE_CHECKING, Any, List, Union
33

44
if TYPE_CHECKING:
55
import torch
66

77

88
class Patch:
99

10-
def __call__(self, module: Union['torch.nn.Module', List['torch.nn.Module']], *args, **kwargs):
10+
def __call__(self, module: Union['torch.nn.Module', List['torch.nn.Module'], Any], *args, **kwargs):
1111
...

src/twinkle/sampler/base.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import twinkle
77
from twinkle import remote_function
88
from twinkle.data_format import InputFeature, SampleResponse, SamplingParams, Trajectory
9+
from twinkle.patch import Patch
910
from twinkle.template import Template
1011
from twinkle.utils import construct_class
1112

@@ -42,6 +43,10 @@ def sample(
4243
"""
4344
pass
4445

46+
@abstractmethod
47+
def apply_patch(self, patch_cls: Union[Patch, Type[Patch], str], **kwargs) -> None:
48+
...
49+
4550
@staticmethod
4651
def _not_encoded(inputs: Any) -> bool:
4752
"""Check if inputs are not yet encoded (i.e., is Trajectory, not InputFeature).

src/twinkle/sampler/vllm_sampler/vllm_sampler.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,11 +24,12 @@
2424
import numpy as np
2525
import os
2626
import threading
27-
from typing import Any, Dict, List, Optional, Union
27+
from typing import Any, Dict, List, Optional, Type, Union
2828

2929
from twinkle import DeviceMesh, get_logger, remote_class, remote_function, requires
3030
from twinkle.checkpoint_engine import CheckpointEngineMixin
3131
from twinkle.data_format import InputFeature, SampledSequence, SampleResponse, SamplingParams, Trajectory
32+
from twinkle.patch import Patch, apply_patch
3233
from twinkle.patch.vllm_lora_weights import VLLMLoraWeights
3334
from twinkle.sampler.base import Sampler
3435
from twinkle.utils import Platform
@@ -212,6 +213,9 @@ def encode_trajectory_for_vllm(self,
212213
result['videos'] = videos
213214
return result
214215

216+
def apply_patch(self, patch_cls: Union[Patch, Type[Patch], str], **kwargs) -> None:
217+
apply_patch(self, patch_cls, **kwargs)
218+
215219
async def _sample_single(
216220
self,
217221
feat: Dict[str, Any],

0 commit comments

Comments
 (0)