pyo3_introspection/
stubs.rs

1use crate::model::{
2    Argument, Arguments, Attribute, Class, Function, Module, PythonIdentifier, TypeHint,
3    TypeHintExpr, VariableLengthArgument,
4};
5use std::collections::{BTreeMap, BTreeSet, HashMap};
6use std::path::PathBuf;
7use std::str::FromStr;
8
9/// Generates the [type stubs](https://typing.readthedocs.io/en/latest/source/stubs.html) of a given module.
10/// It returns a map between the file name and the file content.
11/// The root module stubs will be in the `__init__.pyi` file and the submodules directory
12/// in files with a relevant name.
13pub fn module_stub_files(module: &Module) -> HashMap<PathBuf, String> {
14    let mut output_files = HashMap::new();
15    add_module_stub_files(module, &[], &mut output_files);
16    output_files
17}
18
19fn add_module_stub_files(
20    module: &Module,
21    module_path: &[&str],
22    output_files: &mut HashMap<PathBuf, String>,
23) {
24    let mut file_path = PathBuf::new();
25    for e in module_path {
26        file_path = file_path.join(e);
27    }
28    output_files.insert(
29        file_path.join("__init__.pyi"),
30        module_stubs(module, module_path),
31    );
32    let mut module_path = module_path.to_vec();
33    module_path.push(&module.name);
34    for submodule in &module.modules {
35        if submodule.modules.is_empty() {
36            output_files.insert(
37                file_path.join(format!("{}.pyi", submodule.name)),
38                module_stubs(submodule, &module_path),
39            );
40        } else {
41            add_module_stub_files(submodule, &module_path, output_files);
42        }
43    }
44}
45
46/// Generates the module stubs to a String, not including submodules
47fn module_stubs(module: &Module, parents: &[&str]) -> String {
48    let imports = Imports::create(module, parents);
49    let mut elements = Vec::new();
50    for attribute in &module.attributes {
51        elements.push(attribute_stubs(attribute, &imports));
52    }
53    for class in &module.classes {
54        elements.push(class_stubs(class, &imports));
55    }
56    for function in &module.functions {
57        elements.push(function_stubs(function, &imports));
58    }
59
60    // We generate a __getattr__ method to tag incomplete stubs
61    // See https://typing.python.org/en/latest/guides/writing_stubs.html#incomplete-stubs
62    if module.incomplete && !module.functions.iter().any(|f| f.name == "__getattr__") {
63        elements.push(function_stubs(
64            &Function {
65                name: "__getattr__".into(),
66                decorators: Vec::new(),
67                arguments: Arguments {
68                    positional_only_arguments: Vec::new(),
69                    arguments: vec![Argument {
70                        name: "name".to_string(),
71                        default_value: None,
72                        annotation: Some(TypeHint::Ast(
73                            PythonIdentifier {
74                                module: Some("builtins".into()),
75                                name: "str".into(),
76                            }
77                            .into(),
78                        )),
79                    }],
80                    vararg: None,
81                    keyword_only_arguments: Vec::new(),
82                    kwarg: None,
83                },
84                returns: Some(TypeHint::Ast(
85                    PythonIdentifier {
86                        module: Some("_typeshed".into()),
87                        name: "Incomplete".into(),
88                    }
89                    .into(),
90                )),
91            },
92            &imports,
93        ));
94    }
95
96    let mut final_elements = imports.imports;
97    final_elements.extend(elements);
98
99    let mut output = String::new();
100
101    // We insert two line jumps (i.e. empty strings) only above and below multiple line elements (classes with methods, functions with decorators)
102    for element in final_elements {
103        let is_multiline = element.contains('\n');
104        if is_multiline && !output.is_empty() && !output.ends_with("\n\n") {
105            output.push('\n');
106        }
107        output.push_str(&element);
108        output.push('\n');
109        if is_multiline {
110            output.push('\n');
111        }
112    }
113
114    // We remove a line jump at the end if they are two
115    if output.ends_with("\n\n") {
116        output.pop();
117    }
118    output
119}
120
121fn class_stubs(class: &Class, imports: &Imports) -> String {
122    let mut buffer = String::new();
123    for decorator in &class.decorators {
124        buffer.push('@');
125        imports.serialize_identifier(decorator, &mut buffer);
126        buffer.push('\n');
127    }
128    buffer.push_str("class ");
129    buffer.push_str(&class.name);
130    if !class.bases.is_empty() {
131        buffer.push('(');
132        for (i, base) in class.bases.iter().enumerate() {
133            if i > 0 {
134                buffer.push_str(", ");
135            }
136            imports.serialize_identifier(base, &mut buffer);
137        }
138        buffer.push(')');
139    }
140    buffer.push(':');
141    if class.methods.is_empty() && class.attributes.is_empty() {
142        buffer.push_str(" ...");
143        return buffer;
144    }
145    for attribute in &class.attributes {
146        // We do the indentation
147        buffer.push_str("\n    ");
148        buffer.push_str(&attribute_stubs(attribute, imports).replace('\n', "\n    "));
149    }
150    for method in &class.methods {
151        // We do the indentation
152        buffer.push_str("\n    ");
153        buffer.push_str(&function_stubs(method, imports).replace('\n', "\n    "));
154    }
155    buffer
156}
157
158fn function_stubs(function: &Function, imports: &Imports) -> String {
159    // Signature
160    let mut parameters = Vec::new();
161    for argument in &function.arguments.positional_only_arguments {
162        parameters.push(argument_stub(argument, imports));
163    }
164    if !function.arguments.positional_only_arguments.is_empty() {
165        parameters.push("/".into());
166    }
167    for argument in &function.arguments.arguments {
168        parameters.push(argument_stub(argument, imports));
169    }
170    if let Some(argument) = &function.arguments.vararg {
171        parameters.push(format!(
172            "*{}",
173            variable_length_argument_stub(argument, imports)
174        ));
175    } else if !function.arguments.keyword_only_arguments.is_empty() {
176        parameters.push("*".into());
177    }
178    for argument in &function.arguments.keyword_only_arguments {
179        parameters.push(argument_stub(argument, imports));
180    }
181    if let Some(argument) = &function.arguments.kwarg {
182        parameters.push(format!(
183            "**{}",
184            variable_length_argument_stub(argument, imports)
185        ));
186    }
187    let mut buffer = String::new();
188    for decorator in &function.decorators {
189        buffer.push('@');
190        imports.serialize_identifier(decorator, &mut buffer);
191        buffer.push('\n');
192    }
193    buffer.push_str("def ");
194    buffer.push_str(&function.name);
195    buffer.push('(');
196    buffer.push_str(&parameters.join(", "));
197    buffer.push(')');
198    if let Some(returns) = &function.returns {
199        buffer.push_str(" -> ");
200        type_hint_stub(returns, imports, &mut buffer);
201    }
202    buffer.push_str(": ...");
203    buffer
204}
205
206fn attribute_stubs(attribute: &Attribute, imports: &Imports) -> String {
207    let mut buffer = attribute.name.clone();
208    if let Some(annotation) = &attribute.annotation {
209        buffer.push_str(": ");
210        type_hint_stub(annotation, imports, &mut buffer);
211    }
212    if let Some(value) = &attribute.value {
213        buffer.push_str(" = ");
214        buffer.push_str(value);
215    }
216    buffer
217}
218
219fn argument_stub(argument: &Argument, imports: &Imports) -> String {
220    let mut buffer = argument.name.clone();
221    if let Some(annotation) = &argument.annotation {
222        buffer.push_str(": ");
223        type_hint_stub(annotation, imports, &mut buffer);
224    }
225    if let Some(default_value) = &argument.default_value {
226        buffer.push_str(if argument.annotation.is_some() {
227            " = "
228        } else {
229            "="
230        });
231        buffer.push_str(default_value);
232    }
233    buffer
234}
235
236fn variable_length_argument_stub(argument: &VariableLengthArgument, imports: &Imports) -> String {
237    let mut buffer = argument.name.clone();
238    if let Some(annotation) = &argument.annotation {
239        buffer.push_str(": ");
240        type_hint_stub(annotation, imports, &mut buffer);
241    }
242    buffer
243}
244
245fn type_hint_stub(type_hint: &TypeHint, imports: &Imports, buffer: &mut String) {
246    match type_hint {
247        TypeHint::Ast(t) => imports.serialize_type_hint(t, buffer),
248        TypeHint::Plain(t) => buffer.push_str(t),
249    }
250}
251
252/// Datastructure to deduplicate, validate and generate imports
253#[derive(Default)]
254struct Imports {
255    /// Import lines ready to use
256    imports: Vec<String>,
257    /// Renaming map: from module name and member name return the name to use in type hints
258    renaming: BTreeMap<(String, String), String>,
259}
260
261impl Imports {
262    /// This generates a map from the builtin or module name to the actual alias used in the file
263    ///
264    /// For Python builtins and elements declared by the module the alias is always the actual name.
265    ///
266    /// For other elements, we can alias them using the `from X import Y as Z` syntax.
267    /// So, we first list all builtins and local elements, then iterate on imports
268    /// and create the aliases when needed.
269    fn create(module: &Module, module_parents: &[&str]) -> Self {
270        let mut elements_used_in_annotations = ElementsUsedInAnnotations::new();
271        elements_used_in_annotations.walk_module(module);
272
273        let mut imports = Vec::new();
274        let mut renaming = BTreeMap::new();
275        let mut local_name_to_module_and_attribute = BTreeMap::new();
276
277        // We first process local and built-ins elements, they are never aliased or imported
278        for name in module
279            .classes
280            .iter()
281            .map(|c| c.name.clone())
282            .chain(module.functions.iter().map(|f| f.name.clone()))
283            .chain(module.attributes.iter().map(|a| a.name.clone()))
284            .chain(elements_used_in_annotations.locals)
285        {
286            local_name_to_module_and_attribute.insert(name.clone(), (None, name.clone()));
287        }
288
289        // We compute the set of ways the current module can be named
290        let mut possible_current_module_names = vec![module.name.clone()];
291        let mut current_module_name = Some(module.name.clone());
292        for parent in module_parents.iter().rev() {
293            let path = if let Some(current) = current_module_name {
294                format!("{parent}.{current}")
295            } else {
296                parent.to_string()
297            };
298            possible_current_module_names.push(path.clone());
299            current_module_name = Some(path);
300        }
301
302        // We process then imports, normalizing local imports
303        for (module, attrs) in elements_used_in_annotations.module_members {
304            let normalized_module = if possible_current_module_names.contains(&module) {
305                None
306            } else {
307                Some(module.clone())
308            };
309            let mut import_for_module = Vec::new();
310            for attr in attrs {
311                // We split nested classes A.B in "A" (the part that must be imported and can have naming conflicts) and ".B"
312                let (root_attr, attr_path) = attr
313                    .split_once('.')
314                    .map_or((attr.as_str(), None), |(root, path)| (root, Some(path)));
315                let mut local_name = root_attr.to_owned();
316                let mut already_imported = false;
317                while let Some((possible_conflict_module, possible_conflict_attr)) =
318                    local_name_to_module_and_attribute.get(&local_name)
319                {
320                    if *possible_conflict_module == normalized_module
321                        && *possible_conflict_attr == root_attr
322                    {
323                        // It's the same
324                        already_imported = true;
325                        break;
326                    }
327                    // We generate a new local name
328                    // TODO: we use currently a format like Foo2. It might be nicer to use something like ModFoo
329                    let number_of_digits_at_the_end = local_name
330                        .bytes()
331                        .rev()
332                        .take_while(|b| b.is_ascii_digit())
333                        .count();
334                    let (local_name_prefix, local_name_number) =
335                        local_name.split_at(local_name.len() - number_of_digits_at_the_end);
336                    local_name = format!(
337                        "{local_name_prefix}{}",
338                        u64::from_str(local_name_number).unwrap_or(1) + 1
339                    );
340                }
341                renaming.insert(
342                    (module.clone(), attr.clone()),
343                    if let Some(attr_path) = attr_path {
344                        format!("{local_name}.{attr_path}")
345                    } else {
346                        local_name.clone()
347                    },
348                );
349                if !already_imported {
350                    local_name_to_module_and_attribute.insert(
351                        local_name.clone(),
352                        (normalized_module.clone(), root_attr.to_owned()),
353                    );
354                    let is_not_aliased_builtin =
355                        normalized_module.as_deref() == Some("builtins") && local_name == root_attr;
356                    if !is_not_aliased_builtin {
357                        import_for_module.push(if local_name == root_attr {
358                            local_name
359                        } else {
360                            format!("{root_attr} as {local_name}")
361                        });
362                    }
363                }
364            }
365            if let Some(module) = normalized_module {
366                if !import_for_module.is_empty() {
367                    imports.push(format!(
368                        "from {module} import {}",
369                        import_for_module.join(", ")
370                    ));
371                }
372            }
373        }
374
375        Self { imports, renaming }
376    }
377
378    fn serialize_type_hint(&self, expr: &TypeHintExpr, buffer: &mut String) {
379        match expr {
380            TypeHintExpr::Identifier(id) => {
381                self.serialize_identifier(id, buffer);
382            }
383            TypeHintExpr::Union(elts) => {
384                for (i, elt) in elts.iter().enumerate() {
385                    if i > 0 {
386                        buffer.push_str(" | ");
387                    }
388                    self.serialize_type_hint(elt, buffer);
389                }
390            }
391            TypeHintExpr::Subscript { value, slice } => {
392                self.serialize_type_hint(value, buffer);
393                buffer.push('[');
394                for (i, elt) in slice.iter().enumerate() {
395                    if i > 0 {
396                        buffer.push_str(", ");
397                    }
398                    self.serialize_type_hint(elt, buffer);
399                }
400                buffer.push(']');
401            }
402        }
403    }
404
405    fn serialize_identifier(&self, id: &PythonIdentifier, buffer: &mut String) {
406        buffer.push_str(if let Some(module) = &id.module {
407            self.renaming
408                .get(&(module.clone(), id.name.clone()))
409                .expect("All type hint attributes should have been visited")
410        } else {
411            &id.name
412        });
413    }
414}
415
416/// Lists all the elements used in annotations
417struct ElementsUsedInAnnotations {
418    /// module -> name
419    module_members: BTreeMap<String, BTreeSet<String>>,
420    locals: BTreeSet<String>,
421}
422
423impl ElementsUsedInAnnotations {
424    fn new() -> Self {
425        Self {
426            module_members: BTreeMap::new(),
427            locals: BTreeSet::new(),
428        }
429    }
430
431    fn walk_module(&mut self, module: &Module) {
432        for attr in &module.attributes {
433            self.walk_attribute(attr);
434        }
435        for class in &module.classes {
436            self.walk_class(class);
437        }
438        for function in &module.functions {
439            self.walk_function(function);
440        }
441        if module.incomplete {
442            self.module_members
443                .entry("builtins".into())
444                .or_default()
445                .insert("str".into());
446            self.module_members
447                .entry("_typeshed".into())
448                .or_default()
449                .insert("Incomplete".into());
450        }
451    }
452
453    fn walk_class(&mut self, class: &Class) {
454        for base in &class.bases {
455            self.walk_identifier(base);
456        }
457        for decorator in &class.decorators {
458            self.walk_identifier(decorator);
459        }
460        for method in &class.methods {
461            self.walk_function(method);
462        }
463        for attr in &class.attributes {
464            self.walk_attribute(attr);
465        }
466    }
467
468    fn walk_attribute(&mut self, attribute: &Attribute) {
469        if let Some(type_hint) = &attribute.annotation {
470            self.walk_type_hint(type_hint);
471        }
472    }
473
474    fn walk_function(&mut self, function: &Function) {
475        for decorator in &function.decorators {
476            self.walk_identifier(decorator);
477        }
478        for arg in function
479            .arguments
480            .positional_only_arguments
481            .iter()
482            .chain(&function.arguments.arguments)
483            .chain(&function.arguments.keyword_only_arguments)
484        {
485            if let Some(type_hint) = &arg.annotation {
486                self.walk_type_hint(type_hint);
487            }
488        }
489        for arg in function
490            .arguments
491            .vararg
492            .as_ref()
493            .iter()
494            .chain(&function.arguments.kwarg.as_ref())
495        {
496            if let Some(type_hint) = &arg.annotation {
497                self.walk_type_hint(type_hint);
498            }
499        }
500        if let Some(type_hint) = &function.returns {
501            self.walk_type_hint(type_hint);
502        }
503    }
504
505    fn walk_type_hint(&mut self, type_hint: &TypeHint) {
506        if let TypeHint::Ast(type_hint) = type_hint {
507            self.walk_type_hint_expr(type_hint);
508        }
509    }
510
511    fn walk_type_hint_expr(&mut self, expr: &TypeHintExpr) {
512        match expr {
513            TypeHintExpr::Identifier(id) => {
514                self.walk_identifier(id);
515            }
516            TypeHintExpr::Union(elts) => {
517                for elt in elts {
518                    self.walk_type_hint_expr(elt)
519                }
520            }
521            TypeHintExpr::Subscript { value, slice } => {
522                self.walk_type_hint_expr(value);
523                for elt in slice {
524                    self.walk_type_hint_expr(elt);
525                }
526            }
527        }
528    }
529
530    fn walk_identifier(&mut self, id: &PythonIdentifier) {
531        if let Some(module) = &id.module {
532            self.module_members
533                .entry(module.clone())
534                .or_default()
535                .insert(id.name.clone());
536        } else {
537            self.locals.insert(id.name.clone());
538        }
539    }
540}
541
#[cfg(test)]
mod tests {
    use super::*;
    use crate::model::Arguments;

