use crate::attributes::{self, get_pyo3_options, CrateAttribute, FromPyWithAttribute};
use crate::utils::Ctx;
use proc_macro2::TokenStream;
use quote::{format_ident, quote};
use syn::{
parenthesized,
parse::{Parse, ParseStream},
parse_quote,
punctuated::Punctuated,
spanned::Spanned,
Attribute, DataEnum, DeriveInput, Fields, Ident, LitStr, Result, Token,
};
/// Parsed description of an enum that `FromPyObject` is being derived for.
struct Enum<'a> {
/// Identifier of the enum itself; its string form is passed as the type name
/// to `failed_to_extract_enum` in the generated code.
enum_ident: &'a Ident,
/// One `Container` per enum variant, in declaration order.
variants: Vec<Container<'a>>,
}
impl<'a> Enum<'a> {
fn new(data_enum: &'a DataEnum, ident: &'a Ident) -> Result<Self> {
ensure_spanned!(
!data_enum.variants.is_empty(),
ident.span() => "cannot derive FromPyObject for empty enum"
);
let variants = data_enum
.variants
.iter()
.map(|variant| {
let attrs = ContainerOptions::from_attrs(&variant.attrs)?;
let var_ident = &variant.ident;
Container::new(&variant.fields, parse_quote!(#ident::#var_ident), attrs)
})
.collect::<Result<Vec<_>>>()?;
Ok(Enum {
enum_ident: ident,
variants,
})
}
fn build(&self, ctx: &Ctx) -> TokenStream {
let Ctx { pyo3_path, .. } = ctx;
let mut var_extracts = Vec::new();
let mut variant_names = Vec::new();
let mut error_names = Vec::new();
for var in &self.variants {
let struct_derive = var.build(ctx);
let ext = quote!({
let maybe_ret = || -> #pyo3_path::PyResult<Self> {
#struct_derive
}();
match maybe_ret {
ok @ ::std::result::Result::Ok(_) => return ok,
::std::result::Result::Err(err) => err
}
});
var_extracts.push(ext);
variant_names.push(var.path.segments.last().unwrap().ident.to_string());
error_names.push(&var.err_name);
}
let ty_name = self.enum_ident.to_string();
quote!(
let errors = [
#(#var_extracts),*
];
::std::result::Result::Err(
#pyo3_path::impl_::frompyobject::failed_to_extract_enum(
obj.py(),
#ty_name,
&[#(#variant_names),*],
&[#(#error_names),*],
&errors
)
)
)
}
}
/// A single named field of a struct or struct-like variant.
struct NamedStructField<'a> {
/// Identifier of the field as written in the Rust source.
ident: &'a syn::Ident,
/// How the value is looked up on the Python object; `None` means no getter
/// option was given for this field.
getter: Option<FieldGetter>,
/// Optional `from_py_with` attribute parsed from the field.
from_py_with: Option<FromPyWithAttribute>,
}
/// A single unnamed field of a tuple struct or tuple variant.
struct TupleStructField {
/// Optional `from_py_with` attribute parsed from the field.
from_py_with: Option<FromPyWithAttribute>,
}
/// The shape of a container, which selects the code-generation strategy in
/// `Container::build`.
enum ContainerType<'a> {
/// Named fields that are not marked `transparent`.
Struct(Vec<NamedStructField<'a>>),
/// A `transparent` container with exactly one named field; holds that field's
/// identifier and its optional `from_py_with`.
StructNewtype(&'a syn::Ident, Option<FromPyWithAttribute>),
/// Two or more unnamed fields.
Tuple(Vec<TupleStructField>),
/// Exactly one unnamed field; holds its optional `from_py_with`.
TupleNewtype(Option<FromPyWithAttribute>),
}
/// Data container: either a struct or a single enum variant.
struct Container<'a> {
/// Path used to construct the value in generated code (`Struct` or `Enum::Variant`).
path: syn::Path,
/// Shape of the container's fields.
ty: ContainerType<'a>,
/// Name reported for this container in enum extraction errors: the `annotation`
/// value when provided, otherwise the last segment of `path`.
err_name: String,
}
impl<'a> Container<'a> {
fn new(fields: &'a Fields, path: syn::Path, options: ContainerOptions) -> Result<Self> {
let style = match fields {
Fields::Unnamed(unnamed) if !unnamed.unnamed.is_empty() => {
let mut tuple_fields = unnamed
.unnamed
.iter()
.map(|field| {
let attrs = FieldPyO3Attributes::from_attrs(&field.attrs)?;
ensure_spanned!(
attrs.getter.is_none(),
field.span() => "`getter` is not permitted on tuple struct elements."
);
Ok(TupleStructField {
from_py_with: attrs.from_py_with,
})
})
.collect::<Result<Vec<_>>>()?;
if tuple_fields.len() == 1 {
let field = tuple_fields.pop().unwrap();
ContainerType::TupleNewtype(field.from_py_with)
} else if options.transparent {
bail_spanned!(
fields.span() => "transparent structs and variants can only have 1 field"
);
} else {
ContainerType::Tuple(tuple_fields)
}
}
Fields::Named(named) if !named.named.is_empty() => {
let mut struct_fields = named
.named
.iter()
.map(|field| {
let ident = field
.ident
.as_ref()
.expect("Named fields should have identifiers");
let mut attrs = FieldPyO3Attributes::from_attrs(&field.attrs)?;
if let Some(ref from_item_all) = options.from_item_all {
if let Some(replaced) = attrs.getter.replace(FieldGetter::GetItem(None))
{
match replaced {
FieldGetter::GetItem(Some(item_name)) => {
attrs.getter = Some(FieldGetter::GetItem(Some(item_name)));
}
FieldGetter::GetItem(None) => bail_spanned!(from_item_all.span() => "Useless `item` - the struct is already annotated with `from_item_all`"),
FieldGetter::GetAttr(_) => bail_spanned!(
from_item_all.span() => "The struct is already annotated with `from_item_all`, `attribute` is not allowed"
),
}
}
}
Ok(NamedStructField {
ident,
getter: attrs.getter,
from_py_with: attrs.from_py_with,
})
})
.collect::<Result<Vec<_>>>()?;
if options.transparent {
ensure_spanned!(
struct_fields.len() == 1,
fields.span() => "transparent structs and variants can only have 1 field"
);
let field = struct_fields.pop().unwrap();
ensure_spanned!(
field.getter.is_none(),
field.ident.span() => "`transparent` structs may not have a `getter` for the inner field"
);
ContainerType::StructNewtype(field.ident, field.from_py_with)
} else {
ContainerType::Struct(struct_fields)
}
}
_ => bail_spanned!(
fields.span() => "cannot derive FromPyObject for empty structs and variants"
),
};
let err_name = options.annotation.map_or_else(
|| path.segments.last().unwrap().ident.to_string(),
|lit_str| lit_str.value(),
);
let v = Container {
path,
ty: style,
err_name,
};
Ok(v)
}
fn name(&self) -> String {
let mut value = String::new();
for segment in &self.path.segments {
if !value.is_empty() {
value.push_str("::");
}
value.push_str(&segment.ident.to_string());
}
value
}
fn build(&self, ctx: &Ctx) -> TokenStream {
match &self.ty {
ContainerType::StructNewtype(ident, from_py_with) => {
self.build_newtype_struct(Some(ident), from_py_with, ctx)
}
ContainerType::TupleNewtype(from_py_with) => {
self.build_newtype_struct(None, from_py_with, ctx)
}
ContainerType::Tuple(tups) => self.build_tuple_struct(tups, ctx),
ContainerType::Struct(tups) => self.build_struct(tups, ctx),
}
}
fn build_newtype_struct(
&self,
field_ident: Option<&Ident>,
from_py_with: &Option<FromPyWithAttribute>,
ctx: &Ctx,
) -> TokenStream {
let Ctx { pyo3_path, .. } = ctx;
let self_ty = &self.path;
let struct_name = self.name();
if let Some(ident) = field_ident {
let field_name = ident.to_string();
match from_py_with {
None => quote! {
Ok(#self_ty {
#ident: #pyo3_path::impl_::frompyobject::extract_struct_field(obj, #struct_name, #field_name)?
})
},
Some(FromPyWithAttribute {
value: expr_path, ..
}) => quote! {
Ok(#self_ty {
#ident: #pyo3_path::impl_::frompyobject::extract_struct_field_with(#expr_path as fn(_) -> _, obj, #struct_name, #field_name)?
})
},
}
} else {
match from_py_with {
None => quote! {
#pyo3_path::impl_::frompyobject::extract_tuple_struct_field(obj, #struct_name, 0).map(#self_ty)
},
Some(FromPyWithAttribute {
value: expr_path, ..
}) => quote! {
#pyo3_path::impl_::frompyobject::extract_tuple_struct_field_with(#expr_path as fn(_) -> _, obj, #struct_name, 0).map(#self_ty)
},
}
}
}
fn build_tuple_struct(&self, struct_fields: &[TupleStructField], ctx: &Ctx) -> TokenStream {
let Ctx { pyo3_path, .. } = ctx;
let self_ty = &self.path;
let struct_name = &self.name();
let field_idents: Vec<_> = (0..struct_fields.len())
.map(|i| format_ident!("arg{}", i))
.collect();
let fields = struct_fields.iter().zip(&field_idents).enumerate().map(|(index, (field, ident))| {
match &field.from_py_with {
None => quote!(
#pyo3_path::impl_::frompyobject::extract_tuple_struct_field(&#ident, #struct_name, #index)?
),
Some(FromPyWithAttribute {
value: expr_path, ..
}) => quote! (
#pyo3_path::impl_::frompyobject::extract_tuple_struct_field_with(#expr_path as fn(_) -> _, &#ident, #struct_name, #index)?
),
}
});
quote!(
match #pyo3_path::types::PyAnyMethods::extract(obj) {
::std::result::Result::Ok((#(#field_idents),*)) => ::std::result::Result::Ok(#self_ty(#(#fields),*)),
::std::result::Result::Err(err) => ::std::result::Result::Err(err),
}
)
}
fn build_struct(&self, struct_fields: &[NamedStructField<'_>], ctx: &Ctx) -> TokenStream {
let Ctx { pyo3_path, .. } = ctx;
let self_ty = &self.path;
let struct_name = &self.name();
let mut fields: Punctuated<TokenStream, syn::Token![,]> = Punctuated::new();
for field in struct_fields {
let ident = &field.ident;
let field_name = ident.to_string();
let getter = match field.getter.as_ref().unwrap_or(&FieldGetter::GetAttr(None)) {
FieldGetter::GetAttr(Some(name)) => {
quote!(#pyo3_path::types::PyAnyMethods::getattr(obj, #pyo3_path::intern!(obj.py(), #name)))
}
FieldGetter::GetAttr(None) => {
quote!(#pyo3_path::types::PyAnyMethods::getattr(obj, #pyo3_path::intern!(obj.py(), #field_name)))
}
FieldGetter::GetItem(Some(syn::Lit::Str(key))) => {
quote!(#pyo3_path::types::PyAnyMethods::get_item(obj, #pyo3_path::intern!(obj.py(), #key)))
}
FieldGetter::GetItem(Some(key)) => {
quote!(#pyo3_path::types::PyAnyMethods::get_item(obj, #key))
}
FieldGetter::GetItem(None) => {
quote!(#pyo3_path::types::PyAnyMethods::get_item(obj, #pyo3_path::intern!(obj.py(), #field_name)))
}
};
let extractor = match &field.from_py_with {
None => {
quote!(#pyo3_path::impl_::frompyobject::extract_struct_field(&#getter?, #struct_name, #field_name)?)
}
Some(FromPyWithAttribute {
value: expr_path, ..
}) => {
quote! (#pyo3_path::impl_::frompyobject::extract_struct_field_with(#expr_path as fn(_) -> _, &#getter?, #struct_name, #field_name)?)
}
};
fields.push(quote!(#ident: #extractor));
}
quote!(::std::result::Result::Ok(#self_ty{#fields}))
}
}
/// Container-level `#[pyo3(...)]` options collected from a struct, enum or enum variant.
#[derive(Default)]
struct ContainerOptions {
/// Set when the `transparent` option was given.
transparent: bool,
/// The `from_item_all` keyword token if given; kept so errors can point at its span.
from_item_all: Option<attributes::kw::from_item_all>,
/// Value of `annotation = "..."` if given; overrides the container's error name.
annotation: Option<syn::LitStr>,
/// The `crate` option if given; passed to `Ctx::new` by `build_derive_from_pyobject`.
krate: Option<CrateAttribute>,
}
/// A single parsed container-level option inside `#[pyo3(...)]`.
enum ContainerPyO3Attribute {
/// The `transparent` keyword.
Transparent(attributes::kw::transparent),
/// The `from_item_all` keyword.
ItemAll(attributes::kw::from_item_all),
/// The string literal from `annotation = "..."`.
ErrorAnnotation(LitStr),
/// The `crate` option, parsed as a `CrateAttribute`.
Crate(CrateAttribute),
}
impl Parse for ContainerPyO3Attribute {
fn parse(input: ParseStream<'_>) -> Result<Self> {
let lookahead = input.lookahead1();
if lookahead.peek(attributes::kw::transparent) {
let kw: attributes::kw::transparent = input.parse()?;
Ok(ContainerPyO3Attribute::Transparent(kw))
} else if lookahead.peek(attributes::kw::from_item_all) {
let kw: attributes::kw::from_item_all = input.parse()?;
Ok(ContainerPyO3Attribute::ItemAll(kw))
} else if lookahead.peek(attributes::kw::annotation) {
let _: attributes::kw::annotation = input.parse()?;
let _: Token![=] = input.parse()?;
input.parse().map(ContainerPyO3Attribute::ErrorAnnotation)
} else if lookahead.peek(Token![crate]) {
input.parse().map(ContainerPyO3Attribute::Crate)
} else {
Err(lookahead.error())
}
}
}
impl ContainerOptions {
    /// Collect the container options from every attribute for which
    /// `get_pyo3_options` yields options.
    ///
    /// # Errors
    /// Fails when an attribute cannot be parsed or when any option is given more
    /// than once.
    fn from_attrs(attrs: &[Attribute]) -> Result<Self> {
        let mut collected = ContainerOptions::default();

        for attr in attrs {
            // Attributes that yield no options are skipped.
            let parsed = match get_pyo3_options(attr)? {
                Some(parsed) => parsed,
                None => continue,
            };

            for option in parsed {
                match option {
                    ContainerPyO3Attribute::Transparent(keyword) => {
                        ensure_spanned!(
                            !collected.transparent,
                            keyword.span() => "`transparent` may only be provided once"
                        );
                        collected.transparent = true;
                    }
                    ContainerPyO3Attribute::ItemAll(keyword) => {
                        ensure_spanned!(
                            collected.from_item_all.is_none(),
                            keyword.span() => "`from_item_all` may only be provided once"
                        );
                        collected.from_item_all = Some(keyword);
                    }
                    ContainerPyO3Attribute::ErrorAnnotation(text) => {
                        ensure_spanned!(
                            collected.annotation.is_none(),
                            text.span() => "`annotation` may only be provided once"
                        );
                        collected.annotation = Some(text);
                    }
                    ContainerPyO3Attribute::Crate(krate) => {
                        ensure_spanned!(
                            collected.krate.is_none(),
                            krate.span() => "`crate` may only be provided once"
                        );
                        collected.krate = Some(krate);
                    }
                }
            }
        }

        Ok(collected)
    }
}
/// Field-level `#[pyo3(...)]` options collected for a single struct field.
#[derive(Clone, Debug)]
struct FieldPyO3Attributes {
/// The `attribute` / `item` option, if one was given.
getter: Option<FieldGetter>,
/// The `from_py_with` option, if one was given.
from_py_with: Option<FromPyWithAttribute>,
}
/// How a named field's value is obtained from the Python object.
#[derive(Clone, Debug)]
enum FieldGetter {
/// Look the value up with `get_item`; `None` means the field name is used as the key.
GetItem(Option<syn::Lit>),
/// Look the value up with `getattr`; `None` means the field name is used as the attribute name.
GetAttr(Option<LitStr>),
}
/// A single parsed field-level option inside `#[pyo3(...)]`.
enum FieldPyO3Attribute {
/// An `attribute` or `item` option.
Getter(FieldGetter),
/// A `from_py_with` option.
FromPyWith(FromPyWithAttribute),
}
impl Parse for FieldPyO3Attribute {
fn parse(input: ParseStream<'_>) -> Result<Self> {
let lookahead = input.lookahead1();
if lookahead.peek(attributes::kw::attribute) {
let _: attributes::kw::attribute = input.parse()?;
if input.peek(syn::token::Paren) {
let content;
let _ = parenthesized!(content in input);
let attr_name: LitStr = content.parse()?;
if !content.is_empty() {
return Err(content.error(
"expected at most one argument: `attribute` or `attribute(\"name\")`",
));
}
ensure_spanned!(
!attr_name.value().is_empty(),
attr_name.span() => "attribute name cannot be empty"
);
Ok(FieldPyO3Attribute::Getter(FieldGetter::GetAttr(Some(
attr_name,
))))
} else {
Ok(FieldPyO3Attribute::Getter(FieldGetter::GetAttr(None)))
}
} else if lookahead.peek(attributes::kw::item) {
let _: attributes::kw::item = input.parse()?;
if input.peek(syn::token::Paren) {
let content;
let _ = parenthesized!(content in input);
let key = content.parse()?;
if !content.is_empty() {
return Err(
content.error("expected at most one argument: `item` or `item(key)`")
);
}
Ok(FieldPyO3Attribute::Getter(FieldGetter::GetItem(Some(key))))
} else {
Ok(FieldPyO3Attribute::Getter(FieldGetter::GetItem(None)))
}
} else if lookahead.peek(attributes::kw::from_py_with) {
input.parse().map(FieldPyO3Attribute::FromPyWith)
} else {
Err(lookahead.error())
}
}
}
impl FieldPyO3Attributes {
    /// Extract the field options from every attribute for which `get_pyo3_options`
    /// yields options.
    ///
    /// # Errors
    /// Fails when an attribute cannot be parsed, when more than one getter
    /// (`attribute` / `item`) is given, or when `from_py_with` is repeated.
    fn from_attrs(attrs: &[Attribute]) -> Result<Self> {
        let mut found = FieldPyO3Attributes {
            getter: None,
            from_py_with: None,
        };

        for attr in attrs {
            // Attributes that yield no options are skipped.
            let parsed = match get_pyo3_options(attr)? {
                Some(parsed) => parsed,
                None => continue,
            };

            for option in parsed {
                match option {
                    FieldPyO3Attribute::Getter(field_getter) => {
                        ensure_spanned!(
                            found.getter.is_none(),
                            attr.span() => "only one of `attribute` or `item` can be provided"
                        );
                        found.getter = Some(field_getter);
                    }
                    FieldPyO3Attribute::FromPyWith(with) => {
                        ensure_spanned!(
                            found.from_py_with.is_none(),
                            attr.span() => "`from_py_with` may only be provided once"
                        );
                        found.from_py_with = Some(with);
                    }
                }
            }
        }

        Ok(found)
    }
}
/// Returns the first lifetime parameter declared in `generics`, if any.
///
/// # Errors
/// Fails when `generics` declares more than one lifetime parameter.
fn verify_and_get_lifetime(generics: &syn::Generics) -> Result<Option<&syn::LifetimeParam>> {
    let mut declared = generics.lifetimes();
    let first = declared.next();
    let has_second = declared.next().is_some();
    ensure_spanned!(
        !has_second,
        generics.span() => "FromPyObject can be derived with at most one lifetime parameter"
    );
    Ok(first)
}
/// Derive `FromPyObject` for enums and structs.
///
/// * At most one lifetime parameter is accepted; it is used as the lifetime of the
///   `FromPyObject` impl. When the type declares none, a `'py` lifetime is added to the
///   impl generics.
/// * Every type parameter receives a `FromPyObject` bound in the generated `where` clause.
/// * Enums: `transparent` and `annotation` are rejected at the top level.
/// * Structs: `annotation` is rejected.
/// * Unions are not supported.
///
/// # Errors
/// Returns a spanned error for any of the rejected cases above, or when the container or
/// field attributes fail to parse.
pub fn build_derive_from_pyobject(tokens: &DeriveInput) -> Result<TokenStream> {
    let options = ContainerOptions::from_attrs(&tokens.attrs)?;
    let ctx = &Ctx::new(&options.krate, None);
    let Ctx { pyo3_path, .. } = &ctx;

    let (_, ty_generics, _) = tokens.generics.split_for_impl();
    let mut trait_generics = tokens.generics.clone();
    let lt_param = if let Some(lt) = verify_and_get_lifetime(&trait_generics)? {
        lt.clone()
    } else {
        trait_generics.params.push(parse_quote!('py));
        parse_quote!('py)
    };
    // Only the bare lifetime is interpolated below: a `LifetimeParam` also renders its
    // bounds (`'a: 'b`), which are not valid in generic-argument position.
    let lifetime = &lt_param.lifetime;
    let (impl_generics, _, where_clause) = trait_generics.split_for_impl();

    let mut where_clause = where_clause.cloned().unwrap_or_else(|| parse_quote!(where));
    for param in trait_generics.type_params() {
        let gen_ident = &param.ident;
        // Bound against the impl's actual lifetime; a hard-coded `'py` would be undeclared
        // whenever the type supplies its own lifetime parameter under another name.
        where_clause
            .predicates
            .push(parse_quote!(#gen_ident: #pyo3_path::FromPyObject<#lifetime>))
    }

    let derives = match &tokens.data {
        syn::Data::Enum(en) => {
            if options.transparent || options.annotation.is_some() {
                bail_spanned!(tokens.span() => "`transparent` or `annotation` is not supported \
                                                at top level for enums");
            }
            let en = Enum::new(en, &tokens.ident)?;
            en.build(ctx)
        }
        syn::Data::Struct(st) => {
            if let Some(lit_str) = &options.annotation {
                bail_spanned!(lit_str.span() => "`annotation` is unsupported for structs");
            }
            let ident = &tokens.ident;
            let st = Container::new(&st.fields, parse_quote!(#ident), options)?;
            st.build(ctx)
        }
        syn::Data::Union(_) => bail_spanned!(
            tokens.span() => "#[derive(FromPyObject)] is not supported for unions"
        ),
    };

    let ident = &tokens.ident;
    Ok(quote!(
        #[automatically_derived]
        impl #impl_generics #pyo3_path::FromPyObject<#lifetime> for #ident #ty_generics #where_clause {
            fn extract_bound(obj: &#pyo3_path::Bound<#lifetime, #pyo3_path::PyAny>) -> #pyo3_path::PyResult<Self> {
                #derives
            }
        }
    ))
}