diff --git a/core/auth/api_hmac_auth.py b/core/auth/api_hmac_auth.py index 7800164..00b7ced 100644 --- a/core/auth/api_hmac_auth.py +++ b/core/auth/api_hmac_auth.py @@ -1,9 +1,11 @@ import hashlib import hmac import logging +from typing import Optional, Tuple from rest_framework import authentication from rest_framework import exceptions +from django.contrib.auth.models import User from core.models.facade import Profile from lib.cache import redis_client as redis_c @@ -13,8 +15,18 @@ class HMACAuthentication(authentication.BaseAuthentication): def authenticate(self, request): - api_key, access_signature, nonce = get_authorization_header(request=request) - + api_key, access_signature, nonce = get_authorization_header( + request=request) + return self.authenticate_values( + api_key, access_signature, nonce + ) + + @staticmethod + def authenticate_values( + api_key: Optional[str], + access_signature: Optional[str], + nonce: Optional[str] + ) -> Tuple[User, Optional[str]]: if not api_key or not access_signature or not nonce: raise exceptions.AuthenticationFailed('APIKEY, SIGNATURE or NONCE header does not set') @@ -70,4 +82,3 @@ def get_authorization_header(request): nonce = request.META.get('HTTP_NONCE') return [key, signature, nonce] - diff --git a/core/websockets/consumers.py b/core/websockets/consumers.py index ca860f5..707c5df 100644 --- a/core/websockets/consumers.py +++ b/core/websockets/consumers.py @@ -1,11 +1,13 @@ import logging from asyncio.futures import Future +from typing import Dict, Any from asgiref.sync import sync_to_async from cached_property import asyncio from channels.generic.websocket import AsyncJsonWebsocketConsumer -from django.contrib.auth.models import AnonymousUser +from django.contrib.auth.models import AnonymousUser, User +from core.auth.api_hmac_auth import HMACAuthentication from core.utils.auth import get_user_from_token from exchange.notifications import balance_notificator, executed_order_notificator, wallet_history_endpoint, \ opened_orders_endpoint, closed_orders_endpoint, opened_orders_by_pair_endpoint, closed_orders_by_pair_endpoint @@ -29,7 +31,6 @@ logger = logging.getLogger(__name__) - class LiveNotificationsConsumer(AsyncJsonWebsocketConsumer): AUTH_TIMEOUT = 10 @@ -52,14 +53,34 @@ async def do_auth(self, msg): token = msg.get('token', None) assert token user, token = await sync_to_async(get_user_from_token)(token) - self.scope['user'] = user - await self.join_group(user_notificator.gen_channel(user_id=self.scope['user'].id)) - data = user_notificator.prepare_data({'hello': self.scope['user'].username}) - await self.send_json(data) - self.authed.set_result(user) + await self._set_user(user) + except Exception as e: + self.authed.set_exception(e) + + async def do_auth_api_key(self, msg: Dict[str, Any]) -> None: + logger.debug('do_auth_api_key') + try: + authenticator = HMACAuthentication() + api_key = msg.get('api_key') + signature = msg.get('signature') + nonce = msg.get('nonce') + user, _ = await sync_to_async( + authenticator.authenticate_values + )( + api_key, signature, nonce + ) + await self._set_user(user) except Exception as e: self.authed.set_exception(e) + async def _set_user(self, user: User) -> None: + logger.debug('set_user') + self.scope['user'] = user + await self.join_group(user_notificator.gen_channel(user_id=self.scope['user'].id)) + data = user_notificator.prepare_data({'hello': self.scope['user'].username}) + await self.send_json(data) + self.authed.set_result(user) + async def join_group(self, grp_name): await self.channel_layer.group_add(grp_name, self.channel_name) self.groups.add(grp_name) @@ -78,7 +99,8 @@ async def receive_json(self, content, **kwargs): # if not self.authed.done(): # return await self.do_auth(content) - if 'token' in content and content.get('token', None) is None: + if ('token' in content and content.get('token', None) is None) or \ + ('api_key' in content and content.get('api_key', None) is None): await self.leave_group(user_notificator.gen_channel(user_id=self.scope['user'].id)) self.authed = Future() self.scope['user'] = AnonymousUser() @@ -86,6 +108,9 @@ async def receive_json(self, content, **kwargs): if not self.authed.done() and 'token' in content and content.get('token', None) is not None: return await self.do_auth(content) + if not self.authed.done() and 'api_key' in content and content.get('api_key', None) is not None: + return await self.do_auth_api_key(content) + command = content.get('command', None) params = content.get('params', {}) if not command: diff --git a/public_api/docs/websocket.md b/public_api/docs/websocket.md new file mode 100644 index 0000000..35b7867 --- /dev/null +++ b/public_api/docs/websocket.md @@ -0,0 +1,46 @@ +```python +import asyncio +import websockets +import hashlib +import hmac +import json + +async def connect(): + async with websockets.connect('wss://example.com/wsapi/v1/live_notifications') as websocket: + nonce = str(int(round(time.time() * 1000))) + api_key = '' + secret_key = '' + message = api_key + nonce + signature = hmac.new( + secret_key.encode(), + message.encode('utf-8'), + hashlib.sha256 + ).hexdigest().upper() + + # Construct authentication message + auth_data = { + 'api_key': api_key, + 'signature': signature, + 'nonce': nonce + } + # Send authentication message + await websocket.send(json.dumps(auth_data)) + + # Wait for server response + response = await websocket.recv() + print(response) + + # Send sample messages + for i in range(5): + message = { + 'type': 'sample_message', + 'data': {'index': i}, + } + await websocket.send(json.dumps(message)) + + # Wait for server response + response = await websocket.recv() + print(response) + +asyncio.get_event_loop().run_until_complete(connect()) +```