    /// Builds an identifier `name`, qualified by `module` when given.
    fn python_identifier(module: Option<&str>, name: &str) -> PythonIdentifier {
        PythonIdentifier {
            module: module.map(Into::into),
            name: name.into(),
        }
    }

    /// Builds an AST type hint made of a single identifier.
    fn identifier(module: Option<&str>, name: &str) -> TypeHintExpr {
        python_identifier(module, name).into()
    }

    /// Builds a parameter with an optional default value and an optional plain annotation.
    fn parameter(name: &str, default_value: Option<&str>, annotation: Option<&str>) -> Argument {
        Argument {
            name: name.into(),
            default_value: default_value.map(Into::into),
            annotation: annotation.map(|hint| TypeHint::Plain(hint.into())),
        }
    }

    #[test]
    fn function_stubs_with_variable_length() {
        let function = Function {
            name: "func".into(),
            decorators: Vec::new(),
            arguments: Arguments {
                positional_only_arguments: vec![parameter("posonly", None, None)],
                arguments: vec![parameter("arg", None, None)],
                vararg: Some(VariableLengthArgument {
                    name: "varargs".into(),
                    annotation: None,
                }),
                keyword_only_arguments: vec![parameter("karg", None, Some("str"))],
                kwarg: Some(VariableLengthArgument {
                    name: "kwarg".into(),
                    annotation: Some(TypeHint::Plain("str".into())),
                }),
            },
            returns: Some(TypeHint::Plain("list[str]".into())),
        };
        assert_eq!(
            function_stubs(&function, &Imports::default()),
            "def func(posonly, /, arg, *varargs, karg: str, **kwarg: str) -> list[str]: ..."
        );
    }

