Skip to content

Introspection: Adds basic input type annotations #5089

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions guide/src/class.md
Original file line number Diff line number Diff line change
Expand Up @@ -1378,6 +1378,7 @@ impl pyo3::types::DerefToPyAny for MyClass {}
unsafe impl pyo3::type_object::PyTypeInfo for MyClass {
const NAME: &'static str = "MyClass";
const MODULE: ::std::option::Option<&'static str> = ::std::option::Option::None;

#[inline]
fn type_object_raw(py: pyo3::Python<'_>) -> *mut pyo3::ffi::PyTypeObject {
<Self as pyo3::impl_::pyclass::PyClassImpl>::lazy_type_object()
Expand All @@ -1393,6 +1394,8 @@ impl pyo3::PyClass for MyClass {
impl<'a, 'py> pyo3::impl_::extract_argument::PyFunctionArgument<'a, 'py, false> for &'a MyClass
{
type Holder = ::std::option::Option<pyo3::PyRef<'py, MyClass>>;
#[cfg(feature = "experimental-inspect")]
const INPUT_TYPE: &'static str = "MyClass";

#[inline]
fn extract(obj: &'a pyo3::Bound<'py, PyAny>, holder: &'a mut Self::Holder) -> pyo3::PyResult<Self> {
Expand All @@ -1403,6 +1406,8 @@ impl<'a, 'py> pyo3::impl_::extract_argument::PyFunctionArgument<'a, 'py, false>
impl<'a, 'py> pyo3::impl_::extract_argument::PyFunctionArgument<'a, 'py, false> for &'a mut MyClass
{
type Holder = ::std::option::Option<pyo3::PyRefMut<'py, MyClass>>;
#[cfg(feature = "experimental-inspect")]
const INPUT_TYPE: &'static str = "MyClass";

#[inline]
fn extract(obj: &'a pyo3::Bound<'py, PyAny>, holder: &'a mut Self::Holder) -> pyo3::PyResult<Self> {
Expand Down
1 change: 1 addition & 0 deletions noxfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -865,6 +865,7 @@ def update_ui_tests(session: nox.Session):
@nox.session(name="test-introspection")
def test_introspection(session: nox.Session):
session.install("maturin")
session.install("ruff")
target = os.environ.get("CARGO_BUILD_TARGET")
for options in ([], ["--release"]):
if target is not None:
Expand Down
3 changes: 3 additions & 0 deletions pyo3-introspection/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,5 +14,8 @@ goblin = "0.9.0"
serde = { version = "1", features = ["derive"] }
serde_json = "1"

[dev-dependencies]
tempfile = "3.12.0"

[lints]
workspace = true
3 changes: 3 additions & 0 deletions pyo3-introspection/src/introspection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,7 @@ fn convert_argument(arg: &ChunkArgument) -> Argument {
Argument {
name: arg.name.clone(),
default_value: arg.default.clone(),
annotation: arg.annotation.clone(),
}
}

Expand Down Expand Up @@ -315,4 +316,6 @@ struct ChunkArgument {
name: String,
#[serde(default)]
default: Option<String>,
#[serde(default)]
annotation: Option<String>,
}
2 changes: 2 additions & 0 deletions pyo3-introspection/src/model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@ pub struct Argument {
pub name: String,
/// Default value as a Python expression
pub default_value: Option<String>,
/// Type annotation as a Python expression
pub annotation: Option<String>,
}

/// A variable length argument ie. *vararg or **kwarg
Expand Down
51 changes: 38 additions & 13 deletions pyo3-introspection/src/stubs.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use crate::model::{Argument, Class, Function, Module, VariableLengthArgument};
use std::collections::HashMap;
use std::collections::{BTreeSet, HashMap};
use std::path::{Path, PathBuf};

/// Generates the [type stubs](https://typing.readthedocs.io/en/latest/source/stubs.html) of a given module.
Expand Down Expand Up @@ -32,51 +32,70 @@ fn add_module_stub_files(

/// Generates the module stubs to a String, not including submodules
fn module_stubs(module: &Module) -> String {
let mut modules_to_import = BTreeSet::new();
let mut elements = Vec::new();
for class in &module.classes {
elements.push(class_stubs(class));
}
for function in &module.functions {
elements.push(function_stubs(function));
elements.push(function_stubs(function, &mut modules_to_import));
}
elements.push(String::new()); // last line jump
elements.join("\n")

let mut final_elements = Vec::new();
for module_to_import in &modules_to_import {
final_elements.push(format!("import {module_to_import}"));
}
final_elements.extend(elements);
final_elements.join("\n")
}

fn class_stubs(class: &Class) -> String {
format!("class {}: ...", class.name)
}

fn function_stubs(function: &Function) -> String {
fn function_stubs(function: &Function, modules_to_import: &mut BTreeSet<String>) -> String {
// Signature
let mut parameters = Vec::new();
for argument in &function.arguments.positional_only_arguments {
parameters.push(argument_stub(argument));
parameters.push(argument_stub(argument, modules_to_import));
}
if !function.arguments.positional_only_arguments.is_empty() {
parameters.push("/".into());
}
for argument in &function.arguments.arguments {
parameters.push(argument_stub(argument));
parameters.push(argument_stub(argument, modules_to_import));
}
if let Some(argument) = &function.arguments.vararg {
parameters.push(format!("*{}", variable_length_argument_stub(argument)));
} else if !function.arguments.keyword_only_arguments.is_empty() {
parameters.push("*".into());
}
for argument in &function.arguments.keyword_only_arguments {
parameters.push(argument_stub(argument));
parameters.push(argument_stub(argument, modules_to_import));
}
if let Some(argument) = &function.arguments.kwarg {
parameters.push(format!("**{}", variable_length_argument_stub(argument)));
}
format!("def {}({}): ...", function.name, parameters.join(", "))
}

fn argument_stub(argument: &Argument) -> String {
fn argument_stub(argument: &Argument, modules_to_import: &mut BTreeSet<String>) -> String {
let mut output = argument.name.clone();
if let Some(annotation) = &argument.annotation {
output.push_str(": ");
output.push_str(annotation);
if let Some((module, _)) = annotation.rsplit_once('.') {
// TODO: this is very naive
modules_to_import.insert(module.into());
}
}
if let Some(default_value) = &argument.default_value {
output.push('=');
output.push_str(if argument.annotation.is_some() {
" = "
} else {
"="
});
output.push_str(default_value);
}
output
Expand All @@ -99,26 +118,29 @@ mod tests {
positional_only_arguments: vec![Argument {
name: "posonly".into(),
default_value: None,
annotation: None,
}],
arguments: vec![Argument {
name: "arg".into(),
default_value: None,
annotation: None,
}],
vararg: Some(VariableLengthArgument {
name: "varargs".into(),
}),
keyword_only_arguments: vec![Argument {
name: "karg".into(),
default_value: None,
annotation: Some("str".into()),
}],
kwarg: Some(VariableLengthArgument {
name: "kwarg".into(),
}),
},
};
assert_eq!(
"def func(posonly, /, arg, *varargs, karg, **kwarg): ...",
function_stubs(&function)
"def func(posonly, /, arg, *varargs, karg: str, **kwarg): ...",
function_stubs(&function, &mut BTreeSet::new())
)
}

Expand All @@ -130,22 +152,25 @@ mod tests {
positional_only_arguments: vec![Argument {
name: "posonly".into(),
default_value: Some("1".into()),
annotation: None,
}],
arguments: vec![Argument {
name: "arg".into(),
default_value: Some("True".into()),
annotation: None,
}],
vararg: None,
keyword_only_arguments: vec![Argument {
name: "karg".into(),
default_value: Some("\"foo\"".into()),
annotation: Some("str".into()),
}],
kwarg: None,
},
};
assert_eq!(
"def afunc(posonly=1, /, arg=True, *, karg=\"foo\"): ...",
function_stubs(&function)
"def afunc(posonly=1, /, arg=True, *, karg: str = \"foo\"): ...",
function_stubs(&function, &mut BTreeSet::new())
)
}
}
34 changes: 31 additions & 3 deletions pyo3-introspection/tests/test.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
use anyhow::Result;
use anyhow::{ensure, Result};
use pyo3_introspection::{introspect_cdylib, module_stub_files};
use std::collections::HashMap;
use std::io::{Read, Seek, SeekFrom, Write};
use std::path::{Path, PathBuf};
use std::process::Command;
use std::{env, fs};
use tempfile::NamedTempFile;

#[test]
fn pytests_stubs() -> Result<()> {
Expand Down Expand Up @@ -42,9 +45,12 @@ fn pytests_stubs() -> Result<()> {
file_name.display()
)
});

let actual_file_content = format_with_ruff(actual_file_content)?;

assert_eq!(
&expected_file_content.replace('\r', ""), // Windows compatibility
actual_file_content,
expected_file_content.as_str(),
actual_file_content.as_str(),
"The content of file {} is different",
file_name.display()
)
Expand Down Expand Up @@ -75,3 +81,25 @@ fn add_dir_files(
}
Ok(())
}

fn format_with_ruff(code: &str) -> Result<String> {
let temp_file = NamedTempFile::with_suffix(".pyi")?;
// Write to file
{
let mut file = temp_file.as_file();
file.write_all(code.as_bytes())?;
file.flush()?;
file.seek(SeekFrom::Start(0))?;
}
ensure!(
Command::new("ruff")
.arg("format")
.arg(temp_file.path())
.status()?
.success(),
"Failed to run ruff"
);
let mut content = String::new();
temp_file.as_file().read_to_string(&mut content)?;
Ok(content)
}
53 changes: 46 additions & 7 deletions pyo3-macros-backend/src/introspection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

use crate::method::{FnArg, RegularArg};
use crate::pyfunction::FunctionSignature;
use crate::utils::PyO3CratePath;
use crate::utils::{PyO3CratePath, TypeExt};
use proc_macro2::{Span, TokenStream};
use quote::{format_ident, quote, ToTokens};
use std::borrow::Cow;
Expand All @@ -19,7 +19,7 @@ use std::collections::HashMap;
use std::hash::{Hash, Hasher};
use std::mem::take;
use std::sync::atomic::{AtomicUsize, Ordering};
use syn::{Attribute, Ident};
use syn::{Attribute, Ident, Type};

static GLOBAL_COUNTER_FOR_UNIQUE_NAMES: AtomicUsize = AtomicUsize::new(0);

Expand Down Expand Up @@ -179,20 +179,44 @@ fn argument_introspection_data<'a>(
IntrospectionNode::String(desc.default_value().into()),
);
}
if desc.from_py_with.is_none() {
// If from_py_with is set we don't know anything on the input type
if let Some(ty) = desc.option_wrapped_type {
// Special case to properly generate a `T | None` annotation
let ty = ty.clone().elide_lifetimes();
params.insert(
"annotation",
IntrospectionNode::InputType {
rust_type: ty,
nullable: true,
},
);
} else {
let ty = desc.ty.clone().elide_lifetimes();
params.insert(
"annotation",
IntrospectionNode::InputType {
rust_type: ty,
nullable: false,
},
);
}
}
IntrospectionNode::Map(params)
}

enum IntrospectionNode<'a> {
String(Cow<'a, str>),
IntrospectionId(Option<&'a Ident>),
InputType { rust_type: Type, nullable: bool },
Map(HashMap<&'static str, IntrospectionNode<'a>>),
List(Vec<IntrospectionNode<'a>>),
}

impl IntrospectionNode<'_> {
fn emit(self, pyo3_crate_path: &PyO3CratePath) -> TokenStream {
let mut content = ConcatenationBuilder::default();
self.add_to_serialization(&mut content);
self.add_to_serialization(&mut content, pyo3_crate_path);
let content = content.into_token_stream(pyo3_crate_path);

let static_name = format_ident!("PYO3_INTROSPECTION_0_{}", unique_element_id());
Expand All @@ -206,7 +230,11 @@ impl IntrospectionNode<'_> {
}
}

fn add_to_serialization(self, content: &mut ConcatenationBuilder) {
fn add_to_serialization(
self,
content: &mut ConcatenationBuilder,
pyo3_crate_path: &PyO3CratePath,
) {
match self {
Self::String(string) => {
content.push_str_to_escape(&string);
Expand All @@ -216,10 +244,21 @@ impl IntrospectionNode<'_> {
content.push_tokens(if let Some(ident) = ident {
quote! { #ident::_PYO3_INTROSPECTION_ID }
} else {
Ident::new("_PYO3_INTROSPECTION_ID", Span::call_site()).into_token_stream()
quote! { _PYO3_INTROSPECTION_ID }
});
content.push_str("\"");
}
Self::InputType {
rust_type,
nullable,
} => {
content.push_str("\"");
content.push_tokens(quote! { <#rust_type as #pyo3_crate_path::impl_::extract_argument::PyFunctionArgument<false>>::INPUT_TYPE });
if nullable {
content.push_str(" | None");
}
content.push_str("\"");
}
Self::Map(map) => {
content.push_str("{");
for (i, (key, value)) in map.into_iter().enumerate() {
Expand All @@ -228,7 +267,7 @@ impl IntrospectionNode<'_> {
}
content.push_str_to_escape(key);
content.push_str(":");
value.add_to_serialization(content);
value.add_to_serialization(content, pyo3_crate_path);
}
content.push_str("}");
}
Expand All @@ -238,7 +277,7 @@ impl IntrospectionNode<'_> {
if i > 0 {
content.push_str(",");
}
value.add_to_serialization(content);
value.add_to_serialization(content, pyo3_crate_path);
}
content.push_str("]");
}
Expand Down
Loading
Loading