diff --git a/pytr/alarms.py b/pytr/alarms.py index 5de1cbc..a88853a 100644 --- a/pytr/alarms.py +++ b/pytr/alarms.py @@ -101,6 +101,7 @@ async def set_alarms(self): while action_count > 0: await self.tr.recv() action_count -= 1 + await self.tr.close() return def overview(self): @@ -151,7 +152,11 @@ def get(self): except InvalidOperation: raise ValueError(f"{token} is no valid ISIN or decimal value that could represent an alarm.") - asyncio.run(self.alarms_loop()) + async def get_alarms_and_close(): + await self.alarms_loop() + await self.tr.close() + + asyncio.run(get_alarms_and_close()) self.overview() diff --git a/pytr/api.py b/pytr/api.py index d106626..050a177 100644 --- a/pytr/api.py +++ b/pytr/api.py @@ -303,6 +303,13 @@ async def _get_ws(self): return self._ws + async def close(self): + """Close the websocket connection gracefully.""" + if self._ws is not None: + self.log.info("Closing websocket connection...") + await self._ws.close() + self._ws = None + async def _next_subscription_id(self): async with self._lock: subscription_id = self._subscription_id_counter diff --git a/pytr/details.py b/pytr/details.py index 15cc3e0..e556bd7 100644 --- a/pytr/details.py +++ b/pytr/details.py @@ -50,6 +50,7 @@ async def details_loop(self): print(f"unmatched subscription of type '{subscription['type']}':\n{preview(response, num_lines=30)}") if recv == 6: + await self.tr.close() return def print_instrument(self): diff --git a/pytr/portfolio.py b/pytr/portfolio.py index 1d1ac47..f36b61b 100644 --- a/pytr/portfolio.py +++ b/pytr/portfolio.py @@ -188,6 +188,8 @@ async def portfolio_loop(self): portfolionew.append(pos) self.portfolio = portfolionew + await self.tr.close() + def _get_sort_func(self): if self.sort_by_column: match self.sort_by_column.lower(): diff --git a/pytr/timeline.py b/pytr/timeline.py index 957e4ee..84f5ea1 100644 --- a/pytr/timeline.py +++ b/pytr/timeline.py @@ -109,6 +109,8 @@ async def tl_loop(self): else: self.log.warning(f"unmatched subscription of type '{subscription['type']}':\n{preview(response)}") + await self.tr.close() + async def get_next_timeline_transactions(self, response): """ Get timeline transactions and store them in list timeline_transactions