From cfd6a8a706f3f891bdf0b62f54f68052a4236712 Mon Sep 17 00:00:00 2001 From: Gregg Donovan Date: Mon, 22 Dec 2025 10:40:05 -0500 Subject: [PATCH 1/3] Adopt standard Python src layout for lib/py MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This commit migrates lib/py from a non-standard package layout to the modern Python "src layout" pattern while preserving full API compatibility. ## Background The previous layout used a `package_dir={'thrift': 'src'}` mapping in setup.py, which mapped the `thrift` package to the `src/` directory. While functional, this approach: - Required workarounds for type checkers (`extra-paths` in pyproject.toml) - Confused some IDEs and tooling that expect standard layouts - Made editable installs less intuitive ## Changes ### Directory Structure Moved all package contents from `src/` to `src/thrift/`: ``` Before: After: lib/py/src/ lib/py/src/thrift/ ├── __init__.py ├── __init__.py ├── Thrift.py ├── Thrift.py ├── protocol/ ├── protocol/ ├── transport/ ├── transport/ ├── server/ ├── server/ └── ext/ └── ext/ ``` ### Build Configuration - setup.py: Changed `package_dir={'thrift': 'src'}` to `package_dir={'': 'src'}` - setup.py: Updated C extension include_dirs from `src` to `src/thrift` - pyproject.toml: Removed `extra-paths` workaround for type checker - pyproject.toml: Added pytest configuration ### Type Checking Fixes The standard layout enabled proper import resolution, which revealed several type errors that were previously hidden: 1. **Parameter name mismatches** (Liskov Substitution Principle): - TJSONProtocol: `binary`→`str_val`, `dbl`→`dub`, `type`→`ttype` - TMultiplexedProtocol: `type`→`ttype` 2. **Type compatibility fixes**: - TCompactProtocol: `writeVarint` now accepts `TTransportBase | BytesIO` - THttpServer: Added `cast(BinaryIO, self.rfile)` for type safety - TTransport: Added assertion for SASL mechanism - TNonblockingServer: Added null check for client.handle ## API Compatibility Client code remains unchanged - all imports work as before: ```python from thrift.Thrift import TException from thrift.protocol import TBinaryProtocol from thrift.transport import TSocket ``` ## Testing - Type checker: 0 errors (3 warnings for optional attributes) - pytest: 4 passed, 11 skipped (SSL tests require test resources) 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 --- lib/py/pyproject.toml | 71 ++++ lib/py/setup.py | 32 +- .../src/{ => thrift}/TMultiplexedProcessor.py | 30 +- lib/py/src/{ => thrift}/TRecursive.py | 18 +- lib/py/src/{ => thrift}/TSCons.py | 12 +- lib/py/src/{ => thrift}/TSerialization.py | 26 +- lib/py/src/{ => thrift}/TTornado.py | 109 ++++-- lib/py/src/{ => thrift}/Thrift.py | 119 +++--- lib/py/src/{ => thrift}/__init__.py | 0 lib/py/src/{ => thrift}/ext/binary.cpp | 0 lib/py/src/{ => thrift}/ext/binary.h | 0 lib/py/src/{ => thrift}/ext/compact.cpp | 0 lib/py/src/{ => thrift}/ext/compact.h | 0 lib/py/src/{ => thrift}/ext/endian.h | 0 lib/py/src/{ => thrift}/ext/module.cpp | 0 lib/py/src/{ => thrift}/ext/protocol.h | 0 lib/py/src/{ => thrift}/ext/protocol.tcc | 0 lib/py/src/{ => thrift}/ext/types.cpp | 0 lib/py/src/{ => thrift}/ext/types.h | 0 lib/py/src/{ => thrift}/protocol/TBase.py | 36 +- .../{ => thrift}/protocol/TBinaryProtocol.py | 161 +++++---- .../{ => thrift}/protocol/TCompactProtocol.py | 250 +++++++------ .../{ => thrift}/protocol/THeaderProtocol.py | 116 +++--- .../{ => thrift}/protocol/TJSONProtocol.py | 339 +++++++++--------- .../protocol/TMultiplexedProtocol.py | 19 +- lib/py/src/{ => thrift}/protocol/TProtocol.py | 268 ++++++++------ .../protocol/TProtocolDecorator.py | 12 +- lib/py/src/{ => thrift}/protocol/__init__.py | 0 lib/py/src/thrift/py.typed | 0 lib/py/src/{ => thrift}/server/THttpServer.py | 53 ++- .../{ => thrift}/server/TNonblockingServer.py | 165 +++++---- .../{ => thrift}/server/TProcessPoolServer.py | 37 +- lib/py/src/{ => thrift}/server/TServer.py | 73 ++-- lib/py/src/{ => thrift}/server/__init__.py | 0 .../transport/THeaderTransport.py | 103 +++--- .../src/{ => thrift}/transport/THttpClient.py | 100 ++++-- .../src/{ => thrift}/transport/TSSLSocket.py | 139 +++---- lib/py/src/{ => thrift}/transport/TSocket.py | 84 +++-- .../src/{ => thrift}/transport/TTransport.py | 260 ++++++++------ lib/py/src/{ => thrift}/transport/TTwisted.py | 211 +++++++---- .../{ => thrift}/transport/TZlibTransport.py | 70 ++-- lib/py/src/{ => thrift}/transport/__init__.py | 0 .../src/{ => thrift}/transport/sslcompat.py | 18 +- 43 files changed, 1798 insertions(+), 1133 deletions(-) create mode 100644 lib/py/pyproject.toml rename lib/py/src/{ => thrift}/TMultiplexedProcessor.py (76%) rename lib/py/src/{ => thrift}/TRecursive.py (91%) rename lib/py/src/{ => thrift}/TSCons.py (78%) rename lib/py/src/{ => thrift}/TSerialization.py (67%) rename lib/py/src/{ => thrift}/TTornado.py (68%) rename lib/py/src/{ => thrift}/Thrift.py (71%) rename lib/py/src/{ => thrift}/__init__.py (100%) rename lib/py/src/{ => thrift}/ext/binary.cpp (100%) rename lib/py/src/{ => thrift}/ext/binary.h (100%) rename lib/py/src/{ => thrift}/ext/compact.cpp (100%) rename lib/py/src/{ => thrift}/ext/compact.h (100%) rename lib/py/src/{ => thrift}/ext/endian.h (100%) rename lib/py/src/{ => thrift}/ext/module.cpp (100%) rename lib/py/src/{ => thrift}/ext/protocol.h (100%) rename lib/py/src/{ => thrift}/ext/protocol.tcc (100%) rename lib/py/src/{ => thrift}/ext/types.cpp (100%) rename lib/py/src/{ => thrift}/ext/types.h (100%) rename lib/py/src/{ => thrift}/protocol/TBase.py (71%) rename lib/py/src/{ => thrift}/protocol/TBinaryProtocol.py (66%) rename lib/py/src/{ => thrift}/protocol/TCompactProtocol.py (70%) rename lib/py/src/{ => thrift}/protocol/THeaderProtocol.py (68%) rename lib/py/src/{ => thrift}/protocol/TJSONProtocol.py (64%) rename lib/py/src/{ => thrift}/protocol/TMultiplexedProtocol.py (74%) rename lib/py/src/{ => thrift}/protocol/TProtocol.py (56%) rename lib/py/src/{ => thrift}/protocol/TProtocolDecorator.py (76%) rename lib/py/src/{ => thrift}/protocol/__init__.py (100%) create mode 100644 lib/py/src/thrift/py.typed rename lib/py/src/{ => thrift}/server/THttpServer.py (77%) rename lib/py/src/{ => thrift}/server/TNonblockingServer.py (78%) rename lib/py/src/{ => thrift}/server/TProcessPoolServer.py (80%) rename lib/py/src/{ => thrift}/server/TServer.py (86%) rename lib/py/src/{ => thrift}/server/__init__.py (100%) rename lib/py/src/{ => thrift}/transport/THeaderTransport.py (84%) rename lib/py/src/{ => thrift}/transport/THttpClient.py (68%) rename lib/py/src/{ => thrift}/transport/TSSLSocket.py (79%) rename lib/py/src/{ => thrift}/transport/TSocket.py (81%) rename lib/py/src/{ => thrift}/transport/TTransport.py (66%) rename lib/py/src/{ => thrift}/transport/TTwisted.py (59%) rename lib/py/src/{ => thrift}/transport/TZlibTransport.py (83%) rename lib/py/src/{ => thrift}/transport/__init__.py (100%) rename lib/py/src/{ => thrift}/transport/sslcompat.py (86%) diff --git a/lib/py/pyproject.toml b/lib/py/pyproject.toml new file mode 100644 index 00000000000..a487583ae1a --- /dev/null +++ b/lib/py/pyproject.toml @@ -0,0 +1,71 @@ +[project] +name = "thrift" +version = "0.23.0" +description = "Python bindings for the Apache Thrift RPC system" +readme = "README.md" +license = "Apache-2.0" +requires-python = ">=3.12" +authors = [ + { name = "Apache Thrift Developers", email = "dev@thrift.apache.org" } +] +classifiers = [ + "Development Status :: 5 - Production/Stable", + "Environment :: Console", + "Intended Audience :: Developers", + "Programming Language :: Python", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.12", + "Programming Language :: Python :: 3.13", + "Typing :: Typed", + "Topic :: Software Development :: Libraries", + "Topic :: System :: Networking", +] + +[project.urls] +Homepage = "https://thrift.apache.org" +Repository = "https://github.com/apache/thrift" + +[project.optional-dependencies] +tornado = ["tornado>=4.0"] +twisted = ["twisted"] +all = ["tornado>=4.0", "twisted"] + +[build-system] +requires = ["setuptools>=61.0", "wheel"] +build-backend = "setuptools.build_meta" + +[tool.setuptools.packages.find] +where = ["src"] + +[tool.setuptools.package-data] +thrift = ["py.typed"] + +# Development dependencies +[dependency-groups] +dev = [ + "ty>=0.0.5", + "pytest>=8.0", + "pure-sasl", # For SASL transport tests +] + +# ty type checker configuration +[tool.ty.environment] +python-version = "3.12" + +[tool.ty.rules] +possibly-unresolved-reference = "error" +invalid-assignment = "error" +invalid-argument-type = "error" +invalid-return-type = "error" +missing-argument = "error" +call-non-callable = "error" +no-matching-overload = "error" +unresolved-import = "error" +invalid-type-form = "error" + +# pytest configuration +[tool.pytest.ini_options] +testpaths = ["test"] +python_files = ["test_*.py"] +python_classes = ["Test*"] +python_functions = ["test_*"] diff --git a/lib/py/setup.py b/lib/py/setup.py index a02cc4ff1f1..9c3a4a7c547 100644 --- a/lib/py/setup.py +++ b/lib/py/setup.py @@ -20,12 +20,9 @@ # import sys -try: - from setuptools import setup, Extension -except Exception: - from distutils.core import setup, Extension -from distutils.command.build_ext import build_ext +from setuptools import Extension, setup +from setuptools.command.build_ext import build_ext from distutils.errors import CCompilerError, DistutilsExecError, DistutilsPlatformError # Fix to build sdist under vagrant @@ -36,7 +33,7 @@ except AttributeError: pass -include_dirs = ['src'] +include_dirs = ['src/thrift'] if sys.platform == 'win32': include_dirs.append('compat/win32') ext_errors = (CCompilerError, DistutilsExecError, DistutilsPlatformError, IOError) @@ -83,10 +80,10 @@ def run_setup(with_binary): Extension('thrift.protocol.fastbinary', extra_compile_args=['-std=c++11'], sources=[ - 'src/ext/module.cpp', - 'src/ext/types.cpp', - 'src/ext/binary.cpp', - 'src/ext/compact.cpp', + 'src/thrift/ext/module.cpp', + 'src/thrift/ext/types.cpp', + 'src/thrift/ext/binary.cpp', + 'src/thrift/ext/compact.cpp', ], include_dirs=include_dirs, ) @@ -96,13 +93,11 @@ def run_setup(with_binary): else: extensions = dict() - ssl_deps = [] - if sys.hexversion < 0x03050000: - ssl_deps.append('backports.ssl_match_hostname>=3.5') tornado_deps = ['tornado>=4.0'] twisted_deps = ['twisted'] setup(name='thrift', + python_requires='>=3.12', version='0.23.0', description='Python bindings for the Apache Thrift RPC system', long_description=read_file("README.md"), @@ -112,10 +107,12 @@ def run_setup(with_binary): url='http://thrift.apache.org', license='Apache License 2.0', extras_require={ - 'ssl': ssl_deps, 'tornado': tornado_deps, 'twisted': twisted_deps, - 'all': ssl_deps + tornado_deps + twisted_deps, + 'all': tornado_deps + twisted_deps, + }, + package_data={ + 'thrift': ['py.typed'], }, packages=[ 'thrift', @@ -123,13 +120,16 @@ def run_setup(with_binary): 'thrift.transport', 'thrift.server', ], - package_dir={'thrift': 'src'}, + package_dir={'': 'src'}, # Standard src layout classifiers=[ 'Development Status :: 5 - Production/Stable', 'Environment :: Console', 'Intended Audience :: Developers', 'Programming Language :: Python', 'Programming Language :: Python :: 3', + 'Programming Language :: Python :: 3.12', + 'Programming Language :: Python :: 3.13', + 'Typing :: Typed', 'Topic :: Software Development :: Libraries', 'Topic :: System :: Networking' ], diff --git a/lib/py/src/TMultiplexedProcessor.py b/lib/py/src/thrift/TMultiplexedProcessor.py similarity index 76% rename from lib/py/src/TMultiplexedProcessor.py rename to lib/py/src/thrift/TMultiplexedProcessor.py index ff88430bd0b..3c372d50a69 100644 --- a/lib/py/src/TMultiplexedProcessor.py +++ b/lib/py/src/thrift/TMultiplexedProcessor.py @@ -17,17 +17,29 @@ # under the License. # +from __future__ import annotations + +from typing import Any, Callable, TYPE_CHECKING + from thrift.Thrift import TProcessor, TMessageType from thrift.protocol import TProtocolDecorator, TMultiplexedProtocol from thrift.protocol.TProtocol import TProtocolException +if TYPE_CHECKING: + from thrift.protocol.TProtocol import TProtocolBase + class TMultiplexedProcessor(TProcessor): - def __init__(self): + """Processor that multiplexes multiple services on a single connection.""" + + defaultProcessor: TProcessor | None + services: dict[str, TProcessor] + + def __init__(self) -> None: self.defaultProcessor = None self.services = {} - def registerDefault(self, processor): + def registerDefault(self, processor: TProcessor) -> None: """ If a non-multiplexed processor connects to the server and wants to communicate, use the given processor to handle it. This mechanism @@ -36,14 +48,14 @@ def registerDefault(self, processor): """ self.defaultProcessor = processor - def registerProcessor(self, serviceName, processor): + def registerProcessor(self, serviceName: str, processor: TProcessor) -> None: self.services[serviceName] = processor - def on_message_begin(self, func): + def on_message_begin(self, func: Callable[[str, int, int], None]) -> None: for key in self.services.keys(): self.services[key].on_message_begin(func) - def process(self, iprot, oprot): + def process(self, iprot: TProtocolBase, oprot: TProtocolBase) -> bool | None: (name, type, seqid) = iprot.readMessageBegin() if type != TMessageType.CALL and type != TMessageType.ONEWAY: raise TProtocolException( @@ -75,8 +87,12 @@ def process(self, iprot, oprot): class StoredMessageProtocol(TProtocolDecorator.TProtocolDecorator): - def __init__(self, protocol, messageBegin): + """Protocol decorator that stores and returns a predetermined message begin.""" + + messageBegin: tuple[str, int, int] + + def __init__(self, protocol: TProtocolBase, messageBegin: tuple[str, int, int]) -> None: self.messageBegin = messageBegin - def readMessageBegin(self): + def readMessageBegin(self) -> tuple[str, int, int]: return self.messageBegin diff --git a/lib/py/src/TRecursive.py b/lib/py/src/thrift/TRecursive.py similarity index 91% rename from lib/py/src/TRecursive.py rename to lib/py/src/thrift/TRecursive.py index 3e11dee25c6..676799323a4 100644 --- a/lib/py/src/TRecursive.py +++ b/lib/py/src/thrift/TRecursive.py @@ -17,15 +17,19 @@ # under the License. # +from __future__ import annotations + +from typing import Any + from thrift.Thrift import TType -TYPE_IDX = 1 -SPEC_ARGS_IDX = 3 -SPEC_ARGS_CLASS_REF_IDX = 0 -SPEC_ARGS_THRIFT_SPEC_IDX = 1 +TYPE_IDX: int = 1 +SPEC_ARGS_IDX: int = 3 +SPEC_ARGS_CLASS_REF_IDX: int = 0 +SPEC_ARGS_THRIFT_SPEC_IDX: int = 1 -def fix_spec(all_structs): +def fix_spec(all_structs: list[Any]) -> None: """Wire up recursive references for all TStruct definitions inside of each thrift_spec.""" for struc in all_structs: spec = struc.thrift_spec @@ -41,7 +45,7 @@ def fix_spec(all_structs): _fix_map(thrift_spec[SPEC_ARGS_IDX]) -def _fix_list_or_set(element_type): +def _fix_list_or_set(element_type: list[Any]) -> None: # For a list or set, the thrift_spec entry looks like, # (1, TType.LIST, 'lister', (TType.STRUCT, [RecList, None], False), None, ), # 1 # so ``element_type`` will be, @@ -54,7 +58,7 @@ def _fix_list_or_set(element_type): _fix_map(element_type[1]) -def _fix_map(element_type): +def _fix_map(element_type: list[Any]) -> None: # For a map of key -> value type, ``element_type`` will be, # (TType.I16, None, TType.STRUCT, [RecMapBasic, None], False), None, ) # which is just a normal struct definition. diff --git a/lib/py/src/TSCons.py b/lib/py/src/thrift/TSCons.py similarity index 78% rename from lib/py/src/TSCons.py rename to lib/py/src/thrift/TSCons.py index 633f67ab008..70455bbbfc7 100644 --- a/lib/py/src/TSCons.py +++ b/lib/py/src/thrift/TSCons.py @@ -17,19 +17,23 @@ # under the License. # +from __future__ import annotations + from os import path -from SCons.Builder import Builder +from typing import Any, Iterator + +from SCons.Builder import Builder # type: ignore[import-untyped] -def scons_env(env, add=''): +def scons_env(env: Any, add: str = '') -> None: opath = path.dirname(path.abspath('$TARGET')) lstr = 'thrift --gen cpp -o ' + opath + ' ' + add + ' $SOURCE' cppbuild = Builder(action=lstr) env.Append(BUILDERS={'ThriftCpp': cppbuild}) -def gen_cpp(env, dir, file): +def gen_cpp(env: Any, dir: str, file: str) -> Any: scons_env(env) suffixes = ['_types.h', '_types.cpp'] - targets = map(lambda s: 'gen-cpp/' + file + s, suffixes) + targets: Iterator[str] = map(lambda s: 'gen-cpp/' + file + s, suffixes) return env.ThriftCpp(targets, dir + file + '.thrift') diff --git a/lib/py/src/TSerialization.py b/lib/py/src/thrift/TSerialization.py similarity index 67% rename from lib/py/src/TSerialization.py rename to lib/py/src/thrift/TSerialization.py index fbbe7680766..257241bb02e 100644 --- a/lib/py/src/TSerialization.py +++ b/lib/py/src/thrift/TSerialization.py @@ -17,22 +17,36 @@ # under the License. # +from __future__ import annotations + +from typing import Any, TYPE_CHECKING, TypeVar + from .protocol import TBinaryProtocol from .transport import TTransport +if TYPE_CHECKING: + from thrift.protocol.TProtocol import TProtocolFactory + from thrift.protocol.TBase import TBase + +T = TypeVar('T') + -def serialize(thrift_object, - protocol_factory=TBinaryProtocol.TBinaryProtocolFactory()): +def serialize( + thrift_object: TBase, + protocol_factory: TProtocolFactory = TBinaryProtocol.TBinaryProtocolFactory(), +) -> bytes: transport = TTransport.TMemoryBuffer() protocol = protocol_factory.getProtocol(transport) thrift_object.write(protocol) return transport.getvalue() -def deserialize(base, - buf, - protocol_factory=TBinaryProtocol.TBinaryProtocolFactory()): +def deserialize( + base: T, + buf: bytes, + protocol_factory: TProtocolFactory = TBinaryProtocol.TBinaryProtocolFactory(), +) -> T: transport = TTransport.TMemoryBuffer(buf) protocol = protocol_factory.getProtocol(transport) - base.read(protocol) + base.read(protocol) # type: ignore[union-attr] return base diff --git a/lib/py/src/TTornado.py b/lib/py/src/thrift/TTornado.py similarity index 68% rename from lib/py/src/TTornado.py rename to lib/py/src/thrift/TTornado.py index c7218301afd..a44740522d3 100644 --- a/lib/py/src/TTornado.py +++ b/lib/py/src/thrift/TTornado.py @@ -17,47 +17,58 @@ # under the License. # +from __future__ import annotations + import logging import socket import struct import warnings +from collections import deque +from contextlib import contextmanager +from io import BytesIO +from typing import Any, Callable, Generator, TYPE_CHECKING + +from tornado import gen, iostream, ioloop, tcpserver, concurrent # type: ignore[import-untyped] from .transport.TTransport import TTransportException, TTransportBase, TMemoryBuffer -from io import BytesIO -from collections import deque -from contextlib import contextmanager -from tornado import gen, iostream, ioloop, tcpserver, concurrent +if TYPE_CHECKING: + from thrift.protocol.TProtocol import TProtocolFactory + from thrift.Thrift import TProcessor __all__ = ['TTornadoServer', 'TTornadoStreamTransport'] -logger = logging.getLogger(__name__) +logger: logging.Logger = logging.getLogger(__name__) class _Lock: - def __init__(self): + """Simple async lock for Tornado coroutines.""" + + _waiters: deque[concurrent.Future[None]] + + def __init__(self) -> None: self._waiters = deque() - def acquired(self): + def acquired(self) -> bool: return len(self._waiters) > 0 - @gen.coroutine - def acquire(self): + @gen.coroutine # type: ignore[misc] + def acquire(self) -> Generator[Any, Any, Any]: blocker = self._waiters[-1] if self.acquired() else None - future = concurrent.Future() + future: concurrent.Future[None] = concurrent.Future() self._waiters.append(future) if blocker: yield blocker raise gen.Return(self._lock_context()) - def release(self): + def release(self) -> None: assert self.acquired(), 'Lock not aquired' future = self._waiters.popleft() future.set_result(None) @contextmanager - def _lock_context(self): + def _lock_context(self) -> Generator[None, None, None]: try: yield finally: @@ -66,7 +77,21 @@ def _lock_context(self): class TTornadoStreamTransport(TTransportBase): """a framed, buffered transport over a Tornado stream""" - def __init__(self, host, port, stream=None, io_loop=None): + + host: str + port: int + io_loop: ioloop.IOLoop + __wbuf: BytesIO + _read_lock: _Lock + stream: iostream.IOStream | None + + def __init__( + self, + host: str, + port: int, + stream: iostream.IOStream | None = None, + io_loop: ioloop.IOLoop | None = None, + ) -> None: if io_loop is not None: warnings.warn( "The `io_loop` parameter is deprecated and unused. Passing " @@ -86,16 +111,16 @@ def __init__(self, host, port, stream=None, io_loop=None): # servers provide a ready-to-go stream self.stream = stream - def with_timeout(self, timeout, future): + def with_timeout(self, timeout: float, future: Any) -> Any: return gen.with_timeout(timeout, future) - def isOpen(self): + def isOpen(self) -> bool: if self.stream is None: return False return not self.stream.closed() - @gen.coroutine - def open(self, timeout=None): + @gen.coroutine # type: ignore[misc] + def open(self, timeout: float | None = None) -> Generator[Any, Any, TTornadoStreamTransport]: logger.debug('socket connecting') sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0) self.stream = iostream.IOStream(sock) @@ -114,24 +139,24 @@ def open(self, timeout=None): raise gen.Return(self) - def set_close_callback(self, callback): + def set_close_callback(self, callback: Callable[[], None] | None) -> None: """ Should be called only after open() returns """ - self.stream.set_close_callback(callback) + self.stream.set_close_callback(callback) # type: ignore[union-attr] - def close(self): + def close(self) -> None: # don't raise if we intend to close - self.stream.set_close_callback(None) - self.stream.close() + self.stream.set_close_callback(None) # type: ignore[union-attr] + self.stream.close() # type: ignore[union-attr] - def read(self, _): + def read(self, sz: int) -> bytes: # The generated code for Tornado shouldn't do individual reads -- only # frames at a time - assert False, "you're doing it wrong" + raise NotImplementedError("Use readFrame() for Tornado transport") @contextmanager - def io_exception_context(self): + def io_exception_context(self) -> Generator[None, None, None]: try: yield except (socket.error, IOError) as e: @@ -143,8 +168,8 @@ def io_exception_context(self): type=TTransportException.UNKNOWN, message=str(e)) - @gen.coroutine - def readFrame(self): + @gen.coroutine # type: ignore[misc] + def readFrame(self) -> Generator[Any, Any, bytes]: # IOStream processes reads one at a time with (yield self._read_lock.acquire()): with self.io_exception_context(): @@ -155,21 +180,33 @@ def readFrame(self): frame = yield self.stream.read_bytes(frame_length) raise gen.Return(frame) - def write(self, buf): + def write(self, buf: bytes) -> None: self.__wbuf.write(buf) - def flush(self): + def flush(self) -> Any: frame = self.__wbuf.getvalue() # reset wbuf before write/flush to preserve state on underlying failure frame_length = struct.pack('!i', len(frame)) self.__wbuf = BytesIO() with self.io_exception_context(): - return self.stream.write(frame_length + frame) + return self.stream.write(frame_length + frame) # type: ignore[union-attr] + + +class TTornadoServer(tcpserver.TCPServer): # type: ignore[misc] + """Tornado-based Thrift server.""" + _processor: TProcessor + _iprot_factory: TProtocolFactory + _oprot_factory: TProtocolFactory -class TTornadoServer(tcpserver.TCPServer): - def __init__(self, processor, iprot_factory, oprot_factory=None, - *args, **kwargs): + def __init__( + self, + processor: TProcessor, + iprot_factory: TProtocolFactory, + oprot_factory: TProtocolFactory | None = None, + *args: Any, + **kwargs: Any, + ) -> None: super(TTornadoServer, self).__init__(*args, **kwargs) self._processor = processor @@ -177,14 +214,14 @@ def __init__(self, processor, iprot_factory, oprot_factory=None, self._oprot_factory = (oprot_factory if oprot_factory is not None else iprot_factory) - @gen.coroutine - def handle_stream(self, stream, address): + @gen.coroutine # type: ignore[misc] + def handle_stream(self, stream: iostream.IOStream, address: tuple[str, int]) -> Generator[Any, Any, None]: host, port = address[:2] trans = TTornadoStreamTransport(host=host, port=port, stream=stream) oprot = self._oprot_factory.getProtocol(trans) try: - while not trans.stream.closed(): + while not trans.stream.closed(): # type: ignore[union-attr] try: frame = yield trans.readFrame() except TTransportException as e: diff --git a/lib/py/src/Thrift.py b/lib/py/src/thrift/Thrift.py similarity index 71% rename from lib/py/src/Thrift.py rename to lib/py/src/thrift/Thrift.py index 81fe8cf33fe..23e3767611a 100644 --- a/lib/py/src/Thrift.py +++ b/lib/py/src/thrift/Thrift.py @@ -17,27 +17,36 @@ # under the License. # - -class TType(object): - STOP = 0 - VOID = 1 - BOOL = 2 - BYTE = 3 - I08 = 3 - DOUBLE = 4 - I16 = 6 - I32 = 8 - I64 = 10 - STRING = 11 - UTF7 = 11 - STRUCT = 12 - MAP = 13 - SET = 14 - LIST = 15 - UTF8 = 16 - UTF16 = 17 - - _VALUES_TO_NAMES = ( +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Callable, NoReturn + +if TYPE_CHECKING: + from thrift.protocol.TProtocol import TProtocolBase + + +class TType: + """Thrift type constants.""" + + STOP: int = 0 + VOID: int = 1 + BOOL: int = 2 + BYTE: int = 3 + I08: int = 3 + DOUBLE: int = 4 + I16: int = 6 + I32: int = 8 + I64: int = 10 + STRING: int = 11 + UTF7: int = 11 + STRUCT: int = 12 + MAP: int = 13 + SET: int = 14 + LIST: int = 15 + UTF8: int = 16 + UTF16: int = 17 + + _VALUES_TO_NAMES: tuple[str | None, ...] = ( 'STOP', 'VOID', 'BOOL', @@ -59,17 +68,19 @@ class TType(object): ) -class TMessageType(object): - CALL = 1 - REPLY = 2 - EXCEPTION = 3 - ONEWAY = 4 +class TMessageType: + """Thrift message type constants.""" + + CALL: int = 1 + REPLY: int = 2 + EXCEPTION: int = 3 + ONEWAY: int = 4 -class TProcessor(object): +class TProcessor: """Base class for processor, which works on two streams.""" - def process(self, iprot, oprot): + def process(self, iprot: TProtocolBase, oprot: TProtocolBase) -> bool | None: """ Process a request. The normal behvaior is to have the processor invoke the correct handler and then it is the @@ -77,7 +88,7 @@ def process(self, iprot, oprot): """ pass - def on_message_begin(self, func): + def on_message_begin(self, func: Callable[[str, int, int], None]) -> None: """ Install a callback that receives (name, type, seqid) after the message header is read. @@ -88,7 +99,9 @@ def on_message_begin(self, func): class TException(Exception): """Base class for all thrift exceptions.""" - def __init__(self, message=None): + message: str | None + + def __init__(self, message: str | None = None) -> None: Exception.__init__(self, message) super(TException, self).__setattr__("message", message) @@ -96,23 +109,25 @@ def __init__(self, message=None): class TApplicationException(TException): """Application level thrift exceptions.""" - UNKNOWN = 0 - UNKNOWN_METHOD = 1 - INVALID_MESSAGE_TYPE = 2 - WRONG_METHOD_NAME = 3 - BAD_SEQUENCE_ID = 4 - MISSING_RESULT = 5 - INTERNAL_ERROR = 6 - PROTOCOL_ERROR = 7 - INVALID_TRANSFORM = 8 - INVALID_PROTOCOL = 9 - UNSUPPORTED_CLIENT_TYPE = 10 - - def __init__(self, type=UNKNOWN, message=None): + UNKNOWN: int = 0 + UNKNOWN_METHOD: int = 1 + INVALID_MESSAGE_TYPE: int = 2 + WRONG_METHOD_NAME: int = 3 + BAD_SEQUENCE_ID: int = 4 + MISSING_RESULT: int = 5 + INTERNAL_ERROR: int = 6 + PROTOCOL_ERROR: int = 7 + INVALID_TRANSFORM: int = 8 + INVALID_PROTOCOL: int = 9 + UNSUPPORTED_CLIENT_TYPE: int = 10 + + type: int + + def __init__(self, type: int = UNKNOWN, message: str | None = None) -> None: TException.__init__(self, message) self.type = type - def __str__(self): + def __str__(self) -> str: if self.message: return self.message elif self.type == self.UNKNOWN_METHOD: @@ -138,7 +153,7 @@ def __str__(self): else: return 'Default (unknown) TApplicationException' - def read(self, iprot): + def read(self, iprot: TProtocolBase) -> None: iprot.readStructBegin() while True: (fname, ftype, fid) = iprot.readFieldBegin() @@ -159,7 +174,7 @@ def read(self, iprot): iprot.readFieldEnd() iprot.readStructEnd() - def write(self, oprot): + def write(self, oprot: TProtocolBase) -> None: oprot.writeStructBegin('TApplicationException') if self.message is not None: oprot.writeFieldBegin('message', TType.STRING, 1) @@ -173,21 +188,23 @@ def write(self, oprot): oprot.writeStructEnd() -class TFrozenDict(dict): +class TFrozenDict(dict[Any, Any]): """A dictionary that is "frozen" like a frozenset""" - def __init__(self, *args, **kwargs): + __hashval: int + + def __init__(self, *args: Any, **kwargs: Any) -> None: super(TFrozenDict, self).__init__(*args, **kwargs) # Sort the items so they will be in a consistent order. # XOR in the hash of the class so we don't collide with # the hash of a list of tuples. self.__hashval = hash(TFrozenDict) ^ hash(tuple(sorted(self.items()))) - def __setitem__(self, *args): + def __setitem__(self, *args: Any) -> NoReturn: raise TypeError("Can't modify frozen TFreezableDict") - def __delitem__(self, *args): + def __delitem__(self, *args: Any) -> NoReturn: raise TypeError("Can't modify frozen TFreezableDict") - def __hash__(self): + def __hash__(self) -> int: # type: ignore[override] return self.__hashval diff --git a/lib/py/src/__init__.py b/lib/py/src/thrift/__init__.py similarity index 100% rename from lib/py/src/__init__.py rename to lib/py/src/thrift/__init__.py diff --git a/lib/py/src/ext/binary.cpp b/lib/py/src/thrift/ext/binary.cpp similarity index 100% rename from lib/py/src/ext/binary.cpp rename to lib/py/src/thrift/ext/binary.cpp diff --git a/lib/py/src/ext/binary.h b/lib/py/src/thrift/ext/binary.h similarity index 100% rename from lib/py/src/ext/binary.h rename to lib/py/src/thrift/ext/binary.h diff --git a/lib/py/src/ext/compact.cpp b/lib/py/src/thrift/ext/compact.cpp similarity index 100% rename from lib/py/src/ext/compact.cpp rename to lib/py/src/thrift/ext/compact.cpp diff --git a/lib/py/src/ext/compact.h b/lib/py/src/thrift/ext/compact.h similarity index 100% rename from lib/py/src/ext/compact.h rename to lib/py/src/thrift/ext/compact.h diff --git a/lib/py/src/ext/endian.h b/lib/py/src/thrift/ext/endian.h similarity index 100% rename from lib/py/src/ext/endian.h rename to lib/py/src/thrift/ext/endian.h diff --git a/lib/py/src/ext/module.cpp b/lib/py/src/thrift/ext/module.cpp similarity index 100% rename from lib/py/src/ext/module.cpp rename to lib/py/src/thrift/ext/module.cpp diff --git a/lib/py/src/ext/protocol.h b/lib/py/src/thrift/ext/protocol.h similarity index 100% rename from lib/py/src/ext/protocol.h rename to lib/py/src/thrift/ext/protocol.h diff --git a/lib/py/src/ext/protocol.tcc b/lib/py/src/thrift/ext/protocol.tcc similarity index 100% rename from lib/py/src/ext/protocol.tcc rename to lib/py/src/thrift/ext/protocol.tcc diff --git a/lib/py/src/ext/types.cpp b/lib/py/src/thrift/ext/types.cpp similarity index 100% rename from lib/py/src/ext/types.cpp rename to lib/py/src/thrift/ext/types.cpp diff --git a/lib/py/src/ext/types.h b/lib/py/src/thrift/ext/types.h similarity index 100% rename from lib/py/src/ext/types.h rename to lib/py/src/thrift/ext/types.h diff --git a/lib/py/src/protocol/TBase.py b/lib/py/src/thrift/protocol/TBase.py similarity index 71% rename from lib/py/src/protocol/TBase.py rename to lib/py/src/thrift/protocol/TBase.py index 6c6ef18e877..739f66864c2 100644 --- a/lib/py/src/protocol/TBase.py +++ b/lib/py/src/thrift/protocol/TBase.py @@ -17,17 +17,25 @@ # under the License. # +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, ClassVar, NoReturn + from thrift.transport import TTransport +if TYPE_CHECKING: + from thrift.protocol.TProtocol import TProtocolBase + -class TBase(object): - __slots__ = () +class TBase: + __slots__: ClassVar[tuple[str, ...]] = () + thrift_spec: ClassVar[tuple[Any, ...] | None] = None - def __repr__(self): + def __repr__(self) -> str: L = ['%s=%r' % (key, getattr(self, key)) for key in self.__slots__] return '%s(%s)' % (self.__class__.__name__, ', '.join(L)) - def __eq__(self, other): + def __eq__(self, other: object) -> bool: if not isinstance(other, self.__class__): return False for attr in self.__slots__: @@ -37,23 +45,23 @@ def __eq__(self, other): return False return True - def __ne__(self, other): + def __ne__(self, other: object) -> bool: return not (self == other) - def read(self, iprot): + def read(self, iprot: TProtocolBase) -> None: if (iprot._fast_decode is not None and isinstance(iprot.trans, TTransport.CReadableTransport) and self.thrift_spec is not None): iprot._fast_decode(self, iprot, [self.__class__, self.thrift_spec]) else: - iprot.readStruct(self, self.thrift_spec) + iprot.readStruct(self, self.thrift_spec) # type: ignore[arg-type] - def write(self, oprot): + def write(self, oprot: TProtocolBase) -> None: if (oprot._fast_encode is not None and self.thrift_spec is not None): oprot.trans.write( oprot._fast_encode(self, [self.__class__, self.thrift_spec])) else: - oprot.writeStruct(self, self.thrift_spec) + oprot.writeStruct(self, self.thrift_spec) # type: ignore[arg-type] class TExceptionBase(TBase, Exception): @@ -61,17 +69,17 @@ class TExceptionBase(TBase, Exception): class TFrozenBase(TBase): - def __setitem__(self, *args): + def __setitem__(self, *args: Any) -> NoReturn: raise TypeError("Can't modify frozen struct") - def __delitem__(self, *args): + def __delitem__(self, *args: Any) -> NoReturn: raise TypeError("Can't modify frozen struct") - def __hash__(self, *args): + def __hash__(self) -> int: return hash(self.__class__) ^ hash(self.__slots__) @classmethod - def read(cls, iprot): + def read(cls, iprot: TProtocolBase) -> TFrozenBase: # type: ignore[override] if (iprot._fast_decode is not None and isinstance(iprot.trans, TTransport.CReadableTransport) and cls.thrift_spec is not None): @@ -79,7 +87,7 @@ def read(cls, iprot): return iprot._fast_decode(None, iprot, [self.__class__, self.thrift_spec]) else: - return iprot.readStruct(cls, cls.thrift_spec, True) + return iprot.readStruct(cls, cls.thrift_spec, True) # type: ignore[arg-type, return-value] class TFrozenExceptionBase(TFrozenBase, TExceptionBase): diff --git a/lib/py/src/protocol/TBinaryProtocol.py b/lib/py/src/thrift/protocol/TBinaryProtocol.py similarity index 66% rename from lib/py/src/protocol/TBinaryProtocol.py rename to lib/py/src/thrift/protocol/TBinaryProtocol.py index af64ec10356..9a8ecad522d 100644 --- a/lib/py/src/protocol/TBinaryProtocol.py +++ b/lib/py/src/thrift/protocol/TBinaryProtocol.py @@ -17,9 +17,14 @@ # under the License. # +from __future__ import annotations + from struct import pack, unpack +from typing import Any -from .TProtocol import TType, TProtocolBase, TProtocolException, TProtocolFactory +from thrift.protocol.TProtocol import TProtocolBase, TProtocolException, TProtocolFactory +from thrift.Thrift import TType +from thrift.transport.TTransport import TTransportBase class TBinaryProtocol(TProtocolBase): @@ -30,108 +35,119 @@ class TBinaryProtocol(TProtocolBase): # instead it'll stay in 32 bit-land. # VERSION_MASK = 0xffff0000 - VERSION_MASK = -65536 + VERSION_MASK: int = -65536 # VERSION_1 = 0x80010000 - VERSION_1 = -2147418112 - - TYPE_MASK = 0x000000ff - - def __init__(self, trans, strictRead=False, strictWrite=True, **kwargs): + VERSION_1: int = -2147418112 + + TYPE_MASK: int = 0x000000ff + + strictRead: bool + strictWrite: bool + string_length_limit: int | None + container_length_limit: int | None + + def __init__( + self, + trans: TTransportBase, + strictRead: bool = False, + strictWrite: bool = True, + **kwargs: Any, + ) -> None: TProtocolBase.__init__(self, trans) self.strictRead = strictRead self.strictWrite = strictWrite self.string_length_limit = kwargs.get('string_length_limit', None) self.container_length_limit = kwargs.get('container_length_limit', None) - def _check_string_length(self, length): + def _check_string_length(self, length: int) -> None: self._check_length(self.string_length_limit, length) - def _check_container_length(self, length): + def _check_container_length(self, length: int) -> None: self._check_length(self.container_length_limit, length) - def writeMessageBegin(self, name, type, seqid): + def writeMessageBegin(self, name: str, ttype: int, seqid: int) -> None: if self.strictWrite: - self.writeI32(TBinaryProtocol.VERSION_1 | type) + self.writeI32(TBinaryProtocol.VERSION_1 | ttype) self.writeString(name) self.writeI32(seqid) else: self.writeString(name) - self.writeByte(type) + self.writeByte(ttype) self.writeI32(seqid) - def writeMessageEnd(self): + def writeMessageEnd(self) -> None: pass - def writeStructBegin(self, name): + def writeStructBegin(self, name: str) -> None: pass - def writeStructEnd(self): + def writeStructEnd(self) -> None: pass - def writeFieldBegin(self, name, type, id): - self.writeByte(type) - self.writeI16(id) + def writeFieldBegin(self, name: str, ttype: int, fid: int) -> None: + self.writeByte(ttype) + self.writeI16(fid) - def writeFieldEnd(self): + def writeFieldEnd(self) -> None: pass - def writeFieldStop(self): + def writeFieldStop(self) -> None: self.writeByte(TType.STOP) - def writeMapBegin(self, ktype, vtype, size): + def writeMapBegin(self, ktype: int, vtype: int, size: int) -> None: self.writeByte(ktype) self.writeByte(vtype) self.writeI32(size) - def writeMapEnd(self): + def writeMapEnd(self) -> None: pass - def writeListBegin(self, etype, size): + def writeListBegin(self, etype: int, size: int) -> None: self.writeByte(etype) self.writeI32(size) - def writeListEnd(self): + def writeListEnd(self) -> None: pass - def writeSetBegin(self, etype, size): + def writeSetBegin(self, etype: int, size: int) -> None: self.writeByte(etype) self.writeI32(size) - def writeSetEnd(self): + def writeSetEnd(self) -> None: pass - def writeBool(self, bool): - if bool: + def writeBool(self, bool_val: bool) -> None: + if bool_val: self.writeByte(1) else: self.writeByte(0) - def writeByte(self, byte): + def writeByte(self, byte: int) -> None: buff = pack("!b", byte) self.trans.write(buff) - def writeI16(self, i16): + def writeI16(self, i16: int) -> None: buff = pack("!h", i16) self.trans.write(buff) - def writeI32(self, i32): + def writeI32(self, i32: int) -> None: buff = pack("!i", i32) self.trans.write(buff) - def writeI64(self, i64): + def writeI64(self, i64: int) -> None: buff = pack("!q", i64) self.trans.write(buff) - def writeDouble(self, dub): + def writeDouble(self, dub: float) -> None: buff = pack("!d", dub) self.trans.write(buff) - def writeBinary(self, str): - self.writeI32(len(str)) - self.trans.write(str) + def writeBinary(self, str_val: bytes) -> None: + self.writeI32(len(str_val)) + self.trans.write(str_val) - def readMessageBegin(self): + def readMessageBegin(self) -> tuple[str, int, int]: sz = self.readI32() if sz < 0: version = sz & TBinaryProtocol.VERSION_MASK @@ -151,85 +167,85 @@ def readMessageBegin(self): seqid = self.readI32() return (name, type, seqid) - def readMessageEnd(self): + def readMessageEnd(self) -> None: pass - def readStructBegin(self): + def readStructBegin(self) -> str | None: pass - def readStructEnd(self): + def readStructEnd(self) -> None: pass - def readFieldBegin(self): + def readFieldBegin(self) -> tuple[str | None, int, int]: type = self.readByte() if type == TType.STOP: return (None, type, 0) id = self.readI16() return (None, type, id) - def readFieldEnd(self): + def readFieldEnd(self) -> None: pass - def readMapBegin(self): + def readMapBegin(self) -> tuple[int, int, int]: ktype = self.readByte() vtype = self.readByte() size = self.readI32() self._check_container_length(size) return (ktype, vtype, size) - def readMapEnd(self): + def readMapEnd(self) -> None: pass - def readListBegin(self): + def readListBegin(self) -> tuple[int, int]: etype = self.readByte() size = self.readI32() self._check_container_length(size) return (etype, size) - def readListEnd(self): + def readListEnd(self) -> None: pass - def readSetBegin(self): + def readSetBegin(self) -> tuple[int, int]: etype = self.readByte() size = self.readI32() self._check_container_length(size) return (etype, size) - def readSetEnd(self): + def readSetEnd(self) -> None: pass - def readBool(self): + def readBool(self) -> bool: byte = self.readByte() if byte == 0: return False return True - def readByte(self): + def readByte(self) -> int: buff = self.trans.readAll(1) val, = unpack('!b', buff) return val - def readI16(self): + def readI16(self) -> int: buff = self.trans.readAll(2) val, = unpack('!h', buff) return val - def readI32(self): + def readI32(self) -> int: buff = self.trans.readAll(4) val, = unpack('!i', buff) return val - def readI64(self): + def readI64(self) -> int: buff = self.trans.readAll(8) val, = unpack('!q', buff) return val - def readDouble(self): + def readDouble(self) -> float: buff = self.trans.readAll(8) val, = unpack('!d', buff) return val - def readBinary(self): + def readBinary(self) -> bytes: size = self.readI32() self._check_string_length(size) s = self.trans.readAll(size) @@ -237,13 +253,23 @@ def readBinary(self): class TBinaryProtocolFactory(TProtocolFactory): - def __init__(self, strictRead=False, strictWrite=True, **kwargs): + strictRead: bool + strictWrite: bool + string_length_limit: int | None + container_length_limit: int | None + + def __init__( + self, + strictRead: bool = False, + strictWrite: bool = True, + **kwargs: Any, + ) -> None: self.strictRead = strictRead self.strictWrite = strictWrite self.string_length_limit = kwargs.get('string_length_limit', None) self.container_length_limit = kwargs.get('container_length_limit', None) - def getProtocol(self, trans): + def getProtocol(self, trans: TTransportBase) -> TBinaryProtocol: prot = TBinaryProtocol(trans, self.strictRead, self.strictWrite, string_length_limit=self.string_length_limit, container_length_limit=self.container_length_limit) @@ -270,13 +296,12 @@ class TBinaryProtocolAccelerated(TBinaryProtocol): Please feel free to report bugs and/or success stories to the public mailing list. """ - pass - def __init__(self, *args, **kwargs): + def __init__(self, *args: Any, **kwargs: Any) -> None: fallback = kwargs.pop('fallback', True) super(TBinaryProtocolAccelerated, self).__init__(*args, **kwargs) try: - from thrift.protocol import fastbinary + from thrift.protocol import fastbinary # type: ignore[attr-defined] except ImportError: if not fallback: raise @@ -286,15 +311,21 @@ def __init__(self, *args, **kwargs): class TBinaryProtocolAcceleratedFactory(TProtocolFactory): - def __init__(self, - string_length_limit=None, - container_length_limit=None, - fallback=True): + string_length_limit: int | None + container_length_limit: int | None + _fallback: bool + + def __init__( + self, + string_length_limit: int | None = None, + container_length_limit: int | None = None, + fallback: bool = True, + ) -> None: self.string_length_limit = string_length_limit self.container_length_limit = container_length_limit self._fallback = fallback - def getProtocol(self, trans): + def getProtocol(self, trans: TTransportBase) -> TBinaryProtocolAccelerated: return TBinaryProtocolAccelerated( trans, string_length_limit=self.string_length_limit, diff --git a/lib/py/src/protocol/TCompactProtocol.py b/lib/py/src/thrift/protocol/TCompactProtocol.py similarity index 70% rename from lib/py/src/protocol/TCompactProtocol.py rename to lib/py/src/thrift/protocol/TCompactProtocol.py index a3527cd47a3..0872e3de49e 100644 --- a/lib/py/src/protocol/TCompactProtocol.py +++ b/lib/py/src/thrift/protocol/TCompactProtocol.py @@ -17,25 +17,32 @@ # under the License. # -from .TProtocol import TType, TProtocolBase, TProtocolException, TProtocolFactory, checkIntegerLimits +from __future__ import annotations + from struct import pack, unpack +from io import BytesIO +from typing import Any, Callable + +from thrift.protocol.TProtocol import TProtocolBase, TProtocolException, TProtocolFactory, checkIntegerLimits +from thrift.Thrift import TType +from thrift.transport.TTransport import TTransportBase -__all__ = ['TCompactProtocol', 'TCompactProtocolFactory'] +__all__ = ['TCompactProtocol', 'TCompactProtocolFactory', 'writeVarint'] -CLEAR = 0 -FIELD_WRITE = 1 -VALUE_WRITE = 2 -CONTAINER_WRITE = 3 -BOOL_WRITE = 4 -FIELD_READ = 5 -CONTAINER_READ = 6 -VALUE_READ = 7 -BOOL_READ = 8 +CLEAR: int = 0 +FIELD_WRITE: int = 1 +VALUE_WRITE: int = 2 +CONTAINER_WRITE: int = 3 +BOOL_WRITE: int = 4 +FIELD_READ: int = 5 +CONTAINER_READ: int = 6 +VALUE_READ: int = 7 +BOOL_READ: int = 8 -def make_helper(v_from, container): - def helper(func): - def nested(self, *args, **kwargs): +def make_helper(v_from: int, container: int) -> Callable[[Callable[..., Any]], Callable[..., Any]]: + def helper(func: Callable[..., Any]) -> Callable[..., Any]: + def nested(self: TCompactProtocol, *args: Any, **kwargs: Any) -> Any: assert self.state in (v_from, container), (self.state, v_from, container) return func(self, *args, **kwargs) return nested @@ -46,16 +53,16 @@ def nested(self, *args, **kwargs): reader = make_helper(VALUE_READ, CONTAINER_READ) -def makeZigZag(n, bits): +def makeZigZag(n: int, bits: int) -> int: checkIntegerLimits(n, bits) return (n << 1) ^ (n >> (bits - 1)) -def fromZigZag(n): +def fromZigZag(n: int) -> int: return (n >> 1) ^ -(n & 1) -def writeVarint(trans, n): +def writeVarint(trans: TTransportBase | BytesIO, n: int) -> None: assert n >= 0, "Input to TCompactProtocol writeVarint cannot be negative!" out = bytearray() while True: @@ -68,7 +75,7 @@ def writeVarint(trans, n): trans.write(bytes(out)) -def readVarint(trans): +def readVarint(trans: TTransportBase) -> int: result = 0 shift = 0 while True: @@ -80,23 +87,23 @@ def readVarint(trans): shift += 7 -class CompactType(object): - STOP = 0x00 - TRUE = 0x01 - FALSE = 0x02 - BYTE = 0x03 - I16 = 0x04 - I32 = 0x05 - I64 = 0x06 - DOUBLE = 0x07 - BINARY = 0x08 - LIST = 0x09 - SET = 0x0A - MAP = 0x0B - STRUCT = 0x0C +class CompactType: + STOP: int = 0x00 + TRUE: int = 0x01 + FALSE: int = 0x02 + BYTE: int = 0x03 + I16: int = 0x04 + I32: int = 0x05 + I64: int = 0x06 + DOUBLE: int = 0x07 + BINARY: int = 0x08 + LIST: int = 0x09 + SET: int = 0x0A + MAP: int = 0x0B + STRUCT: int = 0x0C -CTYPES = { +CTYPES: dict[int, int] = { TType.STOP: CompactType.STOP, TType.BOOL: CompactType.TRUE, # used for collection TType.BYTE: CompactType.BYTE, @@ -111,27 +118,35 @@ class CompactType(object): TType.MAP: CompactType.MAP, } -TTYPES = {} -for k, v in CTYPES.items(): - TTYPES[v] = k +TTYPES: dict[int, int] = {v: k for k, v in CTYPES.items()} TTYPES[CompactType.FALSE] = TType.BOOL -del k -del v class TCompactProtocol(TProtocolBase): """Compact implementation of the Thrift protocol driver.""" - PROTOCOL_ID = 0x82 - VERSION = 1 - VERSION_MASK = 0x1f - TYPE_MASK = 0xe0 - TYPE_BITS = 0x07 - TYPE_SHIFT_AMOUNT = 5 - - def __init__(self, trans, - string_length_limit=None, - container_length_limit=None): + PROTOCOL_ID: int = 0x82 + VERSION: int = 1 + VERSION_MASK: int = 0x1f + TYPE_MASK: int = 0xe0 + TYPE_BITS: int = 0x07 + TYPE_SHIFT_AMOUNT: int = 5 + + state: int + string_length_limit: int | None + container_length_limit: int | None + __last_fid: int + __bool_fid: int | None + __bool_value: bool | int | None + __structs: list[tuple[int, int]] + __containers: list[int] + + def __init__( + self, + trans: TTransportBase, + string_length_limit: int | None = None, + container_length_limit: int | None = None, + ) -> None: TProtocolBase.__init__(self, trans) self.state = CLEAR self.__last_fid = 0 @@ -142,19 +157,19 @@ def __init__(self, trans, self.string_length_limit = string_length_limit self.container_length_limit = container_length_limit - def _check_string_length(self, length): + def _check_string_length(self, length: int) -> None: self._check_length(self.string_length_limit, length) - def _check_container_length(self, length): + def _check_container_length(self, length: int) -> None: self._check_length(self.container_length_limit, length) - def __writeVarint(self, n): + def __writeVarint(self, n: int) -> None: writeVarint(self.trans, n) - def writeMessageBegin(self, name, type, seqid): + def writeMessageBegin(self, name: str, ttype: int, seqid: int) -> None: assert self.state == CLEAR self.__writeUByte(self.PROTOCOL_ID) - self.__writeUByte(self.VERSION | (type << self.TYPE_SHIFT_AMOUNT)) + self.__writeUByte(self.VERSION | (ttype << self.TYPE_SHIFT_AMOUNT)) # The sequence id is a signed 32-bit integer but the compact protocol # writes this out as a "var int" which is always positive, and attempting # to write a negative number results in an infinite loop, so we may @@ -166,24 +181,24 @@ def writeMessageBegin(self, name, type, seqid): self.__writeBinary(bytes(name, 'utf-8')) self.state = VALUE_WRITE - def writeMessageEnd(self): + def writeMessageEnd(self) -> None: assert self.state == VALUE_WRITE self.state = CLEAR - def writeStructBegin(self, name): + def writeStructBegin(self, name: str) -> None: assert self.state in (CLEAR, CONTAINER_WRITE, VALUE_WRITE), self.state self.__structs.append((self.state, self.__last_fid)) self.state = FIELD_WRITE self.__last_fid = 0 - def writeStructEnd(self): + def writeStructEnd(self) -> None: assert self.state == FIELD_WRITE self.state, self.__last_fid = self.__structs.pop() - def writeFieldStop(self): + def writeFieldStop(self) -> None: self.__writeByte(0) - def __writeFieldHeader(self, type, fid): + def __writeFieldHeader(self, type: int, fid: int) -> None: delta = fid - self.__last_fid if 0 < delta <= 15: self.__writeUByte(delta << 4 | type) @@ -192,32 +207,32 @@ def __writeFieldHeader(self, type, fid): self.__writeI16(fid) self.__last_fid = fid - def writeFieldBegin(self, name, type, fid): + def writeFieldBegin(self, name: str, ttype: int, fid: int) -> None: assert self.state == FIELD_WRITE, self.state - if type == TType.BOOL: + if ttype == TType.BOOL: self.state = BOOL_WRITE self.__bool_fid = fid else: self.state = VALUE_WRITE - self.__writeFieldHeader(CTYPES[type], fid) + self.__writeFieldHeader(CTYPES[ttype], fid) - def writeFieldEnd(self): + def writeFieldEnd(self) -> None: assert self.state in (VALUE_WRITE, BOOL_WRITE), self.state self.state = FIELD_WRITE - def __writeUByte(self, byte): + def __writeUByte(self, byte: int) -> None: self.trans.write(pack('!B', byte)) - def __writeByte(self, byte): + def __writeByte(self, byte: int) -> None: self.trans.write(pack('!b', byte)) - def __writeI16(self, i16): + def __writeI16(self, i16: int) -> None: self.__writeVarint(makeZigZag(i16, 16)) - def __writeSize(self, i32): + def __writeSize(self, i32: int) -> None: self.__writeVarint(i32) - def writeCollectionBegin(self, etype, size): + def writeCollectionBegin(self, etype: int, size: int) -> None: assert self.state in (VALUE_WRITE, CONTAINER_WRITE), self.state if size <= 14: self.__writeUByte(size << 4 | CTYPES[etype]) @@ -226,10 +241,11 @@ def writeCollectionBegin(self, etype, size): self.__writeSize(size) self.__containers.append(self.state) self.state = CONTAINER_WRITE + writeSetBegin = writeCollectionBegin writeListBegin = writeCollectionBegin - def writeMapBegin(self, ktype, vtype, size): + def writeMapBegin(self, ktype: int, vtype: int, size: int) -> None: assert self.state in (VALUE_WRITE, CONTAINER_WRITE), self.state if size == 0: self.__writeByte(0) @@ -239,22 +255,23 @@ def writeMapBegin(self, ktype, vtype, size): self.__containers.append(self.state) self.state = CONTAINER_WRITE - def writeCollectionEnd(self): + def writeCollectionEnd(self) -> None: assert self.state == CONTAINER_WRITE, self.state self.state = self.__containers.pop() + writeMapEnd = writeCollectionEnd writeSetEnd = writeCollectionEnd writeListEnd = writeCollectionEnd - def writeBool(self, bool): + def writeBool(self, bool_val: bool) -> None: if self.state == BOOL_WRITE: - if bool: + if bool_val: ctype = CompactType.TRUE else: ctype = CompactType.FALSE - self.__writeFieldHeader(ctype, self.__bool_fid) + self.__writeFieldHeader(ctype, self.__bool_fid) # type: ignore[arg-type] elif self.state == CONTAINER_WRITE: - if bool: + if bool_val: self.__writeByte(CompactType.TRUE) else: self.__writeByte(CompactType.FALSE) @@ -265,23 +282,24 @@ def writeBool(self, bool): writeI16 = writer(__writeI16) @writer - def writeI32(self, i32): + def writeI32(self, i32: int) -> None: self.__writeVarint(makeZigZag(i32, 32)) @writer - def writeI64(self, i64): + def writeI64(self, i64: int) -> None: self.__writeVarint(makeZigZag(i64, 64)) @writer - def writeDouble(self, dub): + def writeDouble(self, dub: float) -> None: self.trans.write(pack(' None: self.__writeSize(len(s)) self.trans.write(s) + writeBinary = writer(__writeBinary) - def readFieldBegin(self): + def readFieldBegin(self) -> tuple[str | None, int, int]: assert self.state == FIELD_READ, self.state type = self.__readUByte() if type & 0x0f == TType.STOP: @@ -303,31 +321,31 @@ def readFieldBegin(self): self.state = VALUE_READ return (None, self.__getTType(type), fid) - def readFieldEnd(self): + def readFieldEnd(self) -> None: assert self.state in (VALUE_READ, BOOL_READ), self.state self.state = FIELD_READ - def __readUByte(self): + def __readUByte(self) -> int: result, = unpack('!B', self.trans.readAll(1)) return result - def __readByte(self): + def __readByte(self) -> int: result, = unpack('!b', self.trans.readAll(1)) return result - def __readVarint(self): + def __readVarint(self) -> int: return readVarint(self.trans) - def __readZigZag(self): + def __readZigZag(self) -> int: return fromZigZag(self.__readVarint()) - def __readSize(self): + def __readSize(self) -> int: result = self.__readVarint() if result < 0: - raise TProtocolException("Length < 0") + raise TProtocolException(TProtocolException.NEGATIVE_SIZE, "Length < 0") return result - def readMessageBegin(self): + def readMessageBegin(self) -> tuple[str, int, int]: assert self.state == CLEAR proto_id = self.__readUByte() if proto_id != self.PROTOCOL_ID: @@ -347,21 +365,22 @@ def readMessageBegin(self): name = self.__readBinary().decode('utf-8') return (name, type, seqid) - def readMessageEnd(self): + def readMessageEnd(self) -> None: assert self.state == CLEAR assert len(self.__structs) == 0 - def readStructBegin(self): + def readStructBegin(self) -> str | None: assert self.state in (CLEAR, CONTAINER_READ, VALUE_READ), self.state self.__structs.append((self.state, self.__last_fid)) self.state = FIELD_READ self.__last_fid = 0 + return None - def readStructEnd(self): + def readStructEnd(self) -> None: assert self.state == FIELD_READ self.state, self.__last_fid = self.__structs.pop() - def readCollectionBegin(self): + def readCollectionBegin(self) -> tuple[int, int]: assert self.state in (VALUE_READ, CONTAINER_READ), self.state size_type = self.__readUByte() size = size_type >> 4 @@ -372,10 +391,11 @@ def readCollectionBegin(self): self.__containers.append(self.state) self.state = CONTAINER_READ return type, size + readSetBegin = readCollectionBegin readListBegin = readCollectionBegin - def readMapBegin(self): + def readMapBegin(self) -> tuple[int, int, int]: assert self.state in (VALUE_READ, CONTAINER_READ), self.state size = self.__readSize() self._check_container_length(size) @@ -388,16 +408,17 @@ def readMapBegin(self): self.state = CONTAINER_READ return (ktype, vtype, size) - def readCollectionEnd(self): + def readCollectionEnd(self) -> None: assert self.state == CONTAINER_READ, self.state self.state = self.__containers.pop() + readSetEnd = readCollectionEnd readListEnd = readCollectionEnd readMapEnd = readCollectionEnd - def readBool(self): + def readBool(self) -> bool: if self.state == BOOL_READ: - return self.__bool_value == CompactType.TRUE + return self.__bool_value == True # noqa: E712 elif self.state == CONTAINER_READ: return self.__readByte() == CompactType.TRUE else: @@ -411,29 +432,35 @@ def readBool(self): readI64 = reader(__readZigZag) @reader - def readDouble(self): + def readDouble(self) -> float: buff = self.trans.readAll(8) val, = unpack(' bytes: size = self.__readSize() self._check_string_length(size) return self.trans.readAll(size) + readBinary = reader(__readBinary) - def __getTType(self, byte): + def __getTType(self, byte: int) -> int: return TTYPES[byte & 0x0f] class TCompactProtocolFactory(TProtocolFactory): - def __init__(self, - string_length_limit=None, - container_length_limit=None): + string_length_limit: int | None + container_length_limit: int | None + + def __init__( + self, + string_length_limit: int | None = None, + container_length_limit: int | None = None, + ) -> None: self.string_length_limit = string_length_limit self.container_length_limit = container_length_limit - def getProtocol(self, trans): + def getProtocol(self, trans: TTransportBase) -> TCompactProtocol: return TCompactProtocol(trans, self.string_length_limit, self.container_length_limit) @@ -453,13 +480,12 @@ class TCompactProtocolAccelerated(TCompactProtocol): In order to take advantage of the C module, just use TCompactProtocolAccelerated instead of TCompactProtocol. """ - pass - def __init__(self, *args, **kwargs): + def __init__(self, *args: Any, **kwargs: Any) -> None: fallback = kwargs.pop('fallback', True) super(TCompactProtocolAccelerated, self).__init__(*args, **kwargs) try: - from thrift.protocol import fastbinary + from thrift.protocol import fastbinary # type: ignore[attr-defined] except ImportError: if not fallback: raise @@ -469,15 +495,21 @@ def __init__(self, *args, **kwargs): class TCompactProtocolAcceleratedFactory(TProtocolFactory): - def __init__(self, - string_length_limit=None, - container_length_limit=None, - fallback=True): + string_length_limit: int | None + container_length_limit: int | None + _fallback: bool + + def __init__( + self, + string_length_limit: int | None = None, + container_length_limit: int | None = None, + fallback: bool = True, + ) -> None: self.string_length_limit = string_length_limit self.container_length_limit = container_length_limit self._fallback = fallback - def getProtocol(self, trans): + def getProtocol(self, trans: TTransportBase) -> TCompactProtocolAccelerated: return TCompactProtocolAccelerated( trans, string_length_limit=self.string_length_limit, diff --git a/lib/py/src/protocol/THeaderProtocol.py b/lib/py/src/thrift/protocol/THeaderProtocol.py similarity index 68% rename from lib/py/src/protocol/THeaderProtocol.py rename to lib/py/src/thrift/protocol/THeaderProtocol.py index 4b58e639da2..7a21901f1b8 100644 --- a/lib/py/src/protocol/THeaderProtocol.py +++ b/lib/py/src/thrift/protocol/THeaderProtocol.py @@ -17,14 +17,17 @@ # under the License. # +from __future__ import annotations + from thrift.protocol.TBinaryProtocol import TBinaryProtocolAccelerated from thrift.protocol.TCompactProtocol import TCompactProtocolAccelerated from thrift.protocol.TProtocol import TProtocolBase, TProtocolException, TProtocolFactory from thrift.Thrift import TApplicationException, TMessageType from thrift.transport.THeaderTransport import THeaderTransport, THeaderSubprotocolID, THeaderClientType +from thrift.transport.TTransport import TTransportBase -PROTOCOLS_BY_ID = { +PROTOCOLS_BY_ID: dict[int, type[TBinaryProtocolAccelerated] | type[TCompactProtocolAccelerated]] = { THeaderSubprotocolID.BINARY: TBinaryProtocolAccelerated, THeaderSubprotocolID.COMPACT: TCompactProtocolAccelerated, } @@ -55,7 +58,15 @@ class THeaderProtocol(TProtocolBase): """ - def __init__(self, transport, allowed_client_types, default_protocol=THeaderSubprotocolID.BINARY): + trans: THeaderTransport + _protocol: TBinaryProtocolAccelerated | TCompactProtocolAccelerated + + def __init__( + self, + transport: TTransportBase | THeaderTransport, + allowed_client_types: tuple[int, ...], + default_protocol: int = THeaderSubprotocolID.BINARY, + ) -> None: # much of the actual work for THeaderProtocol happens down in # THeaderTransport since we need to do low-level shenanigans to detect # if the client is sending us headers or one of the headerless formats @@ -66,80 +77,80 @@ def __init__(self, transport, allowed_client_types, default_protocol=THeaderSubp super(THeaderProtocol, self).__init__(transport) self._set_protocol() - def get_headers(self): + def get_headers(self) -> dict[bytes, bytes]: return self.trans.get_headers() - def set_header(self, key, value): + def set_header(self, key: bytes, value: bytes) -> None: self.trans.set_header(key, value) - def clear_headers(self): + def clear_headers(self) -> None: self.trans.clear_headers() - def add_transform(self, transform_id): + def add_transform(self, transform_id: int) -> None: self.trans.add_transform(transform_id) - def writeMessageBegin(self, name, ttype, seqid): + def writeMessageBegin(self, name: str, ttype: int, seqid: int) -> None: self.trans.sequence_id = seqid return self._protocol.writeMessageBegin(name, ttype, seqid) - def writeMessageEnd(self): + def writeMessageEnd(self) -> None: return self._protocol.writeMessageEnd() - def writeStructBegin(self, name): + def writeStructBegin(self, name: str) -> None: return self._protocol.writeStructBegin(name) - def writeStructEnd(self): + def writeStructEnd(self) -> None: return self._protocol.writeStructEnd() - def writeFieldBegin(self, name, ttype, fid): + def writeFieldBegin(self, name: str, ttype: int, fid: int) -> None: return self._protocol.writeFieldBegin(name, ttype, fid) - def writeFieldEnd(self): + def writeFieldEnd(self) -> None: return self._protocol.writeFieldEnd() - def writeFieldStop(self): + def writeFieldStop(self) -> None: return self._protocol.writeFieldStop() - def writeMapBegin(self, ktype, vtype, size): + def writeMapBegin(self, ktype: int, vtype: int, size: int) -> None: return self._protocol.writeMapBegin(ktype, vtype, size) - def writeMapEnd(self): + def writeMapEnd(self) -> None: return self._protocol.writeMapEnd() - def writeListBegin(self, etype, size): + def writeListBegin(self, etype: int, size: int) -> None: return self._protocol.writeListBegin(etype, size) - def writeListEnd(self): + def writeListEnd(self) -> None: return self._protocol.writeListEnd() - def writeSetBegin(self, etype, size): + def writeSetBegin(self, etype: int, size: int) -> None: return self._protocol.writeSetBegin(etype, size) - def writeSetEnd(self): + def writeSetEnd(self) -> None: return self._protocol.writeSetEnd() - def writeBool(self, bool_val): + def writeBool(self, bool_val: bool) -> None: return self._protocol.writeBool(bool_val) - def writeByte(self, byte): + def writeByte(self, byte: int) -> None: return self._protocol.writeByte(byte) - def writeI16(self, i16): + def writeI16(self, i16: int) -> None: return self._protocol.writeI16(i16) - def writeI32(self, i32): + def writeI32(self, i32: int) -> None: return self._protocol.writeI32(i32) - def writeI64(self, i64): + def writeI64(self, i64: int) -> None: return self._protocol.writeI64(i64) - def writeDouble(self, dub): + def writeDouble(self, dub: float) -> None: return self._protocol.writeDouble(dub) - def writeBinary(self, str_val): + def writeBinary(self, str_val: bytes) -> None: return self._protocol.writeBinary(str_val) - def _set_protocol(self): + def _set_protocol(self) -> None: try: protocol_cls = PROTOCOLS_BY_ID[self.trans.protocol_id] except KeyError: @@ -152,81 +163,84 @@ def _set_protocol(self): self._fast_encode = self._protocol._fast_encode self._fast_decode = self._protocol._fast_decode - def readMessageBegin(self): + def readMessageBegin(self) -> tuple[str, int, int]: try: self.trans.readFrame(0) self._set_protocol() except TApplicationException as exc: - self._protocol.writeMessageBegin(b"", TMessageType.EXCEPTION, 0) + self._protocol.writeMessageBegin(b"", TMessageType.EXCEPTION, 0) # type: ignore[arg-type] exc.write(self._protocol) self._protocol.writeMessageEnd() self.trans.flush() return self._protocol.readMessageBegin() - def readMessageEnd(self): + def readMessageEnd(self) -> None: return self._protocol.readMessageEnd() - def readStructBegin(self): + def readStructBegin(self) -> str | None: return self._protocol.readStructBegin() - def readStructEnd(self): + def readStructEnd(self) -> None: return self._protocol.readStructEnd() - def readFieldBegin(self): + def readFieldBegin(self) -> tuple[str | None, int, int]: return self._protocol.readFieldBegin() - def readFieldEnd(self): + def readFieldEnd(self) -> None: return self._protocol.readFieldEnd() - def readMapBegin(self): + def readMapBegin(self) -> tuple[int, int, int]: return self._protocol.readMapBegin() - def readMapEnd(self): + def readMapEnd(self) -> None: return self._protocol.readMapEnd() - def readListBegin(self): + def readListBegin(self) -> tuple[int, int]: return self._protocol.readListBegin() - def readListEnd(self): + def readListEnd(self) -> None: return self._protocol.readListEnd() - def readSetBegin(self): + def readSetBegin(self) -> tuple[int, int]: return self._protocol.readSetBegin() - def readSetEnd(self): + def readSetEnd(self) -> None: return self._protocol.readSetEnd() - def readBool(self): + def readBool(self) -> bool: return self._protocol.readBool() - def readByte(self): + def readByte(self) -> int: return self._protocol.readByte() - def readI16(self): + def readI16(self) -> int: return self._protocol.readI16() - def readI32(self): + def readI32(self) -> int: return self._protocol.readI32() - def readI64(self): + def readI64(self) -> int: return self._protocol.readI64() - def readDouble(self): + def readDouble(self) -> float: return self._protocol.readDouble() - def readBinary(self): + def readBinary(self) -> bytes: return self._protocol.readBinary() class THeaderProtocolFactory(TProtocolFactory): + allowed_client_types: tuple[int, ...] + default_protocol: int + def __init__( self, - allowed_client_types=(THeaderClientType.HEADERS,), - default_protocol=THeaderSubprotocolID.BINARY, - ): + allowed_client_types: tuple[int, ...] = (THeaderClientType.HEADERS,), + default_protocol: int = THeaderSubprotocolID.BINARY, + ) -> None: self.allowed_client_types = allowed_client_types self.default_protocol = default_protocol - def getProtocol(self, trans): + def getProtocol(self, trans: TTransportBase) -> THeaderProtocol: return THeaderProtocol(trans, self.allowed_client_types, self.default_protocol) diff --git a/lib/py/src/protocol/TJSONProtocol.py b/lib/py/src/thrift/protocol/TJSONProtocol.py similarity index 64% rename from lib/py/src/protocol/TJSONProtocol.py rename to lib/py/src/thrift/protocol/TJSONProtocol.py index a42aaa6315d..cbddf6b8f18 100644 --- a/lib/py/src/protocol/TJSONProtocol.py +++ b/lib/py/src/thrift/protocol/TJSONProtocol.py @@ -17,11 +17,14 @@ # under the License. # -from .TProtocol import (TType, TProtocolBase, TProtocolException, - TProtocolFactory, checkIntegerLimits) +from __future__ import annotations + import base64 import math -import sys + +from thrift.protocol.TProtocol import TProtocolBase, TProtocolException, TProtocolFactory, checkIntegerLimits +from thrift.Thrift import TType +from thrift.transport.TTransport import TTransportBase __all__ = ['TJSONProtocol', @@ -29,21 +32,21 @@ 'TSimpleJSONProtocol', 'TSimpleJSONProtocolFactory'] -VERSION = 1 - -COMMA = b',' -COLON = b':' -LBRACE = b'{' -RBRACE = b'}' -LBRACKET = b'[' -RBRACKET = b']' -QUOTE = b'"' -BACKSLASH = b'\\' -ZERO = b'0' - -ESCSEQ0 = ord('\\') -ESCSEQ1 = ord('u') -ESCAPE_CHAR_VALS = { +VERSION: int = 1 + +COMMA: bytes = b',' +COLON: bytes = b':' +LBRACE: bytes = b'{' +RBRACE: bytes = b'}' +LBRACKET: bytes = b'[' +RBRACKET: bytes = b']' +QUOTE: bytes = b'"' +BACKSLASH: bytes = b'\\' +ZERO: bytes = b'0' + +ESCSEQ0: int = ord('\\') +ESCSEQ1: int = ord('u') +ESCAPE_CHAR_VALS: dict[str, str] = { '"': '\\"', '\\': '\\\\', '\b': '\\b', @@ -53,7 +56,7 @@ '\t': '\\t', # '/': '\\/', } -ESCAPE_CHARS = { +ESCAPE_CHARS: dict[bytes, str] = { b'"': '"', b'\\': '\\', b'b': '\b', @@ -63,9 +66,9 @@ b't': '\t', b'/': '/', } -NUMERIC_CHAR = b'+-.0123456789Ee' +NUMERIC_CHAR: bytes = b'+-.0123456789Ee' -CTYPES = { +CTYPES: dict[int, str] = { TType.BOOL: 'tf', TType.BYTE: 'i8', TType.I16: 'i16', @@ -79,90 +82,94 @@ TType.MAP: 'map', } -JTYPES = {} +JTYPES: dict[str, int] = {} for key in CTYPES.keys(): JTYPES[CTYPES[key]] = key -class JSONBaseContext(object): +class JSONBaseContext: + protocol: TJSONProtocolBase + first: bool - def __init__(self, protocol): + def __init__(self, protocol: TJSONProtocolBase) -> None: self.protocol = protocol self.first = True - def doIO(self, function): + def doIO(self, function: object) -> None: pass - def write(self): + def write(self) -> None: pass - def read(self): + def read(self) -> None: pass - def escapeNum(self): + def escapeNum(self) -> bool: return False - def __str__(self): + def __str__(self) -> str: return self.__class__.__name__ class JSONListContext(JSONBaseContext): - def doIO(self, function): + def doIO(self, function: object) -> None: if self.first is True: self.first = False else: - function(COMMA) + function(COMMA) # type: ignore[operator] - def write(self): + def write(self) -> None: self.doIO(self.protocol.trans.write) - def read(self): + def read(self) -> None: self.doIO(self.protocol.readJSONSyntaxChar) class JSONPairContext(JSONBaseContext): + colon: bool - def __init__(self, protocol): + def __init__(self, protocol: TJSONProtocolBase) -> None: super(JSONPairContext, self).__init__(protocol) self.colon = True - def doIO(self, function): + def doIO(self, function: object) -> None: if self.first: self.first = False self.colon = True else: - function(COLON if self.colon else COMMA) + function(COLON if self.colon else COMMA) # type: ignore[operator] self.colon = not self.colon - def write(self): + def write(self) -> None: self.doIO(self.protocol.trans.write) - def read(self): + def read(self) -> None: self.doIO(self.protocol.readJSONSyntaxChar) - def escapeNum(self): + def escapeNum(self) -> bool: return self.colon - def __str__(self): + def __str__(self) -> str: return '%s, colon=%s' % (self.__class__.__name__, self.colon) -class LookaheadReader(): - hasData = False - data = '' +class LookaheadReader: + hasData: bool = False + data: bytes = b'' + protocol: TJSONProtocolBase - def __init__(self, protocol): + def __init__(self, protocol: TJSONProtocolBase) -> None: self.protocol = protocol - def read(self): + def read(self) -> bytes: if self.hasData is True: self.hasData = False else: self.data = self.protocol.trans.read(1) return self.data - def peek(self): + def peek(self) -> bytes: if self.hasData is False: self.data = self.protocol.trans.read(1) self.hasData = True @@ -170,41 +177,44 @@ def peek(self): class TJSONProtocolBase(TProtocolBase): + context: JSONBaseContext + contextStack: list[JSONBaseContext] + reader: LookaheadReader - def __init__(self, trans): + def __init__(self, trans: TTransportBase) -> None: TProtocolBase.__init__(self, trans) self.resetWriteContext() self.resetReadContext() # We don't have length limit implementation for JSON protocols @property - def string_length_limit(senf): + def string_length_limit(self) -> None: return None @property - def container_length_limit(senf): + def container_length_limit(self) -> None: return None - def resetWriteContext(self): + def resetWriteContext(self) -> None: self.context = JSONBaseContext(self) self.contextStack = [self.context] - def resetReadContext(self): + def resetReadContext(self) -> None: self.resetWriteContext() self.reader = LookaheadReader(self) - def pushContext(self, ctx): + def pushContext(self, ctx: JSONBaseContext) -> None: self.contextStack.append(ctx) self.context = ctx - def popContext(self): + def popContext(self) -> None: self.contextStack.pop() if self.contextStack: self.context = self.contextStack[-1] else: self.context = JSONBaseContext(self) - def writeJSONString(self, string): + def writeJSONString(self, string: str) -> None: self.context.write() json_str = ['"'] for s in string: @@ -213,7 +223,7 @@ def writeJSONString(self, string): json_str.append('"') self.trans.write(bytes(''.join(json_str), 'utf-8')) - def writeJSONNumber(self, number, formatter='{0}'): + def writeJSONNumber(self, number: int | float, formatter: str = '{0}') -> None: self.context.write() jsNumber = str(formatter.format(number)).encode('ascii') if self.context.escapeNum(): @@ -223,43 +233,43 @@ def writeJSONNumber(self, number, formatter='{0}'): else: self.trans.write(jsNumber) - def writeJSONBase64(self, binary): + def writeJSONBase64(self, binary: bytes) -> None: self.context.write() self.trans.write(QUOTE) self.trans.write(base64.b64encode(binary)) self.trans.write(QUOTE) - def writeJSONObjectStart(self): + def writeJSONObjectStart(self) -> None: self.context.write() self.trans.write(LBRACE) self.pushContext(JSONPairContext(self)) - def writeJSONObjectEnd(self): + def writeJSONObjectEnd(self) -> None: self.popContext() self.trans.write(RBRACE) - def writeJSONArrayStart(self): + def writeJSONArrayStart(self) -> None: self.context.write() self.trans.write(LBRACKET) self.pushContext(JSONListContext(self)) - def writeJSONArrayEnd(self): + def writeJSONArrayEnd(self) -> None: self.popContext() self.trans.write(RBRACKET) - def readJSONSyntaxChar(self, character): + def readJSONSyntaxChar(self, character: bytes) -> None: current = self.reader.read() if character != current: raise TProtocolException(TProtocolException.INVALID_DATA, "Unexpected character: %s" % current) - def _isHighSurrogate(self, codeunit): + def _isHighSurrogate(self, codeunit: int) -> bool: return codeunit >= 0xd800 and codeunit <= 0xdbff - def _isLowSurrogate(self, codeunit): + def _isLowSurrogate(self, codeunit: int) -> bool: return codeunit >= 0xdc00 and codeunit <= 0xdfff - def _toChar(self, high, low=None): + def _toChar(self, high: int, low: int | None = None) -> str: if not low: return chr(high) else: @@ -267,9 +277,9 @@ def _toChar(self, high, low=None): codepoint += low & 0x3ff return chr(codepoint) - def readJSONString(self, skipContext): - highSurrogate = None - string = [] + def readJSONString(self, skipContext: bool) -> str: + highSurrogate: int | None = None + string: list[str] = [] if skipContext is False: self.context.read() self.readJSONSyntaxChar(QUOTE) @@ -280,8 +290,8 @@ def readJSONString(self, skipContext): if ord(character) == ESCSEQ0: character = self.reader.read() if ord(character) == ESCSEQ1: - character = self.trans.read(4).decode('ascii') - codeunit = int(character, 16) + char_str = self.trans.read(4).decode('ascii') + codeunit = int(char_str, 16) if self._isHighSurrogate(codeunit): if highSurrogate: raise TProtocolException( @@ -294,40 +304,40 @@ def readJSONString(self, skipContext): raise TProtocolException( TProtocolException.INVALID_DATA, "Expected high surrogate char") - character = self._toChar(highSurrogate, codeunit) + char_result = self._toChar(highSurrogate, codeunit) highSurrogate = None else: - character = self._toChar(codeunit) + char_result = self._toChar(codeunit) else: if character not in ESCAPE_CHARS: raise TProtocolException( TProtocolException.INVALID_DATA, "Expected control char") - character = ESCAPE_CHARS[character] - elif character in ESCAPE_CHAR_VALS: + char_result = ESCAPE_CHARS[character] + elif character.decode('latin-1') in ESCAPE_CHAR_VALS: raise TProtocolException(TProtocolException.INVALID_DATA, "Unescaped control char") else: utf8_bytes = bytearray([ord(character)]) while ord(self.reader.peek()) >= 0x80: utf8_bytes.append(ord(self.reader.read())) - character = utf8_bytes.decode('utf-8') - string.append(character) + char_result = utf8_bytes.decode('utf-8') + string.append(char_result) if highSurrogate: raise TProtocolException(TProtocolException.INVALID_DATA, "Expected low surrogate char") return ''.join(string) - def isJSONNumeric(self, character): - return (True if NUMERIC_CHAR.find(character) != - 1 else False) + def isJSONNumeric(self, character: bytes) -> bool: + return True if NUMERIC_CHAR.find(character) != -1 else False - def readJSONQuotes(self): - if (self.context.escapeNum()): + def readJSONQuotes(self) -> None: + if self.context.escapeNum(): self.readJSONSyntaxChar(QUOTE) - def readJSONNumericChars(self): - numeric = [] + def readJSONNumericChars(self) -> str: + numeric: list[bytes] = [] while True: character = self.reader.peek() if self.isJSONNumeric(character) is False: @@ -335,7 +345,7 @@ def readJSONNumericChars(self): numeric.append(self.reader.read()) return b''.join(numeric).decode('ascii') - def readJSONInteger(self): + def readJSONInteger(self) -> int: self.context.read() self.readJSONQuotes() numeric = self.readJSONNumericChars() @@ -346,13 +356,13 @@ def readJSONInteger(self): raise TProtocolException(TProtocolException.INVALID_DATA, "Bad data encounted in numeric data") - def readJSONDouble(self): + def readJSONDouble(self) -> float: self.context.read() if self.reader.peek() == QUOTE: string = self.readJSONString(True) try: double = float(string) - if (self.context.escapeNum is False and + if (self.context.escapeNum is False and # type: ignore[comparison-overlap] not math.isinf(double) and not math.isnan(double)): raise TProtocolException( @@ -371,7 +381,7 @@ def readJSONDouble(self): raise TProtocolException(TProtocolException.INVALID_DATA, "Bad data encounted in numeric data") - def readJSONBase64(self): + def readJSONBase64(self) -> bytes: string = self.readJSONString(False) size = len(string) m = size % 4 @@ -381,28 +391,28 @@ def readJSONBase64(self): string += '=' return base64.b64decode(string) - def readJSONObjectStart(self): + def readJSONObjectStart(self) -> None: self.context.read() self.readJSONSyntaxChar(LBRACE) self.pushContext(JSONPairContext(self)) - def readJSONObjectEnd(self): + def readJSONObjectEnd(self) -> None: self.readJSONSyntaxChar(RBRACE) self.popContext() - def readJSONArrayStart(self): + def readJSONArrayStart(self) -> None: self.context.read() self.readJSONSyntaxChar(LBRACKET) self.pushContext(JSONListContext(self)) - def readJSONArrayEnd(self): + def readJSONArrayEnd(self) -> None: self.readJSONSyntaxChar(RBRACKET) self.popContext() class TJSONProtocol(TJSONProtocolBase): - def readMessageBegin(self): + def readMessageBegin(self) -> tuple[str, int, int]: self.resetReadContext() self.readJSONArrayStart() if self.readJSONInteger() != VERSION: @@ -413,16 +423,17 @@ def readMessageBegin(self): seqid = self.readJSONInteger() return (name, typen, seqid) - def readMessageEnd(self): + def readMessageEnd(self) -> None: self.readJSONArrayEnd() - def readStructBegin(self): + def readStructBegin(self) -> str | None: self.readJSONObjectStart() + return None - def readStructEnd(self): + def readStructEnd(self) -> None: self.readJSONObjectEnd() - def readFieldBegin(self): + def readFieldBegin(self) -> tuple[str | None, int, int]: character = self.reader.peek() ttype = 0 id = 0 @@ -434,10 +445,10 @@ def readFieldBegin(self): ttype = JTYPES[self.readJSONString(False)] return (None, ttype, id) - def readFieldEnd(self): + def readFieldEnd(self) -> None: self.readJSONObjectEnd() - def readMapBegin(self): + def readMapBegin(self) -> tuple[int, int, int]: self.readJSONArrayStart() keyType = JTYPES[self.readJSONString(False)] valueType = JTYPES[self.readJSONString(False)] @@ -445,138 +456,141 @@ def readMapBegin(self): self.readJSONObjectStart() return (keyType, valueType, size) - def readMapEnd(self): + def readMapEnd(self) -> None: self.readJSONObjectEnd() self.readJSONArrayEnd() - def readCollectionBegin(self): + def readCollectionBegin(self) -> tuple[int, int]: self.readJSONArrayStart() elemType = JTYPES[self.readJSONString(False)] size = self.readJSONInteger() return (elemType, size) + readListBegin = readCollectionBegin readSetBegin = readCollectionBegin - def readCollectionEnd(self): + def readCollectionEnd(self) -> None: self.readJSONArrayEnd() + readSetEnd = readCollectionEnd readListEnd = readCollectionEnd - def readBool(self): - return (False if self.readJSONInteger() == 0 else True) + def readBool(self) -> bool: + return False if self.readJSONInteger() == 0 else True - def readNumber(self): + def readNumber(self) -> int: return self.readJSONInteger() + readByte = readNumber readI16 = readNumber readI32 = readNumber readI64 = readNumber - def readDouble(self): + def readDouble(self) -> float: return self.readJSONDouble() - def readString(self): + def readString(self) -> str: return self.readJSONString(False) - def readBinary(self): + def readBinary(self) -> bytes: return self.readJSONBase64() - def writeMessageBegin(self, name, request_type, seqid): + def writeMessageBegin(self, name: str, ttype: int, seqid: int) -> None: self.resetWriteContext() self.writeJSONArrayStart() self.writeJSONNumber(VERSION) self.writeJSONString(name) - self.writeJSONNumber(request_type) + self.writeJSONNumber(ttype) self.writeJSONNumber(seqid) - def writeMessageEnd(self): + def writeMessageEnd(self) -> None: self.writeJSONArrayEnd() - def writeStructBegin(self, name): + def writeStructBegin(self, name: str) -> None: self.writeJSONObjectStart() - def writeStructEnd(self): + def writeStructEnd(self) -> None: self.writeJSONObjectEnd() - def writeFieldBegin(self, name, ttype, id): - self.writeJSONNumber(id) + def writeFieldBegin(self, name: str, ttype: int, fid: int) -> None: + self.writeJSONNumber(fid) self.writeJSONObjectStart() self.writeJSONString(CTYPES[ttype]) - def writeFieldEnd(self): + def writeFieldEnd(self) -> None: self.writeJSONObjectEnd() - def writeFieldStop(self): + def writeFieldStop(self) -> None: pass - def writeMapBegin(self, ktype, vtype, size): + def writeMapBegin(self, ktype: int, vtype: int, size: int) -> None: self.writeJSONArrayStart() self.writeJSONString(CTYPES[ktype]) self.writeJSONString(CTYPES[vtype]) self.writeJSONNumber(size) self.writeJSONObjectStart() - def writeMapEnd(self): + def writeMapEnd(self) -> None: self.writeJSONObjectEnd() self.writeJSONArrayEnd() - def writeListBegin(self, etype, size): + def writeListBegin(self, etype: int, size: int) -> None: self.writeJSONArrayStart() self.writeJSONString(CTYPES[etype]) self.writeJSONNumber(size) - def writeListEnd(self): + def writeListEnd(self) -> None: self.writeJSONArrayEnd() - def writeSetBegin(self, etype, size): + def writeSetBegin(self, etype: int, size: int) -> None: self.writeJSONArrayStart() self.writeJSONString(CTYPES[etype]) self.writeJSONNumber(size) - def writeSetEnd(self): + def writeSetEnd(self) -> None: self.writeJSONArrayEnd() - def writeBool(self, boolean): - self.writeJSONNumber(1 if boolean is True else 0) + def writeBool(self, bool_val: bool) -> None: + self.writeJSONNumber(1 if bool_val is True else 0) - def writeByte(self, byte): + def writeByte(self, byte: int) -> None: checkIntegerLimits(byte, 8) self.writeJSONNumber(byte) - def writeI16(self, i16): + def writeI16(self, i16: int) -> None: checkIntegerLimits(i16, 16) self.writeJSONNumber(i16) - def writeI32(self, i32): + def writeI32(self, i32: int) -> None: checkIntegerLimits(i32, 32) self.writeJSONNumber(i32) - def writeI64(self, i64): + def writeI64(self, i64: int) -> None: checkIntegerLimits(i64, 64) self.writeJSONNumber(i64) - def writeDouble(self, dbl): + def writeDouble(self, dub: float) -> None: # 17 significant digits should be just enough for any double precision # value. - self.writeJSONNumber(dbl, '{0:.17g}') + self.writeJSONNumber(dub, '{0:.17g}') - def writeString(self, string): - self.writeJSONString(string) + def writeString(self, str_val: str) -> None: + self.writeJSONString(str_val) - def writeBinary(self, binary): - self.writeJSONBase64(binary) + def writeBinary(self, str_val: bytes) -> None: + self.writeJSONBase64(str_val) class TJSONProtocolFactory(TProtocolFactory): - def getProtocol(self, trans): + def getProtocol(self, trans: TTransportBase) -> TJSONProtocol: return TJSONProtocol(trans) @property - def string_length_limit(senf): + def string_length_limit(self) -> None: return None @property - def container_length_limit(senf): + def container_length_limit(self) -> None: return None @@ -586,82 +600,83 @@ class TSimpleJSONProtocol(TJSONProtocolBase): Useful for interacting with scripting languages. """ - def readMessageBegin(self): + def readMessageBegin(self) -> tuple[str, int, int]: raise NotImplementedError() - def readMessageEnd(self): + def readMessageEnd(self) -> None: raise NotImplementedError() - def readStructBegin(self): + def readStructBegin(self) -> str | None: raise NotImplementedError() - def readStructEnd(self): + def readStructEnd(self) -> None: raise NotImplementedError() - def writeMessageBegin(self, name, request_type, seqid): + def writeMessageBegin(self, name: str, ttype: int, seqid: int) -> None: self.resetWriteContext() - def writeMessageEnd(self): + def writeMessageEnd(self) -> None: pass - def writeStructBegin(self, name): + def writeStructBegin(self, name: str) -> None: self.writeJSONObjectStart() - def writeStructEnd(self): + def writeStructEnd(self) -> None: self.writeJSONObjectEnd() - def writeFieldBegin(self, name, ttype, fid): + def writeFieldBegin(self, name: str, ttype: int, fid: int) -> None: self.writeJSONString(name) - def writeFieldEnd(self): + def writeFieldEnd(self) -> None: pass - def writeMapBegin(self, ktype, vtype, size): + def writeMapBegin(self, ktype: int, vtype: int, size: int) -> None: self.writeJSONObjectStart() - def writeMapEnd(self): + def writeMapEnd(self) -> None: self.writeJSONObjectEnd() - def _writeCollectionBegin(self, etype, size): + def _writeCollectionBegin(self, etype: int, size: int) -> None: self.writeJSONArrayStart() - def _writeCollectionEnd(self): + def _writeCollectionEnd(self) -> None: self.writeJSONArrayEnd() + writeListBegin = _writeCollectionBegin writeListEnd = _writeCollectionEnd writeSetBegin = _writeCollectionBegin writeSetEnd = _writeCollectionEnd - def writeByte(self, byte): + def writeByte(self, byte: int) -> None: checkIntegerLimits(byte, 8) self.writeJSONNumber(byte) - def writeI16(self, i16): + def writeI16(self, i16: int) -> None: checkIntegerLimits(i16, 16) self.writeJSONNumber(i16) - def writeI32(self, i32): + def writeI32(self, i32: int) -> None: checkIntegerLimits(i32, 32) self.writeJSONNumber(i32) - def writeI64(self, i64): + def writeI64(self, i64: int) -> None: checkIntegerLimits(i64, 64) self.writeJSONNumber(i64) - def writeBool(self, boolean): - self.writeJSONNumber(1 if boolean is True else 0) + def writeBool(self, bool_val: bool) -> None: + self.writeJSONNumber(1 if bool_val is True else 0) - def writeDouble(self, dbl): - self.writeJSONNumber(dbl) + def writeDouble(self, dub: float) -> None: + self.writeJSONNumber(dub) - def writeString(self, string): - self.writeJSONString(string) + def writeString(self, str_val: str) -> None: + self.writeJSONString(str_val) - def writeBinary(self, binary): - self.writeJSONBase64(binary) + def writeBinary(self, str_val: bytes) -> None: + self.writeJSONBase64(str_val) class TSimpleJSONProtocolFactory(TProtocolFactory): - def getProtocol(self, trans): + def getProtocol(self, trans: TTransportBase) -> TSimpleJSONProtocol: return TSimpleJSONProtocol(trans) diff --git a/lib/py/src/protocol/TMultiplexedProtocol.py b/lib/py/src/thrift/protocol/TMultiplexedProtocol.py similarity index 74% rename from lib/py/src/protocol/TMultiplexedProtocol.py rename to lib/py/src/thrift/protocol/TMultiplexedProtocol.py index 0f8390fdbfb..60e425f09f5 100644 --- a/lib/py/src/protocol/TMultiplexedProtocol.py +++ b/lib/py/src/thrift/protocol/TMultiplexedProtocol.py @@ -17,23 +17,28 @@ # under the License. # +from __future__ import annotations + from thrift.Thrift import TMessageType from thrift.protocol import TProtocolDecorator +from thrift.protocol.TProtocol import TProtocolBase -SEPARATOR = ":" +SEPARATOR: str = ":" class TMultiplexedProtocol(TProtocolDecorator.TProtocolDecorator): - def __init__(self, protocol, serviceName): + serviceName: str + + def __init__(self, protocol: TProtocolBase, serviceName: str) -> None: self.serviceName = serviceName - def writeMessageBegin(self, name, type, seqid): - if (type == TMessageType.CALL or - type == TMessageType.ONEWAY): + def writeMessageBegin(self, name: str, ttype: int, seqid: int) -> None: + if (ttype == TMessageType.CALL or + ttype == TMessageType.ONEWAY): super(TMultiplexedProtocol, self).writeMessageBegin( self.serviceName + SEPARATOR + name, - type, + ttype, seqid ) else: - super(TMultiplexedProtocol, self).writeMessageBegin(name, type, seqid) + super(TMultiplexedProtocol, self).writeMessageBegin(name, ttype, seqid) diff --git a/lib/py/src/protocol/TProtocol.py b/lib/py/src/thrift/protocol/TProtocol.py similarity index 56% rename from lib/py/src/protocol/TProtocol.py rename to lib/py/src/thrift/protocol/TProtocol.py index 5b4f4d85d81..efcaced53ea 100644 --- a/lib/py/src/protocol/TProtocol.py +++ b/lib/py/src/thrift/protocol/TProtocol.py @@ -17,171 +17,195 @@ # under the License. # -from thrift.Thrift import TException, TType, TFrozenDict -from thrift.transport.TTransport import TTransportException +from __future__ import annotations import sys +from abc import ABC, abstractmethod from itertools import islice +from typing import Any, Generator, Iterable + +from thrift.Thrift import TException, TFrozenDict, TType +from thrift.transport.TTransport import TTransportBase, TTransportException class TProtocolException(TException): """Custom Protocol Exception class""" - UNKNOWN = 0 - INVALID_DATA = 1 - NEGATIVE_SIZE = 2 - SIZE_LIMIT = 3 - BAD_VERSION = 4 - NOT_IMPLEMENTED = 5 - DEPTH_LIMIT = 6 - INVALID_PROTOCOL = 7 + UNKNOWN: int = 0 + INVALID_DATA: int = 1 + NEGATIVE_SIZE: int = 2 + SIZE_LIMIT: int = 3 + BAD_VERSION: int = 4 + NOT_IMPLEMENTED: int = 5 + DEPTH_LIMIT: int = 6 + INVALID_PROTOCOL: int = 7 + + type: int - def __init__(self, type=UNKNOWN, message=None): + def __init__(self, type: int = UNKNOWN, message: str | None = None) -> None: TException.__init__(self, message) self.type = type -class TProtocolBase(object): +class TProtocolBase(ABC): """Base class for Thrift protocol driver.""" - def __init__(self, trans): + trans: TTransportBase + _fast_decode: Any + _fast_encode: Any + + def __init__(self, trans: TTransportBase) -> None: self.trans = trans self._fast_decode = None self._fast_encode = None @staticmethod - def _check_length(limit, length): + def _check_length(limit: int | None, length: int) -> None: if length < 0: - raise TTransportException(TTransportException.NEGATIVE_SIZE, - 'Negative length: %d' % length) + raise TTransportException( + TTransportException.NEGATIVE_SIZE, 'Negative length: %d' % length + ) if limit is not None and length > limit: - raise TTransportException(TTransportException.SIZE_LIMIT, - 'Length exceeded max allowed: %d' % limit) + raise TTransportException( + TTransportException.SIZE_LIMIT, 'Length exceeded max allowed: %d' % limit + ) - def writeMessageBegin(self, name, ttype, seqid): + def writeMessageBegin(self, name: str, ttype: int, seqid: int) -> None: pass - def writeMessageEnd(self): + def writeMessageEnd(self) -> None: pass - def writeStructBegin(self, name): + def writeStructBegin(self, name: str) -> None: pass - def writeStructEnd(self): + def writeStructEnd(self) -> None: pass - def writeFieldBegin(self, name, ttype, fid): + def writeFieldBegin(self, name: str, ttype: int, fid: int) -> None: pass - def writeFieldEnd(self): + def writeFieldEnd(self) -> None: pass - def writeFieldStop(self): + def writeFieldStop(self) -> None: pass - def writeMapBegin(self, ktype, vtype, size): + def writeMapBegin(self, ktype: int, vtype: int, size: int) -> None: pass - def writeMapEnd(self): + def writeMapEnd(self) -> None: pass - def writeListBegin(self, etype, size): + def writeListBegin(self, etype: int, size: int) -> None: pass - def writeListEnd(self): + def writeListEnd(self) -> None: pass - def writeSetBegin(self, etype, size): + def writeSetBegin(self, etype: int, size: int) -> None: pass - def writeSetEnd(self): + def writeSetEnd(self) -> None: pass - def writeBool(self, bool_val): + def writeBool(self, bool_val: bool) -> None: pass - def writeByte(self, byte): + def writeByte(self, byte: int) -> None: pass - def writeI16(self, i16): + def writeI16(self, i16: int) -> None: pass - def writeI32(self, i32): + def writeI32(self, i32: int) -> None: pass - def writeI64(self, i64): + def writeI64(self, i64: int) -> None: pass - def writeDouble(self, dub): + def writeDouble(self, dub: float) -> None: pass - def writeString(self, str_val): + def writeString(self, str_val: str) -> None: self.writeBinary(bytes(str_val, 'utf-8')) - def writeBinary(self, str_val): + def writeBinary(self, str_val: bytes) -> None: pass - def readMessageBegin(self): + @abstractmethod + def readMessageBegin(self) -> tuple[str, int, int]: pass - def readMessageEnd(self): + def readMessageEnd(self) -> None: pass - def readStructBegin(self): + def readStructBegin(self) -> str | None: pass - def readStructEnd(self): + def readStructEnd(self) -> None: pass - def readFieldBegin(self): + @abstractmethod + def readFieldBegin(self) -> tuple[str | None, int, int]: pass - def readFieldEnd(self): + def readFieldEnd(self) -> None: pass - def readMapBegin(self): + @abstractmethod + def readMapBegin(self) -> tuple[int, int, int]: pass - def readMapEnd(self): + def readMapEnd(self) -> None: pass - def readListBegin(self): + @abstractmethod + def readListBegin(self) -> tuple[int, int]: pass - def readListEnd(self): + def readListEnd(self) -> None: pass - def readSetBegin(self): + @abstractmethod + def readSetBegin(self) -> tuple[int, int]: pass - def readSetEnd(self): + def readSetEnd(self) -> None: pass - def readBool(self): + @abstractmethod + def readBool(self) -> bool: pass - def readByte(self): + @abstractmethod + def readByte(self) -> int: pass - def readI16(self): + @abstractmethod + def readI16(self) -> int: pass - def readI32(self): + @abstractmethod + def readI32(self) -> int: pass - def readI64(self): + @abstractmethod + def readI64(self) -> int: pass - def readDouble(self): + @abstractmethod + def readDouble(self) -> float: pass - def readString(self): + def readString(self) -> str: return self.readBinary().decode('utf-8') - def readBinary(self): + @abstractmethod + def readBinary(self) -> bytes: pass - def skip(self, ttype): + def skip(self, ttype: int) -> None: if ttype == TType.BOOL: self.readBool() elif ttype == TType.BYTE: @@ -222,12 +246,12 @@ def skip(self, ttype): self.skip(etype) self.readListEnd() else: - raise TProtocolException( - TProtocolException.INVALID_DATA, - "invalid TType") + raise TProtocolException(TProtocolException.INVALID_DATA, "invalid TType") # tuple of: ( 'reader method' name, is_container bool, 'writer_method' name ) - _TTYPE_HANDLERS = ( + _TTYPE_HANDLERS: tuple[ + tuple[str | None, str | None, bool], ... + ] = ( (None, None, False), # 0 TType.STOP (None, None, False), # 1 TType.VOID # TODO: handle void? ('readBool', 'writeBool', False), # 2 TType.BOOL @@ -245,61 +269,75 @@ def skip(self, ttype): ('readContainerSet', 'writeContainerSet', True), # 14 TType.SET ('readContainerList', 'writeContainerList', True), # 15 TType.LIST (None, None, False), # 16 TType.UTF8 # TODO: handle utf8 types? - (None, None, False) # 17 TType.UTF16 # TODO: handle utf16 types? + (None, None, False), # 17 TType.UTF16 # TODO: handle utf16 types? ) - def _ttype_handlers(self, ttype, spec): + def _ttype_handlers( + self, ttype: int, spec: Any + ) -> tuple[str | None, str | None, bool]: if spec == 'BINARY': if ttype != TType.STRING: - raise TProtocolException(type=TProtocolException.INVALID_DATA, - message='Invalid binary field type %d' % ttype) + raise TProtocolException( + type=TProtocolException.INVALID_DATA, + message='Invalid binary field type %d' % ttype, + ) return ('readBinary', 'writeBinary', False) - return self._TTYPE_HANDLERS[ttype] if ttype < len(self._TTYPE_HANDLERS) else (None, None, False) - - def _read_by_ttype(self, ttype, spec, espec): + return ( + self._TTYPE_HANDLERS[ttype] + if ttype < len(self._TTYPE_HANDLERS) + else (None, None, False) + ) + + def _read_by_ttype( + self, ttype: int, spec: Any, espec: Any + ) -> Generator[Any, None, None]: reader_name, _, is_container = self._ttype_handlers(ttype, espec) if reader_name is None: - raise TProtocolException(type=TProtocolException.INVALID_DATA, - message='Invalid type %d' % (ttype)) + raise TProtocolException( + type=TProtocolException.INVALID_DATA, message='Invalid type %d' % (ttype) + ) reader_func = getattr(self, reader_name) read = (lambda: reader_func(espec)) if is_container else reader_func while True: yield read() - def readFieldByTType(self, ttype, spec): + def readFieldByTType(self, ttype: int, spec: Any) -> Any: return next(self._read_by_ttype(ttype, spec, spec)) - def readContainerList(self, spec): + def readContainerList(self, spec: tuple[int, Any, bool]) -> list[Any] | tuple[Any, ...]: ttype, tspec, is_immutable = spec (list_type, list_len) = self.readListBegin() # TODO: compare types we just decoded with thrift_spec elems = islice(self._read_by_ttype(ttype, spec, tspec), list_len) - results = (tuple if is_immutable else list)(elems) + results: list[Any] | tuple[Any, ...] = (tuple if is_immutable else list)(elems) self.readListEnd() return results - def readContainerSet(self, spec): + def readContainerSet(self, spec: tuple[int, Any, bool]) -> set[Any] | frozenset[Any]: ttype, tspec, is_immutable = spec (set_type, set_len) = self.readSetBegin() # TODO: compare types we just decoded with thrift_spec elems = islice(self._read_by_ttype(ttype, spec, tspec), set_len) - results = (frozenset if is_immutable else set)(elems) + results: set[Any] | frozenset[Any] = (frozenset if is_immutable else set)(elems) self.readSetEnd() return results - def readContainerStruct(self, spec): + def readContainerStruct(self, spec: tuple[type, Any]) -> Any: (obj_class, obj_spec) = spec # If obj_class.read is a classmethod (e.g. in frozen structs), # call it as such. - if getattr(obj_class.read, '__self__', None) is obj_class: - obj = obj_class.read(self) + read_method = getattr(obj_class, 'read', None) + if read_method is not None and getattr(read_method, '__self__', None) is obj_class: + obj = read_method(self) else: obj = obj_class() obj.read(self) return obj - def readContainerMap(self, spec): + def readContainerMap( + self, spec: tuple[int, Any, int, Any, bool] + ) -> dict[Any, Any] | TFrozenDict: ktype, kspec, vtype, vspec, is_immutable = spec (map_ktype, map_vtype, map_len) = self.readMapBegin() # TODO: compare types we just decoded with thrift_spec and @@ -307,13 +345,16 @@ def readContainerMap(self, spec): keys = self._read_by_ttype(ktype, spec, kspec) vals = self._read_by_ttype(vtype, spec, vspec) keyvals = islice(zip(keys, vals), map_len) - results = (TFrozenDict if is_immutable else dict)(keyvals) + results: dict[Any, Any] | TFrozenDict = (TFrozenDict if is_immutable else dict)( + keyvals + ) self.readMapEnd() return results - def readStruct(self, obj, thrift_spec, is_immutable=False): - if is_immutable: - fields = {} + def readStruct( + self, obj: Any, thrift_spec: tuple[Any, ...], is_immutable: bool = False + ) -> Any: + fields: dict[str, Any] = {} if is_immutable else {} self.readStructBegin() while True: (fname, ftype, fid) = self.readFieldBegin() @@ -338,33 +379,38 @@ def readStruct(self, obj, thrift_spec, is_immutable=False): self.readStructEnd() if is_immutable: return obj(**fields) + return None - def writeContainerStruct(self, val, spec): + def writeContainerStruct(self, val: Any, spec: Any) -> None: val.write(self) - def writeContainerList(self, val, spec): + def writeContainerList(self, val: list[Any] | tuple[Any, ...], spec: tuple[int, Any, bool]) -> None: ttype, tspec, _ = spec self.writeListBegin(ttype, len(val)) for _ in self._write_by_ttype(ttype, val, spec, tspec): pass self.writeListEnd() - def writeContainerSet(self, val, spec): + def writeContainerSet(self, val: set[Any] | frozenset[Any], spec: tuple[int, Any, bool]) -> None: ttype, tspec, _ = spec self.writeSetBegin(ttype, len(val)) for _ in self._write_by_ttype(ttype, val, spec, tspec): pass self.writeSetEnd() - def writeContainerMap(self, val, spec): + def writeContainerMap( + self, val: dict[Any, Any], spec: tuple[int, Any, int, Any, bool] + ) -> None: ktype, kspec, vtype, vspec, _ = spec self.writeMapBegin(ktype, vtype, len(val)) - for _ in zip(self._write_by_ttype(ktype, val.keys(), spec, kspec), - self._write_by_ttype(vtype, val.values(), spec, vspec)): + for _ in zip( + self._write_by_ttype(ktype, val.keys(), spec, kspec), + self._write_by_ttype(vtype, val.values(), spec, vspec), + ): pass self.writeMapEnd() - def writeStruct(self, obj, thrift_spec): + def writeStruct(self, obj: Any, thrift_spec: tuple[Any, ...]) -> None: self.writeStructBegin(obj.__class__.__name__) for field in thrift_spec: if field is None: @@ -383,32 +429,42 @@ def writeStruct(self, obj, thrift_spec): self.writeFieldStop() self.writeStructEnd() - def _write_by_ttype(self, ttype, vals, spec, espec): + def _write_by_ttype( + self, ttype: int, vals: Iterable[Any], spec: Any, espec: Any + ) -> Generator[Any, None, None]: _, writer_name, is_container = self._ttype_handlers(ttype, espec) + assert writer_name is not None, f"No writer for ttype {ttype}" writer_func = getattr(self, writer_name) write = (lambda v: writer_func(v, espec)) if is_container else writer_func for v in vals: yield write(v) - def writeFieldByTType(self, ttype, val, spec): + def writeFieldByTType(self, ttype: int, val: Any, spec: Any) -> None: next(self._write_by_ttype(ttype, [val], spec, spec)) -def checkIntegerLimits(i, bits): +def checkIntegerLimits(i: int, bits: int) -> None: if bits == 8 and (i < -128 or i > 127): - raise TProtocolException(TProtocolException.INVALID_DATA, - "i8 requires -128 <= number <= 127") + raise TProtocolException( + TProtocolException.INVALID_DATA, "i8 requires -128 <= number <= 127" + ) elif bits == 16 and (i < -32768 or i > 32767): - raise TProtocolException(TProtocolException.INVALID_DATA, - "i16 requires -32768 <= number <= 32767") + raise TProtocolException( + TProtocolException.INVALID_DATA, "i16 requires -32768 <= number <= 32767" + ) elif bits == 32 and (i < -2147483648 or i > 2147483647): - raise TProtocolException(TProtocolException.INVALID_DATA, - "i32 requires -2147483648 <= number <= 2147483647") + raise TProtocolException( + TProtocolException.INVALID_DATA, + "i32 requires -2147483648 <= number <= 2147483647", + ) elif bits == 64 and (i < -9223372036854775808 or i > 9223372036854775807): - raise TProtocolException(TProtocolException.INVALID_DATA, - "i64 requires -9223372036854775808 <= number <= 9223372036854775807") + raise TProtocolException( + TProtocolException.INVALID_DATA, + "i64 requires -9223372036854775808 <= number <= 9223372036854775807", + ) -class TProtocolFactory(object): - def getProtocol(self, trans): +class TProtocolFactory(ABC): + @abstractmethod + def getProtocol(self, trans: TTransportBase) -> TProtocolBase: pass diff --git a/lib/py/src/protocol/TProtocolDecorator.py b/lib/py/src/thrift/protocol/TProtocolDecorator.py similarity index 76% rename from lib/py/src/protocol/TProtocolDecorator.py rename to lib/py/src/thrift/protocol/TProtocolDecorator.py index f5546c736e1..2a9fa051db1 100644 --- a/lib/py/src/protocol/TProtocolDecorator.py +++ b/lib/py/src/thrift/protocol/TProtocolDecorator.py @@ -17,9 +17,17 @@ # under the License. # +from __future__ import annotations -class TProtocolDecorator(object): - def __new__(cls, protocol, *args, **kwargs): +from typing import Any + +from thrift.protocol.TProtocol import TProtocolBase + + +class TProtocolDecorator(TProtocolBase): + """Protocol decorator base class that wraps another protocol.""" + + def __new__(cls, protocol: TProtocolBase, *args: Any, **kwargs: Any) -> TProtocolDecorator: decorated_cls = type(''.join(['Decorated', protocol.__class__.__name__]), (cls, protocol.__class__), protocol.__dict__) diff --git a/lib/py/src/protocol/__init__.py b/lib/py/src/thrift/protocol/__init__.py similarity index 100% rename from lib/py/src/protocol/__init__.py rename to lib/py/src/thrift/protocol/__init__.py diff --git a/lib/py/src/thrift/py.typed b/lib/py/src/thrift/py.typed new file mode 100644 index 00000000000..e69de29bb2d diff --git a/lib/py/src/server/THttpServer.py b/lib/py/src/thrift/server/THttpServer.py similarity index 77% rename from lib/py/src/server/THttpServer.py rename to lib/py/src/thrift/server/THttpServer.py index 21f2c869149..1836026ac90 100644 --- a/lib/py/src/server/THttpServer.py +++ b/lib/py/src/thrift/server/THttpServer.py @@ -17,14 +17,20 @@ # under the License. # -import ssl +from __future__ import annotations import http.server as BaseHTTPServer +import ssl +from typing import Any, BinaryIO, Callable, TYPE_CHECKING, cast from thrift.Thrift import TMessageType from thrift.server import TServer from thrift.transport import TTransport +if TYPE_CHECKING: + from thrift.protocol.TProtocol import TProtocolFactory + from thrift.Thrift import TProcessor + class ResponseException(Exception): """Allows handlers to override the HTTP response @@ -37,7 +43,10 @@ class ResponseException(Exception): for ONEWAY requests, as the HTTP response must be sent before the RPC is processed. """ - def __init__(self, handler): + + handler: Callable[[Any], None] + + def __init__(self, handler: Callable[[Any], None]) -> None: self.handler = handler @@ -50,13 +59,19 @@ class THttpServer(TServer.TServer): transport/protocol/processor/server layering, by performing the transport functions here. This means things like oneway handling are oddly exposed. """ - def __init__(self, - processor, - server_address, - inputProtocolFactory, - outputProtocolFactory=None, - server_class=BaseHTTPServer.HTTPServer, - **kwargs): + + httpd: BaseHTTPServer.HTTPServer + _replied: bool | None + + def __init__( + self, + processor: TProcessor, + server_address: tuple[str, int], + inputProtocolFactory: TProtocolFactory, + outputProtocolFactory: TProtocolFactory | None = None, + server_class: type[BaseHTTPServer.HTTPServer] = BaseHTTPServer.HTTPServer, + **kwargs: Any, + ) -> None: """Set up protocol factories and HTTP (or HTTPS) server. See BaseHTTPServer for server_address. @@ -80,7 +95,7 @@ class RequestHander(BaseHTTPServer.BaseHTTPRequestHandler): def do_POST(self): # Don't care about the request path. thttpserver._replied = False - iftrans = TTransport.TFileObjectTransport(self.rfile) + iftrans = TTransport.TFileObjectTransport(cast(BinaryIO, self.rfile)) itrans = TTransport.TBufferedTransport( iftrans, int(self.headers['Content-Length'])) otrans = TTransport.TMemoryBuffer() @@ -96,7 +111,7 @@ def do_POST(self): # If the request was ONEWAY we would have replied already data = otrans.getvalue() self.send_response(200) - self.send_header("Content-Length", len(data)) + self.send_header("Content-Length", str(len(data))) self.send_header("Content-Type", "application/x-thrift") self.end_headers() self.wfile.write(data) @@ -116,16 +131,20 @@ def on_begin(self, name, type, seqid): self.httpd = server_class(server_address, RequestHander) - if (kwargs.get('cafile') or kwargs.get('cert_file') or kwargs.get('key_file')): - context = ssl.create_default_context(cafile=kwargs.get('cafile')) + cert_file = kwargs.get('cert_file') + key_file = kwargs.get('key_file') + cafile = kwargs.get('cafile') + if cafile or cert_file or key_file: + context = ssl.create_default_context(cafile=cafile) context.check_hostname = False - context.load_cert_chain(kwargs.get('cert_file'), kwargs.get('key_file')) - context.verify_mode = ssl.CERT_REQUIRED if kwargs.get('cafile') else ssl.CERT_NONE + if cert_file: + context.load_cert_chain(cert_file, key_file) + context.verify_mode = ssl.CERT_REQUIRED if cafile else ssl.CERT_NONE self.httpd.socket = context.wrap_socket(self.httpd.socket, server_side=True) - def serve(self): + def serve(self) -> None: self.httpd.serve_forever() - def shutdown(self): + def shutdown(self) -> None: self.httpd.socket.close() # self.httpd.shutdown() # hangs forever, python doesn't handle POLLNVAL properly! diff --git a/lib/py/src/server/TNonblockingServer.py b/lib/py/src/thrift/server/TNonblockingServer.py similarity index 78% rename from lib/py/src/server/TNonblockingServer.py rename to lib/py/src/thrift/server/TNonblockingServer.py index a7a40cafb53..97651713719 100644 --- a/lib/py/src/server/TNonblockingServer.py +++ b/lib/py/src/thrift/server/TNonblockingServer.py @@ -25,37 +25,46 @@ maximum connections """ +from __future__ import annotations + import logging +import queue import select import socket import struct import threading - from collections import deque -import queue +from typing import Any, Callable, TYPE_CHECKING from thrift.transport import TTransport from thrift.protocol.TBinaryProtocol import TBinaryProtocolFactory +if TYPE_CHECKING: + from thrift.protocol.TProtocol import TProtocolFactory + from thrift.Thrift import TProcessor + from thrift.transport.TSocket import TServerSocket + __all__ = ['TNonblockingServer'] -logger = logging.getLogger(__name__) +logger: logging.Logger = logging.getLogger(__name__) class Worker(threading.Thread): """Worker is a small helper to process incoming connection.""" - def __init__(self, queue): + queue: queue.Queue[list[Any]] + + def __init__(self, queue: queue.Queue[list[Any]]) -> None: threading.Thread.__init__(self) self.queue = queue - def run(self): + def run(self) -> None: """Process queries from task queue, stop if processor is None.""" while True: + processor, iprot, oprot, otrans, callback = self.queue.get() + if processor is None: + break try: - processor, iprot, oprot, otrans, callback = self.queue.get() - if processor is None: - break processor.process(iprot, oprot) callback(True, otrans.getvalue()) except Exception: @@ -63,16 +72,16 @@ def run(self): callback(False, b'') -WAIT_LEN = 0 -WAIT_MESSAGE = 1 -WAIT_PROCESS = 2 -SEND_ANSWER = 3 -CLOSED = 4 +WAIT_LEN: int = 0 +WAIT_MESSAGE: int = 1 +WAIT_PROCESS: int = 2 +SEND_ANSWER: int = 3 +CLOSED: int = 4 -def locked(func): +def locked(func: Callable[..., Any]) -> Callable[..., Any]: """Decorator which locks self.lock.""" - def nested(self, *args, **kwargs): + def nested(self: Any, *args: Any, **kwargs: Any) -> Any: self.lock.acquire() try: return func(self, *args, **kwargs) @@ -81,9 +90,9 @@ def nested(self, *args, **kwargs): return nested -def socket_exception(func): +def socket_exception(func: Callable[..., Any]) -> Callable[..., Any]: """Decorator close object on socket.error.""" - def read(self, *args, **kwargs): + def read(self: Any, *args: Any, **kwargs: Any) -> Any: try: return func(self, *args, **kwargs) except socket.error: @@ -92,19 +101,26 @@ def read(self, *args, **kwargs): return read -class Message(object): - def __init__(self, offset, len_, header): +class Message: + """Represents a message being read from or written to a connection.""" + + offset: int + len: int + buffer: bytes | None + is_header: bool + + def __init__(self, offset: int, len_: int, header: bool) -> None: self.offset = offset self.len = len_ self.buffer = None self.is_header = header @property - def end(self): + def end(self) -> int: return self.offset + self.len -class Connection(object): +class Connection: """Basic class is represented connection. It can be in state: @@ -116,7 +132,19 @@ class Connection(object): of answer). CLOSED --- socket was closed and connection should be deleted. """ - def __init__(self, new_socket, wake_up): + + socket: socket.socket + status: int + len: int + received: deque[Message] + _reading: Message + _rbuf: bytes + _wbuf: bytes + lock: threading.Lock + wake_up: Callable[[], None] + remaining: bool + + def __init__(self, new_socket: socket.socket, wake_up: Callable[[], None]) -> None: self.socket = new_socket self.socket.setblocking(False) self.status = WAIT_LEN @@ -130,7 +158,7 @@ def __init__(self, new_socket, wake_up): self.remaining = False @socket_exception - def read(self): + def read(self) -> None: """Reads data from stream and switch state.""" assert self.status in (WAIT_LEN, WAIT_MESSAGE) assert not self.received @@ -169,7 +197,7 @@ def read(self): self.remaining = not done @socket_exception - def write(self): + def write(self) -> None: """Writes data from socket and switch state.""" assert self.status == SEND_ANSWER sent = self.socket.send(self._wbuf) @@ -181,7 +209,7 @@ def write(self): self._wbuf = self._wbuf[sent:] @locked - def ready(self, all_ok, message): + def ready(self, all_ok: bool, message: bytes) -> None: """Callback function for switching state and waking up main thread. This function is the only function witch can be called asynchronous. @@ -209,40 +237,55 @@ def ready(self, all_ok, message): self.wake_up() @locked - def is_writeable(self): + def is_writeable(self) -> bool: """Return True if connection should be added to write list of select""" return self.status == SEND_ANSWER # it's not necessary, but... @locked - def is_readable(self): + def is_readable(self) -> bool: """Return True if connection should be added to read list of select""" return self.status in (WAIT_LEN, WAIT_MESSAGE) @locked - def is_closed(self): + def is_closed(self) -> bool: """Returns True if connection is closed.""" return self.status == CLOSED - def fileno(self): + def fileno(self) -> int: """Returns the file descriptor of the associated socket.""" return self.socket.fileno() - def close(self): + def close(self) -> None: """Closes connection""" self.status = CLOSED self.socket.close() -class TNonblockingServer(object): +class TNonblockingServer: """Non-blocking server.""" - def __init__(self, - processor, - lsocket, - inputProtocolFactory=None, - outputProtocolFactory=None, - threads=10): + processor: TProcessor + socket: TServerSocket + in_protocol: TProtocolFactory + out_protocol: TProtocolFactory + threads: int + clients: dict[int, Connection] + tasks: queue.Queue[list[Any]] + _read: socket.socket + _write: socket.socket + prepared: bool + _stop: bool + poll: select.poll | None + + def __init__( + self, + processor: TProcessor, + lsocket: TServerSocket, + inputProtocolFactory: TProtocolFactory | None = None, + outputProtocolFactory: TProtocolFactory | None = None, + threads: int = 10, + ) -> None: self.processor = processor self.socket = lsocket self.in_protocol = inputProtocolFactory or TBinaryProtocolFactory() @@ -255,13 +298,13 @@ def __init__(self, self._stop = False self.poll = select.poll() if hasattr(select, 'poll') else None - def setNumThreads(self, num): + def setNumThreads(self, num: int) -> None: """Set the number of worker threads that should be created.""" # implement ThreadPool interface assert not self.prepared, "Can't change number of threads after start" self.threads = num - def prepare(self): + def prepare(self) -> None: """Prepares server for serve requests.""" if self.prepared: return @@ -272,7 +315,7 @@ def prepare(self): thread.start() self.prepared = True - def wake_up(self): + def wake_up(self) -> None: """Wake up main thread. The server usually waits in select call in we should terminate one. @@ -285,7 +328,7 @@ def wake_up(self): """ self._write.send(b'1') - def stop(self): + def stop(self) -> None: """Stop the server. This method causes the serve() method to return. stop() may be invoked @@ -300,11 +343,11 @@ def stop(self): self._stop = True self.wake_up() - def _select(self): + def _select(self) -> tuple[list[int], list[int], list[int], bool]: """Does select on open connections.""" - readable = [self.socket.handle.fileno(), self._read.fileno()] - writable = [] - remaining = [] + readable: list[int] = [self.socket.handle.fileno(), self._read.fileno()] # type: ignore[union-attr] + writable: list[int] = [] + remaining: list[int] = [] for i, connection in list(self.clients.items()): if connection.is_readable(): readable.append(connection.fileno()) @@ -317,35 +360,35 @@ def _select(self): if remaining: return remaining, [], [], False else: - return select.select(readable, writable, readable) + (True,) + return select.select(readable, writable, readable) + (True,) # type: ignore[return-value] - def _poll_select(self): + def _poll_select(self) -> tuple[list[int], list[int], list[int], bool]: """Does poll on open connections, if available.""" - remaining = [] + remaining: list[int] = [] - self.poll.register(self.socket.handle.fileno(), select.POLLIN | select.POLLRDNORM) - self.poll.register(self._read.fileno(), select.POLLIN | select.POLLRDNORM) + self.poll.register(self.socket.handle.fileno(), select.POLLIN | select.POLLRDNORM) # type: ignore[union-attr] + self.poll.register(self._read.fileno(), select.POLLIN | select.POLLRDNORM) # type: ignore[union-attr] for i, connection in list(self.clients.items()): if connection.is_readable(): - self.poll.register(connection.fileno(), select.POLLIN | select.POLLRDNORM | select.POLLERR | select.POLLHUP | select.POLLNVAL) + self.poll.register(connection.fileno(), select.POLLIN | select.POLLRDNORM | select.POLLERR | select.POLLHUP | select.POLLNVAL) # type: ignore[union-attr] if connection.remaining or connection.received: remaining.append(connection.fileno()) if connection.is_writeable(): - self.poll.register(connection.fileno(), select.POLLOUT | select.POLLWRNORM) + self.poll.register(connection.fileno(), select.POLLOUT | select.POLLWRNORM) # type: ignore[union-attr] if connection.is_closed(): try: - self.poll.unregister(i) + self.poll.unregister(i) # type: ignore[union-attr] except KeyError: logger.debug("KeyError in unregistering connections...") del self.clients[i] if remaining: return remaining, [], [], False - rlist = [] - wlist = [] - xlist = [] - pollres = self.poll.poll() + rlist: list[int] = [] + wlist: list[int] = [] + xlist: list[int] = [] + pollres = self.poll.poll() # type: ignore[union-attr] for fd, event in pollres: if event & (select.POLLERR | select.POLLHUP | select.POLLNVAL): xlist.append(fd) @@ -359,7 +402,7 @@ def _poll_select(self): return rlist, wlist, xlist, True - def handle(self): + def handle(self) -> None: """Handle requests. WARNING! You must call prepare() BEFORE calling handle() @@ -373,7 +416,7 @@ def handle(self): elif readable == self.socket.handle.fileno(): try: client = self.socket.accept() - if client: + if client and client.handle: self.clients[client.handle.fileno()] = Connection(client.handle, self.wake_up) except socket.error: @@ -398,14 +441,14 @@ def handle(self): for oob in xset: self.clients[oob].close() - def close(self): + def close(self) -> None: """Closes the server.""" for _ in range(self.threads): self.tasks.put([None, None, None, None, None]) self.socket.close() self.prepared = False - def serve(self): + def serve(self) -> None: """Serve requests. Serve requests forever, or until stop() is called. diff --git a/lib/py/src/server/TProcessPoolServer.py b/lib/py/src/thrift/server/TProcessPoolServer.py similarity index 80% rename from lib/py/src/server/TProcessPoolServer.py rename to lib/py/src/thrift/server/TProcessPoolServer.py index c9cfa1104bc..17718dd8d8e 100644 --- a/lib/py/src/server/TProcessPoolServer.py +++ b/lib/py/src/thrift/server/TProcessPoolServer.py @@ -18,14 +18,21 @@ # -import logging +from __future__ import annotations +import logging from multiprocessing import Process, Value, Condition +from multiprocessing.synchronize import Condition as ConditionType +from multiprocessing.sharedctypes import Synchronized +from typing import Any, Callable, TYPE_CHECKING from .TServer import TServer from thrift.transport.TTransport import TTransportException -logger = logging.getLogger(__name__) +if TYPE_CHECKING: + from thrift.transport.TTransport import TTransportBase + +logger: logging.Logger = logging.getLogger(__name__) class TProcessPoolServer(TServer): @@ -34,7 +41,14 @@ class TProcessPoolServer(TServer): Note that if you need shared state between the handlers - it's up to you! Written by Dvir Volk, doat.com """ - def __init__(self, *args): + + numWorkers: int + workers: list[Process] | None + isRunning: Synchronized[bool] + stopCondition: ConditionType + postForkCallback: Callable[[], None] | None + + def __init__(self, *args: Any) -> None: TServer.__init__(self, *args) self.numWorkers = 10 self.workers = [] @@ -42,21 +56,21 @@ def __init__(self, *args): self.stopCondition = Condition() self.postForkCallback = None - def __getstate__(self): + def __getstate__(self) -> dict[str, Any]: state = self.__dict__.copy() state['workers'] = None return state - def setPostForkCallback(self, callback): + def setPostForkCallback(self, callback: Callable[[], None]) -> None: if not callable(callback): raise TypeError("This is not a callback!") self.postForkCallback = callback - def setNumWorkers(self, num): + def setNumWorkers(self, num: int) -> None: """Set the number of worker threads that should be created""" self.numWorkers = num - def workerProcess(self): + def workerProcess(self) -> int | None: """Loop getting clients from the shared queue and process them""" if self.postForkCallback: self.postForkCallback() @@ -71,8 +85,9 @@ def workerProcess(self): return 0 except Exception as x: logger.exception(x) + return None - def serveClient(self, client): + def serveClient(self, client: TTransportBase) -> None: """Process input/output from a client for as long as possible""" itrans = self.inputTransportFactory.getTransport(client) otrans = self.outputTransportFactory.getTransport(client) @@ -90,7 +105,7 @@ def serveClient(self, client): itrans.close() otrans.close() - def serve(self): + def serve(self) -> None: """Start workers and put into queue""" # this is a shared state that can tell the workers to exit when False self.isRunning.value = True @@ -104,7 +119,7 @@ def serve(self): w = Process(target=self.workerProcess) w.daemon = True w.start() - self.workers.append(w) + self.workers.append(w) # type: ignore[union-attr] except Exception as x: logger.exception(x) @@ -121,7 +136,7 @@ def serve(self): self.isRunning.value = False - def stop(self): + def stop(self) -> None: self.isRunning.value = False self.stopCondition.acquire() self.stopCondition.notify() diff --git a/lib/py/src/server/TServer.py b/lib/py/src/thrift/server/TServer.py similarity index 86% rename from lib/py/src/server/TServer.py rename to lib/py/src/thrift/server/TServer.py index 81144f14a9b..58926ea6209 100644 --- a/lib/py/src/server/TServer.py +++ b/lib/py/src/thrift/server/TServer.py @@ -17,19 +17,27 @@ # under the License. # +from __future__ import annotations + import queue import logging import os import threading +from typing import Any, TYPE_CHECKING from thrift.protocol import TBinaryProtocol from thrift.protocol.THeaderProtocol import THeaderProtocolFactory from thrift.transport import TTransport -logger = logging.getLogger(__name__) +if TYPE_CHECKING: + from thrift.protocol.TProtocol import TProtocolFactory + from thrift.transport.TTransport import TTransportFactoryBase, TServerTransportBase, TTransportBase + from thrift.Thrift import TProcessor + +logger: logging.Logger = logging.getLogger(__name__) -class TServer(object): +class TServer: """Base interface for a server, which must have a serve() method. Three constructors for all servers: @@ -39,7 +47,15 @@ class TServer(object): inputTransportFactory, outputTransportFactory, inputProtocolFactory, outputProtocolFactory) """ - def __init__(self, *args): + + processor: TProcessor + serverTransport: TServerTransportBase + inputTransportFactory: TTransportFactoryBase + outputTransportFactory: TTransportFactoryBase + inputProtocolFactory: TProtocolFactory + outputProtocolFactory: TProtocolFactory + + def __init__(self, *args: Any) -> None: if (len(args) == 2): self.__initArgs__(args[0], args[1], TTransport.TTransportFactoryBase(), @@ -51,9 +67,15 @@ def __init__(self, *args): elif (len(args) == 6): self.__initArgs__(args[0], args[1], args[2], args[3], args[4], args[5]) - def __initArgs__(self, processor, serverTransport, - inputTransportFactory, outputTransportFactory, - inputProtocolFactory, outputProtocolFactory): + def __initArgs__( + self, + processor: TProcessor, + serverTransport: TServerTransportBase, + inputTransportFactory: TTransportFactoryBase, + outputTransportFactory: TTransportFactoryBase, + inputProtocolFactory: TProtocolFactory, + outputProtocolFactory: TProtocolFactory, + ) -> None: self.processor = processor self.serverTransport = serverTransport self.inputTransportFactory = inputTransportFactory @@ -67,17 +89,17 @@ def __initArgs__(self, processor, serverTransport, raise ValueError("THeaderProtocol servers require that both the input and " "output protocols are THeaderProtocol.") - def serve(self): + def serve(self) -> None: pass class TSimpleServer(TServer): """Simple single-threaded server that just pumps around one transport.""" - def __init__(self, *args): + def __init__(self, *args: Any) -> None: TServer.__init__(self, *args) - def serve(self): + def serve(self) -> None: self.serverTransport.listen() while True: client = self.serverTransport.accept() @@ -113,11 +135,13 @@ def serve(self): class TThreadedServer(TServer): """Threaded server that spawns a new thread per each connection.""" - def __init__(self, *args, **kwargs): + daemon: bool + + def __init__(self, *args: Any, **kwargs: Any) -> None: TServer.__init__(self, *args) self.daemon = kwargs.get("daemon", False) - def serve(self): + def serve(self) -> None: self.serverTransport.listen() while True: try: @@ -132,7 +156,7 @@ def serve(self): except Exception as x: logger.exception(x) - def handle(self, client): + def handle(self, client: TTransportBase) -> None: itrans = self.inputTransportFactory.getTransport(client) iprot = self.inputProtocolFactory.getProtocol(itrans) @@ -162,17 +186,21 @@ def handle(self, client): class TThreadPoolServer(TServer): """Server with a fixed size pool of threads which service requests.""" - def __init__(self, *args, **kwargs): + clients: queue.Queue[TTransportBase] + threads: int + daemon: bool + + def __init__(self, *args: Any, **kwargs: Any) -> None: TServer.__init__(self, *args) self.clients = queue.Queue() self.threads = 10 self.daemon = kwargs.get("daemon", False) - def setNumThreads(self, num): + def setNumThreads(self, num: int) -> None: """Set the number of worker threads that should be created""" self.threads = num - def serveThread(self): + def serveThread(self) -> None: """Loop around getting clients from the shared queue and process them.""" while True: try: @@ -181,7 +209,7 @@ def serveThread(self): except Exception as x: logger.exception(x) - def serveClient(self, client): + def serveClient(self, client: TTransportBase) -> None: """Process input/output from a client for as long as possible""" itrans = self.inputTransportFactory.getTransport(client) iprot = self.inputProtocolFactory.getProtocol(itrans) @@ -208,7 +236,7 @@ def serveClient(self, client): if otrans: otrans.close() - def serve(self): + def serve(self) -> None: """Start a fixed number of worker threads and put client into a queue""" for i in range(self.threads): try: @@ -243,12 +271,15 @@ class TForkingServer(TServer): This code is heavily inspired by SocketServer.ForkingMixIn in the Python stdlib. """ - def __init__(self, *args): + + children: list[int] + + def __init__(self, *args: Any) -> None: TServer.__init__(self, *args) self.children = [] - def serve(self): - def try_close(file): + def serve(self) -> None: + def try_close(file: TTransportBase) -> None: try: file.close() except IOError as e: @@ -310,7 +341,7 @@ def try_close(file): except Exception as x: logger.exception(x) - def collect_children(self): + def collect_children(self) -> None: while self.children: try: pid, status = os.waitpid(0, os.WNOHANG) diff --git a/lib/py/src/server/__init__.py b/lib/py/src/thrift/server/__init__.py similarity index 100% rename from lib/py/src/server/__init__.py rename to lib/py/src/thrift/server/__init__.py diff --git a/lib/py/src/transport/THeaderTransport.py b/lib/py/src/thrift/transport/THeaderTransport.py similarity index 84% rename from lib/py/src/transport/THeaderTransport.py rename to lib/py/src/thrift/transport/THeaderTransport.py index 4fb20343020..52ca9086dcf 100644 --- a/lib/py/src/transport/THeaderTransport.py +++ b/lib/py/src/thrift/transport/THeaderTransport.py @@ -17,9 +17,12 @@ # under the License. # +from __future__ import annotations + import struct import zlib from io import BytesIO +from typing import Callable from thrift.protocol.TBinaryProtocol import TBinaryProtocol from thrift.protocol.TCompactProtocol import TCompactProtocol, readVarint, writeVarint @@ -31,46 +34,46 @@ TTransportException, ) -U16 = struct.Struct("!H") -I32 = struct.Struct("!i") -HEADER_MAGIC = 0x0FFF -HARD_MAX_FRAME_SIZE = 0x3FFFFFFF +U16: struct.Struct = struct.Struct("!H") +I32: struct.Struct = struct.Struct("!i") +HEADER_MAGIC: int = 0x0FFF +HARD_MAX_FRAME_SIZE: int = 0x3FFFFFFF -class THeaderClientType(object): - HEADERS = 0x00 +class THeaderClientType: + HEADERS: int = 0x00 - FRAMED_BINARY = 0x01 - UNFRAMED_BINARY = 0x02 + FRAMED_BINARY: int = 0x01 + UNFRAMED_BINARY: int = 0x02 - FRAMED_COMPACT = 0x03 - UNFRAMED_COMPACT = 0x04 + FRAMED_COMPACT: int = 0x03 + UNFRAMED_COMPACT: int = 0x04 -class THeaderSubprotocolID(object): - BINARY = 0x00 - COMPACT = 0x02 +class THeaderSubprotocolID: + BINARY: int = 0x00 + COMPACT: int = 0x02 -class TInfoHeaderType(object): - KEY_VALUE = 0x01 +class TInfoHeaderType: + KEY_VALUE: int = 0x01 -class THeaderTransformID(object): - ZLIB = 0x01 +class THeaderTransformID: + ZLIB: int = 0x01 -READ_TRANSFORMS_BY_ID = { +READ_TRANSFORMS_BY_ID: dict[int, Callable[[bytes], bytes]] = { THeaderTransformID.ZLIB: zlib.decompress, } -WRITE_TRANSFORMS_BY_ID = { +WRITE_TRANSFORMS_BY_ID: dict[int, Callable[[bytes], bytes]] = { THeaderTransformID.ZLIB: zlib.compress, } -def _readString(trans): +def _readString(trans: TMemoryBuffer) -> bytes: size = readVarint(trans) if size < 0: raise TTransportException( @@ -80,13 +83,31 @@ def _readString(trans): return trans.read(size) -def _writeString(trans, value): +def _writeString(trans: BytesIO, value: bytes) -> None: writeVarint(trans, len(value)) trans.write(value) class THeaderTransport(TTransportBase, CReadableTransport): - def __init__(self, transport, allowed_client_types, default_protocol=THeaderSubprotocolID.BINARY): + _transport: TTransportBase + _client_type: int + _allowed_client_types: tuple[int, ...] + _read_buffer: BytesIO + _read_headers: dict[bytes, bytes] + _write_buffer: BytesIO + _write_headers: dict[bytes, bytes] + _write_transforms: list[int] + flags: int + sequence_id: int + _protocol_id: int + _max_frame_size: int + + def __init__( + self, + transport: TTransportBase, + allowed_client_types: tuple[int, ...], + default_protocol: int = THeaderSubprotocolID.BINARY, + ) -> None: self._transport = transport self._client_type = THeaderClientType.HEADERS self._allowed_client_types = allowed_client_types @@ -103,40 +124,40 @@ def __init__(self, transport, allowed_client_types, default_protocol=THeaderSubp self._protocol_id = default_protocol self._max_frame_size = HARD_MAX_FRAME_SIZE - def isOpen(self): + def isOpen(self) -> bool: return self._transport.isOpen() - def open(self): + def open(self) -> None: return self._transport.open() - def close(self): + def close(self) -> None: return self._transport.close() - def get_headers(self): + def get_headers(self) -> dict[bytes, bytes]: return self._read_headers - def set_header(self, key, value): + def set_header(self, key: bytes, value: bytes) -> None: if not isinstance(key, bytes): raise ValueError("header names must be bytes") if not isinstance(value, bytes): raise ValueError("header values must be bytes") self._write_headers[key] = value - def clear_headers(self): + def clear_headers(self) -> None: self._write_headers.clear() - def add_transform(self, transform_id): + def add_transform(self, transform_id: int) -> None: if transform_id not in WRITE_TRANSFORMS_BY_ID: raise ValueError("unknown transform") self._write_transforms.append(transform_id) - def set_max_frame_size(self, size): + def set_max_frame_size(self, size: int) -> None: if not 0 < size < HARD_MAX_FRAME_SIZE: raise ValueError("maximum frame size should be < %d and > 0" % HARD_MAX_FRAME_SIZE) self._max_frame_size = size @property - def protocol_id(self): + def protocol_id(self) -> int: if self._client_type == THeaderClientType.HEADERS: return self._protocol_id elif self._client_type in (THeaderClientType.FRAMED_BINARY, THeaderClientType.UNFRAMED_BINARY): @@ -149,7 +170,7 @@ def protocol_id(self): "Protocol ID not know for client type %d" % self._client_type, ) - def read(self, sz): + def read(self, sz: int) -> bytes: # if there are bytes left in the buffer, produce those first. bytes_read = self._read_buffer.read(sz) bytes_left_to_read = sz - len(bytes_read) @@ -166,7 +187,7 @@ def read(self, sz): self.readFrame(bytes_left_to_read) return bytes_read + self._read_buffer.read(bytes_left_to_read) - def _set_client_type(self, client_type): + def _set_client_type(self, client_type: int) -> None: if client_type not in self._allowed_client_types: raise TTransportException( TTransportException.INVALID_CLIENT_TYPE, @@ -174,7 +195,7 @@ def _set_client_type(self, client_type): ) self._client_type = client_type - def readFrame(self, req_sz): + def readFrame(self, req_sz: int) -> None: # the first word could either be the length field of a framed message # or the first bytes of an unframed message. first_word = self._transport.readAll(I32.size) @@ -227,7 +248,7 @@ def readFrame(self, req_sz): "Could not detect client transport type.", ) - def _parse_header_format(self, buffer): + def _parse_header_format(self, buffer: BytesIO) -> BytesIO: # make BytesIO look like TTransport for varint helpers buffer_transport = TMemoryBuffer() buffer_transport._buffer = buffer @@ -246,7 +267,7 @@ def _parse_header_format(self, buffer): self._protocol_id = readVarint(buffer_transport) - transforms = [] + transforms: list[int] = [] transform_count = readVarint(buffer_transport) for _ in range(transform_count): transform_id = readVarint(buffer_transport) @@ -258,7 +279,7 @@ def _parse_header_format(self, buffer): transforms.append(transform_id) transforms.reverse() - headers = {} + headers: dict[bytes, bytes] = {} while buffer.tell() < end_of_headers: header_type = readVarint(buffer_transport) if header_type == TInfoHeaderType.KEY_VALUE: @@ -280,10 +301,10 @@ def _parse_header_format(self, buffer): payload = transform_fn(payload) return BytesIO(payload) - def write(self, buf): + def write(self, buf: bytes) -> None: self._write_buffer.write(buf) - def flush(self): + def flush(self) -> None: payload = self._write_buffer.getvalue() self._write_buffer = BytesIO() @@ -340,10 +361,10 @@ def flush(self): self._transport.flush() @property - def cstringio_buf(self): + def cstringio_buf(self) -> BytesIO: return self._read_buffer - def cstringio_refill(self, partialread, reqlen): + def cstringio_refill(self, partialread: bytes, reqlen: int) -> BytesIO: result = bytearray(partialread) while len(result) < reqlen: result += self.read(reqlen - len(result)) diff --git a/lib/py/src/transport/THttpClient.py b/lib/py/src/thrift/transport/THttpClient.py similarity index 68% rename from lib/py/src/transport/THttpClient.py rename to lib/py/src/thrift/transport/THttpClient.py index 6281165ea25..4a55556af92 100644 --- a/lib/py/src/transport/THttpClient.py +++ b/lib/py/src/thrift/transport/THttpClient.py @@ -17,24 +17,53 @@ # under the License. # +from __future__ import annotations + from io import BytesIO import os import ssl import sys import warnings import base64 - +import http.client import urllib.parse import urllib.request -import http.client +from email.message import Message +from typing import Any -from .TTransport import TTransportBase +from .TTransport import TTransportBase, TTransportException class THttpClient(TTransportBase): """Http implementation of TTransport base.""" - def __init__(self, uri_or_host, port=None, path=None, cafile=None, cert_file=None, key_file=None, ssl_context=None): + scheme: str + host: str | None + port: int + path: str + context: ssl.SSLContext | None + realhost: str | None + realport: int | None + proxy_auth: str | None + __wbuf: BytesIO + __http: http.client.HTTPConnection | http.client.HTTPSConnection | None + __http_response: http.client.HTTPResponse | None + __timeout: float | None + __custom_headers: dict[str, str] | None + headers: Message | None + code: int + message: str + + def __init__( + self, + uri_or_host: str, + port: int | None = None, + path: str | None = None, + cafile: str | None = None, + cert_file: str | None = None, + key_file: str | None = None, + ssl_context: ssl.SSLContext | None = None, + ) -> None: """THttpClient supports two different types of construction: THttpClient(host, port, path) - deprecated @@ -65,7 +94,8 @@ def __init__(self, uri_or_host, port=None, path=None, cafile=None, cert_file=Non self.port = parsed.port or http.client.HTTPS_PORT if (cafile or cert_file or key_file) and not ssl_context: self.context = ssl.create_default_context(cafile=cafile) - self.context.load_cert_chain(certfile=cert_file, keyfile=key_file) + if cert_file: + self.context.load_cert_chain(certfile=cert_file, keyfile=key_file) else: self.context = ssl_context self.host = parsed.hostname @@ -77,14 +107,14 @@ def __init__(self, uri_or_host, port=None, path=None, cafile=None, cert_file=Non except KeyError: proxy = None else: - if urllib.request.proxy_bypass(self.host): + if self.host and urllib.request.proxy_bypass(self.host): proxy = None if proxy: parsed = urllib.parse.urlparse(proxy) self.realhost = self.host self.realport = self.port self.host = parsed.hostname - self.port = parsed.port + self.port = parsed.port or self.port # Fall back to original port if proxy port not specified self.proxy_auth = self.basic_proxy_auth_header(parsed) else: self.realhost = self.realport = self.proxy_auth = None @@ -96,18 +126,20 @@ def __init__(self, uri_or_host, port=None, path=None, cafile=None, cert_file=Non self.headers = None @staticmethod - def basic_proxy_auth_header(proxy): + def basic_proxy_auth_header(proxy: urllib.parse.ParseResult | None) -> str | None: if proxy is None or not proxy.username: return None ap = "%s:%s" % (urllib.parse.unquote(proxy.username), - urllib.parse.unquote(proxy.password)) + urllib.parse.unquote(proxy.password or '')) cr = base64.b64encode(ap.encode()).strip() - return "Basic " + six.ensure_str(cr) + return "Basic " + cr.decode('ascii') - def using_proxy(self): + def using_proxy(self) -> bool: return self.realhost is not None - def open(self): + def open(self) -> None: + if not self.host: + raise TTransportException(TTransportException.NOT_OPEN, "No host specified") if self.scheme == 'http': self.__http = http.client.HTTPConnection(self.host, self.port, timeout=self.__timeout) @@ -116,33 +148,33 @@ def open(self): timeout=self.__timeout, context=self.context) if self.using_proxy(): - self.__http.set_tunnel(self.realhost, self.realport, - {"Proxy-Authorization": self.proxy_auth}) + headers = {"Proxy-Authorization": self.proxy_auth} if self.proxy_auth else None + self.__http.set_tunnel(self.realhost, self.realport, headers) # type: ignore[union-attr] - def close(self): - self.__http.close() + def close(self) -> None: + self.__http.close() # type: ignore[union-attr] self.__http = None self.__http_response = None - def isOpen(self): + def isOpen(self) -> bool: return self.__http is not None - def setTimeout(self, ms): + def setTimeout(self, ms: int | None) -> None: if ms is None: self.__timeout = None else: self.__timeout = ms / 1000.0 - def setCustomHeaders(self, headers): + def setCustomHeaders(self, headers: dict[str, str]) -> None: self.__custom_headers = headers - def read(self, sz): - return self.__http_response.read(sz) + def read(self, sz: int) -> bytes: + return self.__http_response.read(sz) # type: ignore[union-attr] - def write(self, buf): + def write(self, buf: bytes) -> None: self.__wbuf.write(buf) - def flush(self): + def flush(self) -> None: if self.isOpen(): self.close() self.open() @@ -154,41 +186,41 @@ def flush(self): # HTTP request if self.using_proxy() and self.scheme == "http": # need full URL of real host for HTTP proxy here (HTTPS uses CONNECT tunnel) - self.__http.putrequest('POST', "http://%s:%s%s" % + self.__http.putrequest('POST', "http://%s:%s%s" % # type: ignore[union-attr] (self.realhost, self.realport, self.path)) else: - self.__http.putrequest('POST', self.path) + self.__http.putrequest('POST', self.path) # type: ignore[union-attr] # Write headers - self.__http.putheader('Content-Type', 'application/x-thrift') - self.__http.putheader('Content-Length', str(len(data))) + self.__http.putheader('Content-Type', 'application/x-thrift') # type: ignore[union-attr] + self.__http.putheader('Content-Length', str(len(data))) # type: ignore[union-attr] if self.using_proxy() and self.scheme == "http" and self.proxy_auth is not None: - self.__http.putheader("Proxy-Authorization", self.proxy_auth) + self.__http.putheader("Proxy-Authorization", self.proxy_auth) # type: ignore[union-attr] if not self.__custom_headers or 'User-Agent' not in self.__custom_headers: user_agent = 'Python/THttpClient' script = os.path.basename(sys.argv[0]) if script: user_agent = '%s (%s)' % (user_agent, urllib.parse.quote(script)) - self.__http.putheader('User-Agent', user_agent) + self.__http.putheader('User-Agent', user_agent) # type: ignore[union-attr] if self.__custom_headers: for key, val in self.__custom_headers.items(): - self.__http.putheader(key, val) + self.__http.putheader(key, val) # type: ignore[union-attr] # Saves the cookie sent by the server in the previous response. # HTTPConnection.putheader can only be called after a request has been # started, and before it's been sent. if self.headers and 'Set-Cookie' in self.headers: - self.__http.putheader('Cookie', self.headers['Set-Cookie']) + self.__http.putheader('Cookie', self.headers['Set-Cookie']) # type: ignore[union-attr] - self.__http.endheaders() + self.__http.endheaders() # type: ignore[union-attr] # Write payload - self.__http.send(data) + self.__http.send(data) # type: ignore[union-attr] # Get reply to flush the request - self.__http_response = self.__http.getresponse() + self.__http_response = self.__http.getresponse() # type: ignore[union-attr] self.code = self.__http_response.status self.message = self.__http_response.reason self.headers = self.__http_response.msg diff --git a/lib/py/src/transport/TSSLSocket.py b/lib/py/src/thrift/transport/TSSLSocket.py similarity index 79% rename from lib/py/src/transport/TSSLSocket.py rename to lib/py/src/thrift/transport/TSSLSocket.py index dc6c1fb5d31..fdafabdc1e0 100644 --- a/lib/py/src/transport/TSSLSocket.py +++ b/lib/py/src/thrift/transport/TSSLSocket.py @@ -17,30 +17,35 @@ # under the License. # +from __future__ import annotations + import logging import os import socket import ssl import sys import warnings +from typing import Any, Callable from .sslcompat import _match_has_ipaddress from thrift.transport import TSocket from thrift.transport.TTransport import TTransportException -_match_hostname = lambda cert, hostname: True +_match_hostname: Callable[[dict[str, Any], str], None] = lambda cert, hostname: None -logger = logging.getLogger(__name__) +logger: logging.Logger = logging.getLogger(__name__) warnings.filterwarnings( 'default', category=DeprecationWarning, module=__name__) -class TSSLBase(object): +class TSSLBase: + """Base class for SSL socket implementations.""" + # SSLContext is not available for Python < 2.7.9 - _has_ssl_context = sys.hexversion >= 0x020709F0 + _has_ssl_context: bool = sys.hexversion >= 0x020709F0 # ciphers argument is not available for Python < 2.7.0 - _has_ciphers = sys.hexversion >= 0x020700F0 + _has_ciphers: bool = sys.hexversion >= 0x020700F0 # For python >= 2.7.9, use latest TLS that both client and server # supports. @@ -48,6 +53,7 @@ class TSSLBase(object): # For python < 2.7.9, use TLS 1.0 since TLSv1_X nor OP_NO_SSLvX is # unavailable. # For python < 3.6, use SSLv23 since TLS is not available + _default_protocol: int if sys.version_info < (3, 6): _default_protocol = ssl.PROTOCOL_SSLv23 if _has_ssl_context else \ ssl.PROTOCOL_TLSv1 @@ -55,7 +61,18 @@ class TSSLBase(object): _default_protocol = ssl.PROTOCOL_TLS_CLIENT if _has_ssl_context else \ ssl.PROTOCOL_TLSv1 - def _init_context(self, ssl_version): + _context: ssl.SSLContext | None + _ssl_version: int + _server_side: bool + _server_hostname: str | None + _custom_context: bool + cert_reqs: int + ca_certs: str | None + keyfile: str | None + _certfile: str | None + ciphers: str | None + + def _init_context(self, ssl_version: int) -> None: if self._has_ssl_context: self._context = ssl.SSLContext(ssl_version) if self._context.protocol == ssl.PROTOCOL_SSLv23: @@ -66,31 +83,31 @@ def _init_context(self, ssl_version): self._ssl_version = ssl_version @property - def _should_verify(self): + def _should_verify(self) -> bool: if self._has_ssl_context: - return self._context.verify_mode != ssl.CERT_NONE + return self._context.verify_mode != ssl.CERT_NONE # type: ignore[union-attr] else: return self.cert_reqs != ssl.CERT_NONE @property - def ssl_version(self): + def ssl_version(self) -> int: if self._has_ssl_context: - return self.ssl_context.protocol + return self.ssl_context.protocol # type: ignore[union-attr] else: return self._ssl_version @property - def ssl_context(self): + def ssl_context(self) -> ssl.SSLContext | None: return self._context - SSL_VERSION = _default_protocol + SSL_VERSION: int = _default_protocol """ Default SSL version. For backwards compatibility, it can be modified. Use __init__ keyword argument "ssl_version" instead. """ - def _deprecated_arg(self, args, kwargs, pos, key): + def _deprecated_arg(self, args: tuple[Any, ...], kwargs: dict[str, Any], pos: int, key: str) -> None: if len(args) <= pos: return real_pos = pos + 3 @@ -105,14 +122,14 @@ def _deprecated_arg(self, args, kwargs, pos, key): % (real_pos, key)) kwargs[key] = args[pos] - def _unix_socket_arg(self, host, port, args, kwargs): + def _unix_socket_arg(self, host: str | None, port: int | None, args: tuple[Any, ...], kwargs: dict[str, Any]) -> bool: key = 'unix_socket' if host is None and port is None and len(args) == 1 and key not in kwargs: kwargs[key] = args[0] return True return False - def __getattr__(self, key): + def __getattr__(self, key: str) -> Any: if key == 'SSL_VERSION': warnings.warn( 'SSL_VERSION is deprecated.' @@ -120,7 +137,7 @@ def __getattr__(self, key): DeprecationWarning, stacklevel=2) return self.ssl_version - def __init__(self, server_side, host, ssl_opts): + def __init__(self, server_side: bool, host: str | None, ssl_opts: dict[str, Any]) -> None: self._server_side = server_side if TSSLBase.SSL_VERSION != self._default_protocol: warnings.warn( @@ -164,47 +181,35 @@ def __init__(self, server_side, host, ssl_opts): 'certificates.' % (self.ca_certs)) @property - def certfile(self): + def certfile(self) -> str | None: return self._certfile @certfile.setter - def certfile(self, certfile): + def certfile(self, certfile: str | None) -> None: if self._server_side and not certfile: raise ValueError('certfile is needed for server-side') if certfile and not os.access(certfile, os.R_OK): raise IOError('No such certfile found: %s' % (certfile)) self._certfile = certfile - def _wrap_socket(self, sock): - if self._has_ssl_context: - if not self._custom_context: - self.ssl_context.verify_mode = self.cert_reqs - if self.certfile: - self.ssl_context.load_cert_chain(self.certfile, - self.keyfile) - if self.ciphers: - self.ssl_context.set_ciphers(self.ciphers) - if self.ca_certs: - self.ssl_context.load_verify_locations(self.ca_certs) - return self.ssl_context.wrap_socket( - sock, server_side=self._server_side, - server_hostname=self._server_hostname) - else: - ssl_opts = { - 'ssl_version': self._ssl_version, - 'server_side': self._server_side, - 'ca_certs': self.ca_certs, - 'keyfile': self.keyfile, - 'certfile': self.certfile, - 'cert_reqs': self.cert_reqs, - } + def _wrap_socket(self, sock: socket.socket) -> ssl.SSLSocket: + if not self._has_ssl_context: + # ssl.wrap_socket was removed in Python 3.12 + raise RuntimeError("SSLContext is required for Python 3.12+") + ctx = self.ssl_context + if ctx is None: + raise RuntimeError("ssl_context is None but _has_ssl_context is True") + if not self._custom_context: + ctx.verify_mode = ssl.VerifyMode(self.cert_reqs) + if self.certfile: + ctx.load_cert_chain(self.certfile, self.keyfile) if self.ciphers: - if self._has_ciphers: - ssl_opts['ciphers'] = self.ciphers - else: - logger.warning( - 'ciphers is specified but ignored due to old Python version') - return ssl.wrap_socket(sock, **ssl_opts) + ctx.set_ciphers(self.ciphers) + if self.ca_certs: + ctx.load_verify_locations(self.ca_certs) + return ctx.wrap_socket( + sock, server_side=self._server_side, + server_hostname=self._server_hostname) class TSSLSocket(TSocket.TSocket, TSSLBase): @@ -215,6 +220,10 @@ class TSSLSocket(TSocket.TSocket, TSSLBase): python standard ssl module for encrypted connections. """ + is_valid: bool + peercert: dict[str, Any] | None + _validate_callback: Callable[[dict[str, Any], str | None], None] + # New signature # def __init__(self, host='localhost', port=9090, unix_socket=None, # **ssl_args): @@ -222,7 +231,7 @@ class TSSLSocket(TSocket.TSocket, TSSLBase): # def __init__(self, host='localhost', port=9090, validate=True, # ca_certs=None, keyfile=None, certfile=None, # unix_socket=None, ciphers=None): - def __init__(self, host='localhost', port=9090, *args, **kwargs): + def __init__(self, host: str = 'localhost', port: int = 9090, *args: Any, **kwargs: Any) -> None: """Positional arguments: ``host``, ``port``, ``unix_socket`` Keyword arguments: ``keyfile``, ``certfile``, ``cert_reqs``, @@ -273,28 +282,28 @@ def __init__(self, host='localhost', port=9090, *args, **kwargs): TSocket.TSocket.__init__(self, host, port, unix_socket, socket_keepalive=socket_keepalive) - def close(self): + def close(self) -> None: try: - self.handle.settimeout(0.001) - self.handle = self.handle.unwrap() + self.handle.settimeout(0.001) # type: ignore[union-attr] + self.handle = self.handle.unwrap() # type: ignore[union-attr] except (ssl.SSLError, socket.error, OSError): # could not complete shutdown in a reasonable amount of time. bail. pass TSocket.TSocket.close(self) @property - def validate(self): + def validate(self) -> bool: warnings.warn('validate is deprecated. please use cert_reqs instead', DeprecationWarning, stacklevel=2) return self.cert_reqs != ssl.CERT_NONE @validate.setter - def validate(self, value): + def validate(self, value: bool) -> None: warnings.warn('validate is deprecated. please use cert_reqs instead', DeprecationWarning, stacklevel=2) self.cert_reqs = ssl.CERT_REQUIRED if value else ssl.CERT_NONE - def _do_open(self, family, socktype): + def _do_open(self, family: socket.AddressFamily, socktype: socket.SocketKind) -> ssl.SSLSocket: plain_sock = socket.socket(family, socktype) try: return self._wrap_socket(plain_sock) @@ -304,12 +313,12 @@ def _do_open(self, family, socktype): logger.exception(msg) raise TTransportException(type=TTransportException.NOT_OPEN, message=msg, inner=ex) - def open(self): + def open(self) -> None: super(TSSLSocket, self).open() if self._should_verify: - self.peercert = self.handle.getpeercert() + self.peercert = self.handle.getpeercert() # type: ignore[union-attr] try: - self._validate_callback(self.peercert, self._server_hostname) + self._validate_callback(self.peercert, self._server_hostname) # type: ignore[arg-type] self.is_valid = True except TTransportException: raise @@ -324,11 +333,13 @@ class TSSLServerSocket(TSocket.TServerSocket, TSSLBase): negotiated encryption. """ + _validate_callback: Callable[[dict[str, Any], str], None] + # New signature # def __init__(self, host='localhost', port=9090, unix_socket=None, **ssl_args): # Deprecated signature # def __init__(self, host=None, port=9090, certfile='cert.pem', unix_socket=None, ciphers=None): - def __init__(self, host=None, port=9090, *args, **kwargs): + def __init__(self, host: str | None = None, port: int = 9090, *args: Any, **kwargs: Any) -> None: """Positional arguments: ``host``, ``port``, ``unix_socket`` Keyword arguments: ``keyfile``, ``certfile``, ``cert_reqs``, ``ssl_version``, @@ -368,7 +379,7 @@ def __init__(self, host=None, port=9090, *args, **kwargs): raise ValueError('Need ipaddress and backports.ssl_match_hostname ' 'module to verify client certificate') - def setCertfile(self, certfile): + def setCertfile(self, certfile: str) -> None: """Set or change the server certificate file used to wrap new connections. @@ -383,8 +394,8 @@ def setCertfile(self, certfile): DeprecationWarning, stacklevel=2) self.certfile = certfile - def accept(self): - plain_client, addr = self.handle.accept() + def accept(self) -> TSocket.TSocket | None: + plain_client, addr = self.handle.accept() # type: ignore[union-attr] try: client = self._wrap_socket(plain_client) except (ssl.SSLError, socket.error, OSError): @@ -399,10 +410,10 @@ def accept(self): return None if self._should_verify: - client.peercert = client.getpeercert() + client.peercert = client.getpeercert() # type: ignore[attr-defined] try: - self._validate_callback(client.peercert, addr[0]) - client.is_valid = True + self._validate_callback(client.peercert, addr[0]) # type: ignore[attr-defined] + client.is_valid = True # type: ignore[attr-defined] except Exception: logger.warning('Failed to validate client certificate address: %s', addr[0], exc_info=True) diff --git a/lib/py/src/transport/TSocket.py b/lib/py/src/thrift/transport/TSocket.py similarity index 81% rename from lib/py/src/transport/TSocket.py rename to lib/py/src/thrift/transport/TSocket.py index 195bfcb57a9..be2417b109d 100644 --- a/lib/py/src/transport/TSocket.py +++ b/lib/py/src/thrift/transport/TSocket.py @@ -17,20 +17,29 @@ # under the License. # +from __future__ import annotations + import errno import logging import os import socket import sys import platform +from typing import Any -from .TTransport import TTransportBase, TTransportException, TServerTransportBase +from thrift.transport.TTransport import TTransportBase, TTransportException, TServerTransportBase logger = logging.getLogger(__name__) class TSocketBase(TTransportBase): - def _resolveAddr(self): + host: str + port: int + handle: socket.socket | None + _unix_socket: str | None + _socket_family: socket.AddressFamily + + def _resolveAddr(self) -> list[tuple[Any, ...]]: if self._unix_socket is not None: return [(socket.AF_UNIX, socket.SOCK_STREAM, None, None, self._unix_socket)] @@ -42,7 +51,7 @@ def _resolveAddr(self): 0, socket.AI_PASSIVE) - def close(self): + def close(self) -> None: if self.handle: self.handle.close() self.handle = None @@ -51,9 +60,17 @@ def close(self): class TSocket(TSocketBase): """Socket implementation of TTransport base.""" - def __init__(self, host='localhost', port=9090, unix_socket=None, - socket_family=socket.AF_UNSPEC, - socket_keepalive=False): + _timeout: float | None + _socket_keepalive: bool + + def __init__( + self, + host: str = 'localhost', + port: int = 9090, + unix_socket: str | None = None, + socket_family: socket.AddressFamily = socket.AF_UNSPEC, + socket_keepalive: bool = False, + ) -> None: """Initialize a TSocket @param host(str) The host to connect to. @@ -71,10 +88,10 @@ def __init__(self, host='localhost', port=9090, unix_socket=None, self._socket_family = socket_family self._socket_keepalive = socket_keepalive - def setHandle(self, h): + def setHandle(self, h: socket.socket) -> None: self.handle = h - def isOpen(self): + def isOpen(self) -> bool: if self.handle is None: return False @@ -108,7 +125,7 @@ def isOpen(self): self.close() return False - def setTimeout(self, ms): + def setTimeout(self, ms: int | None) -> None: if ms is None: self._timeout = None else: @@ -117,14 +134,14 @@ def setTimeout(self, ms): if self.handle is not None: self.handle.settimeout(self._timeout) - def _do_open(self, family, socktype): + def _do_open(self, family: socket.AddressFamily, socktype: socket.SocketKind) -> socket.socket: return socket.socket(family, socktype) @property - def _address(self): + def _address(self) -> str: return self._unix_socket if self._unix_socket else '%s:%d' % (self.host, self.port) - def open(self): + def open(self) -> None: if self.handle: raise TTransportException(type=TTransportException.ALREADY_OPEN, message="already open") try: @@ -134,7 +151,7 @@ def open(self): logger.exception(msg) raise TTransportException(type=TTransportException.NOT_OPEN, message=msg, inner=gai) # Preserve the last exception to report if all addresses fail. - last_exc = None + last_exc: Exception | None = None for family, socktype, _, _, sockaddr in addrs: handle = self._do_open(family, socktype) @@ -156,9 +173,9 @@ def open(self): logger.error(msg) raise TTransportException(type=TTransportException.NOT_OPEN, message=msg, inner=last_exc) - def read(self, sz): + def read(self, sz: int) -> bytes: try: - buff = self.handle.recv(sz) + buff = self.handle.recv(sz) # type: ignore[union-attr] # TODO: remove socket.timeout when 3.10 becomes the earliest version of python supported. except (socket.timeout, TimeoutError) as e: raise TTransportException(type=TTransportException.TIMED_OUT, message="read timeout", inner=e) @@ -171,7 +188,7 @@ def read(self, sz): # in lib/cpp/src/transport/TSocket.cpp. self.close() # Trigger the check to raise the END_OF_FILE exception below. - buff = '' + buff = b'' else: raise TTransportException(message="unexpected exception", inner=e) if len(buff) == 0: @@ -179,49 +196,60 @@ def read(self, sz): message='TSocket read 0 bytes') return buff - def write(self, buff): + def write(self, buf: bytes) -> None: if not self.handle: raise TTransportException(type=TTransportException.NOT_OPEN, message='Transport not open') sent = 0 - have = len(buff) + have = len(buf) while sent < have: try: - plus = self.handle.send(buff) + plus = self.handle.send(buf) if plus == 0: raise TTransportException(type=TTransportException.END_OF_FILE, message='TSocket sent 0 bytes') sent += plus - buff = buff[plus:] + buf = buf[plus:] except socket.error as e: raise TTransportException(message="unexpected exception", inner=e) - def flush(self): + def flush(self) -> None: pass class TServerSocket(TSocketBase, TServerTransportBase): """Socket implementation of TServerTransport base.""" - def __init__(self, host=None, port=9090, unix_socket=None, socket_family=socket.AF_UNSPEC): - self.host = host + _backlog: int + + def __init__( + self, + host: str | None = None, + port: int = 9090, + unix_socket: str | None = None, + socket_family: socket.AddressFamily = socket.AF_UNSPEC, + ) -> None: + self.host = host or 'localhost' self.port = port self._unix_socket = unix_socket self._socket_family = socket_family self.handle = None self._backlog = 128 - def setBacklog(self, backlog=None): + def setBacklog(self, backlog: int | None = None) -> None: if not self.handle: - self._backlog = backlog + self._backlog = backlog or 128 else: # We cann't update backlog when it is already listening, since the # handle has been created. logger.warning('You have to set backlog before listen.') - def listen(self): + def listen(self) -> None: res0 = self._resolveAddr() + if not res0: + raise TTransportException(TTransportException.NOT_OPEN, "Could not resolve address") socket_family = self._socket_family == socket.AF_UNSPEC and socket.AF_INET6 or self._socket_family + res = res0[-1] # Default to last result for res in res0: if res[0] is socket_family or res is res0[-1]: break @@ -249,8 +277,8 @@ def listen(self): s.bind(res[4]) s.listen(self._backlog) - def accept(self): - client, addr = self.handle.accept() + def accept(self) -> TSocket | None: + client, addr = self.handle.accept() # type: ignore[union-attr] result = TSocket() result.setHandle(client) return result diff --git a/lib/py/src/transport/TTransport.py b/lib/py/src/thrift/transport/TTransport.py similarity index 66% rename from lib/py/src/transport/TTransport.py rename to lib/py/src/thrift/transport/TTransport.py index 4f6b67fe123..cba0139232f 100644 --- a/lib/py/src/transport/TTransport.py +++ b/lib/py/src/thrift/transport/TTransport.py @@ -17,49 +17,66 @@ # under the License. # +from __future__ import annotations + +from abc import ABC, abstractmethod from io import BytesIO from struct import pack, unpack +from typing import TYPE_CHECKING, Any, BinaryIO from thrift.Thrift import TException +if TYPE_CHECKING: + from puresasl.client import SASLClient # type: ignore[import-not-found] + class TTransportException(TException): """Custom Transport Exception class""" - UNKNOWN = 0 - NOT_OPEN = 1 - ALREADY_OPEN = 2 - TIMED_OUT = 3 - END_OF_FILE = 4 - NEGATIVE_SIZE = 5 - SIZE_LIMIT = 6 - INVALID_CLIENT_TYPE = 7 - - def __init__(self, type=UNKNOWN, message=None, inner=None): + UNKNOWN: int = 0 + NOT_OPEN: int = 1 + ALREADY_OPEN: int = 2 + TIMED_OUT: int = 3 + END_OF_FILE: int = 4 + NEGATIVE_SIZE: int = 5 + SIZE_LIMIT: int = 6 + INVALID_CLIENT_TYPE: int = 7 + + type: int + inner: Exception | None + + def __init__( + self, + type: int = UNKNOWN, + message: str | None = None, + inner: Exception | None = None, + ) -> None: TException.__init__(self, message) self.type = type self.inner = inner -class TTransportBase(object): +class TTransportBase(ABC): """Base class for Thrift transport layer.""" - def isOpen(self): + @abstractmethod + def isOpen(self) -> bool: pass - def open(self): + def open(self) -> None: pass - def close(self): + def close(self) -> None: pass - def read(self, sz): + @abstractmethod + def read(self, sz: int) -> bytes: pass - def readAll(self, sz): + def readAll(self, sz: int) -> bytes: buff = b'' have = 0 - while (have < sz): + while have < sz: chunk = self.read(sz - have) chunkLen = len(chunk) have += chunkLen @@ -70,29 +87,24 @@ def readAll(self, sz): return buff - def write(self, buf): + def write(self, buf: bytes) -> None: pass - def flush(self): + def flush(self) -> None: pass -# This class should be thought of as an interface. -class CReadableTransport(object): - """base class for transports that are readable from C""" +class CReadableTransport(ABC): + """Base class for transports that are readable from C""" - # TODO(dreiss): Think about changing this interface to allow us to use - # a (Python, not c) StringIO instead, because it allows - # you to write after reading. - - # NOTE: This is a classic class, so properties will NOT work - # correctly for setting. @property - def cstringio_buf(self): + @abstractmethod + def cstringio_buf(self) -> BytesIO: """A cStringIO buffer that contains the current chunk we are reading.""" pass - def cstringio_refill(self, partialread, reqlen): + @abstractmethod + def cstringio_refill(self, partialread: bytes, reqlen: int) -> BytesIO: """Refills cstringio_buf. Returns the currently used buffer (which can but need not be the same as @@ -106,30 +118,30 @@ def cstringio_refill(self, partialread, reqlen): pass -class TServerTransportBase(object): +class TServerTransportBase: """Base class for Thrift server transports.""" - def listen(self): + def listen(self) -> None: pass - def accept(self): + def accept(self) -> TTransportBase | None: pass - def close(self): + def close(self) -> None: pass -class TTransportFactoryBase(object): +class TTransportFactoryBase: """Base class for a Transport Factory""" - def getTransport(self, trans): + def getTransport(self, trans: TTransportBase) -> TTransportBase: return trans -class TBufferedTransportFactory(object): +class TBufferedTransportFactory: """Factory transport that builds buffered transports""" - def getTransport(self, trans): + def getTransport(self, trans: TTransportBase) -> TBufferedTransport: buffered = TBufferedTransport(trans) return buffered @@ -140,32 +152,38 @@ class TBufferedTransport(TTransportBase, CReadableTransport): The implementation uses a (configurable) fixed-size read buffer but buffers all writes until a flush is performed. """ - DEFAULT_BUFFER = 4096 - def __init__(self, trans, rbuf_size=DEFAULT_BUFFER): + DEFAULT_BUFFER: int = 4096 + + __trans: TTransportBase + __wbuf: BytesIO + __rbuf: BytesIO + __rbuf_size: int + + def __init__(self, trans: TTransportBase, rbuf_size: int = DEFAULT_BUFFER) -> None: self.__trans = trans self.__wbuf = BytesIO() # Pass string argument to initialize read buffer as cStringIO.InputType self.__rbuf = BytesIO(b'') self.__rbuf_size = rbuf_size - def isOpen(self): + def isOpen(self) -> bool: return self.__trans.isOpen() - def open(self): + def open(self) -> None: return self.__trans.open() - def close(self): + def close(self) -> None: return self.__trans.close() - def read(self, sz): + def read(self, sz: int) -> bytes: ret = self.__rbuf.read(sz) if len(ret) != 0: return ret self.__rbuf = BytesIO(self.__trans.read(max(sz, self.__rbuf_size))) return self.__rbuf.read(sz) - def write(self, buf): + def write(self, buf: bytes) -> None: try: self.__wbuf.write(buf) except Exception as e: @@ -173,7 +191,7 @@ def write(self, buf): self.__wbuf = BytesIO() raise e - def flush(self): + def flush(self) -> None: out = self.__wbuf.getvalue() # reset wbuf before write/flush to preserve state on underlying failure self.__wbuf = BytesIO() @@ -182,10 +200,10 @@ def flush(self): # Implement the CReadableTransport interface. @property - def cstringio_buf(self): + def cstringio_buf(self) -> BytesIO: return self.__rbuf - def cstringio_refill(self, partialread, reqlen): + def cstringio_refill(self, partialread: bytes, reqlen: int) -> BytesIO: retstring = partialread if reqlen < self.__rbuf_size: # try to make a read of as much as we can. @@ -208,7 +226,9 @@ class TMemoryBuffer(TTransportBase, CReadableTransport): TODO(dreiss): Make this work like the C++ version. """ - def __init__(self, value=None, offset=0): + _buffer: BytesIO + + def __init__(self, value: bytes | None = None, offset: int = 0) -> None: """value -- a value to read from for stringio If value is set, this will be a transport for reading, @@ -220,41 +240,41 @@ def __init__(self, value=None, offset=0): if offset: self._buffer.seek(offset) - def isOpen(self): + def isOpen(self) -> bool: return not self._buffer.closed - def open(self): + def open(self) -> None: pass - def close(self): + def close(self) -> None: self._buffer.close() - def read(self, sz): + def read(self, sz: int) -> bytes: return self._buffer.read(sz) - def write(self, buf): + def write(self, buf: bytes) -> None: self._buffer.write(buf) - def flush(self): + def flush(self) -> None: pass - def getvalue(self): + def getvalue(self) -> bytes: return self._buffer.getvalue() # Implement the CReadableTransport interface. @property - def cstringio_buf(self): + def cstringio_buf(self) -> BytesIO: return self._buffer - def cstringio_refill(self, partialread, reqlen): + def cstringio_refill(self, partialread: bytes, reqlen: int) -> BytesIO: # only one shot at reading... raise EOFError() -class TFramedTransportFactory(object): +class TFramedTransportFactory: """Factory transport that builds framed transports""" - def getTransport(self, trans): + def getTransport(self, trans: TTransportBase) -> TFramedTransport: framed = TFramedTransport(trans) return framed @@ -262,21 +282,25 @@ def getTransport(self, trans): class TFramedTransport(TTransportBase, CReadableTransport): """Class that wraps another transport and frames its I/O when writing.""" - def __init__(self, trans,): + __trans: TTransportBase + __rbuf: BytesIO + __wbuf: BytesIO + + def __init__(self, trans: TTransportBase) -> None: self.__trans = trans self.__rbuf = BytesIO(b'') self.__wbuf = BytesIO() - def isOpen(self): + def isOpen(self) -> bool: return self.__trans.isOpen() - def open(self): + def open(self) -> None: return self.__trans.open() - def close(self): + def close(self) -> None: return self.__trans.close() - def read(self, sz): + def read(self, sz: int) -> bytes: ret = self.__rbuf.read(sz) if len(ret) != 0: return ret @@ -284,15 +308,15 @@ def read(self, sz): self.readFrame() return self.__rbuf.read(sz) - def readFrame(self): + def readFrame(self) -> None: buff = self.__trans.readAll(4) - sz, = unpack('!i', buff) + (sz,) = unpack('!i', buff) self.__rbuf = BytesIO(self.__trans.readAll(sz)) - def write(self, buf): + def write(self, buf: bytes) -> None: self.__wbuf.write(buf) - def flush(self): + def flush(self) -> None: wout = self.__wbuf.getvalue() wsz = len(wout) # reset wbuf before write/flush to preserve state on underlying failure @@ -307,39 +331,41 @@ def flush(self): # Implement the CReadableTransport interface. @property - def cstringio_buf(self): + def cstringio_buf(self) -> BytesIO: return self.__rbuf - def cstringio_refill(self, prefix, reqlen): + def cstringio_refill(self, partialread: bytes, reqlen: int) -> BytesIO: # self.__rbuf will already be empty here because fastbinary doesn't # ask for a refill until the previous buffer is empty. Therefore, # we can start reading new frames immediately. - while len(prefix) < reqlen: + while len(partialread) < reqlen: self.readFrame() - prefix += self.__rbuf.getvalue() - self.__rbuf = BytesIO(prefix) + partialread += self.__rbuf.getvalue() + self.__rbuf = BytesIO(partialread) return self.__rbuf class TFileObjectTransport(TTransportBase): """Wraps a file-like object to make it work as a Thrift transport.""" - def __init__(self, fileobj): + fileobj: BinaryIO + + def __init__(self, fileobj: BinaryIO) -> None: self.fileobj = fileobj - def isOpen(self): + def isOpen(self) -> bool: return True - def close(self): + def close(self) -> None: self.fileobj.close() - def read(self, sz): + def read(self, sz: int) -> bytes: return self.fileobj.read(sz) - def write(self, buf): + def write(self, buf: bytes) -> None: self.fileobj.write(buf) - def flush(self): + def flush(self) -> None: self.fileobj.flush() @@ -348,14 +374,25 @@ class TSaslClientTransport(TTransportBase, CReadableTransport): SASL transport """ - START = 1 - OK = 2 - BAD = 3 - ERROR = 4 - COMPLETE = 5 - - def __init__(self, transport, host, service, mechanism='GSSAPI', - **sasl_kwargs): + START: int = 1 + OK: int = 2 + BAD: int = 3 + ERROR: int = 4 + COMPLETE: int = 5 + + transport: TTransportBase + sasl: SASLClient + __wbuf: BytesIO + __rbuf: BytesIO + + def __init__( + self, + transport: TTransportBase, + host: str, + service: str, + mechanism: str = 'GSSAPI', + **sasl_kwargs: Any, + ) -> None: """ transport: an underlying transport to use, typically just a TSocket host: the name of the server, from a SASL perspective @@ -366,7 +403,7 @@ def __init__(self, transport, host, service, mechanism='GSSAPI', constructor. """ - from puresasl.client import SASLClient + from puresasl.client import SASLClient # type: ignore[import-not-found] self.transport = transport self.sasl = SASLClient(host, service, mechanism, **sasl_kwargs) @@ -374,11 +411,13 @@ def __init__(self, transport, host, service, mechanism='GSSAPI', self.__wbuf = BytesIO() self.__rbuf = BytesIO(b'') - def open(self): + def open(self) -> None: if not self.transport.isOpen(): self.transport.open() - self.send_sasl_msg(self.START, bytes(self.sasl.mechanism, 'ascii')) + mechanism = self.sasl.mechanism + assert mechanism is not None, "SASL mechanism must be set" + self.send_sasl_msg(self.START, bytes(mechanism, 'ascii')) self.send_sasl_msg(self.OK, self.sasl.process()) while True: @@ -390,43 +429,44 @@ def open(self): raise TTransportException( TTransportException.NOT_OPEN, "The server erroneously indicated " - "that SASL negotiation was complete") + "that SASL negotiation was complete", + ) else: break else: raise TTransportException( TTransportException.NOT_OPEN, - "Bad SASL negotiation status: %d (%s)" - % (status, challenge)) + "Bad SASL negotiation status: %d (%s)" % (status, challenge), + ) - def isOpen(self): + def isOpen(self) -> bool: return self.transport.isOpen() - def send_sasl_msg(self, status, body): + def send_sasl_msg(self, status: int, body: bytes) -> None: header = pack(">BI", status, len(body)) self.transport.write(header + body) self.transport.flush() - def recv_sasl_msg(self): + def recv_sasl_msg(self) -> tuple[int, bytes | str]: header = self.transport.readAll(5) status, length = unpack(">BI", header) if length > 0: - payload = self.transport.readAll(length) + payload: bytes | str = self.transport.readAll(length) else: payload = "" return status, payload - def write(self, data): - self.__wbuf.write(data) + def write(self, buf: bytes) -> None: + self.__wbuf.write(buf) - def flush(self): + def flush(self) -> None: data = self.__wbuf.getvalue() encoded = self.sasl.wrap(data) self.transport.write(pack("!i", len(encoded)) + encoded) self.transport.flush() self.__wbuf = BytesIO() - def read(self, sz): + def read(self, sz: int) -> bytes: ret = self.__rbuf.read(sz) if len(ret) != 0: return ret @@ -434,27 +474,27 @@ def read(self, sz): self._read_frame() return self.__rbuf.read(sz) - def _read_frame(self): + def _read_frame(self) -> None: header = self.transport.readAll(4) - length, = unpack('!i', header) + (length,) = unpack('!i', header) encoded = self.transport.readAll(length) self.__rbuf = BytesIO(self.sasl.unwrap(encoded)) - def close(self): + def close(self) -> None: self.sasl.dispose() self.transport.close() # based on TFramedTransport @property - def cstringio_buf(self): + def cstringio_buf(self) -> BytesIO: return self.__rbuf - def cstringio_refill(self, prefix, reqlen): + def cstringio_refill(self, partialread: bytes, reqlen: int) -> BytesIO: # self.__rbuf will already be empty here because fastbinary doesn't # ask for a refill until the previous buffer is empty. Therefore, # we can start reading new frames immediately. - while len(prefix) < reqlen: + while len(partialread) < reqlen: self._read_frame() - prefix += self.__rbuf.getvalue() - self.__rbuf = BytesIO(prefix) + partialread += self.__rbuf.getvalue() + self.__rbuf = BytesIO(partialread) return self.__rbuf diff --git a/lib/py/src/transport/TTwisted.py b/lib/py/src/thrift/transport/TTwisted.py similarity index 59% rename from lib/py/src/transport/TTwisted.py rename to lib/py/src/thrift/transport/TTwisted.py index a27f0adade2..859c92f60bd 100644 --- a/lib/py/src/transport/TTwisted.py +++ b/lib/py/src/thrift/transport/TTwisted.py @@ -17,52 +17,76 @@ # under the License. # +from __future__ import annotations + from io import BytesIO import struct +from typing import Any, TYPE_CHECKING -from zope.interface import implementer, Interface, Attribute -from twisted.internet.protocol import ServerFactory, ClientFactory, \ - connectionDone -from twisted.internet import defer -from twisted.internet.threads import deferToThread -from twisted.protocols import basic -from twisted.web import server, resource, http +from zope.interface import implementer, Interface, Attribute # type: ignore[import-untyped] +from twisted.internet.protocol import ServerFactory, ClientFactory, connectionDone # type: ignore[import-untyped] +from twisted.internet import defer # type: ignore[import-untyped] +from twisted.internet.threads import deferToThread # type: ignore[import-untyped] +from twisted.protocols import basic # type: ignore[import-untyped] +from twisted.web import server, resource, http # type: ignore[import-untyped] from thrift.transport import TTransport +if TYPE_CHECKING: + from thrift.protocol.TProtocol import TProtocolFactory + class TMessageSenderTransport(TTransport.TTransportBase): + """Base transport class for message-based sending.""" + + __wbuf: BytesIO - def __init__(self): + def __init__(self) -> None: self.__wbuf = BytesIO() - def write(self, buf): + def write(self, buf: bytes) -> None: self.__wbuf.write(buf) - def flush(self): + def flush(self) -> Any: msg = self.__wbuf.getvalue() self.__wbuf = BytesIO() return self.sendMessage(msg) - def sendMessage(self, message): + def sendMessage(self, message: bytes) -> Any: raise NotImplementedError class TCallbackTransport(TMessageSenderTransport): + """Transport that invokes a callback function for message sending.""" - def __init__(self, func): + func: Any # Callable[[bytes], Any] + + def __init__(self, func: Any) -> None: TMessageSenderTransport.__init__(self) self.func = func - def sendMessage(self, message): + def sendMessage(self, message: bytes) -> Any: return self.func(message) -class ThriftClientProtocol(basic.Int32StringReceiver): +class ThriftClientProtocol(basic.Int32StringReceiver): # type: ignore[misc] + """Twisted protocol for Thrift clients.""" + + MAX_LENGTH: int = 2 ** 31 - 1 - MAX_LENGTH = 2 ** 31 - 1 + _client_class: type[Any] + _iprot_factory: TProtocolFactory + _oprot_factory: TProtocolFactory + recv_map: dict[str, Any] + started: defer.Deferred[Any] + client: Any - def __init__(self, client_class, iprot_factory, oprot_factory=None): + def __init__( + self, + client_class: type[Any], + iprot_factory: TProtocolFactory, + oprot_factory: TProtocolFactory | None = None, + ) -> None: self._client_class = client_class self._iprot_factory = iprot_factory if oprot_factory is None: @@ -73,15 +97,15 @@ def __init__(self, client_class, iprot_factory, oprot_factory=None): self.recv_map = {} self.started = defer.Deferred() - def dispatch(self, msg): + def dispatch(self, msg: bytes) -> None: self.sendString(msg) - def connectionMade(self): + def connectionMade(self) -> None: tmo = TCallbackTransport(self.dispatch) self.client = self._client_class(tmo, self._oprot_factory) self.started.callback(self.client) - def connectionLost(self, reason=connectionDone): + def connectionLost(self, reason: Any = connectionDone) -> None: # the called errbacks can add items to our client's _reqs, # so we need to use a tmp, and iterate until no more requests # are added during errbacks @@ -95,7 +119,7 @@ def connectionLost(self, reason=connectionDone): del self.client._reqs self.client = None - def stringReceived(self, frame): + def stringReceived(self, frame: bytes) -> None: tr = TTransport.TMemoryBuffer(frame) iprot = self._iprot_factory.getProtocol(tr) (fname, mtype, rseqid) = iprot.readMessageBegin() @@ -110,17 +134,31 @@ def stringReceived(self, frame): class ThriftSASLClientProtocol(ThriftClientProtocol): - - START = 1 - OK = 2 - BAD = 3 - ERROR = 4 - COMPLETE = 5 - - MAX_LENGTH = 2 ** 31 - 1 - - def __init__(self, client_class, iprot_factory, oprot_factory=None, - host=None, service=None, mechanism='GSSAPI', **sasl_kwargs): + """Twisted protocol for SASL-authenticated Thrift clients.""" + + START: int = 1 + OK: int = 2 + BAD: int = 3 + ERROR: int = 4 + COMPLETE: int = 5 + + MAX_LENGTH: int = 2 ** 31 - 1 + + SASLClient: type[Any] + sasl: Any + _sasl_negotiation_deferred: defer.Deferred[Any] | None + _sasl_negotiation_status: int | None + + def __init__( + self, + client_class: type[Any], + iprot_factory: TProtocolFactory, + oprot_factory: TProtocolFactory | None = None, + host: str | None = None, + service: str | None = None, + mechanism: str = 'GSSAPI', + **sasl_kwargs: Any, + ) -> None: """ host: the name of the server, from a SASL perspective service: the name of the server's service, from a SASL perspective @@ -130,7 +168,7 @@ def __init__(self, client_class, iprot_factory, oprot_factory=None, constructor. """ - from puresasl.client import SASLClient + from puresasl.client import SASLClient # type: ignore[import-untyped] self.SASLCLient = SASLClient ThriftClientProtocol.__init__(self, client_class, iprot_factory, oprot_factory) @@ -142,16 +180,16 @@ def __init__(self, client_class, iprot_factory, oprot_factory=None, if host is not None: self.createSASLClient(host, service, mechanism, **sasl_kwargs) - def createSASLClient(self, host, service, mechanism, **kwargs): + def createSASLClient(self, host: str, service: str | None, mechanism: str, **kwargs: Any) -> None: self.sasl = self.SASLClient(host, service, mechanism, **kwargs) - def dispatch(self, msg): + def dispatch(self, msg: bytes) -> None: encoded = self.sasl.wrap(msg) - len_and_encoded = ''.join((struct.pack('!i', len(encoded)), encoded)) + len_and_encoded = b''.join((struct.pack('!i', len(encoded)), encoded)) ThriftClientProtocol.dispatch(self, len_and_encoded) - @defer.inlineCallbacks - def connectionMade(self): + @defer.inlineCallbacks # type: ignore[misc] + def connectionMade(self) -> Any: self._sendSASLMessage(self.START, self.sasl.mechanism) initial_message = yield deferToThread(self.sasl.process) self._sendSASLMessage(self.OK, initial_message) @@ -165,42 +203,44 @@ def connectionMade(self): if not self.sasl.complete: msg = "The server erroneously indicated that SASL " \ "negotiation was complete" - raise TTransport.TTransportException(msg, message=msg) + raise TTransport.TTransportException(message=msg) else: break else: msg = "Bad SASL negotiation status: %d (%s)" % (status, challenge) - raise TTransport.TTransportException(msg, message=msg) + raise TTransport.TTransportException(message=msg) self._sasl_negotiation_deferred = None ThriftClientProtocol.connectionMade(self) - def _sendSASLMessage(self, status, body): + def _sendSASLMessage(self, status: int, body: bytes | str | None) -> None: if body is None: - body = "" + body = b"" + elif isinstance(body, str): + body = body.encode('utf-8') header = struct.pack(">BI", status, len(body)) self.transport.write(header + body) - def _receiveSASLMessage(self): + def _receiveSASLMessage(self) -> defer.Deferred[tuple[int | None, bytes]]: self._sasl_negotiation_deferred = defer.Deferred() self._sasl_negotiation_status = None return self._sasl_negotiation_deferred - def connectionLost(self, reason=connectionDone): + def connectionLost(self, reason: Any = connectionDone) -> None: if self.client: ThriftClientProtocol.connectionLost(self, reason) - def dataReceived(self, data): + def dataReceived(self, data: bytes) -> None: if self._sasl_negotiation_deferred: # we got a sasl challenge in the format (status, length, challenge) # save the status, let IntNStringReceiver piece the challenge data together - self._sasl_negotiation_status, = struct.unpack("B", data[0]) + self._sasl_negotiation_status, = struct.unpack("B", data[0:1]) ThriftClientProtocol.dataReceived(self, data[1:]) else: # normal frame, let IntNStringReceiver piece it together ThriftClientProtocol.dataReceived(self, data) - def stringReceived(self, frame): + def stringReceived(self, frame: bytes) -> None: if self._sasl_negotiation_deferred: # the frame is just a SASL challenge response = (self._sasl_negotiation_status, frame) @@ -211,23 +251,25 @@ def stringReceived(self, frame): ThriftClientProtocol.stringReceived(self, decoded_frame) -class ThriftServerProtocol(basic.Int32StringReceiver): +class ThriftServerProtocol(basic.Int32StringReceiver): # type: ignore[misc] + """Twisted protocol for Thrift servers.""" - MAX_LENGTH = 2 ** 31 - 1 + MAX_LENGTH: int = 2 ** 31 - 1 + factory: ThriftServerFactory - def dispatch(self, msg): + def dispatch(self, msg: bytes) -> None: self.sendString(msg) - def processError(self, error): + def processError(self, error: Any) -> None: self.transport.loseConnection() - def processOk(self, _, tmo): + def processOk(self, _: Any, tmo: TTransport.TMemoryBuffer) -> None: msg = tmo.getvalue() if len(msg) > 0: self.dispatch(msg) - def stringReceived(self, frame): + def stringReceived(self, frame: bytes) -> None: tmi = TTransport.TMemoryBuffer(frame) tmo = TTransport.TMemoryBuffer() @@ -239,7 +281,8 @@ def stringReceived(self, frame): callbackArgs=(tmo,)) -class IThriftServerFactory(Interface): +class IThriftServerFactory(Interface): # type: ignore[misc] + """Interface for Thrift server factories.""" processor = Attribute("Thrift processor") @@ -248,7 +291,8 @@ class IThriftServerFactory(Interface): oprot_factory = Attribute("Output protocol factory") -class IThriftClientFactory(Interface): +class IThriftClientFactory(Interface): # type: ignore[misc] + """Interface for Thrift client factories.""" client_class = Attribute("Thrift client class") @@ -258,11 +302,21 @@ class IThriftClientFactory(Interface): @implementer(IThriftServerFactory) -class ThriftServerFactory(ServerFactory): +class ThriftServerFactory(ServerFactory): # type: ignore[misc] + """Factory for creating Thrift server protocols.""" + + protocol: type[ThriftServerProtocol] = ThriftServerProtocol - protocol = ThriftServerProtocol + processor: Any + iprot_factory: TProtocolFactory + oprot_factory: TProtocolFactory - def __init__(self, processor, iprot_factory, oprot_factory=None): + def __init__( + self, + processor: Any, + iprot_factory: TProtocolFactory, + oprot_factory: TProtocolFactory | None = None, + ) -> None: self.processor = processor self.iprot_factory = iprot_factory if oprot_factory is None: @@ -272,11 +326,21 @@ def __init__(self, processor, iprot_factory, oprot_factory=None): @implementer(IThriftClientFactory) -class ThriftClientFactory(ClientFactory): +class ThriftClientFactory(ClientFactory): # type: ignore[misc] + """Factory for creating Thrift client protocols.""" - protocol = ThriftClientProtocol + protocol: type[ThriftClientProtocol] = ThriftClientProtocol - def __init__(self, client_class, iprot_factory, oprot_factory=None): + client_class: type[Any] + iprot_factory: TProtocolFactory + oprot_factory: TProtocolFactory + + def __init__( + self, + client_class: type[Any], + iprot_factory: TProtocolFactory, + oprot_factory: TProtocolFactory | None = None, + ) -> None: self.client_class = client_class self.iprot_factory = iprot_factory if oprot_factory is None: @@ -284,19 +348,28 @@ def __init__(self, client_class, iprot_factory, oprot_factory=None): else: self.oprot_factory = oprot_factory - def buildProtocol(self, addr): + def buildProtocol(self, addr: Any) -> ThriftClientProtocol: p = self.protocol(self.client_class, self.iprot_factory, self.oprot_factory) - p.factory = self + p.factory = self # type: ignore[attr-defined] return p -class ThriftResource(resource.Resource): +class ThriftResource(resource.Resource): # type: ignore[misc] + """Twisted web resource for serving Thrift over HTTP.""" + + allowedMethods: tuple[str, ...] = ('POST',) - allowedMethods = ('POST',) + inputProtocolFactory: TProtocolFactory + outputProtocolFactory: TProtocolFactory + processor: Any - def __init__(self, processor, inputProtocolFactory, - outputProtocolFactory=None): + def __init__( + self, + processor: Any, + inputProtocolFactory: TProtocolFactory, + outputProtocolFactory: TProtocolFactory | None = None, + ) -> None: resource.Resource.__init__(self) self.inputProtocolFactory = inputProtocolFactory if outputProtocolFactory is None: @@ -305,17 +378,17 @@ def __init__(self, processor, inputProtocolFactory, self.outputProtocolFactory = outputProtocolFactory self.processor = processor - def getChild(self, path, request): + def getChild(self, path: bytes, request: Any) -> ThriftResource: return self - def _cbProcess(self, _, request, tmo): + def _cbProcess(self, _: Any, request: Any, tmo: TTransport.TMemoryBuffer) -> None: msg = tmo.getvalue() request.setResponseCode(http.OK) request.setHeader("content-type", "application/x-thrift") request.write(msg) request.finish() - def render_POST(self, request): + def render_POST(self, request: Any) -> int: request.content.seek(0, 0) data = request.content.read() tmi = TTransport.TMemoryBuffer(data) diff --git a/lib/py/src/transport/TZlibTransport.py b/lib/py/src/thrift/transport/TZlibTransport.py similarity index 83% rename from lib/py/src/transport/TZlibTransport.py rename to lib/py/src/thrift/transport/TZlibTransport.py index a476d2a0aae..d4ea01a4985 100644 --- a/lib/py/src/transport/TZlibTransport.py +++ b/lib/py/src/thrift/transport/TZlibTransport.py @@ -22,10 +22,12 @@ data compression. """ +from __future__ import annotations + import zlib from io import BytesIO -from .TTransport import TTransportBase, CReadableTransport +from thrift.transport.TTransport import TTransportBase, CReadableTransport class TZlibTransportFactory: @@ -44,10 +46,10 @@ class TZlibTransportFactory: easier to understand. """ # class scoped cache of last transport given and zlibtransport returned - _last_trans = None - _last_z = None + _last_trans: TTransportBase | None = None + _last_z: TZlibTransport | None = None - def getTransport(self, trans, compresslevel=9): + def getTransport(self, trans: TTransportBase, compresslevel: int = 9) -> TZlibTransport: """Wrap a transport, trans, with the TZlibTransport compressed transport class, returning a new transport to the caller. @@ -60,7 +62,7 @@ def getTransport(self, trans, compresslevel=9): passed C{trans} TTransport derived instance. """ if trans == self._last_trans: - return self._last_z + return self._last_z # type: ignore[return-value] ztrans = TZlibTransport(trans, compresslevel) self._last_trans = trans self._last_z = ztrans @@ -74,9 +76,20 @@ class TZlibTransport(TTransportBase, CReadableTransport): """ # Read buffer size for the python fastbinary C extension, # the TBinaryProtocolAccelerated class. - DEFAULT_BUFFSIZE = 4096 - - def __init__(self, trans, compresslevel=9): + DEFAULT_BUFFSIZE: int = 4096 + + __trans: TTransportBase + compresslevel: int + __rbuf: BytesIO + __wbuf: BytesIO + _zcomp_read: zlib._Decompress + _zcomp_write: zlib._Compress + bytes_in: int + bytes_out: int + bytes_in_comp: int + bytes_out_comp: int + + def __init__(self, trans: TTransportBase, compresslevel: int = 9) -> None: """Create a new TZlibTransport, wrapping C{trans}, another TTransport derived object. @@ -93,14 +106,14 @@ def __init__(self, trans, compresslevel=9): self._init_zlib() self._init_stats() - def _reinit_buffers(self): + def _reinit_buffers(self) -> None: """Internal method to initialize/reset the internal StringIO objects for read and write buffers. """ self.__rbuf = BytesIO() self.__wbuf = BytesIO() - def _init_stats(self): + def _init_stats(self) -> None: """Internal method to reset the internal statistics counters for compression ratios and bandwidth savings. """ @@ -109,14 +122,14 @@ def _init_stats(self): self.bytes_in_comp = 0 self.bytes_out_comp = 0 - def _init_zlib(self): + def _init_zlib(self) -> None: """Internal method for setting up the zlib compression and decompression objects. """ self._zcomp_read = zlib.decompressobj() self._zcomp_write = zlib.compressobj(self.compresslevel) - def getCompRatio(self): + def getCompRatio(self) -> tuple[float | None, float | None]: """Get the current measured compression ratios (in,out) from this transport. @@ -133,14 +146,15 @@ def getCompRatio(self): None is returned if no bytes have yet been processed in a particular direction. """ - r_percent, w_percent = (None, None) + r_percent: float | None = None + w_percent: float | None = None if self.bytes_in > 0: r_percent = self.bytes_in_comp / self.bytes_in if self.bytes_out > 0: w_percent = self.bytes_out_comp / self.bytes_out return (r_percent, w_percent) - def getCompSavings(self): + def getCompSavings(self) -> tuple[int, int]: """Get the current count of saved bytes due to data compression. @@ -155,30 +169,30 @@ def getCompSavings(self): w_saved = self.bytes_out - self.bytes_out_comp return (r_saved, w_saved) - def isOpen(self): + def isOpen(self) -> bool: """Return the underlying transport's open status""" return self.__trans.isOpen() - def open(self): + def open(self) -> None: """Open the underlying transport""" self._init_stats() return self.__trans.open() - def listen(self): + def listen(self) -> None: """Invoke the underlying transport's listen() method""" - self.__trans.listen() + self.__trans.listen() # type: ignore[attr-defined] - def accept(self): + def accept(self) -> TTransportBase | None: """Accept connections on the underlying transport""" - return self.__trans.accept() + return self.__trans.accept() # type: ignore[attr-defined] - def close(self): + def close(self) -> None: """Close the underlying transport,""" self._reinit_buffers() self._init_zlib() return self.__trans.close() - def read(self, sz): + def read(self, sz: int) -> bytes: """Read up to sz bytes from the decompressed bytes buffer, and read from the underlying transport if the decompression buffer is empty. @@ -193,7 +207,7 @@ def read(self, sz): ret = self.__rbuf.read(sz) return ret - def readComp(self, sz): + def readComp(self, sz: int) -> bool: """Read compressed data from the underlying transport, then decompress it and append it to the internal StringIO read buffer """ @@ -208,13 +222,13 @@ def readComp(self, sz): return False return True - def write(self, buf): + def write(self, buf: bytes) -> None: """Write some bytes, putting them into the internal write buffer for eventual compression. """ self.__wbuf.write(buf) - def flush(self): + def flush(self) -> None: """Flush any queued up data in the write buffer and ensure the compression buffer is flushed out to the underlying transport """ @@ -224,7 +238,7 @@ def flush(self): self.bytes_out += len(wout) self.bytes_out_comp += len(zbuf) else: - zbuf = '' + zbuf = b'' ztail = self._zcomp_write.flush(zlib.Z_SYNC_FLUSH) self.bytes_out_comp += len(ztail) if (len(zbuf) + len(ztail)) > 0: @@ -233,11 +247,11 @@ def flush(self): self.__trans.flush() @property - def cstringio_buf(self): + def cstringio_buf(self) -> BytesIO: """Implement the CReadableTransport interface""" return self.__rbuf - def cstringio_refill(self, partialread, reqlen): + def cstringio_refill(self, partialread: bytes, reqlen: int) -> BytesIO: """Implement the CReadableTransport interface for refill""" retstring = partialread if reqlen < self.DEFAULT_BUFFSIZE: diff --git a/lib/py/src/transport/__init__.py b/lib/py/src/thrift/transport/__init__.py similarity index 100% rename from lib/py/src/transport/__init__.py rename to lib/py/src/thrift/transport/__init__.py diff --git a/lib/py/src/transport/sslcompat.py b/lib/py/src/thrift/transport/sslcompat.py similarity index 86% rename from lib/py/src/transport/sslcompat.py rename to lib/py/src/thrift/transport/sslcompat.py index 54235ec6d1d..e1852cbd6d4 100644 --- a/lib/py/src/transport/sslcompat.py +++ b/lib/py/src/thrift/transport/sslcompat.py @@ -17,15 +17,18 @@ # under the License. # +from __future__ import annotations + import logging import sys +from typing import Any, Callable from thrift.transport.TTransport import TTransportException -logger = logging.getLogger(__name__) +logger: logging.Logger = logging.getLogger(__name__) -def legacy_validate_callback(cert, hostname): +def legacy_validate_callback(cert: dict[str, Any], hostname: str) -> None: """legacy method to validate the peer's SSL certificate, and to check the commonName of the certificate to ensure it matches the hostname we used to make this connection. Does not support subjectAltName records @@ -64,7 +67,7 @@ def legacy_validate_callback(cert, hostname): % (hostname, cert)) -def _optional_dependencies(): +def _optional_dependencies() -> tuple[bool, Callable[[dict[str, Any], str], None]]: try: import ipaddress # noqa logger.debug('ipaddress module is available') @@ -73,9 +76,10 @@ def _optional_dependencies(): logger.warning('ipaddress module is unavailable') ipaddr = False + match: Callable[[dict[str, Any], str], None] if sys.hexversion < 0x030500F0: try: - from backports.ssl_match_hostname import match_hostname, __version__ as ver + from backports.ssl_match_hostname import match_hostname, __version__ as ver # type: ignore[import-not-found] ver = list(map(int, ver.split('.'))) logger.debug('backports.ssl_match_hostname module is available') match = match_hostname @@ -88,7 +92,7 @@ def _optional_dependencies(): logger.warning('backports.ssl_match_hostname is unavailable') ipaddr = False try: - from ssl import match_hostname + from ssl import match_hostname # type: ignore[attr-defined] logger.debug('ssl.match_hostname is available') match = match_hostname except ImportError: @@ -97,11 +101,13 @@ def _optional_dependencies(): # 3.7. OpenSSL performs hostname matching since Python 3.7, Python no # longer uses the ssl.match_hostname() function."" if sys.version_info[0] > 3 or (sys.version_info[0] == 3 and sys.version_info[1] >= 12): - match = lambda cert, hostname: True + match = lambda cert, hostname: None else: logger.warning('using legacy validation callback') match = legacy_validate_callback return ipaddr, match +_match_has_ipaddress: bool +_match_hostname: Callable[[dict[str, Any], str], None] _match_has_ipaddress, _match_hostname = _optional_dependencies() From 3aa74a0d91cc0014bc33bba7b3fb692cafc2ba0f Mon Sep 17 00:00:00 2001 From: Gregg Donovan Date: Mon, 22 Dec 2025 10:51:40 -0500 Subject: [PATCH 2/3] Fix C extension build on macOS MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit On macOS, includes which already defines ntohll and htonll as macros: #define ntohll(x) __DARWIN_OSSwapInt64(x) #define htonll(x) __DARWIN_OSSwapInt64(x) When we tried to define functions with these names, the macros would expand, causing syntax errors like: error: expected ')' static inline unsigned long long ntohll(...) ^ note: expanded from macro 'ntohll' The fix wraps our ntohll/htonll definitions in #ifndef guards so they're only defined when not already provided by system headers. This allows the C extension (fastbinary) to build successfully on macOS while maintaining compatibility with other platforms. 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 --- lib/py/src/thrift/ext/endian.h | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/lib/py/src/thrift/ext/endian.h b/lib/py/src/thrift/ext/endian.h index 8f9e978c38a..71276e71dbd 100644 --- a/lib/py/src/thrift/ext/endian.h +++ b/lib/py/src/thrift/ext/endian.h @@ -29,6 +29,9 @@ #else #include +// On macOS, ntohll/htonll are already defined as macros in +// (included via ). Only define them if not already available. +#ifndef ntohll static inline unsigned long long ntohll(unsigned long long n) { union { unsigned long long f; @@ -43,8 +46,11 @@ static inline unsigned long long ntohll(unsigned long long n) { | static_cast(u.t[5]) << 16 | static_cast(u.t[6]) << 8 | static_cast(u.t[7]); } +#endif +#ifndef htonll #define htonll(n) ntohll(n) +#endif #endif // !_WIN32 From 7b510a370979120b0ba826657eb1dd7e97612445 Mon Sep 17 00:00:00 2001 From: Gregg Donovan Date: Wed, 14 Jan 2026 06:51:22 -0500 Subject: [PATCH 3/3] Enable Python type hints by default - Changed gen_type_hints_ default from false to true - Removed requirement for enum option when using type_hints - Added no_type_hints option to disable when needed - Updated Optional[T] syntax to modern T | None (Python 3.10+) Co-Authored-By: Claude --- compiler/cpp/src/thrift/generate/t_py_generator.cc | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/compiler/cpp/src/thrift/generate/t_py_generator.cc b/compiler/cpp/src/thrift/generate/t_py_generator.cc index f8fb9f871ff..67a5573e75d 100644 --- a/compiler/cpp/src/thrift/generate/t_py_generator.cc +++ b/compiler/cpp/src/thrift/generate/t_py_generator.cc @@ -63,7 +63,7 @@ class t_py_generator : public t_generator { gen_twisted_ = false; gen_dynamic_ = false; gen_enum_ = false; - gen_type_hints_ = false; + gen_type_hints_ = true; coding_ = ""; gen_dynbaseclass_ = ""; gen_dynbaseclass_exc_ = ""; @@ -127,10 +127,9 @@ class t_py_generator : public t_generator { } else if( iter->first.compare("coding") == 0) { coding_ = iter->second; } else if (iter->first.compare("type_hints") == 0) { - if (!gen_enum_) { - throw "the type_hints py option requires the enum py option"; - } gen_type_hints_ = true; + } else if (iter->first.compare("no_type_hints") == 0) { + gen_type_hints_ = false; } else { throw "unknown option py:" + iter->first; } @@ -2859,7 +2858,7 @@ string t_py_generator::arg_hint(t_type* type) { string t_py_generator::member_hint(t_type* type, t_field::e_req req) { if (gen_type_hints_) { if (req != t_field::T_REQUIRED) { - return ": typing.Optional[" + type_to_py_type(type) + "]"; + return ": " + type_to_py_type(type) + " | None"; } else { return ": " + type_to_py_type(type); } @@ -3023,5 +3022,6 @@ THRIFT_REGISTER_GENERATOR( " Package prefix for generated files.\n" " old_style: Deprecated. Generate old-style classes.\n" " enum: Generates Python's IntEnum, connects thrift to python enums. Python 3.4 and higher.\n" - " type_hints: Generate type hints and type checks in write method. Requires the enum option.\n" + " type_hints: Generate type hints (enabled by default).\n" + " no_type_hints: Disable type hint generation.\n" )