    #[test]
    fn function_stubs_without_variable_length() {
        let function = Function {
            name: "afunc".into(),
            decorators: Vec::new(),
            arguments: Arguments {
                positional_only_arguments: vec![parameter("posonly", Some("1"), None)],
                arguments: vec![parameter("arg", Some("True"), None)],
                vararg: None,
                keyword_only_arguments: vec![parameter("karg", Some("\"foo\""), Some("str"))],
                kwarg: None,
            },
            returns: None,
        };
        assert_eq!(
            function_stubs(&function, &Imports::default()),
            "def afunc(posonly=1, /, arg=True, *, karg: str = \"foo\"): ..."
        );
    }

    #[test]
    fn test_import() {
        let big_type = TypeHintExpr::Subscript {
            value: Box::new(identifier(Some("builtins"), "dict")),
            slice: vec![
                identifier(Some("foo.bar"), "A"),
                TypeHintExpr::Union(vec![
                    identifier(Some("bar"), "A"),
                    identifier(Some("foo"), "A.C"),
                    identifier(Some("foo"), "A.D"),
                    identifier(Some("foo"), "B"),
                    identifier(Some("bat"), "A"),
                    identifier(None, "int"),
                    identifier(Some("builtins"), "int"),
                    identifier(Some("builtins"), "float"),
                ]),
            ],
        };
        let module = Module {
            name: "bar".into(),
            modules: Vec::new(),
            classes: vec![Class {
                name: "A".into(),
                bases: vec![python_identifier(Some("builtins"), "dict")],
                methods: Vec::new(),
                attributes: Vec::new(),
                decorators: vec![python_identifier(Some("typing"), "final")],
            }],
            functions: vec![Function {
                name: String::new(),
                decorators: Vec::new(),
                arguments: Arguments {
                    positional_only_arguments: Vec::new(),
                    arguments: Vec::new(),
                    vararg: None,
                    keyword_only_arguments: Vec::new(),
                    kwarg: None,
                },
                returns: Some(TypeHint::Ast(big_type.clone())),
            }],
            attributes: Vec::new(),
            incomplete: true,
        };
        let imports = Imports::create(&module, &["foo"]);
        assert_eq!(
            &imports.imports,
            &[
                "from _typeshed import Incomplete",
                "from bat import A as A2",
                "from builtins import int as int2",
                "from foo import A as A3, B",
                "from typing import final"
            ]
        );

        let mut serialized = String::new();
        imports.serialize_type_hint(&big_type, &mut serialized);
        assert_eq!(
            serialized,
            "dict[A, A | A3.C | A3.D | B | A2 | int | int2 | float]"
        );
    }
}