Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ fn parse(
platform: Option<String>,
always_true: Option<Vec<String>>,
always_false: Option<Vec<String>>,
) -> PyResult<(Vec<u8>, Vec<PyObject>, Vec<PyObject>, Vec<u8>)> {
) -> PyResult<(Vec<u8>, Vec<PyObject>, Vec<PyObject>, Vec<u8>, bool)> {
// Get defaults from Python if not provided
let python_version = match python_version {
Some(v) => v,
Expand All @@ -63,7 +63,7 @@ fn parse(
let always_false = always_false.unwrap_or_default();

let path = Path::new(&fnam);
let (ast_bytes, syntax_errors, type_ignore_lines, import_bytes) = py
let (ast_bytes, syntax_errors, type_ignore_lines, import_bytes, is_partial_package) = py
.allow_threads(|| {
serialize_ast::serialize_python_file(
path,
Expand Down Expand Up @@ -102,7 +102,7 @@ fn parse(
.collect();
let py_type_ignores = py_type_ignores?;

Ok((ast_bytes, py_errors, py_type_ignores, import_bytes))
Ok((ast_bytes, py_errors, py_type_ignores, import_bytes, is_partial_package))
}

/// Get the default Python version from sys.version_info
Expand Down
19 changes: 17 additions & 2 deletions src/serialize_ast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -171,10 +171,14 @@ pub(crate) fn serialize_python_file(
file_path: &Path,
skip_function_bodies: bool,
options: Options,
) -> Result<(Vec<u8>, Vec<SyntaxError>, Vec<(usize, Vec<String>)>, Vec<u8>)> {
) -> Result<(Vec<u8>, Vec<SyntaxError>, Vec<(usize, Vec<String>)>, Vec<u8>, bool)> {
let source_type = PySourceType::from(file_path);
let source_text = std::fs::read_to_string(file_path)?;
let line_index = LineIndex::from_source_text(&source_text);
let is_stub_package = match file_path.file_name() {
Some(file) => file.as_encoded_bytes() == b"__init__.pyi",
_ => false,
};

// Check if file is all ASCII and build per-line non-ASCII flags if needed
let is_all_ascii = source_text.is_ascii();
Expand Down Expand Up @@ -221,6 +225,7 @@ pub(crate) fn serialize_python_file(
options,
current_unreachable: false,
current_mypy_only: false,
top_level_getattr: false,
};
parsed.syntax().serialize(&mut ser);

Expand All @@ -233,7 +238,10 @@ pub(crate) fn serialize_python_file(
Some(ser.lines_with_non_ascii),
);

Ok((ser.bytes, syntax_errors, type_ignore_lines, import_bytes))
// Return this directly to caller, so that it can check this without deserialization
let is_partial_package = is_stub_package && ser.top_level_getattr;

Ok((ser.bytes, syntax_errors, type_ignore_lines, import_bytes, is_partial_package))
}

// Bit flags for import statement metadata
Expand Down Expand Up @@ -279,6 +287,7 @@ struct Serializer<'a> {
options: Options, // Reachability analysis options
current_unreachable: bool, // Whether we're currently in an unreachable block
current_mypy_only: bool, // Whether we're currently in a mypy-only block (e.g., if TYPE_CHECKING)
top_level_getattr: bool, // Does module have top-level __getattr__() function
}

impl<'a> Serializer<'a> {
Expand Down Expand Up @@ -864,6 +873,10 @@ impl Ser for ast::Stmt {
true
};

if !ser.in_class && !ser.in_function && f.name.as_str() == "__getattr__" {
ser.top_level_getattr = true;
};

// Body - mark that we're inside a function
let was_in_function = ser.in_function;
ser.in_function = true;
Expand Down Expand Up @@ -2580,6 +2593,7 @@ pub fn serialize_imports(
options: Options::default(),
current_unreachable: false,
current_mypy_only: false,
top_level_getattr: false,
};

// Write list of imports
Expand Down Expand Up @@ -2708,6 +2722,7 @@ mod tests {
options: Options::default(),
current_unreachable: false,
current_mypy_only: false,
top_level_getattr: false,
}
}

Expand Down