diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml
index cca213e2f14..be203179a5e 100644
--- a/.github/workflows/build.yml
+++ b/.github/workflows/build.yml
@@ -547,6 +547,9 @@ jobs:
with:
python-version: ${{ matrix.python-version }}
+ - name: Set up uv
+ uses: astral-sh/setup-uv@v5
+
- name: Python setup
run: |
python -m pip install --upgrade pip setuptools wheel flake8 "tornado>=6.3.0" "twisted>=24.3.0" "zope.interface>=6.1"
@@ -609,6 +612,9 @@ jobs:
with:
python-version: ${{ matrix.python-version }}
+ - name: Set up uv
+ uses: astral-sh/setup-uv@v5
+
- name: Python setup
run: |
python -m pip install --upgrade pip setuptools wheel flake8 "tornado>=6.3.0" "twisted>=24.3.0" "zope.interface>=6.1"
diff --git a/.github/workflows/pypi.yml b/.github/workflows/pypi.yml
index 12859dbb60b..783fc5c1c4a 100644
--- a/.github/workflows/pypi.yml
+++ b/.github/workflows/pypi.yml
@@ -38,7 +38,7 @@ jobs:
- name: Set up Python
uses: actions/setup-python@v6
with:
- python-version: "3.8"
+ python-version: "3.10"
- name: Build
run: |
diff --git a/.gitignore b/.gitignore
index aeaaf2ff39f..9feec315571 100644
--- a/.gitignore
+++ b/.gitignore
@@ -57,6 +57,12 @@ project.lock.json
.vscode
.vs
+/.venv*
+/build-cmake/
+/lib/py/src/protocol/*.so
+# Generated from test/test_thrift_file/TestServer.thrift during make check
+/lib/py/test/TestServer/
+
/aclocal/libtool.m4
/aclocal/lt*.m4
/autoscan.log
diff --git a/LANGUAGES.md b/LANGUAGES.md
index 1fd76ea1ed7..d611495a73b 100644
--- a/LANGUAGES.md
+++ b/LANGUAGES.md
@@ -300,7 +300,7 @@ Thrift's core protocol is TBinary, supported by all languages except for JavaScr
Python |
0.2.0 |
 |  |
-2.7.12, 3.5.2 | 2.7.15, 3.6.8 |
+3.10 | 3.14 |
 |
 |  |  |  |  |  |
 |  |  |  |
diff --git a/Makefile.am b/Makefile.am
index 735cd405929..00b9158beb4 100644
--- a/Makefile.am
+++ b/Makefile.am
@@ -78,14 +78,10 @@ empty :=
space := $(empty) $(empty)
comma := ,
-CROSS_LANGS = @MAYBE_CPP@ @MAYBE_C_GLIB@ @MAYBE_CL@ @MAYBE_D@ @MAYBE_JAVA@ @MAYBE_PYTHON@ @MAYBE_PY3@ @MAYBE_RUBY@ @MAYBE_PERL@ @MAYBE_PHP@ @MAYBE_GO@ @MAYBE_NODEJS@ @MAYBE_DART@ @MAYBE_ERLANG@ @MAYBE_LUA@ @MAYBE_RS@ @MAYBE_NETSTD@ @MAYBE_NODETS@ @MAYBE_KOTLIN@ @MAYBE_SWIFT@
+CROSS_LANGS = @MAYBE_CPP@ @MAYBE_C_GLIB@ @MAYBE_CL@ @MAYBE_D@ @MAYBE_JAVA@ @MAYBE_PYTHON@ @MAYBE_RUBY@ @MAYBE_PERL@ @MAYBE_PHP@ @MAYBE_GO@ @MAYBE_NODEJS@ @MAYBE_DART@ @MAYBE_ERLANG@ @MAYBE_LUA@ @MAYBE_RS@ @MAYBE_NETSTD@ @MAYBE_NODETS@ @MAYBE_KOTLIN@ @MAYBE_SWIFT@
CROSS_LANGS_COMMA_SEPARATED = $(subst $(space),$(comma),$(CROSS_LANGS))
-if WITH_PY3
-CROSS_PY=$(PYTHON3)
-else
CROSS_PY=$(PYTHON)
-endif
if WITH_PYTHON
crossfeature: precross
diff --git a/build/appveyor/MSVC-appveyor-full.bat b/build/appveyor/MSVC-appveyor-full.bat
index d4d2896c651..ea8821a5436 100644
--- a/build/appveyor/MSVC-appveyor-full.bat
+++ b/build/appveyor/MSVC-appveyor-full.bat
@@ -145,9 +145,7 @@ IF "%WITH_PYTHON%" == "ON" (
"!PYTHON_ROOT!\python.exe" -m ensurepip --upgrade || EXIT /B
"!PYTHON_ROOT!\python.exe" -m pip install --upgrade pip setuptools wheel || EXIT /B
"!PYTHON_ROOT!\python.exe" -m pip ^
- install backports.ssl_match_hostname ^
- ipaddress ^
- tornado>=6.3.0 ^
+ install tornado>=6.3.0 ^
twisted>=24.3.0 ^
zope.interface>=6.1 || EXIT /B
)
diff --git a/build/cmake/DefineOptions.cmake b/build/cmake/DefineOptions.cmake
index 928e19b165d..1608051831e 100644
--- a/build/cmake/DefineOptions.cmake
+++ b/build/cmake/DefineOptions.cmake
@@ -117,7 +117,7 @@ CMAKE_DEPENDENT_OPTION(BUILD_NODEJS "Build NodeJS library" ON
# Python
option(WITH_PYTHON "Build Python Thrift library" ON)
-find_package(Python3
+find_package(Python3 3.10
COMPONENTS
Interpreter # for Python executable
Development # for Python.h
diff --git a/build/docker/README.md b/build/docker/README.md
index 0f2d293dffa..f70e0cd7522 100644
--- a/build/docker/README.md
+++ b/build/docker/README.md
@@ -172,31 +172,30 @@ Last updated: March 5, 2024
## Compiler/Language Versions per Dockerfile ##
-| Tool | ubuntu-focal | ubuntu-jammy | ubuntu-noble | Notes |
-| :-------- | :------------ | :------------ | :------------ | :---- |
-| as of | Mar 06, 2018 | Jul 1, 2019 | | |
-| as3 | 4.6.0 | 4.6.0 | | |
-| C++ gcc | 9.4.0 | 11.4.0 | | |
-| C++ clang | 13.0.0 | 13.0.0 | | |
-| c\_glib | 3.2.12 | 3.2.12 | | |
-| cl (sbcl) | | 1.5.3 | | |
-| d | 2.087.0 | 2.087.0 | | |
-| dart | 2.7.2-1 | 2.7.2-1 | | |
-| delphi | | | | Not in CI |
-| erlang | OTP-25.3.2.9 | OTP-25.3.2.9 | | |
-| go | 1.21.7 | 1.21.7 | | |
-| haxe | 4.2.1 | 4.2.1 | | |
-| java | 17 | 17 | | |
-| js | Node.js 16.20.2, npm 8.19.4 | | | Node.js 16.20.2, npm 8.19.4 |
-| lua | 5.2.4 | 5.2.4 | | Lua 5.3: see THRIFT-4386 |
-| netstd | 9.0 | 9.0 | 9.0 | |
-| nodejs | 16.20.2 | 16.20.2 | | |
-| ocaml | 4.08.1 | 4.13.1 | | |
-| perl | 5.30.0 | 5.34.0 | | |
-| php | 7.4.3 | 8.1.2 | 8.3 | |
-| python2 | 2.7.18 | | | |
-| python3 | 3.8.10 | 3.10.12 | | |
-| ruby | 2.7.0p0 | 3.0.2p107 | | |
-| rust | 1.83.0 | 1.83.0 | | |
-| smalltalk | | | | Not in CI |
-| swift | 5.7 | 5.7 | 6.1 | |
+| Tool | ubuntu-focal | ubuntu-jammy | ubuntu-noble | Notes |
+| :-------- | :------------ | :------------ | :------------ |:-----------------------------------------------|
+| as of | Mar 06, 2018 | Jul 1, 2019 | | |
+| as3 | 4.6.0 | 4.6.0 | | |
+| C++ gcc | 9.4.0 | 11.4.0 | | |
+| C++ clang | 13.0.0 | 13.0.0 | | |
+| c\_glib | 3.2.12 | 3.2.12 | | |
+| cl (sbcl) | | 1.5.3 | | |
+| d | 2.087.0 | 2.087.0 | | |
+| dart | 2.7.2-1 | 2.7.2-1 | | |
+| delphi | | | | Not in CI |
+| erlang | OTP-25.3.2.9 | OTP-25.3.2.9 | | |
+| go | 1.21.7 | 1.21.7 | | |
+| haxe | 4.2.1 | 4.2.1 | | |
+| java | 17 | 17 | | |
+| js | Node.js 16.20.2, npm 8.19.4 | | | Node.js 16.20.2, npm 8.19.4 |
+| lua | 5.2.4 | 5.2.4 | | Lua 5.3: see THRIFT-4386 |
+| netstd | 9.0 | 9.0 | 9.0 | |
+| nodejs | 16.20.2 | 16.20.2 | | |
+| ocaml | 4.08.1 | 4.13.1 | | |
+| perl | 5.30.0 | 5.34.0 | | |
+| php | 7.4.3 | 8.1.2 | 8.3 | |
+| python | 3.10.14 | 3.10.12 | 3.12.3 | focal: built from source (ships with 3.8) |
+| ruby | 2.7.0p0 | 3.0.2p107 | | |
+| rust | 1.83.0 | 1.83.0 | | |
+| smalltalk | | | | Not in CI |
+| swift | 5.7 | 5.7 | 6.1 | |
diff --git a/build/docker/ubuntu-focal/Dockerfile b/build/docker/ubuntu-focal/Dockerfile
index 465c0f1e439..df30ab85685 100644
--- a/build/docker/ubuntu-focal/Dockerfile
+++ b/build/docker/ubuntu-focal/Dockerfile
@@ -254,16 +254,37 @@ RUN apt-get install -y --no-install-recommends \
re2c \
composer
+# Python 3.10 built from source (Focal ships with 3.8, but we require 3.10+)
+ENV PYTHON_VERSION=3.10.14
RUN apt-get install -y --no-install-recommends \
- `# Python3 dependencies` \
- python3-all \
- python3-all-dbg \
- python3-all-dev \
- python3-pip \
- python3-setuptools \
- python3-wheel
-
-RUN python3 -m pip install --no-cache-dir --upgrade "tornado>=6.3.0" "twisted>=24.3.0" "zope.interface>=6.1"
+ `# Python build dependencies` \
+ libbz2-dev \
+ libffi-dev \
+ libgdbm-dev \
+ liblzma-dev \
+ libncurses5-dev \
+ libreadline-dev \
+ libsqlite3-dev \
+ libssl-dev \
+ tk-dev \
+ uuid-dev \
+ xz-utils && \
+ cd /tmp && \
+ wget -q https://www.python.org/ftp/python/${PYTHON_VERSION}/Python-${PYTHON_VERSION}.tgz && \
+ tar xzf Python-${PYTHON_VERSION}.tgz && \
+ cd Python-${PYTHON_VERSION} && \
+ ./configure --enable-optimizations --with-ensurepip=install && \
+ make -j$(nproc) && \
+ make altinstall && \
+ cd / && rm -rf /tmp/Python-${PYTHON_VERSION}* && \
+ update-alternatives --install /usr/bin/python3 python3 /usr/local/bin/python3.10 1 && \
+ python3.10 -m pip install --upgrade pip && \
+ pip3.10 install --no-cache-dir `# quote specifiers: a bare > would be a shell redirect` \
+ setuptools \
+ wheel \
+ "tornado>=6.3.0" \
+ "twisted>=24.3.0" \
+ "zope.interface>=6.1"
RUN apt-get install -y --no-install-recommends \
`# Ruby dependencies` \
@@ -281,7 +302,6 @@ USER root
RUN apt-get install -yq \
libedit-dev \
libz3-dev \
- libpython2-dev \
libxml2-dev && \
cd / && \
wget --quiet https://download.swift.org/swift-5.7-release/ubuntu2004/swift-5.7-RELEASE/swift-5.7-RELEASE-ubuntu20.04.tar.gz && \
diff --git a/build/docker/ubuntu-jammy/Dockerfile b/build/docker/ubuntu-jammy/Dockerfile
index a2331ab695d..df1e0b03db9 100644
--- a/build/docker/ubuntu-jammy/Dockerfile
+++ b/build/docker/ubuntu-jammy/Dockerfile
@@ -274,7 +274,6 @@ USER root
RUN apt-get install -yq \
libedit-dev \
libz3-dev \
- libpython2-dev \
libxml2-dev && \
cd / && \
wget --quiet https://download.swift.org/swift-5.7-release/ubuntu2204/swift-5.7-RELEASE/swift-5.7-RELEASE-ubuntu22.04.tar.gz && \
diff --git a/build/docker/ubuntu-noble/Dockerfile b/build/docker/ubuntu-noble/Dockerfile
index a195fd460b5..63ae03eafc9 100644
--- a/build/docker/ubuntu-noble/Dockerfile
+++ b/build/docker/ubuntu-noble/Dockerfile
@@ -295,8 +295,8 @@ RUN apt-get install -y --no-install-recommends \
RUN apt-get install -y --no-install-recommends \
`# Static Code Analysis dependencies` \
cppcheck \
- sloccount
-
+ sloccount
+
#RUN pip install flake8
# NOTE: this does not reduce the image size but adds an additional layer.
diff --git a/compiler/cpp/CMakeLists.txt b/compiler/cpp/CMakeLists.txt
index 2f5cb7a1e1f..bcab4219b5c 100644
--- a/compiler/cpp/CMakeLists.txt
+++ b/compiler/cpp/CMakeLists.txt
@@ -108,7 +108,7 @@ THRIFT_ADD_COMPILER(netstd "Enable compiler for .NET Standard" ON)
THRIFT_ADD_COMPILER(ocaml "Enable compiler for OCaml" ON)
THRIFT_ADD_COMPILER(perl "Enable compiler for Perl" ON)
THRIFT_ADD_COMPILER(php "Enable compiler for PHP" ON)
-THRIFT_ADD_COMPILER(py "Enable compiler for Python 2.0" ON)
+THRIFT_ADD_COMPILER(py "Enable compiler for Python" ON)
THRIFT_ADD_COMPILER(rb "Enable compiler for Ruby" ON)
THRIFT_ADD_COMPILER(rs "Enable compiler for Rust" ON)
THRIFT_ADD_COMPILER(st "Enable compiler for Smalltalk" ON)
diff --git a/compiler/cpp/src/thrift/generate/t_py_generator.cc b/compiler/cpp/src/thrift/generate/t_py_generator.cc
index f8fb9f871ff..848c23bbe16 100644
--- a/compiler/cpp/src/thrift/generate/t_py_generator.cc
+++ b/compiler/cpp/src/thrift/generate/t_py_generator.cc
@@ -54,16 +54,12 @@ class t_py_generator : public t_generator {
std::map::const_iterator iter;
- gen_newstyle_ = true;
- gen_utf8strings_ = true;
gen_dynbase_ = false;
gen_slots_ = false;
gen_tornado_ = false;
gen_zope_interface_ = false;
gen_twisted_ = false;
gen_dynamic_ = false;
- gen_enum_ = false;
- gen_type_hints_ = false;
coding_ = "";
gen_dynbaseclass_ = "";
gen_dynbaseclass_exc_ = "";
@@ -72,24 +68,12 @@ class t_py_generator : public t_generator {
import_dynbase_ = "";
package_prefix_ = "";
for( iter = parsed_options.begin(); iter != parsed_options.end(); ++iter) {
- if( iter->first.compare("enum") == 0) {
- gen_enum_ = true;
- } else if( iter->first.compare("new_style") == 0) {
- pwarning(0, "new_style is enabled by default, so the option will be removed in the near future.\n");
- } else if( iter->first.compare("old_style") == 0) {
- gen_newstyle_ = false;
- pwarning(0, "old_style is deprecated and may be removed in the future.\n");
- } else if( iter->first.compare("utf8strings") == 0) {
- pwarning(0, "utf8strings is enabled by default, so the option will be removed in the near future.\n");
- } else if( iter->first.compare("no_utf8strings") == 0) {
- gen_utf8strings_ = false;
- } else if( iter->first.compare("slots") == 0) {
+ if( iter->first.compare("slots") == 0) {
gen_slots_ = true;
} else if( iter->first.compare("package_prefix") == 0) {
package_prefix_ = iter->second;
} else if( iter->first.compare("dynamic") == 0) {
gen_dynamic_ = true;
- gen_newstyle_ = false; // dynamic is newstyle
if( gen_dynbaseclass_.empty()) {
gen_dynbaseclass_ = "TBase";
}
@@ -126,11 +110,6 @@ class t_py_generator : public t_generator {
gen_tornado_ = true;
} else if( iter->first.compare("coding") == 0) {
coding_ = iter->second;
- } else if (iter->first.compare("type_hints") == 0) {
- if (!gen_enum_) {
- throw "the type_hints py option requires the enum py option";
- }
- gen_type_hints_ = true;
} else {
throw "unknown option py:" + iter->first;
}
@@ -303,12 +282,6 @@ class t_py_generator : public t_generator {
private:
- /**
- * True if we should generate new-style classes.
- */
- bool gen_newstyle_;
- bool gen_enum_;
-
/**
* True if we should generate dynamic style classes.
*/
@@ -324,11 +297,6 @@ class t_py_generator : public t_generator {
bool gen_slots_;
- /**
- * True if we should generate classes type hints and type checks in write methods.
- */
- bool gen_type_hints_;
-
std::string copy_options_;
/**
@@ -346,11 +314,6 @@ class t_py_generator : public t_generator {
*/
bool gen_tornado_;
- /**
- * True if strings should be encoded using utf-8.
- */
- bool gen_utf8strings_;
-
/**
* specify generated file encoding
* eg. # -*- coding: utf-8 -*-
@@ -372,9 +335,9 @@ class t_py_generator : public t_generator {
protected:
std::set lang_keywords_for_validation() const override {
- std::string keywords[] = { "False", "None", "True", "and", "as", "assert", "break", "class",
- "continue", "def", "del", "elif", "else", "except", "exec", "finally", "for", "from",
- "global", "if", "import", "in", "is", "lambda", "nonlocal", "not", "or", "pass", "print",
+ std::string keywords[] = { "False", "None", "True", "and", "as", "assert", "async", "await",
+ "break", "class", "continue", "def", "del", "elif", "else", "except", "finally", "for",
+ "from", "global", "if", "import", "in", "is", "lambda", "nonlocal", "not", "or", "pass",
"raise", "return", "try", "while", "with", "yield" };
return std::set(keywords, keywords + sizeof(keywords)/sizeof(keywords[0]) );
}
@@ -430,6 +393,12 @@ void t_py_generator::init_generator() {
f_init << "]" << '\n';
f_init.close();
+ // Generate py.typed marker for PEP 561 (typed package)
+ string f_py_typed_name = package_dir_ + "/py.typed";
+ ofstream_with_content_based_conditional_update f_py_typed;
+ f_py_typed.open(f_py_typed_name.c_str());
+ f_py_typed.close();
+
// Print header
f_types_ << py_autogen_comment() << '\n'
<< py_imports() << '\n'
@@ -437,11 +406,7 @@ void t_py_generator::init_generator() {
<< "from thrift.transport import TTransport" << '\n'
<< import_dynbase_;
- if (gen_type_hints_) {
- f_types_ << "all_structs: list[typing.Any] = []" << '\n';
- } else {
- f_types_ << "all_structs = []" << '\n';
- }
+ f_types_ << "all_structs: list[typing.Any] = []" << '\n';
f_consts_ <<
py_autogen_comment() << '\n' <<
@@ -479,11 +444,10 @@ string t_py_generator::py_autogen_comment() {
*/
string t_py_generator::py_imports() {
ostringstream ss;
- if (gen_type_hints_) {
- ss << "from __future__ import annotations" << '\n' << "import typing" << '\n';
- }
-
- ss << "from thrift.Thrift import TType, TMessageType, TFrozenDict, TException, "
+ ss << "from __future__ import annotations" << '\n'
+ << "import typing" << '\n'
+ << '\n'
+ << "from thrift.Thrift import TType, TMessageType, TFrozenDict, TException, "
"TApplicationException"
<< '\n'
<< "from thrift.protocol.TProtocol import TProtocolException"
@@ -491,13 +455,8 @@ string t_py_generator::py_imports() {
<< "from thrift.TRecursive import fix_spec"
<< '\n'
<< "from uuid import UUID"
- << '\n';
- if (gen_enum_) {
- ss << "from enum import IntEnum" << '\n';
- }
- if (gen_utf8strings_) {
- ss << '\n' << "import sys";
- }
+ << '\n'
+ << "from enum import IntEnum" << '\n';
return ss.str();
}
@@ -531,49 +490,39 @@ void t_py_generator::generate_typedef(t_typedef* ttypedef) {
* @param tenum The enumeration
*/
void t_py_generator::generate_enum(t_enum* tenum) {
- std::ostringstream to_string_mapping, from_string_mapping;
- std::string base_class;
-
- if (gen_enum_) {
- base_class = "IntEnum";
- } else if (gen_newstyle_) {
- base_class = "object";
- } else if (gen_dynamic_) {
- base_class = gen_dynbaseclass_;
- }
-
+ // Python 3.10+: All enums use IntEnum
f_types_ << '\n'
<< '\n'
- << "class " << tenum->get_name()
- << (base_class.empty() ? "" : "(" + base_class + ")")
- << ":"
+ << "class " << tenum->get_name() << "(IntEnum):"
<< '\n';
indent_up();
generate_python_docstring(f_types_, tenum);
- to_string_mapping << indent() << "_VALUES_TO_NAMES = {" << '\n';
- from_string_mapping << indent() << "_NAMES_TO_VALUES = {" << '\n';
-
vector constants = tenum->get_constants();
vector::iterator c_iter;
for (c_iter = constants.begin(); c_iter != constants.end(); ++c_iter) {
int value = (*c_iter)->get_value();
indent(f_types_) << (*c_iter)->get_name() << " = " << value << '\n';
-
- // Dictionaries to/from string names of enums
- to_string_mapping << indent() << indent() << value << ": \""
- << escape_string((*c_iter)->get_name()) << "\"," << '\n';
- from_string_mapping << indent() << indent() << '"' << escape_string((*c_iter)->get_name())
- << "\": " << value << ',' << '\n';
}
- to_string_mapping << indent() << "}" << '\n';
- from_string_mapping << indent() << "}" << '\n';
+
+ // Handle unknown enum values gracefully
+ f_types_ << '\n';
+ indent(f_types_) << "@classmethod" << '\n';
+ indent(f_types_) << "def _missing_(cls, value):" << '\n';
+ indent_up();
+ indent(f_types_) << "if not isinstance(value, int):" << '\n';
+ indent_up();
+ indent(f_types_) << "return None" << '\n';
+ indent_down();
+ indent(f_types_) << "unknown = int.__new__(cls, value)" << '\n';
+ indent(f_types_) << "unknown._name_ = f\"UNKNOWN_{value}\"" << '\n';
+ indent(f_types_) << "unknown._value_ = value" << '\n';
+ indent(f_types_) << "cls._value2member_map_.setdefault(value, unknown)" << '\n';
+ indent(f_types_) << "return unknown" << '\n';
+ indent_down();
indent_down();
f_types_ << '\n';
- if (!gen_enum_) {
- f_types_ << to_string_mapping.str() << '\n' << from_string_mapping.str();
- }
}
/**
@@ -631,12 +580,8 @@ string t_py_generator::render_const_value(t_type* type, t_const_value* value) {
} else if (type->is_enum()) {
out << indent();
int64_t int_val = value->get_integer();
- if (gen_enum_) {
- t_enum_value* enum_val = ((t_enum*)type)->get_constant_by_value(int_val);
- out << type_name(type) << "." << enum_val->get_name();
- } else {
- out << int_val;
- }
+ t_enum_value* enum_val = ((t_enum*)type)->get_constant_by_value(int_val);
+ out << type_name(type) << "." << enum_val->get_name();
} else if (type->is_struct() || type->is_xception()) {
out << type_name(type) << "(**{" << '\n';
indent_up();
@@ -829,14 +774,12 @@ void t_py_generator::generate_py_struct_definition(ostream& out,
} else {
out << "(" << gen_dynbaseclass_ << ")";
}
- } else if (gen_newstyle_) {
- out << "(object)";
}
+ // Note: For Python 3.10+, we don't need explicit (object) base class
out << ":" << '\n';
indent_up();
generate_python_docstring(out, tstruct);
- std::string thrift_spec_type = gen_type_hints_ ? ": typing.Any" : "";
- out << indent() << "thrift_spec" << thrift_spec_type << " = None" << '\n';
+ out << indent() << "thrift_spec: typing.Any = None" << '\n';
out << '\n';
@@ -870,6 +813,17 @@ void t_py_generator::generate_py_struct_definition(ostream& out,
indent(out) << ")" << '\n' << '\n';
}
+ // For immutable structs without slots, declare class-level attributes
+ // so type checkers can recognize the attributes set via super().__setattr__
+ // Always use | None since __init__ parameters always allow None
+ if (is_immutable(tstruct) && !gen_slots_ && !gen_dynamic_ && members.size() > 0) {
+ for (m_iter = sorted_members.begin(); m_iter != sorted_members.end(); ++m_iter) {
+ indent(out) << (*m_iter)->get_name()
+ << ": " << type_to_py_type((*m_iter)->get_type()) << " | None" << '\n';
+ }
+ out << '\n';
+ }
+
// TODO(dreiss): Look into generating an empty tuple instead of None
// for structures with no members.
// TODO(dreiss): Test encoding of structs where some inner structs
@@ -899,21 +853,37 @@ void t_py_generator::generate_py_struct_definition(ostream& out,
}
if (is_immutable(tstruct)) {
- if (gen_enum_ && type->is_enum()) {
- indent(out) << "super(" << tstruct->get_name() << ", self).__setattr__('"
- << (*m_iter)->get_name() << "', " << (*m_iter)->get_name()
- << " if hasattr(" << (*m_iter)->get_name() << ", 'value') else "
- << type_name(type) << ".__members__.get(" << (*m_iter)->get_name() << "))" << '\n';
- } else if (gen_newstyle_ || gen_dynamic_) {
- indent(out) << "super(" << tstruct->get_name() << ", self).__setattr__('"
- << (*m_iter)->get_name() << "', " << (*m_iter)->get_name() << ")" << '\n';
+ if (type->is_enum()) {
+ string enum_value = tmp("_enum_value");
+ indent(out) << enum_value << " = " << (*m_iter)->get_name() << '\n';
+ indent(out) << "if " << enum_value << " is not None and not hasattr(" << enum_value
+ << ", 'value'):" << '\n';
+ indent_up();
+ indent(out) << "try:" << '\n';
+ indent_up();
+ indent(out) << enum_value << " = " << type_name(type) << "(" << enum_value << ")" << '\n';
+ indent_down();
+ indent(out) << "except (ValueError, TypeError):" << '\n';
+ indent_up();
+ indent(out) << enum_value << " = " << type_name(type) << ".__members__.get(" << enum_value
+ << ")" << '\n';
+ indent(out) << "if " << enum_value << " is None:" << '\n';
+ indent_up();
+ indent(out) << "raise" << '\n';
+ indent_down();
+ indent_down();
+ indent_down();
+ indent(out) << "super().__setattr__('"
+ << (*m_iter)->get_name() << "', " << enum_value << ")" << '\n';
} else {
- indent(out) << "self.__dict__['" << (*m_iter)->get_name()
- << "'] = " << (*m_iter)->get_name() << '\n';
+ // For immutable structs, use super().__setattr__ to bypass __setattr__ override
+ indent(out) << "super().__setattr__('"
+ << (*m_iter)->get_name() << "', " << (*m_iter)->get_name() << ")" << '\n';
}
} else {
+ // Instance attribute type hint should always allow None to match __init__ params
indent(out) << "self." << (*m_iter)->get_name()
- << member_hint((*m_iter)->get_type(), (*m_iter)->get_req()) << " = "
+ << ": " << type_to_py_type((*m_iter)->get_type()) << " | None = "
<< (*m_iter)->get_name() << '\n';
}
}
@@ -937,6 +907,14 @@ void t_py_generator::generate_py_struct_definition(ostream& out,
out << indent() << "super().__setattr__(*args)" << '\n'
<< indent() << "return" << '\n';
indent_down();
+ } else if (is_exception) {
+ // For exceptions without slots, allow Python internal exception attributes
+ // that are modified by contextlib.contextmanager and multiprocessing.Pool
+ out << indent() << "if args[0] in ('__traceback__', '__context__', '__cause__', '__suppress_context__'):" << '\n';
+ indent_up();
+ out << indent() << "super().__setattr__(*args)" << '\n'
+ << indent() << "return" << '\n';
+ indent_down();
}
out << indent() << "raise TypeError(\"can't modify immutable instance\")" << '\n';
indent_down();
@@ -955,6 +933,14 @@ void t_py_generator::generate_py_struct_definition(ostream& out,
out << indent() << "super().__delattr__(*args)" << '\n'
<< indent() << "return" << '\n';
indent_down();
+ } else if (is_exception) {
+ // For exceptions without slots, allow Python internal exception attributes
+ // that are modified by contextlib.contextmanager and multiprocessing.Pool
+ out << indent() << "if args[0] in ('__traceback__', '__context__', '__cause__', '__suppress_context__'):" << '\n';
+ indent_up();
+ out << indent() << "super().__delattr__(*args)" << '\n'
+ << indent() << "return" << '\n';
+ indent_down();
}
out << indent() << "raise TypeError(\"can't modify immutable instance\")" << '\n';
indent_down();
@@ -970,7 +956,8 @@ void t_py_generator::generate_py_struct_definition(ostream& out,
}
out << "))" << '\n';
- } else if (gen_enum_) {
+ } else {
+ // For mutable structs with enum fields, generate __setattr__ to handle enum conversion
bool has_enum = false;
for (m_iter = members.begin(); m_iter != members.end(); ++m_iter) {
t_type* type = (*m_iter)->get_type();
@@ -987,10 +974,28 @@ void t_py_generator::generate_py_struct_definition(ostream& out,
for (m_iter = members.begin(); m_iter != members.end(); ++m_iter) {
t_type* type = (*m_iter)->get_type();
if (type->is_enum()) {
- out << indent() << "if name == \"" << (*m_iter)->get_name() << "\":" << '\n'
- << indent() << indent_str() << "super().__setattr__(name, value if hasattr(value, 'value') or value is None else "
- << type_name(type) << "(value))" << '\n'
- << indent() << indent_str() << "return" << '\n';
+ out << indent() << "if name == \"" << (*m_iter)->get_name() << "\":" << '\n';
+ indent_up();
+ out << indent() << "if hasattr(value, 'value') or value is None:" << '\n';
+ indent_up();
+ out << indent() << "super().__setattr__(name, value)" << '\n'
+ << indent() << "return" << '\n';
+ indent_down();
+ out << indent() << "try:" << '\n';
+ indent_up();
+ out << indent() << "enum_value = " << type_name(type) << "(value)" << '\n';
+ indent_down();
+ out << indent() << "except (ValueError, TypeError):" << '\n';
+ indent_up();
+ out << indent() << "enum_value = " << type_name(type) << ".__members__.get(value)" << '\n';
+ out << indent() << "if enum_value is None:" << '\n';
+ indent_up();
+ out << indent() << "raise" << '\n';
+ indent_down();
+ indent_down();
+ out << indent() << "super().__setattr__(name, enum_value)" << '\n'
+ << indent() << "return" << '\n';
+ indent_down();
}
}
indent(out) << "super().__setattr__(name, value)" << '\n' << '\n';
@@ -1271,6 +1276,9 @@ void t_py_generator::generate_service(t_service* tservice) {
<< import_dynbase_;
if (gen_zope_interface_) {
f_service_ << "from zope.interface import Interface, implementer" << '\n';
+ } else {
+ // Import Protocol for type-safe interface definitions
+ f_service_ << "from typing import Protocol" << '\n';
}
if (gen_twisted_) {
@@ -1353,8 +1361,10 @@ void t_py_generator::generate_service_interface(t_service* tservice) {
} else {
if (gen_zope_interface_) {
extends_if = "(Interface)";
- } else if (gen_newstyle_ || gen_dynamic_ || gen_tornado_) {
- extends_if = "(object)";
+ } else {
+ // Inherit from Protocol for type-safe interface definitions
+ // This allows type checkers to recognize abstract methods with ellipsis body
+ extends_if = "(Protocol)";
}
}
@@ -1376,7 +1386,10 @@ void t_py_generator::generate_service_interface(t_service* tservice) {
f_service_ << indent() << "def " << function_signature(*f_iter, true) << ":" << '\n';
indent_up();
generate_python_docstring(f_service_, (*f_iter));
- f_service_ << indent() << "pass" << '\n';
+ // Use ellipsis (...) instead of pass for interface stubs
+ // This is the Python convention for abstract/protocol methods
+ // and type checkers recognize this pattern
+ f_service_ << indent() << "..." << '\n';
indent_down();
}
}
@@ -1399,11 +1412,8 @@ void t_py_generator::generate_service_client(t_service* tservice) {
} else {
extends_client = extends + ".Client, ";
}
- } else {
- if (gen_zope_interface_ && (gen_newstyle_ || gen_dynamic_)) {
- extends_client = "(object)";
- }
}
+ // Note: For Python 3.10+, we don't need explicit (object) base class
f_service_ << '\n' << '\n';
@@ -1745,10 +1755,7 @@ void t_py_generator::generate_service_remote(t_service* tservice) {
py_autogen_comment() << '\n' <<
"import sys" << '\n' <<
"import pprint" << '\n' <<
- "if sys.version_info[0] > 2:" << '\n' <<
- indent_str() << "from urllib.parse import urlparse" << '\n' <<
- "else:" << '\n' <<
- indent_str() << "from urlparse import urlparse" << '\n' <<
+ "from urllib.parse import urlparse" << '\n' <<
"from thrift.transport import TTransport, TSocket, TSSLSocket, THttpClient" << '\n' <<
"from thrift.protocol.TBinaryProtocol import TBinaryProtocol" << '\n' << '\n';
@@ -2347,10 +2354,8 @@ void t_py_generator::generate_deserialize_field(ostream& out,
case t_base_type::TYPE_STRING:
if (type->is_binary()) {
out << "readBinary()";
- } else if(!gen_utf8strings_) {
- out << "readString()";
} else {
- out << "readString().decode('utf-8', errors='replace') if sys.version_info[0] == 2 else iprot.readString()";
+ out << "readString()";
}
break;
case t_base_type::TYPE_BOOL:
@@ -2380,11 +2385,7 @@ void t_py_generator::generate_deserialize_field(ostream& out,
}
out << '\n';
} else if (type->is_enum()) {
- if (gen_enum_) {
- indent(out) << name << " = " << type_name(type) << "(iprot.readI32())";
- } else {
- indent(out) << name << " = iprot.readI32()";
- }
+ indent(out) << name << " = " << type_name(type) << "(iprot.readI32())";
out << '\n';
} else {
printf("DO NOT KNOW HOW TO DESERIALIZE FIELD '%s' TYPE '%s'\n",
@@ -2542,10 +2543,8 @@ void t_py_generator::generate_serialize_field(ostream& out, t_field* tfield, str
case t_base_type::TYPE_STRING:
if (type->is_binary()) {
out << "writeBinary(" << name << ")";
- } else if (!gen_utf8strings_) {
- out << "writeString(" << name << ")";
} else {
- out << "writeString(" << name << ".encode('utf-8') if sys.version_info[0] == 2 else " << name << ")";
+ out << "writeString(" << name << ")";
}
break;
case t_base_type::TYPE_BOOL:
@@ -2573,11 +2572,7 @@ void t_py_generator::generate_serialize_field(ostream& out, t_field* tfield, str
throw "compiler error: no Python name for base type " + t_base_type::t_base_name(tbase);
}
} else if (type->is_enum()) {
- if (gen_enum_){
- out << "writeI32(" << name << ".value)";
- } else {
- out << "writeI32(" << name << ")";
- }
+ out << "writeI32(" << name << ".value)";
}
out << '\n';
} else {
@@ -2742,8 +2737,10 @@ void t_py_generator::generate_python_docstring(ostream& out, t_doc* tdoc) {
*/
string t_py_generator::declare_argument(t_field* tfield) {
std::ostringstream result;
- t_field::e_req req = tfield->get_req();
- result << tfield->get_name() << member_hint(tfield->get_type(), req);
+ // For __init__ parameters, always use `| None` type hint since all params
+ // have None as default for backward compatibility. Validation of required
+ // fields happens at runtime in validate().
+ result << tfield->get_name() << ": " << type_to_py_type(tfield->get_type()) << " | None";
result << " = ";
if (tfield->get_value() != nullptr) {
@@ -2849,31 +2846,20 @@ string t_py_generator::type_name(t_type* ttype) {
}
string t_py_generator::arg_hint(t_type* type) {
- if (gen_type_hints_) {
- return ": " + type_to_py_type(type);
- }
-
- return "";
+ return ": " + type_to_py_type(type);
}
string t_py_generator::member_hint(t_type* type, t_field::e_req req) {
- if (gen_type_hints_) {
- if (req != t_field::T_REQUIRED) {
- return ": typing.Optional[" + type_to_py_type(type) + "]";
- } else {
- return ": " + type_to_py_type(type);
- }
+ if (req != t_field::T_REQUIRED) {
+ // Python 3.10+ union syntax for optional fields
+ return ": " + type_to_py_type(type) + " | None";
+ } else {
+ return ": " + type_to_py_type(type);
}
-
- return "";
}
string t_py_generator::func_hint(t_type* type) {
- if (gen_type_hints_) {
- return " -> " + type_to_py_type(type);
- }
-
- return "";
+ return " -> " + type_to_py_type(type);
}
/**
@@ -2971,8 +2957,9 @@ string t_py_generator::type_to_spec_args(t_type* ttype) {
if (ttype->is_binary()) {
return "'BINARY'";
- } else if (gen_utf8strings_ && ttype->is_base_type()
+ } else if (ttype->is_base_type()
&& reinterpret_cast(ttype)->is_string()) {
+ // Python 3: strings are always UTF-8
return "'UTF8'";
} else if (ttype->is_base_type() || ttype->is_enum()) {
return "None";
@@ -3009,7 +2996,6 @@ THRIFT_REGISTER_GENERATOR(
" zope.interface: Generate code for use with zope.interface.\n"
" twisted: Generate Twisted-friendly RPC services.\n"
" tornado: Generate code for use with Tornado.\n"
- " no_utf8strings: Do not Encode/decode strings using utf8 in the generated code. Basically no effect for Python 3.\n"
" coding=CODING: Add file encoding declare in generated file.\n"
" slots: Generate code using slots for instance members.\n"
" dynamic: Generate dynamic code, less code generated but slower.\n"
@@ -3021,7 +3007,4 @@ THRIFT_REGISTER_GENERATOR(
" Add an import line to generated code to find the dynbase class.\n"
" package_prefix='top.package.'\n"
" Package prefix for generated files.\n"
- " old_style: Deprecated. Generate old-style classes.\n"
- " enum: Generates Python's IntEnum, connects thrift to python enums. Python 3.4 and higher.\n"
- " type_hints: Generate type hints and type checks in write method. Requires the enum option.\n"
)
diff --git a/compiler/cpp/tests/CMakeLists.txt b/compiler/cpp/tests/CMakeLists.txt
index 468de6ee846..12e24d73965 100644
--- a/compiler/cpp/tests/CMakeLists.txt
+++ b/compiler/cpp/tests/CMakeLists.txt
@@ -136,7 +136,7 @@ THRIFT_ADD_COMPILER(netstd "Enable compiler for .NET Standard" ON)
THRIFT_ADD_COMPILER(ocaml "Enable compiler for OCaml" ON)
THRIFT_ADD_COMPILER(perl "Enable compiler for Perl" OFF)
THRIFT_ADD_COMPILER(php "Enable compiler for PHP" OFF)
-THRIFT_ADD_COMPILER(py "Enable compiler for Python 2.0" OFF)
+THRIFT_ADD_COMPILER(py "Enable compiler for Python" OFF)
THRIFT_ADD_COMPILER(rb "Enable compiler for Ruby" OFF)
THRIFT_ADD_COMPILER(rs "Enable compiler for Rust" OFF)
THRIFT_ADD_COMPILER(st "Enable compiler for Smalltalk" OFF)
diff --git a/configure.ac b/configure.ac
index a93f7019449..d5bd8fe5a29 100644
--- a/configure.ac
+++ b/configure.ac
@@ -128,7 +128,6 @@ if test "$enable_libs" = "no"; then
with_java="no"
with_kotlin="no"
with_python="no"
- with_py3="no"
with_ruby="no"
with_haxe="no"
with_netstd="no"
@@ -286,7 +285,7 @@ fi
AM_CONDITIONAL(WITH_LUA, [test "$have_lua" = "yes"])
# Find python regardless of with_python value, because it's needed by make cross
-AM_PATH_PYTHON(2.6,, :)
+AM_PATH_PYTHON(3.10,, :)
AX_THRIFT_LIB(python, [Python], yes)
if test "$with_python" = "yes"; then
if test -n "$PYTHON"; then
@@ -300,25 +299,6 @@ fi
AM_CONDITIONAL(WITH_PYTHON, [test "$have_python" = "yes"])
AM_CONDITIONAL(WITH_TWISTED_TEST, [test "$have_trial" = "yes"])
-# Find "python3" executable.
-# It's distro specific and far from ideal but needed to cross test py2-3 at once.
-# TODO: find "python2" if it's 3.x
-have_py3="no"
-AX_THRIFT_LIB(py3, [Py3], yes)
-if test "$with_py3" = "yes"; then
- # if $PYTHON is 2.x then search for python 3. otherwise, $PYTHON is already 3.x
- if $PYTHON --version 2>&1 | grep -q "Python 2"; then
- AC_PATH_PROGS([PYTHON3], [python3 python3.8 python38 python3.7 python37 python3.6 python36 python3.5 python35 python3.4 python34])
- if test -n "$PYTHON3"; then
- have_py3="yes"
- fi
- elif $PYTHON --version 2>&1 | grep -q "Python 3"; then
- have_py3="yes"
- PYTHON3=$PYTHON
- fi
-fi
-AM_CONDITIONAL(WITH_PY3, [test "$have_py3" = "yes"])
-
AX_THRIFT_LIB(perl, [Perl], yes)
if test "$with_perl" = "yes"; then
AC_PATH_PROG([PERL], [perl])
@@ -873,8 +853,6 @@ if test "$have_kotlin" = "yes" ; then MAYBE_KOTLIN="kotlin" ; else MAYBE_KOTLIN=
AC_SUBST([MAYBE_KOTLIN])
if test "$have_python" = "yes" ; then MAYBE_PYTHON="py" ; else MAYBE_PYTHON="" ; fi
AC_SUBST([MAYBE_PYTHON])
-if test "$have_py3" = "yes" ; then MAYBE_PY3="py3" ; else MAYBE_PY3="" ; fi
-AC_SUBST([MAYBE_PY3])
if test "$have_ruby" = "yes" ; then MAYBE_RUBY="rb" ; else MAYBE_RUBY="" ; fi
AC_SUBST([MAYBE_RUBY])
if test "$have_perl" = "yes" ; then MAYBE_PERL="perl" ; else MAYBE_PERL="" ; fi
@@ -924,7 +902,6 @@ echo "Building NodeJS Library ...... : $have_nodejs"
echo "Building Perl Library ........ : $have_perl"
echo "Building PHP Library ......... : $have_php"
echo "Building Python Library ...... : $have_python"
-echo "Building Py3 Library ......... : $have_py3"
echo "Building Ruby Library ........ : $have_ruby"
echo "Building Rust Library ........ : $have_rs"
echo "Building Swift Library ....... : $have_swift"
@@ -1036,10 +1013,6 @@ if test "$have_python" = "yes" ; then
echo "Python Library:"
echo " Using Python .............. : $PYTHON"
echo " Using Python version ...... : $($PYTHON --version 2>&1)"
- if test "$have_py3" = "yes" ; then
- echo " Using Python3 ............. : $PYTHON3"
- echo " Using Python3 version ..... : $($PYTHON3 --version)"
- fi
if test "$have_trial" = "yes"; then
echo " Using trial ............... : $TRIAL"
fi
diff --git a/contrib/Vagrantfile b/contrib/Vagrantfile
index a5371dd82d0..3e508308faa 100644
--- a/contrib/Vagrantfile
+++ b/contrib/Vagrantfile
@@ -43,10 +43,10 @@ sudo apt-get install -qq automake libtool flex bison pkg-config g++ libssl-dev m
sudo apt-get install -qq libboost-dev libboost-test-dev libboost-program-options-dev libboost-filesystem-dev libboost-system-dev libevent-dev
# Java dependencies
-sudo apt-get install -qq ant openjdk-8-jdk maven
+sudo apt-get install -qq ant openjdk-17-jdk maven
# Python dependencies
-sudo apt-get install -qq python-all python-all-dev python-all-dbg python-setuptools python-support
+sudo apt-get install -qq python3-all python3-all-dev python3-all-dbg python3-setuptools
# Ruby dependencies
sudo apt-get install -qq ruby ruby-dev
@@ -56,7 +56,7 @@ sudo gem install bundler rake
sudo apt-get install -qq libbit-vector-perl libclass-accessor-class-perl
# Php dependencies
-sudo apt-get install -qq php5 php5-dev php5-cli php-pear re2c
+sudo apt-get install -qq php php-dev php-cli php-pear re2c
# GlibC dependencies
sudo apt-get install -qq libglib2.0-dev
@@ -72,7 +72,7 @@ sudo apt-get -y install -qq golang golang-go
sudo apt-get install -qq lua5.2 lua5.2-dev
# Node.js dependencies
-sudo apt-get install -qq nodejs nodejs-dev nodejs-legacy npm
+sudo apt-get install -qq nodejs npm
# D dependencies
sudo wget http://master.dl.sourceforge.net/project/d-apt/files/d-apt.list -O /etc/apt/sources.list.d/d-apt.list
@@ -81,17 +81,8 @@ sudo apt-get install -qq xdg-utils dmd-bin
# Customize the system
# ---
-# Default java to latest 1.8 version
-update-java-alternatives -s java-1.8.0-openjdk-amd64
-
-# PHPUnit package broken in ubuntu. see https://bugs.launchpad.net/ubuntu/+source/phpunit/+bug/701544
-sudo apt-get upgrade pear
-sudo pear channel-discover pear.phpunit.de
-sudo pear channel-discover pear.symfony.com
-sudo pear channel-discover components.ez.no
-sudo pear update-channels
-sudo pear upgrade-all
-sudo pear install --alldeps phpunit/PHPUnit
+# Make OpenJDK 17 the default Java
+update-java-alternatives -s java-1.17.0-openjdk-amd64 || true
date > /etc/vagrant.provisioned
@@ -108,9 +99,9 @@ echo "Finished building Apache Thrift."
SCRIPT
Vagrant.configure("2") do |config|
- # Ubuntu 14.04 LTS (Trusty Tahr)
- config.vm.box = "trusty64"
- config.vm.box_url = "https://cloud-images.ubuntu.com/vagrant/trusty/current/trusty-server-cloudimg-amd64-vagrant-disk1.box"
+ # Ubuntu 22.04 LTS (Jammy Jellyfish)
+ config.vm.box = "ubuntu/jammy64"
+ config.vm.box_url = "https://cloud-images.ubuntu.com/jammy/current/jammy-server-cloudimg-amd64-vagrant.box"
config.vm.synced_folder "../", "/thrift"
diff --git a/contrib/async-test/test-leaf.py b/contrib/async-test/test-leaf.py
index c4772f706a9..808eae2e196 100755
--- a/contrib/async-test/test-leaf.py
+++ b/contrib/async-test/test-leaf.py
@@ -1,5 +1,4 @@
-#!/usr/bin/env python
-
+#!/usr/bin/env python3
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
diff --git a/contrib/fb303/configure.ac b/contrib/fb303/configure.ac
index 73b35ba07a4..0adde2a2417 100644
--- a/contrib/fb303/configure.ac
+++ b/contrib/fb303/configure.ac
@@ -107,7 +107,7 @@ AM_CONDITIONAL(WITH_PHP, [test "$have_php" = "yes"])
AX_THRIFT_LIB(python, [Python], yes)
if test "$with_python" = "yes"; then
- AM_PATH_PYTHON(2.4,, :)
+ AM_PATH_PYTHON(3.10,, :)
if test "x$PYTHON" != "x" && test "x$PYTHON" != "x:" ; then
have_python="yes"
fi
diff --git a/contrib/fb303/py/fb303/FacebookBase.py b/contrib/fb303/py/fb303/FacebookBase.py
index 07db10cd3de..6f0e87d30c8 100644
--- a/contrib/fb303/py/fb303/FacebookBase.py
+++ b/contrib/fb303/py/fb303/FacebookBase.py
@@ -1,5 +1,4 @@
-#!/usr/bin/env python
-
+#!/usr/bin/env python3
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
diff --git a/contrib/fb303/py/fb303_scripts/fb303_simple_mgmt.py b/contrib/fb303/py/fb303_scripts/fb303_simple_mgmt.py
index 62a729e1d8f..9fe46ca93e5 100644
--- a/contrib/fb303/py/fb303_scripts/fb303_simple_mgmt.py
+++ b/contrib/fb303/py/fb303_scripts/fb303_simple_mgmt.py
@@ -1,5 +1,4 @@
-#!/usr/bin/env python
-
+#!/usr/bin/env python3
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
diff --git a/contrib/fb303/py/setup.py b/contrib/fb303/py/setup.py
index c07cf55ca0a..bbc2f0ae491 100644
--- a/contrib/fb303/py/setup.py
+++ b/contrib/fb303/py/setup.py
@@ -1,5 +1,4 @@
-#!/usr/bin/env python
-
+#!/usr/bin/env python3
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
@@ -19,9 +18,7 @@
# under the License.
#
-import sys
-
-from setuptools import Extension, setup
+from setuptools import setup
setup(name='thrift_fb303',
version='1.0.0',
@@ -34,12 +31,18 @@
'fb303',
'fb303_scripts',
],
+ python_requires='>=3.10',
classifiers=[
'Development Status :: 7 - Inactive',
'Environment :: Console',
'Intended Audience :: Developers',
'Programming Language :: Python',
- 'Programming Language :: Python :: 2',
+ 'Programming Language :: Python :: 3',
+ 'Programming Language :: Python :: 3.10',
+ 'Programming Language :: Python :: 3.11',
+ 'Programming Language :: Python :: 3.12',
+ 'Programming Language :: Python :: 3.13',
+ 'Programming Language :: Python :: 3.14',
'Topic :: Software Development :: Libraries',
'Topic :: System :: Networking'
],
diff --git a/contrib/parse_profiling.py b/contrib/parse_profiling.py
index 0be5f29ed7e..42a524f01b3 100755
--- a/contrib/parse_profiling.py
+++ b/contrib/parse_profiling.py
@@ -1,4 +1,4 @@
-#!/usr/bin/env python
+#!/usr/bin/env python3
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
diff --git a/contrib/zeromq/test-client.py b/contrib/zeromq/test-client.py
index d51216e459e..30867879100 100755
--- a/contrib/zeromq/test-client.py
+++ b/contrib/zeromq/test-client.py
@@ -1,4 +1,4 @@
-#!/usr/bin/env python
+#!/usr/bin/env python3
import sys
import time
import zmq
diff --git a/contrib/zeromq/test-server.py b/contrib/zeromq/test-server.py
index 299b84c523a..0463ba79627 100755
--- a/contrib/zeromq/test-server.py
+++ b/contrib/zeromq/test-server.py
@@ -1,5 +1,4 @@
-#!/usr/bin/env python
-
+#!/usr/bin/env python3
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
diff --git a/debian/control b/debian/control
index 06c0d483416..5ba2e0d7d43 100644
--- a/debian/control
+++ b/debian/control
@@ -1,10 +1,9 @@
Source: thrift
Section: devel
Priority: extra
-Build-Depends: dotnet-runtime-6.0, dotnet-sdk-6.0, debhelper (>= 9), build-essential, python-dev, ant,
+Build-Depends: dotnet-runtime-6.0, dotnet-sdk-6.0, debhelper (>= 9), build-essential, ant,
ruby-dev | ruby1.9.1-dev, ruby-bundler ,autoconf, automake,
pkg-config, libtool, bison, flex, libboost-dev | libboost1.56-dev | libboost1.63-all-dev,
- python-all, python-setuptools, python-all-dev, python-all-dbg,
python3-all, python3-setuptools, python3-all-dev, python3-all-dbg,
openjdk-17-jdk | openjdk-17-jdk-headless | default-jdk,
libboost-test-dev | libboost-test1.56-dev | libboost-test1.63-dev, libevent-dev, libssl-dev, perl (>= 5.8.0-7),
@@ -14,8 +13,7 @@ Homepage: http://thrift.apache.org/
Vcs-Git: https://github.com/apache/thrift.git
Vcs-Browser: https://github.com/apache/thrift
Standards-Version: 3.9.7
-X-Python-Version: >= 2.6
-X-Python3-Version: >= 3.3
+X-Python3-Version: >= 3.10
Package: thrift-compiler
Architecture: any
@@ -29,39 +27,6 @@ Description: Compiler for Thrift definition files
from .thrift files (containing the definitions) to the language binding
for the supported languages.
-Package: python-thrift
-Architecture: any
-Section: python
-Depends: ${python:Depends}, ${shlibs:Depends}, ${misc:Depends}
-Recommends: python-twisted-web, python-backports.ssl-match-hostname, python-ipaddress
-Provides: ${python:Provides}
-Description: Python bindings for Thrift (Python 2)
- Thrift is a software framework for scalable cross-language services
- development. It combines a software stack with a code generation engine to
- build services that work efficiently and seamlessly.
- .
- This package contains the Python bindings for Thrift. You will need the thrift
- tool (in the thrift-compiler package) to compile your definition to Python
- classes, and then the modules in this package will allow you to use those
- classes in your programs.
- .
- This package installs the library for Python 2.
-
-Package: python-thrift-dbg
-Architecture: any
-Section: debug
-Depends: ${shlibs:Depends}, ${misc:Depends}, python-thrift (= ${binary:Version}), python-all-dbg
-Provides: ${python:Provides}
-Description: Python bindings for Thrift (debug version)
- Thrift is a software framework for scalable cross-language services
- development. It combines a software stack with a code generation engine to
- build services that work efficiently and seamlessly.
- .
- This package contains the Python bindings for Thrift with debugging symbols.
- You will need the thrift tool (in the thrift-compiler package) to compile your
- definition to Python classes, and then the modules in this package will allow
- you to use those classes in your programs.
-
Package: python3-thrift
Architecture: any
Section: python
diff --git a/debian/rules b/debian/rules
index ba886faaefe..a2b270cf89a 100755
--- a/debian/rules
+++ b/debian/rules
@@ -16,8 +16,6 @@
# This has to be exported to make some magic below work.
export DH_OPTIONS
-PYVERS := $(shell pyversions -r)
-
export CPPFLAGS:=$(shell dpkg-buildflags --get CPPFLAGS)
export CFLAGS:=$(shell dpkg-buildflags --get CFLAGS)
export CXXFLAGS:=$(shell dpkg-buildflags --get CXXFLAGS)
@@ -53,10 +51,8 @@ $(CURDIR)/compiler/cpp/thrift build-arch-stamp: configure-stamp
# Python library
cd $(CURDIR)/lib/py && \
- for py in $(PYVERS); do \
- $$py setup.py build; \
- $$py-dbg setup.py build; \
- done
+ python3 setup.py build && \
+ python3-dbg setup.py build
# PHP
cd $(CURDIR)/lib/php/src/ext/thrift_protocol && \
@@ -91,7 +87,7 @@ clean:
dh_testroot
rm -f build-arch-stamp build-indep-stamp configure-stamp
- cd $(CURDIR)/lib/py && python setup.py clean --all
+ cd $(CURDIR)/lib/py && python3 setup.py clean --all
# Add here commands to clean up after the build process.
-$(MAKE) clean
@@ -153,19 +149,9 @@ install-arch:
# Python
cd $(CURDIR)/lib/py && \
- python2 setup.py install --install-layout=deb --no-compile --root=$(CURDIR)/debian/python-thrift && \
- python2-dbg setup.py install --install-layout=deb --no-compile --root=$(CURDIR)/debian/python-thrift-dbg && \
python3 setup.py install --install-layout=deb --no-compile --root=$(CURDIR)/debian/python3-thrift && \
python3-dbg setup.py install --install-layout=deb --no-compile --root=$(CURDIR)/debian/python3-thrift-dbg
- find $(CURDIR)/debian/python-thrift -name "*.py[co]" -print0 | xargs -0 rm -f
- find $(CURDIR)/debian/python-thrift -name "__pycache__" -print0 | xargs -0 rm -fr
- find $(CURDIR)/debian/python-thrift-dbg -name "__pycache__" -print0 | xargs -0 rm -fr
- find $(CURDIR)/debian/python-thrift-dbg -name "*.py[co]" -print0 | xargs -0 rm -f
- find $(CURDIR)/debian/python-thrift-dbg -name "*.py" -print0 | xargs -0 rm -f
- find $(CURDIR)/debian/python-thrift-dbg -name "*.egg-info" -print0 | xargs -0 rm -rf
- find $(CURDIR)/debian/python-thrift-dbg -depth -type d -empty -exec rmdir {} \;
-
find $(CURDIR)/debian/python3-thrift -name "*.py[co]" -print0 | xargs -0 rm -f
find $(CURDIR)/debian/python3-thrift -name "__pycache__" -print0 | xargs -0 rm -fr
find $(CURDIR)/debian/python3-thrift-dbg -name "__pycache__" -print0 | xargs -0 rm -fr
@@ -201,7 +187,6 @@ binary-common:
dh_installman
dh_link
dh_strip -plibthrift0 --dbg-package=libthrift0-dbg
- dh_strip -ppython-thrift --dbg-package=python-thrift-dbg
dh_strip -ppython3-thrift --dbg-package=python3-thrift-dbg
dh_strip -pthrift-compiler
dh_compress
diff --git a/doc/install/README.md b/doc/install/README.md
index 0ebe77c7110..a07fa33ddd9 100644
--- a/doc/install/README.md
+++ b/doc/install/README.md
@@ -31,7 +31,7 @@ These are only required if you choose to build the libraries for the given langu
* Java 17 (latest LTS)
* Gradle 8.4
* C#: Mono 1.2.4 (and pkg-config to detect it) or Visual Studio 2005+
-* Python 2.6 (including header files for extension modules)
+* Python 3.10+ (including header files for extension modules)
* PHP 7.1 (optionally including header files for extension modules)
* Ruby 1.8
* bundler gem
diff --git a/doc/install/debian.md b/doc/install/debian.md
index 3d80531c85e..7fe32a9a0c9 100644
--- a/doc/install/debian.md
+++ b/doc/install/debian.md
@@ -23,7 +23,7 @@ If you would like to build Apache Thrift libraries for other programming languag
* Ruby
* ruby-full ruby-dev ruby-rspec rake rubygems bundler
* Python
- * python-all python-all-dev python-all-dbg
+ * python3-all python3-all-dev python3-all-dbg python3-setuptools (3.10+)
* Perl
* libbit-vector-perl libclass-accessor-class-perl
* Php, install
diff --git a/lib/py/Makefile.am b/lib/py/Makefile.am
index 2be72de2fa3..5b461d413ba 100644
--- a/lib/py/Makefile.am
+++ b/lib/py/Makefile.am
@@ -19,25 +19,30 @@
AUTOMAKE_OPTIONS = serial-tests
DESTDIR ?= /
-if WITH_PY3
-py3-build:
- $(PYTHON3) setup.py build
-py3-test: py3-build
- $(PYTHON3) test/thrift_json.py
- $(PYTHON3) test/thrift_transport.py
- $(PYTHON3) test/test_sslsocket.py
- $(PYTHON3) test/thrift_TBinaryProtocol.py
- $(PYTHON3) test/thrift_TZlibTransport.py
- $(PYTHON3) test/thrift_TCompactProtocol.py
- $(PYTHON3) test/thrift_TNonblockingServer.py
- $(PYTHON3) test/thrift_TSerializer.py
-else
-py3-build:
-py3-test:
-endif
-
-all-local: py3-build
+py-build:
$(PYTHON) setup.py build
+py-test: py-build
+ $(PYTHON) test/thrift_json.py
+ $(PYTHON) test/thrift_transport.py
+ $(PYTHON) test/test_sslcontext_hostname.py
+ $(PYTHON) test/test_sslsocket.py
+ $(PYTHON) test/test_socket.py
+ $(PYTHON) test/thrift_TBinaryProtocol.py
+ $(PYTHON) test/thrift_TZlibTransport.py
+ $(PYTHON) test/thrift_TCompactProtocol.py
+ $(PYTHON) test/thrift_TNonblockingServer.py
+ $(PYTHON) test/thrift_TSerializer.py
+ $(PYTHON) test/test_type_check.py
+py-security:
+ @command -v uv >/dev/null 2>&1 || { \
+ echo "Installing uv..."; \
+ curl -LsSf https://astral.sh/uv/install.sh | sh; \
+ }
+ uv tool run bandit -r src --severity-level medium --confidence-level medium
+ uv tool run semgrep scan --config p/security-audit --config p/python \
+ --severity WARNING --severity ERROR --metrics off src test
+
+all-local: py-build
${THRIFT} --gen py test/test_thrift_file/TestServer.thrift
${THRIFT} --gen py ../../test/v0.16/FuzzTestNoUuid.thrift
@@ -48,16 +53,7 @@ all-local: py3-build
install-exec-hook:
$(PYTHON) -m pip install . --root=$(DESTDIR) --prefix=$(PY_PREFIX) $(PYTHON_SETUPUTIL_ARGS)
-check-local: all py3-test
- $(PYTHON) test/thrift_json.py
- $(PYTHON) test/thrift_transport.py
- $(PYTHON) test/test_sslsocket.py
- $(PYTHON) test/test_socket.py
- $(PYTHON) test/thrift_TBinaryProtocol.py
- $(PYTHON) test/thrift_TZlibTransport.py
- $(PYTHON) test/thrift_TCompactProtocol.py
- $(PYTHON) test/thrift_TNonblockingServer.py
- $(PYTHON) test/thrift_TSerializer.py
+check-local: all py-test py-security
clean-local:
diff --git a/lib/py/setup.py b/lib/py/setup.py
index 2dd2a77aa32..d314e77d178 100644
--- a/lib/py/setup.py
+++ b/lib/py/setup.py
@@ -1,5 +1,4 @@
-#!/usr/bin/env python
-
+#!/usr/bin/env python3
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
@@ -93,9 +92,6 @@ def run_setup(with_binary):
else:
extensions = dict()
- ssl_deps = []
- if sys.hexversion < 0x03050000:
- ssl_deps.append('backports.ssl_match_hostname>=3.5')
tornado_deps = ['tornado>=6.3.0']
twisted_deps = ['twisted>=24.3.0', 'zope.interface>=6.1']
@@ -109,10 +105,9 @@ def run_setup(with_binary):
url='http://thrift.apache.org',
license='Apache License 2.0',
extras_require={
- 'ssl': ssl_deps,
'tornado': tornado_deps,
'twisted': twisted_deps,
- 'all': ssl_deps + tornado_deps + twisted_deps,
+ 'all': tornado_deps + twisted_deps,
},
packages=[
'thrift',
@@ -121,12 +116,18 @@ def run_setup(with_binary):
'thrift.server',
],
package_dir={'thrift': 'src'},
+ python_requires='>=3.10',
classifiers=[
'Development Status :: 5 - Production/Stable',
'Environment :: Console',
'Intended Audience :: Developers',
'Programming Language :: Python',
'Programming Language :: Python :: 3',
+ 'Programming Language :: Python :: 3.10',
+ 'Programming Language :: Python :: 3.11',
+ 'Programming Language :: Python :: 3.12',
+ 'Programming Language :: Python :: 3.13',
+ 'Programming Language :: Python :: 3.14',
'Topic :: Software Development :: Libraries',
'Topic :: System :: Networking'
],
diff --git a/lib/py/src/TTornado.py b/lib/py/src/TTornado.py
index c7218301afd..f9cd70cfd45 100644
--- a/lib/py/src/TTornado.py
+++ b/lib/py/src/TTornado.py
@@ -20,7 +20,6 @@
import logging
import socket
import struct
-import warnings
from .transport.TTransport import TTransportException, TTransportBase, TMemoryBuffer
@@ -66,17 +65,7 @@ def _lock_context(self):
class TTornadoStreamTransport(TTransportBase):
"""a framed, buffered transport over a Tornado stream"""
- def __init__(self, host, port, stream=None, io_loop=None):
- if io_loop is not None:
- warnings.warn(
- "The `io_loop` parameter is deprecated and unused. Passing "
- "`io_loop` is unnecessary because Tornado now automatically "
- "provides the current I/O loop via `IOLoop.current()`. "
- "Remove the `io_loop` parameter to ensure compatibility - it "
- "will be removed in a future release.",
- DeprecationWarning,
- stacklevel=2,
- )
+ def __init__(self, host, port, stream=None):
self.host = host
self.port = port
self.io_loop = ioloop.IOLoop.current()
diff --git a/lib/py/src/Thrift.py b/lib/py/src/Thrift.py
index 81fe8cf33fe..63e858d21c3 100644
--- a/lib/py/src/Thrift.py
+++ b/lib/py/src/Thrift.py
@@ -18,7 +18,7 @@
#
-class TType(object):
+class TType:
STOP = 0
VOID = 1
BOOL = 2
@@ -59,14 +59,14 @@ class TType(object):
)
-class TMessageType(object):
+class TMessageType:
CALL = 1
REPLY = 2
EXCEPTION = 3
ONEWAY = 4
-class TProcessor(object):
+class TProcessor:
"""Base class for processor, which works on two streams."""
def process(self, iprot, oprot):
diff --git a/lib/py/src/ext/module.cpp b/lib/py/src/ext/module.cpp
index a1b0e5633e6..6f3a4bea57f 100644
--- a/lib/py/src/ext/module.cpp
+++ b/lib/py/src/ext/module.cpp
@@ -28,8 +28,8 @@
// TODO(dreiss): defval appears to be unused. Look into removing it.
// TODO(dreiss): Make parse_spec_args recursive, and cache the output
// permanently in the object. (Malloc and orphan.)
-// TODO(dreiss): Why do we need cStringIO for reading, why not just char*?
-// Can cStringIO let us work with a BufferedTransport?
+// TODO(dreiss): Why do we need BytesIO for reading, why not just char*?
+// Can BytesIO let us work with a BufferedTransport?
// TODO(dreiss): Don't ignore the rv from cwrite (maybe).
// Doing a benchmark shows that interning actually makes a difference, amazingly.
@@ -70,7 +70,7 @@ static PyObject* encode_impl(PyObject* args) {
static inline long as_long_then_delete(PyObject* value, long default_value) {
ScopedPyObject scope(value);
- long v = PyInt_AsLong(value);
+ long v = PyLong_AsLong(value);
if (INT_CONV_ERROR_OCCURRED(v)) {
PyErr_Clear();
return default_value;
@@ -145,8 +145,6 @@ static PyMethodDef ThriftFastBinaryMethods[] = {
{nullptr, nullptr, 0, nullptr} /* Sentinel */
};
-#if PY_MAJOR_VERSION >= 3
-
static struct PyModuleDef ThriftFastBinaryDef = {PyModuleDef_HEAD_INIT,
"thrift.protocol.fastbinary",
nullptr,
@@ -161,21 +159,9 @@ static struct PyModuleDef ThriftFastBinaryDef = {PyModuleDef_HEAD_INIT,
PyObject* PyInit_fastbinary() {
-#else
-
-#define INITERROR return;
-
-void initfastbinary() {
-
- PycString_IMPORT;
- if (PycStringIO == nullptr)
- INITERROR
-
-#endif
-
#define INIT_INTERN_STRING(value) \
do { \
- INTERN_STRING(value) = PyString_InternFromString(#value); \
+ INTERN_STRING(value) = PyUnicode_InternFromString(#value); \
if (!INTERN_STRING(value)) \
INITERROR \
} while (0)
@@ -188,12 +174,7 @@ void initfastbinary() {
INIT_INTERN_STRING(trans);
#undef INIT_INTERN_STRING
- PyObject* module =
-#if PY_MAJOR_VERSION >= 3
- PyModule_Create(&ThriftFastBinaryDef);
-#else
- Py_InitModule("thrift.protocol.fastbinary", ThriftFastBinaryMethods);
-#endif
+ PyObject* module = PyModule_Create(&ThriftFastBinaryDef);
if (module == nullptr)
INITERROR;
@@ -201,8 +182,6 @@ void initfastbinary() {
PyUnstable_Module_SetGIL(module, Py_MOD_GIL_NOT_USED);
#endif
-#if PY_MAJOR_VERSION >= 3
return module;
-#endif
}
}
diff --git a/lib/py/src/ext/protocol.tcc b/lib/py/src/ext/protocol.tcc
index aad5a3c88e5..de1d5656a09 100644
--- a/lib/py/src/ext/protocol.tcc
+++ b/lib/py/src/ext/protocol.tcc
@@ -20,97 +20,16 @@
#ifndef THRIFT_PY_PROTOCOL_TCC
#define THRIFT_PY_PROTOCOL_TCC
+#include
#include
#define CHECK_RANGE(v, min, max) (((v) <= (max)) && ((v) >= (min)))
#define INIT_OUTBUF_SIZE 128
-#if PY_MAJOR_VERSION < 3
-#include
-#else
-#include
-#endif
-
namespace apache {
namespace thrift {
namespace py {
-#if PY_MAJOR_VERSION < 3
-
-namespace detail {
-
-inline bool input_check(PyObject* input) {
- return PycStringIO_InputCheck(input);
-}
-
-inline EncodeBuffer* new_encode_buffer(size_t size) {
- if (!PycStringIO) {
- PycString_IMPORT;
- }
- if (!PycStringIO) {
- return nullptr;
- }
- return PycStringIO->NewOutput(size);
-}
-
-inline int read_buffer(PyObject* buf, char** output, int len) {
- if (!PycStringIO) {
- PycString_IMPORT;
- }
- if (!PycStringIO) {
- PyErr_SetString(PyExc_ImportError, "failed to import native cStringIO");
- return -1;
- }
- return PycStringIO->cread(buf, output, len);
-}
-}
-
-template
-inline ProtocolBase::~ProtocolBase() {
- if (output_) {
- Py_CLEAR(output_);
- }
-}
-
-template
-inline bool ProtocolBase::isUtf8(PyObject* typeargs) {
- return PyString_Check(typeargs) && !strncmp(PyString_AS_STRING(typeargs), "UTF8", 4);
-}
-
-template
-PyObject* ProtocolBase::getEncodedValue() {
- if (!PycStringIO) {
- PycString_IMPORT;
- }
- if (!PycStringIO) {
- return nullptr;
- }
- return PycStringIO->cgetvalue(output_);
-}
-
-template
-inline bool ProtocolBase::writeBuffer(char* data, size_t size) {
- if (!PycStringIO) {
- PycString_IMPORT;
- }
- if (!PycStringIO) {
- PyErr_SetString(PyExc_ImportError, "failed to import native cStringIO");
- return false;
- }
- int len = PycStringIO->cwrite(output_, data, size);
- if (len < 0) {
- PyErr_SetString(PyExc_IOError, "failed to write to cStringIO object");
- return false;
- }
- if (static_cast(len) != size) {
- PyErr_Format(PyExc_EOFError, "write length mismatch: expected %lu got %d", size, len);
- return false;
- }
- return true;
-}
-
-#else
-
namespace detail {
inline bool input_check(PyObject* input) {
@@ -127,22 +46,14 @@ inline EncodeBuffer* new_encode_buffer(size_t size) {
struct bytesio {
PyObject_HEAD
-#if PY_MINOR_VERSION < 5
- char* buf;
-#else
PyObject* buf;
-#endif
Py_ssize_t pos;
Py_ssize_t string_size;
};
inline int read_buffer(PyObject* buf, char** output, int len) {
bytesio* buf2 = reinterpret_cast(buf);
-#if PY_MINOR_VERSION < 5
- *output = buf2->buf + buf2->pos;
-#else
*output = PyBytes_AS_STRING(buf2->buf) + buf2->pos;
-#endif
Py_ssize_t pos0 = buf2->pos;
buf2->pos = (std::min)(buf2->pos + static_cast(len), buf2->string_size);
return static_cast(buf2->pos - pos0);
@@ -158,8 +69,7 @@ inline ProtocolBase::~ProtocolBase() {
template
inline bool ProtocolBase::isUtf8(PyObject* typeargs) {
- // while condition for py2 is "arg == 'UTF8'", it should be "arg != 'BINARY'" for py3.
- // HACK: check the length and don't bother reading the value
+ // Treat typeargs as UTF-8 unless it is a 6-character string (i.e. 'BINARY'); only the length is checked, not the value.
return !PyUnicode_Check(typeargs) || PyUnicode_GET_LENGTH(typeargs) != 6;
}
@@ -183,8 +93,6 @@ inline bool ProtocolBase::writeBuffer(char* data, size_t size) {
return true;
}
-#endif
-
namespace detail {
#define DECLARE_OP_SCOPE(name, op) \
@@ -222,7 +130,7 @@ inline bool check_ssize_t_32(Py_ssize_t len) {
template
bool parse_pyint(PyObject* o, T* ret, int32_t min, int32_t max) {
- long val = PyInt_AsLong(o);
+ long val = PyLong_AsLong(o);
if (INT_CONV_ERROR_OCCURRED(val)) {
return false;
@@ -660,21 +568,21 @@ PyObject* ProtocolBase::decodeValue(TType type, PyObject* typeargs) {
if (!impl()->readI8(v)) {
return nullptr;
}
- return PyInt_FromLong(v);
+ return PyLong_FromLong(v);
}
case T_I16: {
int16_t v = 0;
if (!impl()->readI16(v)) {
return nullptr;
}
- return PyInt_FromLong(v);
+ return PyLong_FromLong(v);
}
case T_I32: {
int32_t v = 0;
if (!impl()->readI32(v)) {
return nullptr;
}
- return PyInt_FromLong(v);
+ return PyLong_FromLong(v);
}
case T_I64: {
@@ -685,7 +593,7 @@ PyObject* ProtocolBase::decodeValue(TType type, PyObject* typeargs) {
// TODO(dreiss): Find out if we can take this fastpath always when
// sizeof(long) == sizeof(long long).
if (CHECK_RANGE(v, LONG_MIN, LONG_MAX)) {
- return PyInt_FromLong((long)v);
+ return PyLong_FromLong((long)v);
}
return PyLong_FromLongLong(v);
}
diff --git a/lib/py/src/ext/types.cpp b/lib/py/src/ext/types.cpp
index 0c20e56224e..d221190b2df 100644
--- a/lib/py/src/ext/types.cpp
+++ b/lib/py/src/ext/types.cpp
@@ -27,11 +27,7 @@ namespace py {
PyObject* ThriftModule = nullptr;
-#if PY_MAJOR_VERSION < 3
-char refill_signature[] = {'s', '#', 'i'};
-#else
const char* refill_signature = "y#i";
-#endif
bool parse_struct_item_spec(StructItemSpec* dest, PyObject* spec_tuple) {
// i'd like to use ParseArgs here, but it seems to be a bottleneck.
@@ -41,12 +37,12 @@ bool parse_struct_item_spec(StructItemSpec* dest, PyObject* spec_tuple) {
return false;
}
- dest->tag = static_cast(PyInt_AsLong(PyTuple_GET_ITEM(spec_tuple, 0)));
+ dest->tag = static_cast(PyLong_AsLong(PyTuple_GET_ITEM(spec_tuple, 0)));
if (INT_CONV_ERROR_OCCURRED(dest->tag)) {
return false;
}
- dest->type = static_cast(PyInt_AsLong(PyTuple_GET_ITEM(spec_tuple, 1)));
+ dest->type = static_cast(PyLong_AsLong(PyTuple_GET_ITEM(spec_tuple, 1)));
if (INT_CONV_ERROR_OCCURRED(dest->type)) {
return false;
}
@@ -63,7 +59,7 @@ bool parse_set_list_args(SetListTypeArgs* dest, PyObject* typeargs) {
return false;
}
- dest->element_type = static_cast(PyInt_AsLong(PyTuple_GET_ITEM(typeargs, 0)));
+ dest->element_type = static_cast(PyLong_AsLong(PyTuple_GET_ITEM(typeargs, 0)));
if (INT_CONV_ERROR_OCCURRED(dest->element_type)) {
return false;
}
@@ -81,12 +77,12 @@ bool parse_map_args(MapTypeArgs* dest, PyObject* typeargs) {
return false;
}
- dest->ktag = static_cast(PyInt_AsLong(PyTuple_GET_ITEM(typeargs, 0)));
+ dest->ktag = static_cast(PyLong_AsLong(PyTuple_GET_ITEM(typeargs, 0)));
if (INT_CONV_ERROR_OCCURRED(dest->ktag)) {
return false;
}
- dest->vtag = static_cast(PyInt_AsLong(PyTuple_GET_ITEM(typeargs, 2)));
+ dest->vtag = static_cast(PyLong_AsLong(PyTuple_GET_ITEM(typeargs, 2)));
if (INT_CONV_ERROR_OCCURRED(dest->vtag)) {
return false;
}
diff --git a/lib/py/src/ext/types.h b/lib/py/src/ext/types.h
index 9b45dd065f5..0bfce2e9287 100644
--- a/lib/py/src/ext/types.h
+++ b/lib/py/src/ext/types.h
@@ -28,18 +28,8 @@
#endif
#include
-#if PY_MAJOR_VERSION >= 3
-
#include
-// TODO: better macros
-#define PyInt_AsLong(v) PyLong_AsLong(v)
-#define PyInt_FromLong(v) PyLong_FromLong(v)
-
-#define PyString_InternFromString(v) PyUnicode_InternFromString(v)
-
-#endif
-
#define INTERN_STRING(value) _intern_##value
#define INT_CONV_ERROR_OCCURRED(v) (((v) == -1) && PyErr_Occurred())
@@ -123,16 +113,11 @@ struct DecodeBuffer {
ScopedPyObject refill_callable;
};
-#if PY_MAJOR_VERSION < 3
-extern char refill_signature[3];
-typedef PyObject EncodeBuffer;
-#else
extern const char* refill_signature;
struct EncodeBuffer {
std::vector buf;
size_t pos;
};
-#endif
/**
* A cache of the spec_args for a set or list,
diff --git a/lib/py/src/protocol/TBase.py b/lib/py/src/protocol/TBase.py
index 6c6ef18e877..4acb21cc928 100644
--- a/lib/py/src/protocol/TBase.py
+++ b/lib/py/src/protocol/TBase.py
@@ -20,7 +20,7 @@
from thrift.transport import TTransport
-class TBase(object):
+class TBase:
__slots__ = ()
def __repr__(self):
diff --git a/lib/py/src/protocol/TBinaryProtocol.py b/lib/py/src/protocol/TBinaryProtocol.py
index af64ec10356..d6fdd51ec3d 100644
--- a/lib/py/src/protocol/TBinaryProtocol.py
+++ b/lib/py/src/protocol/TBinaryProtocol.py
@@ -25,9 +25,8 @@
class TBinaryProtocol(TProtocolBase):
"""Binary implementation of the Thrift protocol driver."""
- # NastyHaxx. Python 2.4+ on 32-bit machines forces hex constants to be
- # positive, converting this into a long. If we hardcode the int value
- # instead it'll stay in 32 bit-land.
+    # NastyHaxx. Python 3 has no separate long type; the mask is written as a
+    # negative int, i.e. the signed 32-bit reading of the hex value below.
# VERSION_MASK = 0xffff0000
VERSION_MASK = -65536
diff --git a/lib/py/src/protocol/TCompactProtocol.py b/lib/py/src/protocol/TCompactProtocol.py
index a3527cd47a3..abe7405eabb 100644
--- a/lib/py/src/protocol/TCompactProtocol.py
+++ b/lib/py/src/protocol/TCompactProtocol.py
@@ -80,7 +80,7 @@ def readVarint(trans):
shift += 7
-class CompactType(object):
+class CompactType:
STOP = 0x00
TRUE = 0x01
FALSE = 0x02
diff --git a/lib/py/src/protocol/TJSONProtocol.py b/lib/py/src/protocol/TJSONProtocol.py
index a42aaa6315d..218cd9c6c95 100644
--- a/lib/py/src/protocol/TJSONProtocol.py
+++ b/lib/py/src/protocol/TJSONProtocol.py
@@ -84,7 +84,7 @@
JTYPES[CTYPES[key]] = key
-class JSONBaseContext(object):
+class JSONBaseContext:
def __init__(self, protocol):
self.protocol = protocol
@@ -178,11 +178,11 @@ def __init__(self, trans):
# We don't have length limit implementation for JSON protocols
@property
- def string_length_limit(senf):
+ def string_length_limit(self):
return None
@property
- def container_length_limit(senf):
+ def container_length_limit(self):
return None
def resetWriteContext(self):
@@ -572,11 +572,11 @@ def getProtocol(self, trans):
return TJSONProtocol(trans)
@property
- def string_length_limit(senf):
+ def string_length_limit(self):
return None
@property
- def container_length_limit(senf):
+ def container_length_limit(self):
return None
diff --git a/lib/py/src/protocol/TProtocol.py b/lib/py/src/protocol/TProtocol.py
index 5b4f4d85d81..4b441e7d8a4 100644
--- a/lib/py/src/protocol/TProtocol.py
+++ b/lib/py/src/protocol/TProtocol.py
@@ -41,7 +41,7 @@ def __init__(self, type=UNKNOWN, message=None):
self.type = type
-class TProtocolBase(object):
+class TProtocolBase:
"""Base class for Thrift protocol driver."""
def __init__(self, trans):
@@ -409,6 +409,6 @@ def checkIntegerLimits(i, bits):
"i64 requires -9223372036854775808 <= number <= 9223372036854775807")
-class TProtocolFactory(object):
+class TProtocolFactory:
def getProtocol(self, trans):
pass
diff --git a/lib/py/src/protocol/TProtocolDecorator.py b/lib/py/src/protocol/TProtocolDecorator.py
index f5546c736e1..87a0693f727 100644
--- a/lib/py/src/protocol/TProtocolDecorator.py
+++ b/lib/py/src/protocol/TProtocolDecorator.py
@@ -18,7 +18,7 @@
#
-class TProtocolDecorator(object):
+class TProtocolDecorator:
def __new__(cls, protocol, *args, **kwargs):
decorated_cls = type(''.join(['Decorated', protocol.__class__.__name__]),
(cls, protocol.__class__),
diff --git a/lib/py/src/server/THttpServer.py b/lib/py/src/server/THttpServer.py
index 21f2c869149..f4432241d6a 100644
--- a/lib/py/src/server/THttpServer.py
+++ b/lib/py/src/server/THttpServer.py
@@ -24,6 +24,7 @@
from thrift.Thrift import TMessageType
from thrift.server import TServer
from thrift.transport import TTransport
+from thrift.transport.sslcompat import enforce_minimum_tls
class ResponseException(Exception):
@@ -117,10 +118,14 @@ def on_begin(self, name, type, seqid):
self.httpd = server_class(server_address, RequestHander)
if (kwargs.get('cafile') or kwargs.get('cert_file') or kwargs.get('key_file')):
- context = ssl.create_default_context(cafile=kwargs.get('cafile'))
- context.check_hostname = False
+ context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
+ if kwargs.get('cafile'):
+ context.load_verify_locations(cafile=kwargs.get('cafile'))
+ context.verify_mode = ssl.CERT_REQUIRED
+ else:
+ context.verify_mode = ssl.CERT_NONE
context.load_cert_chain(kwargs.get('cert_file'), kwargs.get('key_file'))
- context.verify_mode = ssl.CERT_REQUIRED if kwargs.get('cafile') else ssl.CERT_NONE
+ enforce_minimum_tls(context)
self.httpd.socket = context.wrap_socket(self.httpd.socket, server_side=True)
def serve(self):
diff --git a/lib/py/src/server/TNonblockingServer.py b/lib/py/src/server/TNonblockingServer.py
index a7a40cafb53..a4c7af75949 100644
--- a/lib/py/src/server/TNonblockingServer.py
+++ b/lib/py/src/server/TNonblockingServer.py
@@ -92,7 +92,7 @@ def read(self, *args, **kwargs):
return read
-class Message(object):
+class Message:
def __init__(self, offset, len_, header):
self.offset = offset
self.len = len_
@@ -104,7 +104,7 @@ def end(self):
return self.offset + self.len
-class Connection(object):
+class Connection:
"""Basic class is represented connection.
It can be in state:
@@ -234,7 +234,7 @@ def close(self):
self.socket.close()
-class TNonblockingServer(object):
+class TNonblockingServer:
"""Non-blocking server."""
def __init__(self,
@@ -266,6 +266,9 @@ def prepare(self):
if self.prepared:
return
self.socket.listen()
+ if self.poll:
+ self.poll.register(self.socket.handle.fileno(), select.POLLIN | select.POLLRDNORM)
+ self.poll.register(self._read.fileno(), select.POLLIN | select.POLLRDNORM)
for _ in range(self.threads):
thread = Worker(self.tasks)
thread.daemon = True
@@ -323,17 +326,14 @@ def _poll_select(self):
"""Does poll on open connections, if available."""
remaining = []
- self.poll.register(self.socket.handle.fileno(), select.POLLIN | select.POLLRDNORM)
- self.poll.register(self._read.fileno(), select.POLLIN | select.POLLRDNORM)
-
for i, connection in list(self.clients.items()):
if connection.is_readable():
self.poll.register(connection.fileno(), select.POLLIN | select.POLLRDNORM | select.POLLERR | select.POLLHUP | select.POLLNVAL)
if connection.remaining or connection.received:
remaining.append(connection.fileno())
- if connection.is_writeable():
+ elif connection.is_writeable():
self.poll.register(connection.fileno(), select.POLLOUT | select.POLLWRNORM)
- if connection.is_closed():
+ elif connection.is_closed():
try:
self.poll.unregister(i)
except KeyError:
@@ -403,6 +403,16 @@ def close(self):
for _ in range(self.threads):
self.tasks.put([None, None, None, None, None])
self.socket.close()
+ if self._read:
+ try:
+ self._read.close()
+ finally:
+ self._read = None
+ if self._write:
+ try:
+ self._write.close()
+ finally:
+ self._write = None
self.prepared = False
def serve(self):
diff --git a/lib/py/src/server/TServer.py b/lib/py/src/server/TServer.py
index 81144f14a9b..4940db65713 100644
--- a/lib/py/src/server/TServer.py
+++ b/lib/py/src/server/TServer.py
@@ -29,7 +29,7 @@
logger = logging.getLogger(__name__)
-class TServer(object):
+class TServer:
"""Base interface for a server, which must have a serve() method.
Three constructors for all servers:
diff --git a/lib/py/src/transport/THeaderTransport.py b/lib/py/src/transport/THeaderTransport.py
index 4fb20343020..375cf7919f3 100644
--- a/lib/py/src/transport/THeaderTransport.py
+++ b/lib/py/src/transport/THeaderTransport.py
@@ -37,7 +37,7 @@
HARD_MAX_FRAME_SIZE = 0x3FFFFFFF
-class THeaderClientType(object):
+class THeaderClientType:
HEADERS = 0x00
FRAMED_BINARY = 0x01
@@ -47,16 +47,16 @@ class THeaderClientType(object):
UNFRAMED_COMPACT = 0x04
-class THeaderSubprotocolID(object):
+class THeaderSubprotocolID:
BINARY = 0x00
COMPACT = 0x02
-class TInfoHeaderType(object):
+class TInfoHeaderType:
KEY_VALUE = 0x01
-class THeaderTransformID(object):
+class THeaderTransformID:
ZLIB = 0x01
diff --git a/lib/py/src/transport/THttpClient.py b/lib/py/src/transport/THttpClient.py
index 6281165ea25..aa09e9e7f5d 100644
--- a/lib/py/src/transport/THttpClient.py
+++ b/lib/py/src/transport/THttpClient.py
@@ -28,6 +28,7 @@
import urllib.request
import http.client
+from .sslcompat import enforce_minimum_tls, validate_minimum_tls
from .TTransport import TTransportBase
@@ -63,11 +64,14 @@ def __init__(self, uri_or_host, port=None, path=None, cafile=None, cert_file=Non
self.port = parsed.port or http.client.HTTP_PORT
elif self.scheme == 'https':
self.port = parsed.port or http.client.HTTPS_PORT
- if (cafile or cert_file or key_file) and not ssl_context:
- self.context = ssl.create_default_context(cafile=cafile)
- self.context.load_cert_chain(certfile=cert_file, keyfile=key_file)
- else:
+ if ssl_context is not None:
+ validate_minimum_tls(ssl_context)
self.context = ssl_context
+ else:
+ self.context = ssl.create_default_context(cafile=cafile)
+ if cert_file or key_file:
+ self.context.load_cert_chain(certfile=cert_file, keyfile=key_file)
+ enforce_minimum_tls(self.context)
self.host = parsed.hostname
self.path = parsed.path
if parsed.query:
@@ -102,7 +106,7 @@ def basic_proxy_auth_header(proxy):
ap = "%s:%s" % (urllib.parse.unquote(proxy.username),
urllib.parse.unquote(proxy.password))
cr = base64.b64encode(ap.encode()).strip()
- return "Basic " + six.ensure_str(cr)
+ return "Basic " + cr.decode("ascii")
def using_proxy(self):
return self.realhost is not None
@@ -112,9 +116,11 @@ def open(self):
self.__http = http.client.HTTPConnection(self.host, self.port,
timeout=self.__timeout)
elif self.scheme == 'https':
- self.__http = http.client.HTTPSConnection(self.host, self.port,
- timeout=self.__timeout,
- context=self.context)
+ # Python 3.10+ uses an explicit SSLContext; TLS 1.2+ enforced in __init__.
+ self.__http = http.client.HTTPSConnection( # nosem
+ self.host, self.port,
+ timeout=self.__timeout,
+ context=self.context)
if self.using_proxy():
self.__http.set_tunnel(self.realhost, self.realport,
{"Proxy-Authorization": self.proxy_auth})
diff --git a/lib/py/src/transport/TSSLSocket.py b/lib/py/src/transport/TSSLSocket.py
index dc6c1fb5d31..c02742034a0 100644
--- a/lib/py/src/transport/TSSLSocket.py
+++ b/lib/py/src/transport/TSSLSocket.py
@@ -21,75 +21,64 @@
import os
import socket
import ssl
-import sys
import warnings
-from .sslcompat import _match_has_ipaddress
+from .sslcompat import (
+ validate_minimum_tls,
+ MINIMUM_TLS_VERSION,
+)
from thrift.transport import TSocket
from thrift.transport.TTransport import TTransportException
-_match_hostname = lambda cert, hostname: True
-
logger = logging.getLogger(__name__)
warnings.filterwarnings(
'default', category=DeprecationWarning, module=__name__)
-class TSSLBase(object):
- # SSLContext is not available for Python < 2.7.9
- _has_ssl_context = sys.hexversion >= 0x020709F0
-
- # ciphers argument is not available for Python < 2.7.0
- _has_ciphers = sys.hexversion >= 0x020700F0
-
- # For python >= 2.7.9, use latest TLS that both client and server
- # supports.
- # SSL 2.0 and 3.0 are disabled via ssl.OP_NO_SSLv2 and ssl.OP_NO_SSLv3.
- # For python < 2.7.9, use TLS 1.0 since TLSv1_X nor OP_NO_SSLvX is
- # unavailable.
- # For python < 3.6, use SSLv23 since TLS is not available
- if sys.version_info < (3, 6):
- _default_protocol = ssl.PROTOCOL_SSLv23 if _has_ssl_context else \
- ssl.PROTOCOL_TLSv1
- else:
- _default_protocol = ssl.PROTOCOL_TLS_CLIENT if _has_ssl_context else \
- ssl.PROTOCOL_TLSv1
+class TSSLBase:
+ _minimum_tls_version = MINIMUM_TLS_VERSION
def _init_context(self, ssl_version):
- if self._has_ssl_context:
- self._context = ssl.SSLContext(ssl_version)
- if self._context.protocol == ssl.PROTOCOL_SSLv23:
- self._context.options |= ssl.OP_NO_SSLv2
- self._context.options |= ssl.OP_NO_SSLv3
+ """Initialize SSL context with the given version.
+
+ Args:
+ ssl_version: Minimum TLS version to accept. Must be
+ ssl.TLSVersion.TLSv1_2 or ssl.TLSVersion.TLSv1_3.
+ Higher versions are negotiated when available.
+ Deprecated protocol constants are not supported.
+ """
+ if not isinstance(ssl_version, ssl.TLSVersion):
+ raise ValueError(
+ 'ssl_version must be ssl.TLSVersion.TLSv1_2 or ssl.TLSVersion.TLSv1_3. '
+ 'Deprecated protocol constants (PROTOCOL_*) are not supported.'
+ )
+ if ssl_version < self._minimum_tls_version:
+ raise ValueError(
+ 'TLS 1.0/1.1 are not supported; use ssl.TLSVersion.TLSv1_2 or higher.'
+ )
+
+ if self._server_side:
+ protocol = ssl.PROTOCOL_TLS_SERVER
else:
- self._context = None
- self._ssl_version = ssl_version
+ protocol = ssl.PROTOCOL_TLS_CLIENT
+ self._context = ssl.SSLContext(protocol)
+ self._context.minimum_version = ssl_version
+ # Don't set maximum_version - allow negotiation up to newest TLS
@property
def _should_verify(self):
- if self._has_ssl_context:
+ if self._custom_context:
return self._context.verify_mode != ssl.CERT_NONE
- else:
- return self.cert_reqs != ssl.CERT_NONE
+ return self.cert_reqs != ssl.CERT_NONE
@property
def ssl_version(self):
- if self._has_ssl_context:
- return self.ssl_context.protocol
- else:
- return self._ssl_version
+ return self.ssl_context.protocol
@property
def ssl_context(self):
return self._context
- SSL_VERSION = _default_protocol
- """
- Default SSL version.
- For backwards compatibility, it can be modified.
- Use __init__ keyword argument "ssl_version" instead.
- """
-
def _deprecated_arg(self, args, kwargs, pos, key):
if len(args) <= pos:
return
@@ -112,37 +101,22 @@ def _unix_socket_arg(self, host, port, args, kwargs):
return True
return False
- def __getattr__(self, key):
- if key == 'SSL_VERSION':
- warnings.warn(
- 'SSL_VERSION is deprecated.'
- 'please use ssl_version attribute instead.',
- DeprecationWarning, stacklevel=2)
- return self.ssl_version
-
def __init__(self, server_side, host, ssl_opts):
self._server_side = server_side
- if TSSLBase.SSL_VERSION != self._default_protocol:
- warnings.warn(
- 'SSL_VERSION is deprecated.'
- 'please use ssl_version keyword argument instead.',
- DeprecationWarning, stacklevel=2)
self._context = ssl_opts.pop('ssl_context', None)
self._server_hostname = None
if not self._server_side:
self._server_hostname = ssl_opts.pop('server_hostname', host)
if self._context:
self._custom_context = True
+ validate_minimum_tls(self._context)
if ssl_opts:
raise ValueError(
'Incompatible arguments: ssl_context and %s'
% ' '.join(ssl_opts.keys()))
- if not self._has_ssl_context:
- raise ValueError(
- 'ssl_context is not available for this version of Python')
else:
self._custom_context = False
- ssl_version = ssl_opts.pop('ssl_version', TSSLBase.SSL_VERSION)
+ ssl_version = ssl_opts.pop('ssl_version', self._minimum_tls_version)
self._init_context(ssl_version)
self.cert_reqs = ssl_opts.pop('cert_reqs', ssl.CERT_REQUIRED)
self.ca_certs = ssl_opts.pop('ca_certs', None)
@@ -176,35 +150,28 @@ def certfile(self, certfile):
self._certfile = certfile
def _wrap_socket(self, sock):
- if self._has_ssl_context:
- if not self._custom_context:
- self.ssl_context.verify_mode = self.cert_reqs
- if self.certfile:
- self.ssl_context.load_cert_chain(self.certfile,
- self.keyfile)
- if self.ciphers:
- self.ssl_context.set_ciphers(self.ciphers)
- if self.ca_certs:
- self.ssl_context.load_verify_locations(self.ca_certs)
- return self.ssl_context.wrap_socket(
- sock, server_side=self._server_side,
- server_hostname=self._server_hostname)
- else:
- ssl_opts = {
- 'ssl_version': self._ssl_version,
- 'server_side': self._server_side,
- 'ca_certs': self.ca_certs,
- 'keyfile': self.keyfile,
- 'certfile': self.certfile,
- 'cert_reqs': self.cert_reqs,
- }
+ if not self._custom_context:
+ if self._server_side:
+ # Server contexts never perform hostname checks.
+ self.ssl_context.check_hostname = False
+ else:
+ # For client sockets, use OpenSSL hostname checking when we
+ # require a verified server certificate. OpenSSL handles
+ # hostname validation during the TLS handshake.
+ self.ssl_context.check_hostname = (
+ self.cert_reqs in (ssl.CERT_REQUIRED, ssl.CERT_OPTIONAL) and
+ bool(self._server_hostname)
+ )
+ self.ssl_context.verify_mode = self.cert_reqs
+ if self.certfile:
+ self.ssl_context.load_cert_chain(self.certfile, self.keyfile)
if self.ciphers:
- if self._has_ciphers:
- ssl_opts['ciphers'] = self.ciphers
- else:
- logger.warning(
- 'ciphers is specified but ignored due to old Python version')
- return ssl.wrap_socket(sock, **ssl_opts)
+ self.ssl_context.set_ciphers(self.ciphers)
+ if self.ca_certs:
+ self.ssl_context.load_verify_locations(self.ca_certs)
+ return self.ssl_context.wrap_socket(
+ sock, server_side=self._server_side,
+ server_hostname=self._server_hostname)
class TSSLSocket(TSocket.TSocket, TSSLBase):
@@ -226,22 +193,20 @@ def __init__(self, host='localhost', port=9090, *args, **kwargs):
"""Positional arguments: ``host``, ``port``, ``unix_socket``
Keyword arguments: ``keyfile``, ``certfile``, ``cert_reqs``,
- ``ssl_version``, ``ca_certs``,
- ``ciphers`` (Python 2.7.0 or later),
- ``server_hostname`` (Python 2.7.9 or later)
+ ``ssl_version`` (minimum TLS version, defaults to 1.2),
+ ``ca_certs``, ``ciphers``, ``server_hostname``
Passed to ssl.wrap_socket. See ssl.wrap_socket documentation.
- Alternative keyword arguments: (Python 2.7.9 or later)
+ Alternative keyword arguments:
``ssl_context``: ssl.SSLContext to be used for SSLContext.wrap_socket
``server_hostname``: Passed to SSLContext.wrap_socket
Common keyword argument:
- ``validate_callback`` (cert, hostname) -> None:
- Called after SSL handshake. Can raise when hostname does not
- match the cert.
``socket_keepalive`` enable TCP keepalive, default off.
+
+        Note: OpenSSL verifies the hostname during the TLS handshake when
+        cert_reqs is CERT_REQUIRED or CERT_OPTIONAL and server_hostname is set.
"""
- self.is_valid = False
self.peercert = None
if args:
@@ -268,12 +233,13 @@ def __init__(self, host='localhost', port=9090, *args, **kwargs):
unix_socket = kwargs.pop('unix_socket', None)
socket_keepalive = kwargs.pop('socket_keepalive', False)
- self._validate_callback = kwargs.pop('validate_callback', _match_hostname)
TSSLBase.__init__(self, False, host, kwargs)
TSocket.TSocket.__init__(self, host, port, unix_socket,
socket_keepalive=socket_keepalive)
def close(self):
+ if not self.handle:
+ return
try:
self.handle.settimeout(0.001)
self.handle = self.handle.unwrap()
@@ -306,15 +272,10 @@ def _do_open(self, family, socktype):
def open(self):
super(TSSLSocket, self).open()
+ # Hostname verification is handled by OpenSSL during the TLS handshake
+ # when check_hostname=True is set on the SSLContext.
if self._should_verify:
self.peercert = self.handle.getpeercert()
- try:
- self._validate_callback(self.peercert, self._server_hostname)
- self.is_valid = True
- except TTransportException:
- raise
- except Exception as ex:
- raise TTransportException(message=str(ex), inner=ex)
class TSSLServerSocket(TSocket.TServerSocket, TSSLBase):
@@ -331,18 +292,17 @@ class TSSLServerSocket(TSocket.TServerSocket, TSSLBase):
def __init__(self, host=None, port=9090, *args, **kwargs):
"""Positional arguments: ``host``, ``port``, ``unix_socket``
- Keyword arguments: ``keyfile``, ``certfile``, ``cert_reqs``, ``ssl_version``,
- ``ca_certs``, ``ciphers`` (Python 2.7.0 or later)
+ Keyword arguments: ``keyfile``, ``certfile``, ``cert_reqs``,
+ ``ssl_version`` (minimum TLS version, defaults to 1.2),
+ ``ca_certs``, ``ciphers``
See ssl.wrap_socket documentation.
- Alternative keyword arguments: (Python 2.7.9 or later)
+ Alternative keyword arguments:
``ssl_context``: ssl.SSLContext to be used for SSLContext.wrap_socket
- ``server_hostname``: Passed to SSLContext.wrap_socket
- Common keyword argument:
- ``validate_callback`` (cert, hostname) -> None:
- Called after SSL handshake. Can raise when hostname does not
- match the cert.
+ For mTLS (mutual TLS), set cert_reqs=ssl.CERT_REQUIRED and provide
+ ca_certs to verify client certificates. Client certificate validation
+ checks that the certificate is signed by a trusted CA.
"""
if args:
if len(args) > 3:
@@ -356,17 +316,12 @@ def __init__(self, host=None, port=9090, *args, **kwargs):
# Preserve existing behaviors for default values
if 'cert_reqs' not in kwargs:
kwargs['cert_reqs'] = ssl.CERT_NONE
- if'certfile' not in kwargs:
+ if 'certfile' not in kwargs:
kwargs['certfile'] = 'cert.pem'
unix_socket = kwargs.pop('unix_socket', None)
- self._validate_callback = \
- kwargs.pop('validate_callback', _match_hostname)
TSSLBase.__init__(self, True, None, kwargs)
TSocket.TServerSocket.__init__(self, host, port, unix_socket)
- if self._should_verify and not _match_has_ipaddress:
- raise ValueError('Need ipaddress and backports.ssl_match_hostname '
- 'module to verify client certificate')
def setCertfile(self, certfile):
"""Set or change the server certificate file used to wrap new
@@ -398,18 +353,8 @@ def accept(self):
# other exception handling. (but TSimpleServer dies anyway)
return None
- if self._should_verify:
- client.peercert = client.getpeercert()
- try:
- self._validate_callback(client.peercert, addr[0])
- client.is_valid = True
- except Exception:
- logger.warning('Failed to validate client certificate address: %s',
- addr[0], exc_info=True)
- client.close()
- plain_client.close()
- return None
-
+ # For mTLS, OpenSSL validates that the client certificate is signed
+ # by a trusted CA during the handshake (when cert_reqs=CERT_REQUIRED).
result = TSocket.TSocket()
result.handle = client
return result
diff --git a/lib/py/src/transport/TSocket.py b/lib/py/src/transport/TSocket.py
index 195bfcb57a9..93a5469d3a2 100644
--- a/lib/py/src/transport/TSocket.py
+++ b/lib/py/src/transport/TSocket.py
@@ -22,7 +22,6 @@
import os
import socket
import sys
-import platform
from .TTransport import TTransportBase, TTransportException, TServerTransportBase
@@ -159,8 +158,7 @@ def open(self):
def read(self, sz):
try:
buff = self.handle.recv(sz)
- # TODO: remove socket.timeout when 3.10 becomes the earliest version of python supported.
- except (socket.timeout, TimeoutError) as e:
+ except TimeoutError as e:
raise TTransportException(type=TTransportException.TIMED_OUT, message="read timeout", inner=e)
except socket.error as e:
if (e.args[0] == errno.ECONNRESET and
@@ -171,7 +169,7 @@ def read(self, sz):
# in lib/cpp/src/transport/TSocket.cpp.
self.close()
# Trigger the check to raise the END_OF_FILE exception below.
- buff = ''
+ buff = b''
else:
raise TTransportException(message="unexpected exception", inner=e)
if len(buff) == 0:
@@ -236,16 +234,14 @@ def listen(self):
eno, message = err.args
if eno == errno.ECONNREFUSED:
os.unlink(res[4])
+ finally:
+ tmp.close()
self.handle = s = socket.socket(res[0], res[1])
if s.family is socket.AF_INET6:
- if platform.system() == 'Windows' and sys.version_info < (3, 8):
- logger.warning('Windows socket defaulting to IPv4 for Python < 3.8: See https://github.com/python/cpython/issues/73701')
- else:
- s.setsockopt(socket.IPPROTO_IPV6, socket.IPV6_V6ONLY, 0)
+ s.setsockopt(socket.IPPROTO_IPV6, socket.IPV6_V6ONLY, 0)
s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
- if hasattr(s, 'settimeout'):
- s.settimeout(None)
+ s.settimeout(None)
s.bind(res[4])
s.listen(self._backlog)
diff --git a/lib/py/src/transport/TTransport.py b/lib/py/src/transport/TTransport.py
index 4f6b67fe123..799e8627177 100644
--- a/lib/py/src/transport/TTransport.py
+++ b/lib/py/src/transport/TTransport.py
@@ -41,7 +41,7 @@ def __init__(self, type=UNKNOWN, message=None, inner=None):
self.inner = inner
-class TTransportBase(object):
+class TTransportBase:
"""Base class for Thrift transport layer."""
def isOpen(self):
@@ -78,7 +78,7 @@ def flush(self):
# This class should be thought of as an interface.
-class CReadableTransport(object):
+class CReadableTransport:
"""base class for transports that are readable from C"""
# TODO(dreiss): Think about changing this interface to allow us to use
@@ -106,7 +106,7 @@ def cstringio_refill(self, partialread, reqlen):
pass
-class TServerTransportBase(object):
+class TServerTransportBase:
"""Base class for Thrift server transports."""
def listen(self):
@@ -119,14 +119,14 @@ def close(self):
pass
-class TTransportFactoryBase(object):
+class TTransportFactoryBase:
"""Base class for a Transport Factory"""
def getTransport(self, trans):
return trans
-class TBufferedTransportFactory(object):
+class TBufferedTransportFactory:
"""Factory transport that builds buffered transports"""
def getTransport(self, trans):
@@ -251,7 +251,7 @@ def cstringio_refill(self, partialread, reqlen):
raise EOFError()
-class TFramedTransportFactory(object):
+class TFramedTransportFactory:
"""Factory transport that builds framed transports"""
def getTransport(self, trans):
diff --git a/lib/py/src/transport/sslcompat.py b/lib/py/src/transport/sslcompat.py
index 54235ec6d1d..9407f1b3699 100644
--- a/lib/py/src/transport/sslcompat.py
+++ b/lib/py/src/transport/sslcompat.py
@@ -1,107 +1,69 @@
#
-# licensed to the apache software foundation (asf) under one
-# or more contributor license agreements. see the notice file
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
-# regarding copyright ownership. the asf licenses this file
-# to you under the apache license, version 2.0 (the
-# "license"); you may not use this file except in compliance
-# with the license. you may obtain a copy of the license at
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
#
-# http://www.apache.org/licenses/license-2.0
+# http://www.apache.org/licenses/LICENSE-2.0
#
-# unless required by applicable law or agreed to in writing,
-# software distributed under the license is distributed on an
-# "as is" basis, without warranties or conditions of any
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#
-import logging
-import sys
+"""SSL compatibility utilities for Thrift.
-from thrift.transport.TTransport import TTransportException
+For Python 3.10+, hostname verification is handled by OpenSSL during the
+TLS handshake when SSLContext.check_hostname is True. This module provides
+TLS version enforcement utilities.
+"""
-logger = logging.getLogger(__name__)
+import ssl
+# Minimum TLS version for all Thrift SSL connections
+MINIMUM_TLS_VERSION = ssl.TLSVersion.TLSv1_2
-def legacy_validate_callback(cert, hostname):
- """legacy method to validate the peer's SSL certificate, and to check
- the commonName of the certificate to ensure it matches the hostname we
- used to make this connection. Does not support subjectAltName records
- in certificates.
- raises TTransportException if the certificate fails validation.
+def enforce_minimum_tls(context):
+ """Enforce TLS 1.2 or higher on an SSLContext.
+
+ This function modifies the context in-place to ensure that TLS 1.2 or higher
+ is used. It raises ValueError if the context's maximum_version is set to a
+ version lower than TLS 1.2.
+
+ Args:
+ context: An ssl.SSLContext to enforce minimum TLS version on
"""
- if 'subject' not in cert:
- raise TTransportException(
- TTransportException.NOT_OPEN,
- 'No SSL certificate found from %s' % hostname)
- fields = cert['subject']
- for field in fields:
- # ensure structure we get back is what we expect
- if not isinstance(field, tuple):
- continue
- cert_pair = field[0]
- if len(cert_pair) < 2:
- continue
- cert_key, cert_value = cert_pair[0:2]
- if cert_key != 'commonName':
- continue
- certhost = cert_value
- # this check should be performed by some sort of Access Manager
- if certhost == hostname:
- # success, cert commonName matches desired hostname
- return
- else:
- raise TTransportException(
- TTransportException.UNKNOWN,
- 'Hostname we connected to "%s" doesn\'t match certificate '
- 'provided commonName "%s"' % (hostname, certhost))
- raise TTransportException(
- TTransportException.UNKNOWN,
- 'Could not validate SSL certificate from host "%s". Cert=%s'
- % (hostname, cert))
+    if (context.maximum_version != ssl.TLSVersion.MAXIMUM_SUPPORTED and
+            context.maximum_version < MINIMUM_TLS_VERSION):
+        raise ValueError('TLS maximum_version must be TLS 1.2 or higher.')
+    if context.minimum_version < MINIMUM_TLS_VERSION:  # mutate only once valid
+        context.minimum_version = MINIMUM_TLS_VERSION
-def _optional_dependencies():
- try:
- import ipaddress # noqa
- logger.debug('ipaddress module is available')
- ipaddr = True
- except ImportError:
- logger.warning('ipaddress module is unavailable')
- ipaddr = False
+def validate_minimum_tls(context):
+ """Validate that an SSLContext uses TLS 1.2 or higher.
- if sys.hexversion < 0x030500F0:
- try:
- from backports.ssl_match_hostname import match_hostname, __version__ as ver
- ver = list(map(int, ver.split('.')))
- logger.debug('backports.ssl_match_hostname module is available')
- match = match_hostname
- if ver[0] * 10 + ver[1] >= 35:
- return ipaddr, match
- else:
- logger.warning('backports.ssl_match_hostname module is too old')
- ipaddr = False
- except ImportError:
- logger.warning('backports.ssl_match_hostname is unavailable')
- ipaddr = False
- try:
- from ssl import match_hostname
- logger.debug('ssl.match_hostname is available')
- match = match_hostname
- except ImportError:
- # https://docs.python.org/3/whatsnew/3.12.html:
- # "Remove the ssl.match_hostname() function. It was deprecated in Python
- # 3.7. OpenSSL performs hostname matching since Python 3.7, Python no
- # longer uses the ssl.match_hostname() function.""
- if sys.version_info[0] > 3 or (sys.version_info[0] == 3 and sys.version_info[1] >= 12):
- match = lambda cert, hostname: True
- else:
- logger.warning('using legacy validation callback')
- match = legacy_validate_callback
- return ipaddr, match
+ Unlike enforce_minimum_tls, this function does not modify the context.
+ It raises ValueError if the context is configured to use TLS versions
+ lower than 1.2.
+ Args:
+ context: An ssl.SSLContext to validate
-_match_has_ipaddress, _match_hostname = _optional_dependencies()
+ Raises:
+ ValueError: If the context allows TLS versions below 1.2
+ """
+ if context.minimum_version < MINIMUM_TLS_VERSION:
+ raise ValueError(
+ 'ssl_context.minimum_version must be TLS 1.2 or higher.')
+ if (context.maximum_version != ssl.TLSVersion.MAXIMUM_SUPPORTED and
+ context.maximum_version < MINIMUM_TLS_VERSION):
+ raise ValueError(
+ 'ssl_context.maximum_version must be TLS 1.2 or higher.')
diff --git a/lib/py/test/.gitignore b/lib/py/test/.gitignore
new file mode 100644
index 00000000000..24060ee0537
--- /dev/null
+++ b/lib/py/test/.gitignore
@@ -0,0 +1,2 @@
+# Generated code from type check tests
+gen-py-typecheck/
diff --git a/lib/py/test/test_socket.py b/lib/py/test/test_socket.py
index 5e25f1a90bf..4cc8cb65396 100644
--- a/lib/py/test/test_socket.py
+++ b/lib/py/test/test_socket.py
@@ -44,6 +44,7 @@ def test_failed_connection_raises_exception(self):
def test_socket_readtimeout_exception(self):
acc = ServerAcceptor(TServerSocket(port=0))
acc.start()
+ acc.await_listening()
sock = TSocket(host="localhost", port=acc.port)
sock.open()
@@ -65,12 +66,13 @@ def test_isOpen_checks_for_readability(self):
timeouts = [
None, # blocking mode
0, # non-blocking mode
- 1.0, # timeout mode
+ 500, # timeout mode (ms)
]
for timeout in timeouts:
acc = ServerAcceptor(TServerSocket(port=0))
acc.start()
+ acc.await_listening()
sock = TSocket(host="localhost", port=acc.port)
self.assertFalse(sock.isOpen())
diff --git a/lib/py/test/test_sslcontext_hostname.py b/lib/py/test/test_sslcontext_hostname.py
new file mode 100644
index 00000000000..79e729e6d66
--- /dev/null
+++ b/lib/py/test/test_sslcontext_hostname.py
@@ -0,0 +1,137 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+
+"""Tests for SSL hostname verification via OpenSSL.
+
+For Python 3.10+, hostname verification is handled by OpenSSL during the
+TLS handshake when SSLContext.check_hostname is True.
+"""
+
+import os
+import socket
+import ssl
+import unittest
+import warnings
+
+import _import_local_thrift # noqa
+
+from thrift.transport.TSSLSocket import TSSLSocket
+from thrift.transport import sslcompat
+
+SCRIPT_DIR = os.path.realpath(os.path.dirname(__file__))
+ROOT_DIR = os.path.dirname(os.path.dirname(os.path.dirname(SCRIPT_DIR)))
+CA_CERT = os.path.join(ROOT_DIR, 'test', 'keys', 'CA.pem')
+
+
+class TSSLSocketHostnameVerificationTest(unittest.TestCase):
+ """Tests that OpenSSL hostname verification is properly configured."""
+
+ def _wrap_client(self, **kwargs):
+ client = TSSLSocket('localhost', 0, **kwargs)
+ sock = socket.socket()
+ ssl_sock = None
+ try:
+ ssl_sock = client._wrap_socket(sock)
+ finally:
+ if ssl_sock is not None:
+ ssl_sock.close()
+ else:
+ sock.close()
+ return client
+
+ def test_check_hostname_enabled_with_verification(self):
+ """check_hostname should be True when CERT_REQUIRED and server_hostname set."""
+ client = self._wrap_client(
+ cert_reqs=ssl.CERT_REQUIRED,
+ ca_certs=CA_CERT,
+ server_hostname='localhost',
+ )
+ self.assertTrue(client.ssl_context.check_hostname)
+
+ def test_check_hostname_disabled_without_server_hostname(self):
+ """check_hostname should be False when no server_hostname."""
+ client = self._wrap_client(
+ cert_reqs=ssl.CERT_REQUIRED,
+ ca_certs=CA_CERT,
+ server_hostname=None,
+ )
+ self.assertFalse(client.ssl_context.check_hostname)
+
+ def test_check_hostname_disabled_with_cert_none(self):
+ """check_hostname should be False when CERT_NONE."""
+ client = self._wrap_client(
+ cert_reqs=ssl.CERT_NONE,
+ server_hostname='localhost',
+ )
+ self.assertFalse(client.ssl_context.check_hostname)
+
+
+class TLSVersionEnforcementTest(unittest.TestCase):
+ """Tests for TLS version enforcement utilities."""
+
+ def test_enforce_minimum_tls_upgrades_version(self):
+ """enforce_minimum_tls should set minimum_version to TLS 1.2."""
+ context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
+ with warnings.catch_warnings():
+ warnings.filterwarnings('ignore', category=DeprecationWarning)
+ context.minimum_version = ssl.TLSVersion.TLSv1
+ sslcompat.enforce_minimum_tls(context)
+ self.assertEqual(context.minimum_version, ssl.TLSVersion.TLSv1_2)
+
+ def test_enforce_minimum_tls_rejects_low_maximum(self):
+ """enforce_minimum_tls should reject maximum_version below TLS 1.2."""
+ context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
+ with warnings.catch_warnings():
+ warnings.filterwarnings('ignore', category=DeprecationWarning)
+ context.maximum_version = ssl.TLSVersion.TLSv1_1
+ with self.assertRaises(ValueError):
+ sslcompat.enforce_minimum_tls(context)
+
+ def test_validate_minimum_tls_rejects_low_minimum(self):
+ """validate_minimum_tls should reject minimum_version below TLS 1.2."""
+ context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
+ context.check_hostname = False
+ context.verify_mode = ssl.CERT_NONE
+ with warnings.catch_warnings():
+ warnings.filterwarnings('ignore', category=DeprecationWarning)
+ context.minimum_version = ssl.TLSVersion.TLSv1
+ with self.assertRaises(ValueError):
+ sslcompat.validate_minimum_tls(context)
+
+ def test_validate_minimum_tls_accepts_tls12(self):
+ """validate_minimum_tls should accept TLS 1.2."""
+ context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
+ context.check_hostname = False
+ context.verify_mode = ssl.CERT_NONE
+ context.minimum_version = ssl.TLSVersion.TLSv1_2
+ # Should not raise
+ sslcompat.validate_minimum_tls(context)
+
+ def test_validate_minimum_tls_accepts_tls13(self):
+ """validate_minimum_tls should accept TLS 1.3."""
+ context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
+ context.check_hostname = False
+ context.verify_mode = ssl.CERT_NONE
+ context.minimum_version = ssl.TLSVersion.TLSv1_3
+ # Should not raise
+ sslcompat.validate_minimum_tls(context)
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/lib/py/test/test_sslsocket.py b/lib/py/test/test_sslsocket.py
index 2cbf5f8dde2..71532d382f2 100644
--- a/lib/py/test/test_sslsocket.py
+++ b/lib/py/test/test_sslsocket.py
@@ -23,15 +23,18 @@
import os
import platform
import ssl
-import sys
import tempfile
import threading
import unittest
import warnings
+import gc
from contextlib import contextmanager
import _import_local_thrift # noqa
+from thrift.transport.TSSLSocket import TSSLSocket, TSSLServerSocket
+from thrift.transport.TTransport import TTransportException
+
SCRIPT_DIR = os.path.realpath(os.path.dirname(__file__))
ROOT_DIR = os.path.dirname(os.path.dirname(os.path.dirname(SCRIPT_DIR)))
SERVER_PEM = os.path.join(ROOT_DIR, 'test', 'keys', 'server.pem')
@@ -42,6 +45,8 @@
CLIENT_CERT = os.path.join(ROOT_DIR, 'test', 'keys', 'client_v3.crt')
CLIENT_KEY = os.path.join(ROOT_DIR, 'test', 'keys', 'client_v3.key')
CLIENT_CA = os.path.join(ROOT_DIR, 'test', 'keys', 'CA.pem')
+EXPIRED_CERT = os.path.join(ROOT_DIR, 'test', 'keys', 'expired.crt')
+EXPIRED_KEY = os.path.join(ROOT_DIR, 'test', 'keys', 'expired.key')
TEST_CIPHERS = 'DES-CBC3-SHA:ECDHE-RSA-AES128-GCM-SHA256'
@@ -107,21 +112,6 @@ def close(self):
self._server.close()
-# Python 2.6 compat
-class AssertRaises(object):
- def __init__(self, expected):
- self._expected = expected
-
- def __enter__(self):
- pass
-
- def __exit__(self, exc_type, exc_value, traceback):
- if not exc_type or not issubclass(exc_type, self._expected):
- raise Exception('fail')
- return True
-
-
-@unittest.skip("failing SSL test to be fixed in subsequent pull request")
class TSSLSocketTest(unittest.TestCase):
def _server_socket(self, **kwargs):
return TSSLServerSocket(port=0, **kwargs)
@@ -151,25 +141,29 @@ def _assert_connection_failure(self, server, path=None, **client_args):
client.write(b"hello")
client.read(5) # b"there"
finally:
+ try:
+ client.close()
+ except Exception:
+ pass
logging.disable(logging.NOTSET)
def _assert_raises(self, exc):
- if sys.hexversion >= 0x020700F0:
- return self.assertRaises(exc)
- else:
- return AssertRaises(exc)
+ return self.assertRaises(exc)
def _assert_connection_success(self, server, path=None, **client_args):
with self._connectable_client(server, path=path, **client_args) as (acc, client):
+ opened = False
try:
self.assertFalse(client.isOpen())
client.open()
+ opened = True
self.assertTrue(client.isOpen())
client.write(b"hello")
self.assertEqual(client.read(5), b"there")
self.assertTrue(acc.client is not None)
finally:
- client.close()
+ if opened:
+ client.close()
# deprecated feature
def test_deprecation(self):
@@ -232,6 +226,24 @@ def test_unix_domain_socket(self):
finally:
os.unlink(path)
+ def test_unix_socket_listen_closes_probe_socket(self):
+ if platform.system() == 'Windows':
+ print('skipping test_unix_socket_listen_closes_probe_socket')
+ return
+ fd, path = tempfile.mkstemp()
+ os.close(fd)
+ os.unlink(path)
+ server = self._server_socket(unix_socket=path, keyfile=SERVER_KEY, certfile=SERVER_CERT)
+ try:
+ with warnings.catch_warnings():
+ warnings.filterwarnings('error', category=ResourceWarning)
+ server.listen()
+ gc.collect()
+ finally:
+ server.close()
+ if os.path.exists(path):
+ os.unlink(path)
+
def test_server_cert(self):
server = self._server_socket(keyfile=SERVER_KEY, certfile=SERVER_CERT)
self._assert_connection_success(server, cert_reqs=ssl.CERT_REQUIRED, ca_certs=SERVER_CERT)
@@ -243,6 +255,15 @@ def test_server_cert(self):
server = self._server_socket(keyfile=SERVER_KEY, certfile=SERVER_CERT)
self._assert_connection_success(server, cert_reqs=ssl.CERT_NONE)
+ def test_server_hostname_mismatch(self):
+ server = self._server_socket(keyfile=SERVER_KEY, certfile=SERVER_CERT)
+ self._assert_connection_failure(
+ server,
+ cert_reqs=ssl.CERT_REQUIRED,
+ ca_certs=SERVER_CERT,
+ server_hostname='notlocalhost',
+ )
+
def test_set_server_cert(self):
server = self._server_socket(keyfile=SERVER_KEY, certfile=CLIENT_CERT)
with self._assert_raises(Exception):
@@ -253,114 +274,200 @@ def test_set_server_cert(self):
self._assert_connection_success(server, cert_reqs=ssl.CERT_REQUIRED, ca_certs=SERVER_CERT)
def test_client_cert(self):
- if not _match_has_ipaddress:
- print('skipping test_client_cert')
- return
+ # Client presents wrong cert (not trusted by server's CA)
server = self._server_socket(
cert_reqs=ssl.CERT_REQUIRED, keyfile=SERVER_KEY,
certfile=SERVER_CERT, ca_certs=CLIENT_CERT)
- self._assert_connection_failure(server, cert_reqs=ssl.CERT_NONE, certfile=SERVER_CERT, keyfile=SERVER_KEY)
+ self._assert_connection_failure(
+ server, cert_reqs=ssl.CERT_NONE, certfile=SERVER_CERT, keyfile=SERVER_KEY)
+ # Client presents valid cert signed by trusted CA
+ # Note: We no longer validate client cert SAN/CN against client IP address.
+ # mTLS just verifies the cert is signed by a trusted CA.
server = self._server_socket(
cert_reqs=ssl.CERT_REQUIRED, keyfile=SERVER_KEY,
- certfile=SERVER_CERT, ca_certs=CLIENT_CA)
- self._assert_connection_failure(server, cert_reqs=ssl.CERT_NONE, certfile=CLIENT_CERT_NO_IP, keyfile=CLIENT_KEY_NO_IP)
+ certfile=SERVER_CERT, ca_certs=CLIENT_CERT_NO_IP)
+ self._assert_connection_success(
+ server, cert_reqs=ssl.CERT_NONE, certfile=CLIENT_CERT_NO_IP, keyfile=CLIENT_KEY_NO_IP)
server = self._server_socket(
cert_reqs=ssl.CERT_REQUIRED, keyfile=SERVER_KEY,
- certfile=SERVER_CERT, ca_certs=CLIENT_CA)
- self._assert_connection_success(server, cert_reqs=ssl.CERT_NONE, certfile=CLIENT_CERT, keyfile=CLIENT_KEY)
+ certfile=SERVER_CERT, ca_certs=CLIENT_CERT)
+ self._assert_connection_success(
+ server, cert_reqs=ssl.CERT_NONE, certfile=CLIENT_CERT, keyfile=CLIENT_KEY)
server = self._server_socket(
cert_reqs=ssl.CERT_OPTIONAL, keyfile=SERVER_KEY,
- certfile=SERVER_CERT, ca_certs=CLIENT_CA)
- self._assert_connection_success(server, cert_reqs=ssl.CERT_NONE, certfile=CLIENT_CERT, keyfile=CLIENT_KEY)
+ certfile=SERVER_CERT, ca_certs=CLIENT_CERT)
+ self._assert_connection_success(
+ server, cert_reqs=ssl.CERT_NONE, certfile=CLIENT_CERT, keyfile=CLIENT_KEY)
def test_ciphers(self):
- server = self._server_socket(keyfile=SERVER_KEY, certfile=SERVER_CERT, ciphers=TEST_CIPHERS)
- self._assert_connection_success(server, ca_certs=SERVER_CERT, ciphers=TEST_CIPHERS)
+ tls12 = ssl.TLSVersion.TLSv1_2
+ server = self._server_socket(
+ keyfile=SERVER_KEY, certfile=SERVER_CERT, ciphers=TEST_CIPHERS, ssl_version=tls12)
+ self._assert_connection_success(
+ server, ca_certs=SERVER_CERT, ciphers=TEST_CIPHERS, ssl_version=tls12)
+
+ # NULL cipher tests don't work reliably on Windows where the SSL
+ # library may ignore invalid cipher specifications rather than failing.
+ # On other platforms, we must force TLS 1.2 only (not just minimum) to
+ # prevent TLS 1.3 from negotiating with its own cipher suites that
+ # aren't affected by set_ciphers('NULL').
+ if platform.system() != 'Windows':
+ # Create server with TLS 1.2 only (no TLS 1.3 fallback)
+ server_ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
+ server_ctx.minimum_version = tls12
+ server_ctx.maximum_version = tls12
+ server_ctx.load_cert_chain(SERVER_CERT, SERVER_KEY)
+ server = self._server_socket(ssl_context=server_ctx)
+
+ # Create client with NULL ciphers - should fail to connect
+ client_ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
+ client_ctx.minimum_version = tls12
+ client_ctx.maximum_version = tls12
+ client_ctx.check_hostname = False
+ client_ctx.verify_mode = ssl.CERT_REQUIRED
+ client_ctx.load_verify_locations(SERVER_CERT)
+ client_ctx.set_ciphers('NULL')
+ self._assert_connection_failure(server, ssl_context=client_ctx)
+
+ # Same test but server also specifies ciphers
+ server_ctx2 = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
+ server_ctx2.minimum_version = tls12
+ server_ctx2.maximum_version = tls12
+ server_ctx2.load_cert_chain(SERVER_CERT, SERVER_KEY)
+ server_ctx2.set_ciphers(TEST_CIPHERS)
+ server = self._server_socket(ssl_context=server_ctx2)
+ self._assert_connection_failure(server, ssl_context=client_ctx)
+
+ def test_reject_deprecated_protocol_constants(self):
+ """Verify that deprecated PROTOCOL_* constants are rejected."""
+ # Our implementation requires ssl.TLSVersion enum values, not the
+ # deprecated PROTOCOL_* constants. This test verifies that ValueError is raised.
+ with self._assert_raises(ValueError):
+ self._server_socket(
+ keyfile=SERVER_KEY,
+ certfile=SERVER_CERT,
+ ssl_version=ssl.PROTOCOL_TLS,
+ )
+ with self._assert_raises(ValueError):
+ TSSLSocket(
+ 'localhost',
+ 0,
+ cert_reqs=ssl.CERT_NONE,
+ ssl_version=ssl.PROTOCOL_TLS_CLIENT,
+ )
+
+ def test_reject_legacy_tls_versions(self):
+ """Verify that TLS 1.0 and 1.1 are rejected."""
+ with warnings.catch_warnings():
+ warnings.filterwarnings('ignore', category=DeprecationWarning)
+ legacy_versions = (ssl.TLSVersion.TLSv1, ssl.TLSVersion.TLSv1_1)
+ for version in legacy_versions:
+ with self._assert_raises(ValueError):
+ self._server_socket(
+ keyfile=SERVER_KEY,
+ certfile=SERVER_CERT,
+ ssl_version=version,
+ )
+ with self._assert_raises(ValueError):
+ TSSLSocket(
+ 'localhost',
+ 0,
+ cert_reqs=ssl.CERT_NONE,
+ ssl_version=version,
+ )
+
+ def test_default_context_minimum_tls(self):
+ client = TSSLSocket('localhost', 0, cert_reqs=ssl.CERT_NONE)
+ try:
+ self.assertGreaterEqual(
+ client.ssl_context.minimum_version,
+ ssl.TLSVersion.TLSv1_2,
+ )
+ if client.ssl_context.maximum_version != ssl.TLSVersion.MAXIMUM_SUPPORTED:
+ self.assertGreaterEqual(
+ client.ssl_context.maximum_version,
+ ssl.TLSVersion.TLSv1_2,
+ )
+ finally:
+ client.close()
- if not TSSLSocket._has_ciphers:
- # unittest.skip is not available for Python 2.6
- print('skipping test_ciphers')
- return
server = self._server_socket(keyfile=SERVER_KEY, certfile=SERVER_CERT)
- self._assert_connection_failure(server, ca_certs=SERVER_CERT, ciphers='NULL')
-
- server = self._server_socket(keyfile=SERVER_KEY, certfile=SERVER_CERT, ciphers=TEST_CIPHERS)
- self._assert_connection_failure(server, ca_certs=SERVER_CERT, ciphers='NULL')
-
- def test_ssl2_and_ssl3_disabled(self):
- if not hasattr(ssl, 'PROTOCOL_SSLv3'):
- print('PROTOCOL_SSLv3 is not available')
- else:
- server = self._server_socket(keyfile=SERVER_KEY, certfile=SERVER_CERT)
- self._assert_connection_failure(server, ca_certs=SERVER_CERT, ssl_version=ssl.PROTOCOL_SSLv3)
-
- server = self._server_socket(keyfile=SERVER_KEY, certfile=SERVER_CERT, ssl_version=ssl.PROTOCOL_SSLv3)
- self._assert_connection_failure(server, ca_certs=SERVER_CERT)
-
- if not hasattr(ssl, 'PROTOCOL_SSLv2'):
- print('PROTOCOL_SSLv2 is not available')
- else:
- server = self._server_socket(keyfile=SERVER_KEY, certfile=SERVER_CERT)
- self._assert_connection_failure(server, ca_certs=SERVER_CERT, ssl_version=ssl.PROTOCOL_SSLv2)
-
- server = self._server_socket(keyfile=SERVER_KEY, certfile=SERVER_CERT, ssl_version=ssl.PROTOCOL_SSLv2)
- self._assert_connection_failure(server, ca_certs=SERVER_CERT)
-
- def test_newer_tls(self):
- if not TSSLSocket._has_ssl_context:
- # unittest.skip is not available for Python 2.6
- print('skipping test_newer_tls')
- return
- if not hasattr(ssl, 'PROTOCOL_TLSv1_2'):
- print('PROTOCOL_TLSv1_2 is not available')
- else:
- server = self._server_socket(keyfile=SERVER_KEY, certfile=SERVER_CERT, ssl_version=ssl.PROTOCOL_TLSv1_2)
- self._assert_connection_success(server, ca_certs=SERVER_CERT, ssl_version=ssl.PROTOCOL_TLSv1_2)
-
- if not hasattr(ssl, 'PROTOCOL_TLSv1_1'):
- print('PROTOCOL_TLSv1_1 is not available')
- else:
- server = self._server_socket(keyfile=SERVER_KEY, certfile=SERVER_CERT, ssl_version=ssl.PROTOCOL_TLSv1_1)
- self._assert_connection_success(server, ca_certs=SERVER_CERT, ssl_version=ssl.PROTOCOL_TLSv1_1)
-
- if not hasattr(ssl, 'PROTOCOL_TLSv1_1') or not hasattr(ssl, 'PROTOCOL_TLSv1_2'):
- print('PROTOCOL_TLSv1_1 and/or PROTOCOL_TLSv1_2 is not available')
- else:
- server = self._server_socket(keyfile=SERVER_KEY, certfile=SERVER_CERT, ssl_version=ssl.PROTOCOL_TLSv1_2)
- self._assert_connection_failure(server, ca_certs=SERVER_CERT, ssl_version=ssl.PROTOCOL_TLSv1_1)
+ try:
+ self.assertGreaterEqual(
+ server.ssl_context.minimum_version,
+ ssl.TLSVersion.TLSv1_2,
+ )
+ if server.ssl_context.maximum_version != ssl.TLSVersion.MAXIMUM_SUPPORTED:
+ self.assertGreaterEqual(
+ server.ssl_context.maximum_version,
+ ssl.TLSVersion.TLSv1_2,
+ )
+ finally:
+ server.close()
+
+ def test_tls12_supported(self):
+ server = self._server_socket(
+ keyfile=SERVER_KEY, certfile=SERVER_CERT, ssl_version=ssl.TLSVersion.TLSv1_2)
+ self._assert_connection_success(
+ server, ca_certs=SERVER_CERT, ssl_version=ssl.TLSVersion.TLSv1_2)
+
+ def test_tls12_context_no_deprecation_warning(self):
+ with warnings.catch_warnings():
+ warnings.filterwarnings(
+ 'error',
+ category=DeprecationWarning,
+ module=r'thrift\.transport\.TSSLSocket',
+ )
+ server = self._server_socket(
+ keyfile=SERVER_KEY,
+ certfile=SERVER_CERT,
+ ssl_version=ssl.TLSVersion.TLSv1_2,
+ )
+ self._assert_connection_success(
+ server,
+ ca_certs=SERVER_CERT,
+ ssl_version=ssl.TLSVersion.TLSv1_2,
+ )
def test_ssl_context(self):
- if not TSSLSocket._has_ssl_context:
- # unittest.skip is not available for Python 2.6
- print('skipping test_ssl_context')
- return
server_context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
+ server_context.minimum_version = ssl.TLSVersion.TLSv1_2
server_context.load_cert_chain(SERVER_CERT, SERVER_KEY)
- server_context.load_verify_locations(CLIENT_CA)
+ server_context.load_verify_locations(CLIENT_CERT)
server_context.verify_mode = ssl.CERT_REQUIRED
server = self._server_socket(ssl_context=server_context)
client_context = ssl.create_default_context(ssl.Purpose.SERVER_AUTH)
+ client_context.minimum_version = ssl.TLSVersion.TLSv1_2
client_context.load_cert_chain(CLIENT_CERT, CLIENT_KEY)
client_context.load_verify_locations(SERVER_CERT)
client_context.verify_mode = ssl.CERT_REQUIRED
self._assert_connection_success(server, ssl_context=client_context)
+ def test_ssl_context_requires_tls12(self):
+ client_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
+ with warnings.catch_warnings():
+ warnings.filterwarnings('ignore', category=DeprecationWarning)
+ client_context.minimum_version = ssl.TLSVersion.TLSv1_1
+ with self._assert_raises(ValueError):
+ TSSLSocket('localhost', 0, ssl_context=client_context)
-# Add a dummy test because starting from python 3.12, if all tests in a test
-# file are skipped that's considered an error.
-class DummyTest(unittest.TestCase):
- def test_dummy(self):
- self.assertEqual(0, 0)
+ def test_expired_certificate_rejected(self):
+ """Verify that expired server certificates are rejected."""
+ if not os.path.exists(EXPIRED_CERT):
+ self.skipTest('expired.crt not found in test/keys/')
+ server = self._server_socket(keyfile=EXPIRED_KEY, certfile=EXPIRED_CERT)
+ self._assert_connection_failure(
+ server, cert_reqs=ssl.CERT_REQUIRED, ca_certs=EXPIRED_CERT)
if __name__ == '__main__':
logging.basicConfig(level=logging.WARN)
- from thrift.transport.TSSLSocket import TSSLSocket, TSSLServerSocket, _match_has_ipaddress
+ from thrift.transport.TSSLSocket import TSSLSocket, TSSLServerSocket
from thrift.transport.TTransport import TTransportException
unittest.main()
diff --git a/lib/py/test/test_type_check.py b/lib/py/test/test_type_check.py
new file mode 100644
index 00000000000..4e7a89e4ca8
--- /dev/null
+++ b/lib/py/test/test_type_check.py
@@ -0,0 +1,328 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+
+"""
+Comprehensive type checking tests for thrift-generated Python code.
+
+Uses Astral's ty type checker to validate that generated code has correct
+and complete Python 3.10+ type hints.
+"""
+
+import glob
+import os
+import shutil
+import subprocess
+import sys
+import unittest
+
+# Add thrift library from build directory to path before any imports
+# This mirrors the pattern used by other tests in this directory
+_TEST_DIR = os.path.dirname(os.path.abspath(__file__))
+_ROOT_DIR = os.path.dirname(os.path.dirname(os.path.dirname(_TEST_DIR)))
+for _libpath in glob.glob(os.path.join(_ROOT_DIR, "lib", "py", "build", "lib.*")):
+ for _pattern in ("-%d.%d", "-%d%d"):
+ _postfix = _pattern % (sys.version_info[0], sys.version_info[1])
+ if _libpath.endswith(_postfix):
+ sys.path.insert(0, _libpath)
+ break
+
+
+def ensure_ty_installed() -> None:
+ """Install ty if not available, using uv."""
+ if shutil.which("ty") is not None:
+ return
+
+ # Try uv first (preferred)
+ if shutil.which("uv") is not None:
+ subprocess.run(
+ ["uv", "tool", "install", "ty"],
+ check=True,
+ capture_output=True,
+ )
+ else:
+ # Fall back to installing uv first, then ty
+ subprocess.run(
+ [sys.executable, "-m", "pip", "install", "uv"],
+ check=True,
+ capture_output=True,
+ )
+ subprocess.run(
+ ["uv", "tool", "install", "ty"],
+ check=True,
+ capture_output=True,
+ )
+
+
+def find_thrift_compiler() -> str:
+ """Find the thrift compiler binary."""
+ # Check PATH first
+ thrift_bin = shutil.which("thrift")
+ if thrift_bin is not None:
+ return thrift_bin
+
+ # Try common build directories
+ test_dir = os.path.dirname(__file__)
+ candidates = [
+ os.path.join(test_dir, "..", "..", "..", "build-compiler", "compiler", "cpp", "bin", "thrift"),
+ os.path.join(test_dir, "..", "..", "..", "compiler", "cpp", "thrift"),
+ os.path.join(test_dir, "..", "..", "..", "build", "compiler", "cpp", "bin", "thrift"),
+ ]
+
+ for candidate in candidates:
+ abs_path = os.path.abspath(candidate)
+ if os.path.exists(abs_path) and os.access(abs_path, os.X_OK):
+ return abs_path
+
+ raise RuntimeError(
+ "thrift compiler not found. Ensure it is in PATH or built in build-compiler/"
+ )
+
+
+def find_thrift_lib_paths() -> list[str]:
+ """Find paths where the thrift library might be located.
+
+ Returns a list of paths to add to ty's extra-search-path.
+ Checks both build directories (for local development) and
+ installed package locations (for CI environments).
+ """
+ paths: list[str] = []
+
+ # Check build directory (local development)
+ test_dir = os.path.dirname(os.path.abspath(__file__))
+ root_dir = os.path.dirname(os.path.dirname(os.path.dirname(test_dir)))
+ for libpath in glob.glob(os.path.join(root_dir, "lib", "py", "build", "lib.*")):
+ for pattern in ("-%d.%d", "-%d%d"):
+ postfix = pattern % (sys.version_info[0], sys.version_info[1])
+ if libpath.endswith(postfix):
+ paths.append(libpath)
+
+ # Check if thrift is importable (installed in site-packages or virtualenv)
+ try:
+ import thrift
+
+ thrift_path = os.path.dirname(os.path.dirname(thrift.__file__))
+ if thrift_path not in paths:
+ paths.append(thrift_path)
+ except ImportError:
+ pass
+
+ # Also fall back to the in-tree lib/py directory if it exists
+ lib_py_dir = os.path.join(root_dir, "lib", "py")
+ if os.path.isdir(lib_py_dir) and lib_py_dir not in paths:
+ paths.append(lib_py_dir)
+
+ return paths
+
+
+class TypeCheckTest(unittest.TestCase):
+ """Tests that validate type hints in generated Python code."""
+
+ gen_dir: str
+ thrift_lib_paths: list[str]
+
+ @classmethod
+ def setUpClass(cls) -> None:
+ ensure_ty_installed()
+
+ # Paths
+ test_dir = os.path.dirname(__file__)
+ thrift_file = os.path.join(test_dir, "type_check_test.thrift")
+ cls.gen_dir = os.path.join(test_dir, "gen-py-typecheck")
+ cls.thrift_lib_paths = find_thrift_lib_paths()
+
+ # Find thrift compiler
+ thrift_bin = find_thrift_compiler()
+
+ # Clean and regenerate
+ if os.path.exists(cls.gen_dir):
+ shutil.rmtree(cls.gen_dir)
+ os.makedirs(cls.gen_dir, exist_ok=True)
+
+ # Run thrift compiler
+ result = subprocess.run(
+ [thrift_bin, "--gen", "py", "-out", cls.gen_dir, thrift_file],
+ capture_output=True,
+ text=True,
+ )
+ if result.returncode != 0:
+ raise RuntimeError(
+ f"thrift compiler failed:\nstdout: {result.stdout}\nstderr: {result.stderr}"
+ )
+
+ # Add generated code to path for import tests
+ sys.path.insert(0, cls.gen_dir)
+
+ @classmethod
+ def tearDownClass(cls) -> None:
+ # Remove generated code from path
+ if cls.gen_dir in sys.path:
+ sys.path.remove(cls.gen_dir)
+ # Clean up generated files
+ if os.path.exists(cls.gen_dir):
+ shutil.rmtree(cls.gen_dir)
+
+ def test_ty_type_check_passes(self) -> None:
+ """Verify generated code passes ty without errors."""
+ # Build ty command with extra search paths for thrift library
+ # and the generated code directory (for relative imports)
+ cmd = ["ty", "check"]
+ for path in self.thrift_lib_paths:
+ cmd.extend(["--extra-search-path", path])
+ # Add the generated code directory to resolve relative imports
+ cmd.extend(["--extra-search-path", self.gen_dir])
+ cmd.append(self.gen_dir)
+
+ result = subprocess.run(
+ cmd,
+ capture_output=True,
+ text=True,
+ )
+
+ self.assertEqual(
+ result.returncode,
+ 0,
+ f"ty check failed:\nstdout: {result.stdout}\nstderr: {result.stderr}",
+ )
+
+ def test_py_typed_marker_exists(self) -> None:
+ """Verify py.typed marker is generated for PEP 561."""
+ py_typed = os.path.join(self.gen_dir, "type_check_test", "py.typed")
+ self.assertTrue(
+ os.path.exists(py_typed),
+ f"py.typed marker missing at {py_typed}",
+ )
+
+ def test_generated_code_is_importable(self) -> None:
+ """Verify generated code can be imported without errors."""
+ from type_check_test import TypeCheckService, constants, ttypes
+
+ # Verify key types exist
+ self.assertTrue(hasattr(ttypes, "Status"))
+ self.assertTrue(hasattr(ttypes, "Priority"))
+ self.assertTrue(hasattr(ttypes, "Primitives"))
+ self.assertTrue(hasattr(ttypes, "RequiredFields"))
+ self.assertTrue(hasattr(ttypes, "OptionalFields"))
+ self.assertTrue(hasattr(ttypes, "DefaultValues"))
+ self.assertTrue(hasattr(ttypes, "Containers"))
+ self.assertTrue(hasattr(ttypes, "NestedContainers"))
+ self.assertTrue(hasattr(ttypes, "NestedStructs"))
+ self.assertTrue(hasattr(ttypes, "WithEnum"))
+ self.assertTrue(hasattr(ttypes, "WithTypedef"))
+ self.assertTrue(hasattr(ttypes, "TestUnion"))
+ self.assertTrue(hasattr(ttypes, "ValidationError"))
+ self.assertTrue(hasattr(ttypes, "NotFoundError"))
+ self.assertTrue(hasattr(ttypes, "Empty"))
+
+ # Verify constants exist
+ self.assertTrue(hasattr(constants, "MAX_ITEMS"))
+ self.assertTrue(hasattr(constants, "DEFAULT_NAME"))
+ self.assertTrue(hasattr(constants, "VALID_STATUSES"))
+ self.assertTrue(hasattr(constants, "STATUS_CODES"))
+ self.assertTrue(hasattr(constants, "DEFAULT_STATUS"))
+
+ # Verify service exists
+ self.assertTrue(hasattr(TypeCheckService, "Client"))
+ self.assertTrue(hasattr(TypeCheckService, "Processor"))
+
+ def test_enum_is_intenum(self) -> None:
+ """Verify enums are generated as IntEnum."""
+ from enum import IntEnum
+
+ from type_check_test import ttypes
+
+ self.assertTrue(issubclass(ttypes.Status, IntEnum))
+ self.assertTrue(issubclass(ttypes.Priority, IntEnum))
+
+ # Verify enum values
+ self.assertEqual(ttypes.Status.PENDING, 0)
+ self.assertEqual(ttypes.Status.ACTIVE, 1)
+ self.assertEqual(ttypes.Status.DONE, 2)
+ self.assertEqual(ttypes.Status.CANCELLED, -1)
+
+ self.assertEqual(ttypes.Priority.LOW, 1)
+ self.assertEqual(ttypes.Priority.CRITICAL, 100)
+
+ def test_struct_instantiation(self) -> None:
+ """Verify structs can be instantiated with type-correct arguments."""
+ from type_check_test import ttypes
+
+ # Test primitives struct
+ p = ttypes.Primitives(
+ boolField=True,
+ byteField=127,
+ i16Field=32767,
+ i32Field=2147483647,
+ i64Field=9223372036854775807,
+ doubleField=3.14,
+ stringField="test",
+ binaryField=b"bytes",
+ )
+ self.assertEqual(p.boolField, True)
+ self.assertEqual(p.stringField, "test")
+ self.assertEqual(p.binaryField, b"bytes")
+
+ # Test containers struct
+ c = ttypes.Containers(
+ stringList=["a", "b", "c"],
+ intList=[1, 2, 3],
+ longSet={1, 2, 3},
+ stringSet={"a", "b"},
+ stringIntMap={"key": 42},
+ longStringMap={1: "one"},
+ )
+ self.assertEqual(c.stringList, ["a", "b", "c"])
+ self.assertEqual(c.stringIntMap, {"key": 42})
+
+ # Test required fields struct
+ r = ttypes.RequiredFields(
+ name="test",
+ id=123,
+ status=ttypes.Status.ACTIVE,
+ )
+ self.assertEqual(r.name, "test")
+ self.assertEqual(r.status, ttypes.Status.ACTIVE)
+
+ # Test union
+ u = ttypes.TestUnion(stringValue="test")
+ self.assertEqual(u.stringValue, "test")
+
+ def test_exception_inheritance(self) -> None:
+ """Verify exceptions inherit from TException and can be raised."""
+ from type_check_test import ttypes
+
+ from thrift.Thrift import TException
+
+ self.assertTrue(issubclass(ttypes.ValidationError, TException))
+ self.assertTrue(issubclass(ttypes.NotFoundError, TException))
+
+ # Test raising and catching
+ try:
+ raise ttypes.ValidationError(
+ message="test error",
+ code=400,
+ fields=["field1", "field2"],
+ )
+ except ttypes.ValidationError as e:
+ self.assertEqual(e.message, "test error")
+ self.assertEqual(e.code, 400)
+ self.assertEqual(e.fields, ["field1", "field2"])
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/lib/py/test/thrift_TNonblockingServer.py b/lib/py/test/thrift_TNonblockingServer.py
index 7220879ac5d..cf5362b9b48 100644
--- a/lib/py/test/thrift_TNonblockingServer.py
+++ b/lib/py/test/thrift_TNonblockingServer.py
@@ -22,6 +22,7 @@
import threading
import unittest
import time
+import socket
gen_path = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "gen-py")
sys.path.append(gen_path)
@@ -51,8 +52,12 @@ def start_server(self):
self.server.serve()
print("------stop server -----\n")
- def close_server(self):
+ def stop(self):
+ """Signal the server to stop. Must be called before close()."""
self.server.stop()
+
+ def close(self):
+ """Close server resources. Call only after serve() has returned."""
self.server.close()
@@ -60,11 +65,15 @@ class Client:
def start_client(self):
transport = TSocket.TSocket("127.0.0.1", 30030)
+ transport.setTimeout(2000)
trans = TTransport.TFramedTransport(transport)
protocol = TBinaryProtocol.TBinaryProtocol(trans)
client = TestServer.Client(protocol)
trans.open()
- self.msg = client.add_and_get_msg("hello thrift")
+ try:
+ self.msg = client.add_and_get_msg("hello thrift")
+ finally:
+ trans.close()
def get_message(self):
try:
@@ -76,25 +85,49 @@ def get_message(self):
class TestNonblockingServer(unittest.TestCase):
+ def _wait_for_server(self, timeout=2.0):
+ deadline = time.monotonic() + timeout
+ while time.monotonic() < deadline:
+ sock = socket.socket()
+ try:
+ if sock.connect_ex(("127.0.0.1", 30030)) == 0:
+ return True
+ finally:
+ sock.close()
+ time.sleep(0.05)
+ return False
+
+ def test_close_closes_socketpair(self):
+ serve = Server()
+ serve.close()
+ self.assertIsNone(serve.server._read)
+ self.assertIsNone(serve.server._write)
def test_normalconnection(self):
+ # End-to-end check: start the server thread, wait until the port
+ # accepts connections, run one client call, then shut the server down.
serve = Server()
client = Client()
- serve_thread = threading.Thread(target=serve.start_server)
- client_thread = threading.Thread(target=client.start_client)
+ # Daemon threads so a stuck server or client cannot keep the
+ # interpreter alive after the test finishes.
+ serve_thread = threading.Thread(target=serve.start_server, daemon=True)
+ client_thread = threading.Thread(target=client.start_client, daemon=True)
serve_thread.start()
- time.sleep(10)
+ # Poll for readiness instead of sleeping for a fixed interval.
+ self.assertTrue(self._wait_for_server(), "server did not start in time")
client_thread.start()
- client_thread.join(0.5)
+ # Wait at most 2 seconds for the client call to finish.
+ client_thread.join(2.0)
try:
msg = client.get_message()
self.assertEqual("hello thrift", msg)
+ self.assertFalse(client_thread.is_alive(), "client thread did not exit")
except AssertionError as e:
+ # NOTE(review): this handler re-raises the same exception, so the
+ # print below is unreachable.
raise e
print("assert failure")
finally:
- serve.close_server()
+ # stop() must be called before waiting for the thread to exit.
+ # close() should only be called after serve() has returned;
+ # otherwise it destroys the socket pair used to wake up select().
+ serve.stop()
+ serve_thread.join(10.0)
+ # NOTE(review): if this assertion fails, serve.close() below is skipped.
+ self.assertFalse(serve_thread.is_alive(), "server thread did not exit")
+ serve.close()
if __name__ == '__main__':
diff --git a/lib/py/test/thrift_transport.py b/lib/py/test/thrift_transport.py
index cb1bb0ce71a..662e0825c94 100644
--- a/lib/py/test/thrift_transport.py
+++ b/lib/py/test/thrift_transport.py
@@ -17,11 +17,20 @@
# under the License.
#
-import unittest
import os
+import ssl
+import unittest
+import warnings
import _import_local_thrift # noqa
-from thrift.transport import TTransport
+from thrift.protocol import TBinaryProtocol
+from thrift.server import THttpServer as THttpServerModule
+from thrift.transport import THttpClient, TTransport
+
+SCRIPT_DIR = os.path.realpath(os.path.dirname(__file__))
+ROOT_DIR = os.path.dirname(os.path.dirname(os.path.dirname(SCRIPT_DIR)))
+SERVER_CERT = os.path.join(ROOT_DIR, 'test', 'keys', 'server.crt')
+SERVER_KEY = os.path.join(ROOT_DIR, 'test', 'keys', 'server.key')
class TestTFileObjectTransport(unittest.TestCase):
@@ -66,5 +75,45 @@ def test_memorybuffer_read(self):
buffer_r.close()
+class TestHttpTls(unittest.TestCase):
+ # Checks the TLS version bounds on the SSL contexts used by
+ # THttpClient and THttpServer.
+ def test_http_client_minimum_tls(self):
+ # The client's context must not allow anything below TLS 1.2.
+ client = THttpClient.THttpClient('https://localhost:8443/')
+ self.assertGreaterEqual(client.context.minimum_version, ssl.TLSVersion.TLSv1_2)
+ # MAXIMUM_SUPPORTED is a sentinel rather than a concrete protocol
+ # version, so only an explicit cap is compared against TLS 1.2.
+ if client.context.maximum_version != ssl.TLSVersion.MAXIMUM_SUPPORTED:
+ self.assertGreaterEqual(client.context.maximum_version, ssl.TLSVersion.TLSv1_2)
+
+ def test_http_client_rejects_legacy_context(self):
+ # A caller-supplied context whose minimum version is TLS 1.1 is
+ # expected to make the THttpClient constructor raise ValueError.
+ context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
+ # DeprecationWarning is ignored while the minimum version is lowered.
+ with warnings.catch_warnings():
+ warnings.filterwarnings('ignore', category=DeprecationWarning)
+ context.minimum_version = ssl.TLSVersion.TLSv1_1
+ with self.assertRaises(ValueError):
+ THttpClient.THttpClient('https://localhost:8443/', ssl_context=context)
+
+ def test_http_server_minimum_tls(self):
+
+ # Minimal stand-in processor; both methods just return None.
+ class DummyProcessor(object):
+ def on_message_begin(self, _on_begin):
+ return None
+
+ def process(self, _iprot, _oprot):
+ return None
+
+ # The server is constructed with the test certificate and key but
+ # serve() is never called in this test.
+ # NOTE(review): port 0 presumably lets the OS pick a free port - confirm.
+ server = THttpServerModule.THttpServer(
+ DummyProcessor(),
+ ('localhost', 0),
+ TBinaryProtocol.TBinaryProtocolFactory(),
+ cert_file=SERVER_CERT,
+ key_file=SERVER_KEY,
+ )
+ try:
+ # NOTE(review): assumes server.httpd.socket is an SSL-wrapped
+ # socket exposing .context - confirm against THttpServer.
+ context = server.httpd.socket.context
+ self.assertGreaterEqual(context.minimum_version, ssl.TLSVersion.TLSv1_2)
+ # Same sentinel handling as in the client test above.
+ if context.maximum_version != ssl.TLSVersion.MAXIMUM_SUPPORTED:
+ self.assertGreaterEqual(context.maximum_version, ssl.TLSVersion.TLSv1_2)
+ finally:
+ server.shutdown()
+
+
if __name__ == '__main__':
unittest.main()
diff --git a/lib/py/test/type_check_test.thrift b/lib/py/test/type_check_test.thrift
new file mode 100644
index 00000000000..bd69ce9d598
--- /dev/null
+++ b/lib/py/test/type_check_test.thrift
@@ -0,0 +1,178 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/**
+ * Comprehensive test thrift file for validating Python type hints.
+ * Covers all thrift features to ensure generated code passes ty type checking.
+ */
+
+namespace py type_check_test
+
+// ============ ENUMS ============
+// Enum mixing zero, positive and negative explicit values.
+enum Status {
+ PENDING = 0,
+ ACTIVE = 1,
+ DONE = 2,
+ CANCELLED = -1, // Negative value
+}
+
+// Enum with non-contiguous positive values and no zero member.
+enum Priority {
+ LOW = 1,
+ MEDIUM = 5,
+ HIGH = 10,
+ CRITICAL = 100,
+}
+
+// ============ TYPEDEFS ============
+// Aliases for base types.
+typedef i64 UserId
+typedef string Email
+// Aliases for container types; Thrift containers must name their element types.
+typedef list<string> StringList
+// NOTE(review): the key/value types were missing here; map<string, i32> is assumed - confirm.
+typedef map<string, i32> ScoreMap
+
+// ============ STRUCTS ============
+// Struct with no fields.
+struct Empty {}
+
+// One field per base type: bool, byte, i16, i32, i64, double, string, binary.
+struct Primitives {
+ 1: bool boolField,
+ 2: byte byteField,
+ 3: i16 i16Field,
+ 4: i32 i32Field,
+ 5: i64 i64Field,
+ 6: double doubleField,
+ 7: string stringField,
+ 8: binary binaryField,
+}
+
+// Every field is declared `required` (string, i32 and enum).
+struct RequiredFields {
+ 1: required string name,
+ 2: required i32 id,
+ 3: required Status status,
+}
+
+// Every field is declared `optional` (string, i32 and enum).
+struct OptionalFields {
+ 1: optional string name,
+ 2: optional i32 count,
+ 3: optional Status status,
+}
+
+// Every field carries a default: string, i32, enum and list literals.
+struct DefaultValues {
+  1: string name = "default",
+  2: i32 count = 42,
+  3: Status status = Status.PENDING,
+  // Element type restored to match the string-literal default below.
+  4: list<string> tags = ["a", "b"],
+}
+
+// Flat list, set and map fields over base element types.
+// NOTE(review): element types were missing and are restored from the
+// field names (int -> i32, long -> i64) - confirm.
+struct Containers {
+  1: list<string> stringList,
+  2: list<i32> intList,
+  3: set<i64> longSet,
+  4: set<string> stringSet,
+  5: map<string, i32> stringIntMap,
+  6: map<i64, string> longStringMap,
+}
+
+struct NestedContainers {
+ 1: list> matrix,
+ 2: map> mapOfLists,
+ 3: list