Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
106 commits
Select commit Hold shift + click to select a range
3cc126a
Wrap channel in an asyncio.Transport to eliminate loop back connection
bdraco Feb 8, 2025
c95319f
Merge branch 'main' into direct_connect_v2
bdraco Feb 9, 2025
1a12cf7
Merge remote-tracking branch 'origin/main' into direct_connect_v2
bdraco Feb 9, 2025
d0b2ae6
Fix test failures to merge collision
bdraco Feb 9, 2025
1a899ac
Merge branch 'merge_conflict_test' into direct_connect_v2
bdraco Feb 9, 2025
8d6d8bc
Merge branch 'main' into direct_connect_v2
bdraco Feb 9, 2025
5bcfe0d
handle start_tls returning None
bdraco Feb 9, 2025
1900c5b
do not write to closing transport
bdraco Feb 9, 2025
f163321
Merge remote-tracking branch 'origin/main' into direct_connect_v2
bdraco Feb 9, 2025
93d9bbb
add some flow control
bdraco Feb 9, 2025
62b522f
need to get a callback when queue is reduced
bdraco Feb 9, 2025
983fc04
need to get a callback when queue is reduced
bdraco Feb 9, 2025
41d2449
fixes
bdraco Feb 9, 2025
2b3a0e0
fixes
bdraco Feb 9, 2025
a351b1f
fixes
bdraco Feb 9, 2025
38bfb9d
preen
bdraco Feb 9, 2025
bb131a5
fixes
bdraco Feb 9, 2025
2b85cc1
Update snitun/multiplexer/channel.py
bdraco Feb 9, 2025
a9f2824
Update snitun/multiplexer/core.py
bdraco Feb 9, 2025
40a2f5c
buffer limits
bdraco Feb 9, 2025
84d054e
fix
bdraco Feb 9, 2025
6ea21b7
dry
bdraco Feb 9, 2025
25687eb
preen
bdraco Feb 9, 2025
406ddb7
debug
bdraco Feb 9, 2025
30e0ad8
Merge branch 'main' into direct_connect_v2
bdraco Feb 10, 2025
0ce366c
Merge branch 'main' into direct_connect_v2
bdraco Feb 11, 2025
2818e92
fix merge
bdraco Feb 11, 2025
157f812
fix merge
bdraco Feb 11, 2025
1cc3bea
fix merge
bdraco Feb 11, 2025
5f3d75e
Merge remote-tracking branch 'origin/main' into direct_connect_v2
bdraco Feb 11, 2025
7ff98be
Merge branch 'main' into direct_connect_v2
bdraco Feb 11, 2025
30aecec
Merge remote-tracking branch 'origin/main' into direct_connect_v2
bdraco Feb 12, 2025
0fd66c8
Merge branch 'main' into direct_connect_v2
bdraco Feb 12, 2025
e6c652c
preen
bdraco Feb 12, 2025
ed37098
Merge remote-tracking branch 'bdraco/direct_connect_v2' into direct_c…
bdraco Feb 12, 2025
f3d2c62
merge
bdraco Feb 12, 2025
c7e741d
test
bdraco Feb 12, 2025
33db064
Revert "test"
bdraco Feb 12, 2025
4d80687
cleanups
bdraco Feb 12, 2025
7e82626
merge
bdraco Feb 12, 2025
ffad8bb
merge
bdraco Feb 12, 2025
2124060
remove unused
bdraco Feb 12, 2025
a07944a
remove unused
bdraco Feb 12, 2025
79b2005
fix merge
bdraco Feb 12, 2025
579454d
preen
bdraco Feb 12, 2025
bb3336c
preen
bdraco Feb 12, 2025
3fc105b
Merge remote-tracking branch 'origin/main' into direct_connect_v2
bdraco Feb 13, 2025
70def54
cleanup merge
bdraco Feb 13, 2025
982edb9
cleanup merge
bdraco Feb 13, 2025
936e099
Revert "cleanup merge"
bdraco Feb 13, 2025
20e440d
remove
bdraco Feb 13, 2025
749d4ea
merge
bdraco Feb 13, 2025
be40b9d
fix merge
bdraco Feb 13, 2025
969649b
fix merge
bdraco Feb 13, 2025
373c887
Merge branch 'main' into direct_connect_v2
bdraco Feb 13, 2025
43a502d
fix mocking
bdraco Feb 13, 2025
7a12ff9
make it a fixture
bdraco Feb 13, 2025
8a35b4e
reverts
bdraco Feb 13, 2025
aa29c13
fixes
bdraco Feb 13, 2025
b2f9c87
fixes
bdraco Feb 13, 2025
d43cdcc
Merge branch 'main' into direct_connect_v2
bdraco Feb 13, 2025
25db33f
lint
bdraco Feb 13, 2025
1899b3a
lint
bdraco Feb 13, 2025
ed79751
lint
bdraco Feb 13, 2025
5e1fdbe
lint
bdraco Feb 13, 2025
93c07c1
Revert debugging
bdraco Feb 13, 2025
d6a1af8
Revert debug
bdraco Feb 13, 2025
af1ca2e
lint
bdraco Feb 13, 2025
9985039
compat
bdraco Feb 13, 2025
bdb6134
preen
bdraco Feb 13, 2025
322972b
preen
bdraco Feb 13, 2025
2440690
tests for pause/resume failures
bdraco Feb 13, 2025
e905913
tests for pause/resume failures
bdraco Feb 13, 2025
2b6c349
handle write after close
bdraco Feb 13, 2025
8fc1024
preen
bdraco Feb 13, 2025
fafc03a
compat
bdraco Feb 13, 2025
1f98a4c
Update tests/conftest.py
bdraco Feb 13, 2025
2c3c7a7
comments
bdraco Feb 13, 2025
af84014
Merge branch 'main' into direct_connect_v2
bdraco Feb 18, 2025
5150af4
Merge branch 'main' into direct_connect_v2
bdraco Feb 22, 2025
37b50b8
Merge branch 'main' into direct_connect_v2
bdraco Feb 28, 2025
d4059fb
Ensure protocol is ready for writing
bdraco Feb 28, 2025
70ad395
Merge remote-tracking branch 'origin/main' into direct_connect_v2
bdraco Mar 1, 2025
9350da1
format
bdraco Mar 1, 2025
9ad9eb0
Merge branch 'main' into direct_connect_v2
bdraco Mar 1, 2025
3b4ffb6
Merge remote-tracking branch 'origin/main' into direct_connect_v2
bdraco Mar 3, 2025
a9f341a
cleanups
bdraco Mar 3, 2025
076d65f
Merge branch 'main' into direct_connect_v2
bdraco Mar 3, 2025
ce4ed49
Merge branch 'main' into direct_connect_v2
bdraco Mar 3, 2025
e24ae53
Merge branch 'main' into direct_connect_v2
bdraco Mar 20, 2025
98cd876
fix merge
bdraco Mar 20, 2025
d1bc4b1
Merge remote-tracking branch 'origin/main' into direct_connect_v2
bdraco Apr 8, 2025
a71ca88
Merge remote-tracking branch 'origin/main' into direct_connect_v2
bdraco Jul 19, 2025
fdebb8f
Merge branch 'main' into direct_connect_v2
bdraco Jul 19, 2025
59ad787
Merge remote-tracking branch 'origin/main' into direct_connect_v2
bdraco Feb 13, 2026
a70e6d1
update test
bdraco Feb 13, 2026
a663aaf
fix mocking to be async
bdraco Feb 13, 2026
7d06ba7
limit chunk size to avoid buffer overflow
bdraco Feb 13, 2026
53a2fca
cleanup
bdraco Feb 13, 2026
17eebb4
cleanup
bdraco Feb 13, 2026
9f99f43
docs, cleanups
bdraco Feb 17, 2026
3a0f2da
docs, cleanups
bdraco Feb 17, 2026
1f3d8d2
comments, testing
bdraco Feb 17, 2026
d60edd1
address bot review
bdraco Feb 17, 2026
1b03481
Merge branch 'main' into direct_connect_v2
bdraco Feb 17, 2026
7d43ef9
revert version testing
bdraco Feb 17, 2026
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ test = [
"pytest-cov==7.0.0",
"pytest-timeout==2.4.0",
"pytest==9.0.2",
"trustme==1.2.1",
]

