Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions compiler/cpp/src/thrift/generate/t_py_generator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ class t_py_generator : public t_generator {
gen_twisted_ = false;
gen_dynamic_ = false;
gen_enum_ = false;
gen_type_hints_ = false;
gen_type_hints_ = true;
coding_ = "";
gen_dynbaseclass_ = "";
gen_dynbaseclass_exc_ = "";
Expand Down Expand Up @@ -127,10 +127,9 @@ class t_py_generator : public t_generator {
} else if( iter->first.compare("coding") == 0) {
coding_ = iter->second;
} else if (iter->first.compare("type_hints") == 0) {
if (!gen_enum_) {
throw "the type_hints py option requires the enum py option";
}
gen_type_hints_ = true;
} else if (iter->first.compare("no_type_hints") == 0) {
gen_type_hints_ = false;
} else {
throw "unknown option py:" + iter->first;
}
Expand Down Expand Up @@ -2859,7 +2858,7 @@ string t_py_generator::arg_hint(t_type* type) {
string t_py_generator::member_hint(t_type* type, t_field::e_req req) {
if (gen_type_hints_) {
if (req != t_field::T_REQUIRED) {
return ": typing.Optional[" + type_to_py_type(type) + "]";
return ": " + type_to_py_type(type) + " | None";
} else {
return ": " + type_to_py_type(type);
}
Expand Down Expand Up @@ -3023,5 +3022,6 @@ THRIFT_REGISTER_GENERATOR(
" Package prefix for generated files.\n"
" old_style: Deprecated. Generate old-style classes.\n"
" enum: Generates Python's IntEnum, connects thrift to python enums. Python 3.4 and higher.\n"
" type_hints: Generate type hints and type checks in write method. Requires the enum option.\n"
" type_hints: Generate type hints (enabled by default).\n"
" no_type_hints: Disable type hint generation.\n"
)
71 changes: 71 additions & 0 deletions lib/py/pyproject.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
[project]
name = "thrift"
version = "0.23.0"
description = "Python bindings for the Apache Thrift RPC system"
readme = "README.md"
license = "Apache-2.0"
requires-python = ">=3.12"
authors = [
{ name = "Apache Thrift Developers", email = "dev@thrift.apache.org" }
]
classifiers = [
"Development Status :: 5 - Production/Stable",
"Environment :: Console",
"Intended Audience :: Developers",
"Programming Language :: Python",
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3.12",
"Programming Language :: Python :: 3.13",
"Typing :: Typed",
"Topic :: Software Development :: Libraries",
"Topic :: System :: Networking",
]

[project.urls]
Homepage = "https://thrift.apache.org"
Repository = "https://github.com/apache/thrift"

[project.optional-dependencies]
tornado = ["tornado>=4.0"]
twisted = ["twisted"]
all = ["tornado>=4.0", "twisted"]

[build-system]
requires = ["setuptools>=61.0", "wheel"]
build-backend = "setuptools.build_meta"

[tool.setuptools.packages.find]
where = ["src"]

[tool.setuptools.package-data]
thrift = ["py.typed"]

# Development dependencies
[dependency-groups]
dev = [
"ty>=0.0.5",
"pytest>=8.0",
"pure-sasl", # For SASL transport tests
]

# ty type checker configuration
[tool.ty.environment]
python-version = "3.12"

[tool.ty.rules]
possibly-unresolved-reference = "error"
invalid-assignment = "error"
invalid-argument-type = "error"
invalid-return-type = "error"
missing-argument = "error"
call-non-callable = "error"
no-matching-overload = "error"
unresolved-import = "error"
invalid-type-form = "error"

# pytest configuration
[tool.pytest.ini_options]
testpaths = ["test"]
python_files = ["test_*.py"]
python_classes = ["Test*"]
python_functions = ["test_*"]
32 changes: 16 additions & 16 deletions lib/py/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,9 @@
#

import sys
try:
from setuptools import setup, Extension
except Exception:
from distutils.core import setup, Extension

from distutils.command.build_ext import build_ext
from setuptools import Extension, setup
from setuptools.command.build_ext import build_ext
from distutils.errors import CCompilerError, DistutilsExecError, DistutilsPlatformError

# Fix to build sdist under vagrant
Expand All @@ -36,7 +33,7 @@
except AttributeError:
pass

include_dirs = ['src']
include_dirs = ['src/thrift']
if sys.platform == 'win32':
include_dirs.append('compat/win32')
ext_errors = (CCompilerError, DistutilsExecError, DistutilsPlatformError, IOError)
Expand Down Expand Up @@ -83,10 +80,10 @@ def run_setup(with_binary):
Extension('thrift.protocol.fastbinary',
extra_compile_args=['-std=c++11'],
sources=[
'src/ext/module.cpp',
'src/ext/types.cpp',
'src/ext/binary.cpp',
'src/ext/compact.cpp',
'src/thrift/ext/module.cpp',
'src/thrift/ext/types.cpp',
'src/thrift/ext/binary.cpp',
'src/thrift/ext/compact.cpp',
],
include_dirs=include_dirs,
)
Expand All @@ -96,13 +93,11 @@ def run_setup(with_binary):
else:
extensions = dict()

ssl_deps = []
if sys.hexversion < 0x03050000:
ssl_deps.append('backports.ssl_match_hostname>=3.5')
tornado_deps = ['tornado>=4.0']
twisted_deps = ['twisted']

setup(name='thrift',
python_requires='>=3.12',
version='0.23.0',
description='Python bindings for the Apache Thrift RPC system',
long_description=read_file("README.md"),
Expand All @@ -112,24 +107,29 @@ def run_setup(with_binary):
url='http://thrift.apache.org',
license='Apache License 2.0',
extras_require={
'ssl': ssl_deps,
'tornado': tornado_deps,
'twisted': twisted_deps,
'all': ssl_deps + tornado_deps + twisted_deps,
'all': tornado_deps + twisted_deps,
},
package_data={
'thrift': ['py.typed'],
},
packages=[
'thrift',
'thrift.protocol',
'thrift.transport',
'thrift.server',
],
package_dir={'thrift': 'src'},
package_dir={'': 'src'}, # Standard src layout
classifiers=[
'Development Status :: 5 - Production/Stable',
'Environment :: Console',
'Intended Audience :: Developers',
'Programming Language :: Python',
'Programming Language :: Python :: 3',
'Programming Language :: Python :: 3.12',
'Programming Language :: Python :: 3.13',
'Typing :: Typed',
'Topic :: Software Development :: Libraries',
'Topic :: System :: Networking'
],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,17 +17,29 @@
# under the License.
#

from __future__ import annotations

from typing import Any, Callable, TYPE_CHECKING

from thrift.Thrift import TProcessor, TMessageType
from thrift.protocol import TProtocolDecorator, TMultiplexedProtocol
from thrift.protocol.TProtocol import TProtocolException

if TYPE_CHECKING:
from thrift.protocol.TProtocol import TProtocolBase


class TMultiplexedProcessor(TProcessor):
def __init__(self):
"""Processor that multiplexes multiple services on a single connection."""

defaultProcessor: TProcessor | None
services: dict[str, TProcessor]

def __init__(self) -> None:
self.defaultProcessor = None
self.services = {}

def registerDefault(self, processor):
def registerDefault(self, processor: TProcessor) -> None:
"""
If a non-multiplexed processor connects to the server and wants to
communicate, use the given processor to handle it. This mechanism
Expand All @@ -36,14 +48,14 @@ def registerDefault(self, processor):
"""
self.defaultProcessor = processor

def registerProcessor(self, serviceName, processor):
def registerProcessor(self, serviceName: str, processor: TProcessor) -> None:
self.services[serviceName] = processor

def on_message_begin(self, func):
def on_message_begin(self, func: Callable[[str, int, int], None]) -> None:
for key in self.services.keys():
self.services[key].on_message_begin(func)

def process(self, iprot, oprot):
def process(self, iprot: TProtocolBase, oprot: TProtocolBase) -> bool | None:
(name, type, seqid) = iprot.readMessageBegin()
if type != TMessageType.CALL and type != TMessageType.ONEWAY:
raise TProtocolException(
Expand Down Expand Up @@ -75,8 +87,12 @@ def process(self, iprot, oprot):


class StoredMessageProtocol(TProtocolDecorator.TProtocolDecorator):
def __init__(self, protocol, messageBegin):
"""Protocol decorator that stores and returns a predetermined message begin."""

messageBegin: tuple[str, int, int]

def __init__(self, protocol: TProtocolBase, messageBegin: tuple[str, int, int]) -> None:
self.messageBegin = messageBegin

def readMessageBegin(self):
def readMessageBegin(self) -> tuple[str, int, int]:
return self.messageBegin
18 changes: 11 additions & 7 deletions lib/py/src/TRecursive.py → lib/py/src/thrift/TRecursive.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,19 @@
# under the License.
#

from __future__ import annotations

from typing import Any

from thrift.Thrift import TType

TYPE_IDX = 1
SPEC_ARGS_IDX = 3
SPEC_ARGS_CLASS_REF_IDX = 0
SPEC_ARGS_THRIFT_SPEC_IDX = 1
TYPE_IDX: int = 1
SPEC_ARGS_IDX: int = 3
SPEC_ARGS_CLASS_REF_IDX: int = 0
SPEC_ARGS_THRIFT_SPEC_IDX: int = 1


def fix_spec(all_structs):
def fix_spec(all_structs: list[Any]) -> None:
"""Wire up recursive references for all TStruct definitions inside of each thrift_spec."""
for struc in all_structs:
spec = struc.thrift_spec
Expand All @@ -41,7 +45,7 @@ def fix_spec(all_structs):
_fix_map(thrift_spec[SPEC_ARGS_IDX])


def _fix_list_or_set(element_type):
def _fix_list_or_set(element_type: list[Any]) -> None:
# For a list or set, the thrift_spec entry looks like,
# (1, TType.LIST, 'lister', (TType.STRUCT, [RecList, None], False), None, ), # 1
# so ``element_type`` will be,
Expand All @@ -54,7 +58,7 @@ def _fix_list_or_set(element_type):
_fix_map(element_type[1])


def _fix_map(element_type):
def _fix_map(element_type: list[Any]) -> None:
# For a map of key -> value type, ``element_type`` will be,
# (TType.I16, None, TType.STRUCT, [RecMapBasic, None], False), None, )
# which is just a normal struct definition.
Expand Down
12 changes: 8 additions & 4 deletions lib/py/src/TSCons.py → lib/py/src/thrift/TSCons.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,19 +17,23 @@
# under the License.
#

from __future__ import annotations

from os import path
from SCons.Builder import Builder
from typing import Any, Iterator

from SCons.Builder import Builder # type: ignore[import-untyped]


def scons_env(env, add=''):
def scons_env(env: Any, add: str = '') -> None:
opath = path.dirname(path.abspath('$TARGET'))
lstr = 'thrift --gen cpp -o ' + opath + ' ' + add + ' $SOURCE'
cppbuild = Builder(action=lstr)
env.Append(BUILDERS={'ThriftCpp': cppbuild})


def gen_cpp(env, dir, file):
def gen_cpp(env: Any, dir: str, file: str) -> Any:
scons_env(env)
suffixes = ['_types.h', '_types.cpp']
targets = map(lambda s: 'gen-cpp/' + file + s, suffixes)
targets: Iterator[str] = map(lambda s: 'gen-cpp/' + file + s, suffixes)
return env.ThriftCpp(targets, dir + file + '.thrift')
Original file line number Diff line number Diff line change
Expand Up @@ -17,22 +17,36 @@
# under the License.
#

from __future__ import annotations

from typing import Any, TYPE_CHECKING, TypeVar

from .protocol import TBinaryProtocol
from .transport import TTransport

if TYPE_CHECKING:
from thrift.protocol.TProtocol import TProtocolFactory
from thrift.protocol.TBase import TBase

T = TypeVar('T')


def serialize(thrift_object,
protocol_factory=TBinaryProtocol.TBinaryProtocolFactory()):
def serialize(
thrift_object: TBase,
protocol_factory: TProtocolFactory = TBinaryProtocol.TBinaryProtocolFactory(),
) -> bytes:
transport = TTransport.TMemoryBuffer()
protocol = protocol_factory.getProtocol(transport)
thrift_object.write(protocol)
return transport.getvalue()


def deserialize(base,
buf,
protocol_factory=TBinaryProtocol.TBinaryProtocolFactory()):
def deserialize(
base: T,
buf: bytes,
protocol_factory: TProtocolFactory = TBinaryProtocol.TBinaryProtocolFactory(),
) -> T:
transport = TTransport.TMemoryBuffer(buf)
protocol = protocol_factory.getProtocol(transport)
base.read(protocol)
base.read(protocol) # type: ignore[union-attr]
return base
Loading