Skip to content

Commit 8b8b0a1

Browse files
committed
update
1 parent 98b0fe7 commit 8b8b0a1

File tree

1 file changed

+25
-15
lines changed

1 file changed

+25
-15
lines changed

src/twinkle_client/utils/patch_tinker.py

Lines changed: 25 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -144,21 +144,31 @@ def patch_tinker():
144144
from tinker.types.checkpoint import ParsedCheckpointTinkerPath
145145
ParsedCheckpointTinkerPath.from_tinker_path = classmethod(_patched_from_tinker_path)
146146

147-
# Patch 4: inject Twinkle-specific headers
148-
from tinker.lib.public_interfaces import service_client
149-
from twinkle_client.http.utils import get_api_key, get_request_id
150-
151-
original_get_default_headers = service_client._get_default_headers
152-
153-
def _patched_get_default_headers():
154-
headers = original_get_default_headers()
155-
# Add Twinkle-specific headers
156-
headers['serve_multiplexed_model_id'] = get_request_id()
157-
headers['Authorization'] = 'Bearer ' + get_api_key()
158-
headers['Twinkle-Authorization'] = 'Bearer ' + get_api_key()
159-
return headers
160-
161-
service_client._get_default_headers = _patched_get_default_headers
147+
# Patch 4: inject Twinkle-specific headers by patching ServiceClient.__init__.
148+
from tinker.lib.public_interfaces.service_client import ServiceClient
149+
from twinkle_client.http.utils import get_request_id, get_api_key
150+
151+
_original_service_client_init = ServiceClient.__init__
152+
153+
def _patched_service_client_init(self, user_metadata=None, **kwargs):
154+
# Resolve api_key with the same priority order used by AsyncTinker:
155+
# 1. explicit kwarg 2. TINKER_API_KEY env var 3. TWINKLE_SERVER_TOKEN env var
156+
api_key = kwargs.get('api_key')
157+
if api_key is None:
158+
api_key = get_api_key()
159+
160+
twinkle_headers = {
161+
'serve_multiplexed_model_id': get_request_id(),
162+
'Authorization': 'Bearer ' + api_key,
163+
'Twinkle-Authorization': 'Bearer ' + api_key,
164+
}
165+
# Merge: caller-supplied default_headers take precedence over twinkle_headers
166+
user_default_headers = kwargs.pop('default_headers', {})
167+
kwargs['default_headers'] = twinkle_headers | user_default_headers
168+
169+
_original_service_client_init(self, user_metadata=user_metadata, **kwargs)
170+
171+
ServiceClient.__init__ = _patched_service_client_init
162172

163173
_patched = True
164174
except ImportError:

0 commit comments

Comments
 (0)