[project.urls]
Expand Down
228 changes: 118 additions & 110 deletions snitun/client/connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,14 @@
from __future__ import annotations

import asyncio
from collections.abc import Callable, Coroutine
from contextlib import suppress
from collections.abc import Callable
import ipaddress
from ipaddress import IPv4Address
import logging
from typing import Any
from ssl import SSLContext, SSLError

from ..exceptions import MultiplexerTransportClose, MultiplexerTransportError
from ..multiplexer.channel import ChannelFlowControlBase, MultiplexerChannel
from ..multiplexer.channel import MultiplexerChannel
from ..multiplexer.core import Multiplexer
from ..multiplexer.transport import ChannelTransport

_LOGGER = logging.getLogger(__name__)

Expand All @@ -22,19 +20,16 @@ class Connector:

def __init__(
self,
end_host: str,
end_port: int | None = None,
protocol_factory: Callable[[], asyncio.Protocol],
ssl_context: SSLContext,
whitelist: bool = False,
endpoint_connection_error_callback: Callable[[], Coroutine[Any, Any, None]]
| None = None,
) -> None:
"""Initialize Connector."""
self._loop = asyncio.get_event_loop()
self._end_host = end_host
self._end_port = end_port or 443
self._whitelist: set[IPv4Address] = set()
self._loop = asyncio.get_running_loop()
self._whitelist: set[ipaddress.IPv4Address] = set()
self._whitelist_enabled = whitelist
self._endpoint_connection_error_callback = endpoint_connection_error_callback
self._protocol_factory = protocol_factory
self._ssl_context = ssl_context

