1use crate::model::{
2 Argument, Arguments, Attribute, Class, Function, Module, PythonIdentifier, TypeHint,
3 TypeHintExpr, VariableLengthArgument,
4};
5use std::collections::{BTreeMap, BTreeSet, HashMap};
6use std::path::PathBuf;
7use std::str::FromStr;
8
9pub fn module_stub_files(module: &Module) -> HashMap<PathBuf, String> {
14 let mut output_files = HashMap::new();
15 add_module_stub_files(module, &[], &mut output_files);
16 output_files
17}
18
19fn add_module_stub_files(
20 module: &Module,
21 module_path: &[&str],
22 output_files: &mut HashMap<PathBuf, String>,
23) {
24 let mut file_path = PathBuf::new();
25 for e in module_path {
26 file_path = file_path.join(e);
27 }
28 output_files.insert(
29 file_path.join("__init__.pyi"),
30 module_stubs(module, module_path),
31 );
32 let mut module_path = module_path.to_vec();
33 module_path.push(&module.name);
34 for submodule in &module.modules {
35 if submodule.modules.is_empty() {
36 output_files.insert(
37 file_path.join(format!("{}.pyi", submodule.name)),
38 module_stubs(submodule, &module_path),
39 );
40 } else {
41 add_module_stub_files(submodule, &module_path, output_files);
42 }
43 }
44}
45
46fn module_stubs(module: &Module, parents: &[&str]) -> String {
48 let imports = Imports::create(module, parents);
49 let mut elements = Vec::new();
50 for attribute in &module.attributes {
51 elements.push(attribute_stubs(attribute, &imports));
52 }
53 for class in &module.classes {
54 elements.push(class_stubs(class, &imports));
55 }
56 for function in &module.functions {
57 elements.push(function_stubs(function, &imports));
58 }
59
60 if module.incomplete && !module.functions.iter().any(|f| f.name == "__getattr__") {
63 elements.push(function_stubs(
64 &Function {
65 name: "__getattr__".into(),
66 decorators: Vec::new(),
67 arguments: Arguments {
68 positional_only_arguments: Vec::new(),
69 arguments: vec![Argument {
70 name: "name".to_string(),
71 default_value: None,
72 annotation: Some(TypeHint::Ast(
73 PythonIdentifier {
74 module: Some("builtins".into()),
75 name: "str".into(),
76 }
77 .into(),
78 )),
79 }],
80 vararg: None,
81 keyword_only_arguments: Vec::new(),
82 kwarg: None,
83 },
84 returns: Some(TypeHint::Ast(
85 PythonIdentifier {
86 module: Some("_typeshed".into()),
87 name: "Incomplete".into(),
88 }
89 .into(),
90 )),
91 },
92 &imports,
93 ));
94 }
95
96 let mut final_elements = imports.imports;
97 final_elements.extend(elements);
98
99 let mut output = String::new();
100
101 for element in final_elements {
103 let is_multiline = element.contains('\n');
104 if is_multiline && !output.is_empty() && !output.ends_with("\n\n") {
105 output.push('\n');
106 }
107 output.push_str(&element);
108 output.push('\n');
109 if is_multiline {
110 output.push('\n');
111 }
112 }
113
114 if output.ends_with("\n\n") {
116 output.pop();
117 }
118 output
119}
120
121fn class_stubs(class: &Class, imports: &Imports) -> String {
122 let mut buffer = String::new();
123 for decorator in &class.decorators {
124 buffer.push('@');
125 imports.serialize_identifier(decorator, &mut buffer);
126 buffer.push('\n');
127 }
128 buffer.push_str("class ");
129 buffer.push_str(&class.name);
130 if !class.bases.is_empty() {
131 buffer.push('(');
132 for (i, base) in class.bases.iter().enumerate() {
133 if i > 0 {
134 buffer.push_str(", ");
135 }
136 imports.serialize_identifier(base, &mut buffer);
137 }
138 buffer.push(')');
139 }
140 buffer.push(':');
141 if class.methods.is_empty() && class.attributes.is_empty() {
142 buffer.push_str(" ...");
143 return buffer;
144 }
145 for attribute in &class.attributes {
146 buffer.push_str("\n ");
148 buffer.push_str(&attribute_stubs(attribute, imports).replace('\n', "\n "));
149 }
150 for method in &class.methods {
151 buffer.push_str("\n ");
153 buffer.push_str(&function_stubs(method, imports).replace('\n', "\n "));
154 }
155 buffer
156}
157
158fn function_stubs(function: &Function, imports: &Imports) -> String {
159 let mut parameters = Vec::new();
161 for argument in &function.arguments.positional_only_arguments {
162 parameters.push(argument_stub(argument, imports));
163 }
164 if !function.arguments.positional_only_arguments.is_empty() {
165 parameters.push("/".into());
166 }
167 for argument in &function.arguments.arguments {
168 parameters.push(argument_stub(argument, imports));
169 }
170 if let Some(argument) = &function.arguments.vararg {
171 parameters.push(format!(
172 "*{}",
173 variable_length_argument_stub(argument, imports)
174 ));
175 } else if !function.arguments.keyword_only_arguments.is_empty() {
176 parameters.push("*".into());
177 }
178 for argument in &function.arguments.keyword_only_arguments {
179 parameters.push(argument_stub(argument, imports));
180 }
181 if let Some(argument) = &function.arguments.kwarg {
182 parameters.push(format!(
183 "**{}",
184 variable_length_argument_stub(argument, imports)
185 ));
186 }
187 let mut buffer = String::new();
188 for decorator in &function.decorators {
189 buffer.push('@');
190 imports.serialize_identifier(decorator, &mut buffer);
191 buffer.push('\n');
192 }
193 buffer.push_str("def ");
194 buffer.push_str(&function.name);
195 buffer.push('(');
196 buffer.push_str(¶meters.join(", "));
197 buffer.push(')');
198 if let Some(returns) = &function.returns {
199 buffer.push_str(" -> ");
200 type_hint_stub(returns, imports, &mut buffer);
201 }
202 buffer.push_str(": ...");
203 buffer
204}
205
206fn attribute_stubs(attribute: &Attribute, imports: &Imports) -> String {
207 let mut buffer = attribute.name.clone();
208 if let Some(annotation) = &attribute.annotation {
209 buffer.push_str(": ");
210 type_hint_stub(annotation, imports, &mut buffer);
211 }
212 if let Some(value) = &attribute.value {
213 buffer.push_str(" = ");
214 buffer.push_str(value);
215 }
216 buffer
217}
218
219fn argument_stub(argument: &Argument, imports: &Imports) -> String {
220 let mut buffer = argument.name.clone();
221 if let Some(annotation) = &argument.annotation {
222 buffer.push_str(": ");
223 type_hint_stub(annotation, imports, &mut buffer);
224 }
225 if let Some(default_value) = &argument.default_value {
226 buffer.push_str(if argument.annotation.is_some() {
227 " = "
228 } else {
229 "="
230 });
231 buffer.push_str(default_value);
232 }
233 buffer
234}
235
236fn variable_length_argument_stub(argument: &VariableLengthArgument, imports: &Imports) -> String {
237 let mut buffer = argument.name.clone();
238 if let Some(annotation) = &argument.annotation {
239 buffer.push_str(": ");
240 type_hint_stub(annotation, imports, &mut buffer);
241 }
242 buffer
243}
244
245fn type_hint_stub(type_hint: &TypeHint, imports: &Imports, buffer: &mut String) {
246 match type_hint {
247 TypeHint::Ast(t) => imports.serialize_type_hint(t, buffer),
248 TypeHint::Plain(t) => buffer.push_str(t),
249 }
250}
251
252#[derive(Default)]
254struct Imports {
255 imports: Vec<String>,
257 renaming: BTreeMap<(String, String), String>,
259}
260
261impl Imports {
262 fn create(module: &Module, module_parents: &[&str]) -> Self {
270 let mut elements_used_in_annotations = ElementsUsedInAnnotations::new();
271 elements_used_in_annotations.walk_module(module);
272
273 let mut imports = Vec::new();
274 let mut renaming = BTreeMap::new();
275 let mut local_name_to_module_and_attribute = BTreeMap::new();
276
277 for name in module
279 .classes
280 .iter()
281 .map(|c| c.name.clone())
282 .chain(module.functions.iter().map(|f| f.name.clone()))
283 .chain(module.attributes.iter().map(|a| a.name.clone()))
284 .chain(elements_used_in_annotations.locals)
285 {
286 local_name_to_module_and_attribute.insert(name.clone(), (None, name.clone()));
287 }
288
289 let mut possible_current_module_names = vec![module.name.clone()];
291 let mut current_module_name = Some(module.name.clone());
292 for parent in module_parents.iter().rev() {
293 let path = if let Some(current) = current_module_name {
294 format!("{parent}.{current}")
295 } else {
296 parent.to_string()
297 };
298 possible_current_module_names.push(path.clone());
299 current_module_name = Some(path);
300 }
301
302 for (module, attrs) in elements_used_in_annotations.module_members {
304 let normalized_module = if possible_current_module_names.contains(&module) {
305 None
306 } else {
307 Some(module.clone())
308 };
309 let mut import_for_module = Vec::new();
310 for attr in attrs {
311 let (root_attr, attr_path) = attr
313 .split_once('.')
314 .map_or((attr.as_str(), None), |(root, path)| (root, Some(path)));
315 let mut local_name = root_attr.to_owned();
316 let mut already_imported = false;
317 while let Some((possible_conflict_module, possible_conflict_attr)) =
318 local_name_to_module_and_attribute.get(&local_name)
319 {
320 if *possible_conflict_module == normalized_module
321 && *possible_conflict_attr == root_attr
322 {
323 already_imported = true;
325 break;
326 }
327 let number_of_digits_at_the_end = local_name
330 .bytes()
331 .rev()
332 .take_while(|b| b.is_ascii_digit())
333 .count();
334 let (local_name_prefix, local_name_number) =
335 local_name.split_at(local_name.len() - number_of_digits_at_the_end);
336 local_name = format!(
337 "{local_name_prefix}{}",
338 u64::from_str(local_name_number).unwrap_or(1) + 1
339 );
340 }
341 renaming.insert(
342 (module.clone(), attr.clone()),
343 if let Some(attr_path) = attr_path {
344 format!("{local_name}.{attr_path}")
345 } else {
346 local_name.clone()
347 },
348 );
349 if !already_imported {
350 local_name_to_module_and_attribute.insert(
351 local_name.clone(),
352 (normalized_module.clone(), root_attr.to_owned()),
353 );
354 let is_not_aliased_builtin =
355 normalized_module.as_deref() == Some("builtins") && local_name == root_attr;
356 if !is_not_aliased_builtin {
357 import_for_module.push(if local_name == root_attr {
358 local_name
359 } else {
360 format!("{root_attr} as {local_name}")
361 });
362 }
363 }
364 }
365 if let Some(module) = normalized_module {
366 if !import_for_module.is_empty() {
367 imports.push(format!(
368 "from {module} import {}",
369 import_for_module.join(", ")
370 ));
371 }
372 }
373 }
374
375 Self { imports, renaming }
376 }
377
378 fn serialize_type_hint(&self, expr: &TypeHintExpr, buffer: &mut String) {
379 match expr {
380 TypeHintExpr::Identifier(id) => {
381 self.serialize_identifier(id, buffer);
382 }
383 TypeHintExpr::Union(elts) => {
384 for (i, elt) in elts.iter().enumerate() {
385 if i > 0 {
386 buffer.push_str(" | ");
387 }
388 self.serialize_type_hint(elt, buffer);
389 }
390 }
391 TypeHintExpr::Subscript { value, slice } => {
392 self.serialize_type_hint(value, buffer);
393 buffer.push('[');
394 for (i, elt) in slice.iter().enumerate() {
395 if i > 0 {
396 buffer.push_str(", ");
397 }
398 self.serialize_type_hint(elt, buffer);
399 }
400 buffer.push(']');
401 }
402 }
403 }
404
405 fn serialize_identifier(&self, id: &PythonIdentifier, buffer: &mut String) {
406 buffer.push_str(if let Some(module) = &id.module {
407 self.renaming
408 .get(&(module.clone(), id.name.clone()))
409 .expect("All type hint attributes should have been visited")
410 } else {
411 &id.name
412 });
413 }
414}
415
416struct ElementsUsedInAnnotations {
418 module_members: BTreeMap<String, BTreeSet<String>>,
420 locals: BTreeSet<String>,
421}
422
423impl ElementsUsedInAnnotations {
424 fn new() -> Self {
425 Self {
426 module_members: BTreeMap::new(),
427 locals: BTreeSet::new(),
428 }
429 }
430
431 fn walk_module(&mut self, module: &Module) {
432 for attr in &module.attributes {
433 self.walk_attribute(attr);
434 }
435 for class in &module.classes {
436 self.walk_class(class);
437 }
438 for function in &module.functions {
439 self.walk_function(function);
440 }
441 if module.incomplete {
442 self.module_members
443 .entry("builtins".into())
444 .or_default()
445 .insert("str".into());
446 self.module_members
447 .entry("_typeshed".into())
448 .or_default()
449 .insert("Incomplete".into());
450 }
451 }
452
453 fn walk_class(&mut self, class: &Class) {
454 for base in &class.bases {
455 self.walk_identifier(base);
456 }
457 for decorator in &class.decorators {
458 self.walk_identifier(decorator);
459 }
460 for method in &class.methods {
461 self.walk_function(method);
462 }
463 for attr in &class.attributes {
464 self.walk_attribute(attr);
465 }
466 }
467
468 fn walk_attribute(&mut self, attribute: &Attribute) {
469 if let Some(type_hint) = &attribute.annotation {
470 self.walk_type_hint(type_hint);
471 }
472 }
473
474 fn walk_function(&mut self, function: &Function) {
475 for decorator in &function.decorators {
476 self.walk_identifier(decorator);
477 }
478 for arg in function
479 .arguments
480 .positional_only_arguments
481 .iter()
482 .chain(&function.arguments.arguments)
483 .chain(&function.arguments.keyword_only_arguments)
484 {
485 if let Some(type_hint) = &arg.annotation {
486 self.walk_type_hint(type_hint);
487 }
488 }
489 for arg in function
490 .arguments
491 .vararg
492 .as_ref()
493 .iter()
494 .chain(&function.arguments.kwarg.as_ref())
495 {
496 if let Some(type_hint) = &arg.annotation {
497 self.walk_type_hint(type_hint);
498 }
499 }
500 if let Some(type_hint) = &function.returns {
501 self.walk_type_hint(type_hint);
502 }
503 }
504
505 fn walk_type_hint(&mut self, type_hint: &TypeHint) {
506 if let TypeHint::Ast(type_hint) = type_hint {
507 self.walk_type_hint_expr(type_hint);
508 }
509 }
510
511 fn walk_type_hint_expr(&mut self, expr: &TypeHintExpr) {
512 match expr {
513 TypeHintExpr::Identifier(id) => {
514 self.walk_identifier(id);
515 }
516 TypeHintExpr::Union(elts) => {
517 for elt in elts {
518 self.walk_type_hint_expr(elt)
519 }
520 }
521 TypeHintExpr::Subscript { value, slice } => {
522 self.walk_type_hint_expr(value);
523 for elt in slice {
524 self.walk_type_hint_expr(elt);
525 }
526 }
527 }
528 }
529
530 fn walk_identifier(&mut self, id: &PythonIdentifier) {
531 if let Some(module) = &id.module {
532 self.module_members
533 .entry(module.clone())
534 .or_default()
535 .insert(id.name.clone());
536 } else {
537 self.locals.insert(id.name.clone());
538 }
539 }
540}
541
542#[cfg(test)]
543mod tests {
544 use super::*;
545 use crate::model::Arguments;
546
547 #[test]
548 fn function_stubs_with_variable_length() {
549 let function = Function {
550 name: "func".into(),
551 decorators: Vec::new(),
552 arguments: Arguments {
553 positional_only_arguments: vec![Argument {
554 name: "posonly".into(),
555 default_value: None,
556 annotation: None,
557 }],
558 arguments: vec![Argument {
559 name: "arg".into(),
560 default_value: None,
561 annotation: None,
562 }],
563 vararg: Some(VariableLengthArgument {
564 name: "varargs".into(),
565 annotation: None,
566 }),
567 keyword_only_arguments: vec![Argument {
568 name: "karg".into(),
569 default_value: None,
570 annotation: Some(TypeHint::Plain("str".into())),
571 }],
572 kwarg: Some(VariableLengthArgument {
573 name: "kwarg".into(),
574 annotation: Some(TypeHint::Plain("str".into())),
575 }),
576 },
577 returns: Some(TypeHint::Plain("list[str]".into())),
578 };
579 assert_eq!(
580 "def func(posonly, /, arg, *varargs, karg: str, **kwarg: str) -> list[str]: ...",
581 function_stubs(&function, &Imports::default())
582 )
583 }
584
585 #[test]
586 fn function_stubs_without_variable_length() {
587 let function = Function {
588 name: "afunc".into(),
589 decorators: Vec::new(),
590 arguments: Arguments {
591 positional_only_arguments: vec![Argument {
592 name: "posonly".into(),
593 default_value: Some("1".into()),
594 annotation: None,
595 }],
596 arguments: vec![Argument {
597 name: "arg".into(),
598 default_value: Some("True".into()),
599 annotation: None,
600 }],
601 vararg: None,
602 keyword_only_arguments: vec![Argument {
603 name: "karg".into(),
604 default_value: Some("\"foo\"".into()),
605 annotation: Some(TypeHint::Plain("str".into())),
606 }],
607 kwarg: None,
608 },
609 returns: None,
610 };
611 assert_eq!(
612 "def afunc(posonly=1, /, arg=True, *, karg: str = \"foo\"): ...",
613 function_stubs(&function, &Imports::default())
614 )
615 }
616
617 #[test]
618 fn test_import() {
619 let big_type = TypeHintExpr::Subscript {
620 value: Box::new(
621 PythonIdentifier {
622 module: Some("builtins".into()),
623 name: "dict".into(),
624 }
625 .into(),
626 ),
627 slice: vec![
628 PythonIdentifier {
629 module: Some("foo.bar".into()),
630 name: "A".into(),
631 }
632 .into(),
633 TypeHintExpr::Union(vec![
634 PythonIdentifier {
635 module: Some("bar".into()),
636 name: "A".into(),
637 }
638 .into(),
639 PythonIdentifier {
640 module: Some("foo".into()),
641 name: "A.C".into(),
642 }
643 .into(),
644 PythonIdentifier {
645 module: Some("foo".into()),
646 name: "A.D".into(),
647 }
648 .into(),
649 PythonIdentifier {
650 module: Some("foo".into()),
651 name: "B".into(),
652 }
653 .into(),
654 PythonIdentifier {
655 module: Some("bat".into()),
656 name: "A".into(),
657 }
658 .into(),
659 PythonIdentifier {
660 module: None,
661 name: "int".into(),
662 }
663 .into(),
664 PythonIdentifier {
665 module: Some("builtins".into()),
666 name: "int".into(),
667 }
668 .into(),
669 PythonIdentifier {
670 module: Some("builtins".into()),
671 name: "float".into(),
672 }
673 .into(),
674 ]),
675 ],
676 };
677 let imports = Imports::create(
678 &Module {
679 name: "bar".into(),
680 modules: Vec::new(),
681 classes: vec![Class {
682 name: "A".into(),
683 bases: vec![PythonIdentifier {
684 module: Some("builtins".into()),
685 name: "dict".into(),
686 }],
687 methods: Vec::new(),
688 attributes: Vec::new(),
689 decorators: vec![PythonIdentifier {
690 module: Some("typing".into()),
691 name: "final".into(),
692 }],
693 }],
694 functions: vec![Function {
695 name: String::new(),
696 decorators: Vec::new(),
697 arguments: Arguments {
698 positional_only_arguments: Vec::new(),
699 arguments: Vec::new(),
700 vararg: None,
701 keyword_only_arguments: Vec::new(),
702 kwarg: None,
703 },
704 returns: Some(TypeHint::Ast(big_type.clone())),
705 }],
706 attributes: Vec::new(),
707 incomplete: true,
708 },
709 &["foo"],
710 );
711 assert_eq!(
712 &imports.imports,
713 &[
714 "from _typeshed import Incomplete",
715 "from bat import A as A2",
716 "from builtins import int as int2",
717 "from foo import A as A3, B",
718 "from typing import final"
719 ]
720 );
721 let mut output = String::new();
722 imports.serialize_type_hint(&big_type, &mut output);
723 assert_eq!(
724 output,
725 "dict[A, A | A3.C | A3.D | B | A2 | int | int2 | float]"
726 );
727 }
728}