Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .flake8
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
[flake8]
#ignore=E121,E123,E126,E226,E24,E704,W503,W504,E501
max-line-length=160
max-line-length=240
8 changes: 4 additions & 4 deletions tests/__init__.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,20 @@
import time


class AnyValue():
class AnyValue:
def __eq__(self, value):
return True


class NowTimeDeltaValue():
class NowTimeDeltaValue:
def __init__(self, delta_sec=2.5):
self._delta_sec = delta_sec
self._last_time = None

def __eq__(self, value):
value = round(value, 1)
self._last_time = round(time.time(), 1)
return (value - self._delta_sec < self._last_time < value + self._delta_sec)
return value - self._delta_sec < self._last_time < value + self._delta_sec

def __repr__(self):
return f'{self.__class__.__name__}<{self._last_time}±{self._delta_sec}>'
return f"{self.__class__.__name__}<{self._last_time}±{self._delta_sec}>"
4 changes: 2 additions & 2 deletions tests/test_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,7 @@ async def test_no_reties_on_fatal_error(sut):
assert sut.server.app["api"].channels["main"].qsize() == 0
assert sut.server.app["api"].channels["main"].stat() == {
"errors": 1,
"last_error": "Status: 401. Body: {}",
"last_error": "telegram fatal error: Status: 401. Body: {}",
"last_error_at": NowTimeDeltaValue(),
"queued": 1,
"sended": 0,
Expand Down Expand Up @@ -255,7 +255,7 @@ async def test_channel_statistics(sut):
await asyncio.sleep(1)
assert await resp.json() == {
"errors": 1,
"last_error": "Status: 401. Body: bad request",
"last_error": "telegram fatal error: Status: 401. Body: bad request",
"last_error_at": NowTimeDeltaValue(),
"queued": 3,
"sended": 2,
Expand Down
4 changes: 2 additions & 2 deletions tgproxy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,6 @@
from tgproxy.channel import build_channel

__all__ = [
'HttpAPI',
'build_channel',
"HttpAPI",
"build_channel",
]
59 changes: 26 additions & 33 deletions tgproxy/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,38 +5,33 @@

import tgproxy.errors as errors

DEFAULT_LOGGER_NAME = 'tgproxy.app'
DEFAULT_LOGGER_NAME = "tgproxy.app"


class BaseApp:
def __init__(self, logger_name=DEFAULT_LOGGER_NAME):
self.app = web.Application(
middlewares=[
self._error_middleware,
self._error_middleware,
],
)
self._log = logging.getLogger(logger_name)

self.app.add_routes([
web.get('/ping.html', self._on_ping),
])
self.app.add_routes(
[
web.get("/ping.html", self._on_ping),
]
)

def _success_response(self, status=200, **kwargs):
return web.json_response(
data=dict(
status='success',
**kwargs
),
data=dict(status="success", **kwargs),
status=status,
)

def _error_response(self, message, status=500, **kwargs):
return web.json_response(
dict(
status='error',
message=message or 'Unknown error',
**kwargs
),
dict(status="error", message=message or "Unknown error", **kwargs),
status=status,
)

Expand All @@ -54,7 +49,7 @@ async def _error_middleware(self, request, handler):

async def _on_ping(self, request):
return web.Response(
text='OK',
text="OK",
)


Expand All @@ -63,18 +58,20 @@ def __init__(self, channels):
super().__init__()

self.channels = dict(channels)
self.app.add_routes([
web.get('/', self._on_index),
web.get('/{channel_name}', self._on_channel_stat),
web.post('/{channel_name}', self._on_channel_send),
])
self.app.add_routes(
[
web.get("/", self._on_index),
web.get("/{channel_name}", self._on_channel_stat),
web.post("/{channel_name}", self._on_channel_send),
]
)
self.app.on_startup.append(self.start_background_channels_tasks)
self.app.on_shutdown.append(self.stop_background_channels_tasks)

self.background_tasks = list()

async def start_background_channels_tasks(self, app):
self._log.info('Start background tasks')
self._log.info("Start background tasks")
for ch in self.channels.values():
self.background_tasks.append(
asyncio.create_task(
Expand All @@ -84,31 +81,29 @@ async def start_background_channels_tasks(self, app):
)

async def stop_background_channels_tasks(self, app):
self._log.info('Stop background tasks')
self._log.info("Stop background tasks")
for task in self.background_tasks:
task.cancel()
await task

def _get_task_state(self, task):
if task.cancelled():
return 'cancelled'
return "cancelled"
if task.done():
return 'done'
return 'active'
return "done"
return "active"

def _has_failed_workers(self):
return any(map(lambda x: x.cancelled() or x.done(), self.background_tasks))

def _workers(self):
return {
task.get_name(): self._get_task_state(task) for task in self.background_tasks
}
return {task.get_name(): self._get_task_state(task) for task in self.background_tasks}

async def _on_ping(self, request):
if self._has_failed_workers():
return self._error_response(
status=502,
message='Background workers canceled',
message="Background workers canceled",
workers=self._workers(),
)

Expand All @@ -118,13 +113,11 @@ async def _on_ping(self, request):

async def _on_index(self, request):
return self._success_response(
channels={
name: str(ch) for name, ch in self.channels.items()
},
channels={name: str(ch) for name, ch in self.channels.items()},
)

def _get_channel(self, request):
channel_name = request.match_info['channel_name']
channel_name = request.match_info["channel_name"]
channel = self.channels.get(channel_name)
if not channel:
raise errors.ChannelNotFound(f'Channel "{channel_name}" not found')
Expand Down
91 changes: 35 additions & 56 deletions tgproxy/channel.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import tgproxy.utils as utils
from tgproxy.queue import MemoryQueue

DEFAULT_LOGGER_NAME = 'tgproxy.channel'
DEFAULT_LOGGER_NAME = "tgproxy.channel"
CHANNELS_TYPES = dict()


Expand All @@ -35,19 +35,13 @@ def build_channel(url, **kwargs):
class Message:
# dict: name: {default: value}
request_fields = {
'text': {'default': '<Empty message>'},
'request_id': {},
"text": {"default": "<Empty message>"},
"request_id": {},
}

@classmethod
def from_request(cls, request):
message = cls(
**{
f: request.get(f, v.get('default'))
for f, v in cls.request_fields.items()
if request.get(f, v.get('default')) is not None
}
)
message = cls(**{f: request.get(f, v.get("default")) for f, v in cls.request_fields.items() if request.get(f, v.get("default")) is not None})
return message

def __init__(self, text, request_id=None, **options):
Expand All @@ -61,7 +55,7 @@ def __repr__(self):


class BaseChannel:
schema = '-'
schema = "-"
message_class = Message

@classmethod
Expand All @@ -74,7 +68,7 @@ def __init__(self, name, queue, provider, send_banner_on_startup=False, logger_n
self.send_banner_on_startup = send_banner_on_startup

self._queue = queue or MemoryQueue()
self._log = logging.getLogger(f'{logger_name}.{name}')
self._log = logging.getLogger(f"{logger_name}.{name}")
self._stat = dict(
queued=0,
sended=0,
Expand All @@ -85,7 +79,7 @@ def __init__(self, name, queue, provider, send_banner_on_startup=False, logger_n
)
self._retryMessageDelayInSeconds = 5

self._log.info(f'self.send_banner_on_startup == {self.send_banner_on_startup}')
self._log.info(f"self.send_banner_on_startup == {self.send_banner_on_startup}")

def qsize(self):
return self._queue.qsize()
Expand All @@ -102,7 +96,7 @@ async def put(self, message):
await self._enqueue(message)

async def process(self):
self._log.info(f'Start queue processor for {self}')
self._log.info(f"Start queue processor for {self}")
if self.send_banner_on_startup:
await self.put(self._get_banner())

Expand All @@ -111,7 +105,7 @@ async def process(self):
while True:
message = await self._dequeue()

self._log.info(f'Send message: {message}')
self._log.info(f"Send message: {message}")
while True:
error = await self._send_message(provider, message)
if error is None or isinstance(error, providers.errors.ProviderFatalError):
Expand All @@ -120,92 +114,77 @@ async def process(self):

await self._queue.task_done()
except asyncio.CancelledError:
self._log.info(f'Finish queue processor. Queue size: {self._queue.qsize()}')
self._log.info(f"Finish queue processor. Queue size: {self._queue.qsize()}")
except Exception as e:
self._log.error(str(e), exc_info=sys.exc_info())
self._log.info(f'Failed queue processor. Queue size: {self._queue.qsize()}')
self._log.info(f"Failed queue processor. Queue size: {self._queue.qsize()}")
raise

async def _enqueue(self, message):
self._log.info(f'Enque message: {message}')
self._log.info(f"Enque message: {message}")
await self._queue.enqueue(message)
self._stat['queued'] += 1
self._stat["queued"] += 1

async def _dequeue(self):
message = await self._queue.dequeue()
self._log.info(f'Deque message: {message}')
self._log.info(f"Deque message: {message}")
return message

def _get_banner(self):
return self.message_class.from_request(
dict(text=f'Start tgproxy on {socket.gethostname()}'),
dict(text=f"Start tgproxy on {socket.gethostname()}"),
)

async def _send_message(self, provider, message):
# return Exception if failed
try:
await provider.send_message(message)
self._log.info(f'Message sended: {message}')
self._stat['sended'] += 1
self._stat['last_sended_at'] = round(time.time(), 3)
self._log.info(f"Message sended: {message}")
self._stat["sended"] += 1
self._stat["last_sended_at"] = round(time.time(), 3)
except providers.errors.ProviderError as e:
self._stat['errors'] += 1
self._stat['last_error'] = str(e)
self._stat['last_error_at'] = round(time.time(), 3)
self._log.error(f'Error: {str(e)} Message: {message}', exc_info=sys.exc_info())
self._stat["errors"] += 1
self._stat["last_error"] = str(e)
self._stat["last_error_at"] = round(time.time(), 3)
self._log.error(f"Error: {str(e)} Message: {message}", exc_info=sys.exc_info())
return e

return None


class TelegramMessage(Message):
request_fields = {
'text': {'default': '<Empty message>'},
'request_id': {},
'parse_mode': {},
'disable_web_page_preview': {'default': 0},
'disable_notifications': {'default': 0},
'reply_to_message_id': {},
"text": {"default": "<Empty message>"},
"request_id": {},
"parse_mode": {},
"disable_web_page_preview": {"default": 0},
"disable_notifications": {"default": 0},
"reply_to_message_id": {},
}


class TelegramChannel(BaseChannel):
schema = 'telegram'
schema = "telegram"
message_class = TelegramMessage

@classmethod
def from_url(cls, url, queue=None, **kwargs):
parsed_url, options = utils.parse_url(url)
return cls(
name=parsed_url.path.strip('/'),
queue=queue,
provider=providers.TelegramChat(
chat_id=parsed_url.hostname,
bot_token=f'{parsed_url.username}:{parsed_url.password}',
**(options | kwargs)
),
**(options | kwargs)
)
return cls(name=parsed_url.path.strip("/"), queue=queue, provider=providers.TelegramChat(chat_id=parsed_url.hostname, bot_token=f"{parsed_url.username}:{parsed_url.password}", **(options | kwargs)), **(options | kwargs))

def __init__(self, name, queue, provider, send_banner_on_startup=True, **kwargs):
super().__init__(
name,
queue,
provider,
send_banner_on_startup=bool(int(send_banner_on_startup)),
**kwargs
)
super().__init__(name, queue, provider, send_banner_on_startup=bool(int(send_banner_on_startup)), **kwargs)
self._bot_token = provider.bot_token
self._bot_name = self._bot_token[:self._bot_token.find(":")]
self._bot_name = self._bot_token[: self._bot_token.find(":")]
self._chat_id = provider.chat_id
self._channel_options = dict(kwargs)

def __str__(self):
co = [f'{k}={v}' for k, v in self._channel_options.items()]
co = [f"{k}={v}" for k, v in self._channel_options.items()]
co = f'{"&".join(co)}'
if co:
co = f'?{co}'
return f'{self.schema}://{self._bot_name}:***@{self._chat_id}/{self.name}{co}'
co = f"?{co}"
return f"{self.schema}://{self._bot_name}:***@{self._chat_id}/{self.name}{co}"


register_channel_type(TelegramChannel)
Loading