@property
def whitelist(self) -> set:
Expand All @@ -43,135 +38,148 @@ def whitelist(self) -> set:

def _whitelist_policy(self, ip_address: ipaddress.IPv4Address) -> bool:
"""Return True if the ip address can access to endpoint."""
if self._whitelist_enabled:
return ip_address in self._whitelist
return True
return not self._whitelist_enabled or ip_address in self._whitelist

async def handler(
self,
multiplexer: Multiplexer,
channel: MultiplexerChannel,
) -> None:
"""Handle new connection from SNIProxy."""
_LOGGER.debug(
"Receive from %s a request for %s",
channel.ip_address,
self._end_host,
)
_LOGGER.debug("New connection from %s", channel.ip_address)

# Check policy
if not self._whitelist_policy(channel.ip_address):
_LOGGER.warning("Block request from %s per policy", channel.ip_address)
multiplexer.delete_channel(channel)
return

await ConnectorHandler(self._loop, channel).start(
multiplexer,
self._end_host,
self._end_port,
self._endpoint_connection_error_callback,
transport = ChannelTransport(channel, multiplexer)

await ConnectorHandler(self._loop, multiplexer, channel, transport).start(
self._protocol_factory,
self._ssl_context,
)


class ConnectorHandler(ChannelFlowControlBase):
"""Handle connection to endpoint."""
class ConnectorHandler:
"""Handle connection to endpoint.

Bridges channel-level flow control to the SSL/app protocol layer.

Unlike ProxyPeerHandler (server side), this class does NOT inherit
from ChannelFlowControlBase because it has no read/write loop to
pause with a future. Instead, the ChannelTransport owns the reader
task, and this handler translates channel backpressure signals into
pause_protocol() / resume_protocol() calls on the transport, which
in turn call SSLProtocol.pause_writing() / resume_writing().
"""

def __init__(
self,
loop: asyncio.AbstractEventLoop,
multiplexer: Multiplexer,
channel: MultiplexerChannel,
transport: ChannelTransport,
) -> None:
"""Initialize ConnectorHandler."""
super().__init__(loop)
self._loop = loop
self._multiplexer = multiplexer
self._channel = channel
self._transport = transport

def _pause_resume_reader_callback(self, pause: bool) -> None:
"""Pause and resume reader."""
_LOGGER.debug(
"%s reader for %s (%s)",
"Pause" if pause else "Resume",
self._channel.ip_address,
self._channel.id,
)
if pause:
self._transport.pause_protocol()
else:
self._transport.resume_protocol()

async def _fail_to_start_tls(self, ex: Exception | None) -> None:
"""Handle failure to start TLS."""
channel = self._channel
_LOGGER.debug(
"Cannot start TLS for %s (%s): %s",
channel.ip_address,
channel.id,
ex,
)
self._multiplexer.delete_channel(channel)
await self._transport.stop_reader()

async def start(
self,
multiplexer: Multiplexer,
end_host: str,
end_port: int,
endpoint_connection_error_callback: Callable[[], Coroutine[Any, Any, None]]
| None = None,
protocol_factory: Callable[[], asyncio.Protocol],
ssl_context: SSLContext,
) -> None:
"""Start handler."""
channel = self._channel
channel.set_pause_resume_reader_callback(self._pause_resume_reader_callback)
# Open connection to endpoint
self._transport.start_reader()
# The request_handler is the aiohttp RequestHandler (or any other protocol)
# that is generated from the protocol_factory that
# was passed in the constructor.
protocol = protocol_factory()

# Upgrade the transport to TLS
try:
reader, writer = await asyncio.open_connection(host=end_host, port=end_port)
except OSError:
_LOGGER.error(
"Can't connect to endpoint %s:%s",
end_host,
end_port,
new_transport = await self._loop.start_tls(
self._transport,
protocol,
ssl_context,
server_side=True,
)
multiplexer.delete_channel(channel)
if endpoint_connection_error_callback:
await endpoint_connection_error_callback()
except (OSError, SSLError) as ex:
# This can can be just about any error, but mostly likely it's a TLS error
# or the connection gets dropped in the middle of the handshake
await self._fail_to_start_tls(ex)
return

from_endpoint: asyncio.Future[None] | asyncio.Task[bytes] | None = None
from_peer: asyncio.Task[bytes] | None = None
try:
# Process stream from multiplexer
while not writer.transport.is_closing():
if not from_endpoint:
# If the multiplexer channel queue is under water, pause the reader
# by waiting for the future to be set, once the queue is not under
# water the future will be set and cleared to resume the reader
from_endpoint = self._pause_future or self._loop.create_task(
reader.read(4096), # type: ignore[arg-type]
)
if not from_peer:
from_peer = self._loop.create_task(channel.read())

# Wait until data need to be processed
await asyncio.wait(
[from_endpoint, from_peer],
return_when=asyncio.FIRST_COMPLETED,
)

# From proxy
if from_endpoint.done():
if from_endpoint_exc := from_endpoint.exception():
raise from_endpoint_exc

if (from_endpoint_result := from_endpoint.result()) is not None:
await channel.write(from_endpoint_result)
from_endpoint = None

# From peer
if from_peer.done():
if from_peer_exc := from_peer.exception():
raise from_peer_exc

writer.write(from_peer.result())
from_peer = None

# Flush buffer
await writer.drain()

except (MultiplexerTransportError, OSError, RuntimeError):
_LOGGER.debug("Transport closed by endpoint for %s", channel.id)
multiplexer.delete_channel(channel)
if not new_transport:
await self._fail_to_start_tls(None)
return

except MultiplexerTransportClose:
_LOGGER.debug("Peer close connection for %s", channel.id)

finally:
# Cleanup peer reader
if from_peer:
if not from_peer.done():
from_peer.cancel()
else:
# Avoid exception was never retrieved
from_peer.exception()

# Cleanup endpoint reader
if from_endpoint and not from_endpoint.done():
from_endpoint.cancel()

# Close Transport
if not writer.transport.is_closing():
with suppress(OSError):
writer.close()
# Now that we have the connection upgraded to TLS, we can
# start the request handler and serve the connection.
#
# When the channel closes, ChannelTransport._force_close()
# schedules SSLProtocol.connection_lost() via call_soon, which
# cascades to the app protocol. We still call connection_lost
# on the app protocol directly here because during sendfile,
# SSLProtocol's cascade reaches _SendfileFallbackProtocol
# instead of the app protocol. aiohttp's connection_lost is
# reentrant so the double call in the normal case is safe.
#
# We intentionally do NOT call new_transport.close() — the
# SSLProtocol is already torn down by connection_lost from
# _force_close, and calling close() would start an SSL
# shutdown that can never complete (the channel is closed so
# the peer's close_notify never arrives).
_LOGGER.info("Connected peer: %s (%s)", channel.ip_address, channel.id)
try:
protocol.connection_made(new_transport)
await self._transport.wait_for_close()
except Exception as ex: # noqa: BLE001
# Make sure we catch any exception that might be raised
# so it gets feed back to connection_lost
_LOGGER.error(
"Transport error for %s (%s): %s",
channel.ip_address,
channel.id,
ex,
)
self._multiplexer.delete_channel(channel)
protocol.connection_lost(ex)
else:
_LOGGER.debug(
"Peer close connection for %s (%s)",
channel.ip_address,
channel.id,
)
protocol.connection_lost(None)
21 changes: 16 additions & 5 deletions snitun/multiplexer/channel.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,19 +233,30 @@ def close(self) -> None:
with suppress(asyncio.QueueFull):
self._input.put_nowait(None)

async def write(self, data: bytes) -> None:
"""Send data to peer."""
def _make_message_or_raise(self, data: bytes) -> MultiplexerMessage:
"""Create message or raise exception."""
if not data:
raise MultiplexerTransportError
if self._closing:
raise MultiplexerTransportClose

# Create message
message = tuple.__new__(
return tuple.__new__(
MultiplexerMessage,
(self._id, CHANNEL_FLOW_DATA, data, b""),
)

def write_no_wait(self, data: bytes) -> None:
"""Send data to peer."""
# Create message
message = self._make_message_or_raise(data)
try:
self._output.put_nowait(self.id, message)
except asyncio.QueueFull:
_LOGGER.debug("Can't write to peer transport")
raise MultiplexerTransportError from None

async def write(self, data: bytes) -> None:
"""Send data to peer."""
message = self._make_message_or_raise(data)
try:
# Try to avoid the timer handle if we can
# add to the queue without waiting
Expand Down
4 changes: 2 additions & 2 deletions snitun/multiplexer/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
MultiplexerTransportDecrypt,
MultiplexerTransportError,
)
from ..utils.asyncio import asyncio_timeout
from ..utils.asyncio import asyncio_timeout, create_eager_task
from ..utils.ipaddress import bytes_to_ip_address
from .channel import MultiplexerChannel
from .const import (
Expand Down Expand Up @@ -359,7 +359,7 @@ async def _process_message(self, message: MultiplexerMessage) -> None:

def _create_channel_task(self, coro: Coroutine[Any, Any, None]) -> None:
"""Create a new task for channel."""
task = self._loop.create_task(coro)
task = create_eager_task(coro, loop=self._loop)
self._channel_tasks.add(task)
task.add_done_callback(self._channel_tasks.remove)

Expand Down
Loading
Loading