blob: 2fe918f4d82aa18e71a4fe78aa2551b08065012c [file] [log] [blame] [edit]
// SPDX-License-Identifier: Apache-2.0 OR MIT
use proc_macro2::{Span, TokenStream};
use quote::{format_ident, quote, quote_spanned};
use syn::{
braced,
parse::{End, Parse},
parse_quote,
punctuated::Punctuated,
spanned::Spanned,
token, Attribute, Block, Expr, ExprCall, ExprPath, Ident, Path, Token, Type,
};
use crate::diagnostics::{DiagCtxt, ErrorGuaranteed};
/// A parsed initializer of the form
/// `#[attrs]* (&<this> in)? Path { field, field, ..<rest> } (? ErrorType)?`.
pub(crate) struct Initializer {
    /// Outer attributes on the initializer; parsing only accepts `#[default_error(...)]`.
    attrs: Vec<InitializerAttribute>,
    /// Optional `&<ident> in` prefix naming a pointer to the value under construction.
    this: Option<This>,
    /// Path of the struct that is being initialized.
    path: Path,
    /// The braces around the field list; the closing span is used for diagnostics.
    brace_token: token::Brace,
    /// The comma-separated field entries.
    fields: Punctuated<InitializerField, Token![,]>,
    /// Optional `..<expr>` after the fields; expansion only accepts `..Zeroable::init_zeroed()`.
    rest: Option<(Token![..], Expr)>,
    /// Optional `? <type>` after the closing brace, naming the error type of the initializer.
    error: Option<(Token![?], Type)>,
}
/// The `&<ident> in` prefix of an initializer.
struct This {
    _and_token: Token![&],
    /// Name under which the expansion binds a `NonNull` pointer to the value being initialized,
    /// so field expressions can refer to it.
    ident: Ident,
    _in_token: Token![in],
}
/// One entry of the field list together with the outer attributes written in front of it.
struct InitializerField {
    /// Outer attributes of the entry; the `#[cfg(...)]` ones are additionally repeated on the
    /// generated field accessor binding and drop guard.
    attrs: Vec<Attribute>,
    /// The actual entry.
    kind: InitializerKind,
}
/// The syntactic forms a field-list entry can take.
enum InitializerKind {
    /// `field: expr`, or the shorthand `field`; the value is written directly into the field.
    Value {
        ident: Ident,
        value: Option<(Token![:], Expr)>,
    },
    /// `field <- expr`; the field is initialized in place by the initializer `expr`.
    Init {
        ident: Ident,
        _left_arrow_token: Token![<-],
        value: Expr,
    },
    /// `_: { ... }`; an arbitrary block that is emitted at this point of the initialization.
    Code {
        _underscore_token: Token![_],
        _colon_token: Token![:],
        block: Block,
    },
}
impl InitializerKind {
    /// Returns the name of the field this entry initializes.
    ///
    /// `_: { ... }` code entries do not initialize a field and therefore yield `None`.
    fn ident(&self) -> Option<&Ident> {
        match self {
            Self::Code { .. } => None,
            Self::Value { ident, .. } => Some(ident),
            Self::Init { ident, .. } => Some(ident),
        }
    }
}
/// An attribute accepted in front of an initializer.
enum InitializerAttribute {
    /// `#[default_error(<type>)]`
    DefaultError(DefaultErrorAttribute),
}
/// Argument of `#[default_error(<type>)]`.
struct DefaultErrorAttribute {
    /// Error type used by the expansion when no `? <type>` follows the initializer.
    ty: Box<Type>,
}
pub(crate) fn expand(
Initializer {
attrs,
this,
path,
brace_token,
fields,
rest,
error,
}: Initializer,
default_error: Option<&'static str>,
pinned: bool,
dcx: &mut DiagCtxt,
) -> Result<TokenStream, ErrorGuaranteed> {
let error = error.map_or_else(
|| {
if let Some(default_error) = attrs.iter().fold(None, |acc, attr| {
#[expect(irrefutable_let_patterns)]
if let InitializerAttribute::DefaultError(DefaultErrorAttribute { ty }) = attr {
Some(ty.clone())
} else {
acc
}
}) {
default_error
} else if let Some(default_error) = default_error {
syn::parse_str(default_error).unwrap()
} else {
dcx.error(brace_token.span.close(), "expected `? <type>` after `}`");
parse_quote!(::core::convert::Infallible)
}
},
|(_, err)| Box::new(err),
);
let slot = format_ident!("slot");
let (has_data_trait, data_trait, get_data, init_from_closure) = if pinned {
(
format_ident!("HasPinData"),
format_ident!("PinData"),
format_ident!("__pin_data"),
format_ident!("pin_init_from_closure"),
)
} else {
(
format_ident!("HasInitData"),
format_ident!("InitData"),
format_ident!("__init_data"),
format_ident!("init_from_closure"),
)
};
let init_kind = get_init_kind(rest, dcx);
let zeroable_check = match init_kind {
InitKind::Normal => quote!(),
InitKind::Zeroing => quote! {
// The user specified `..Zeroable::zeroed()` at the end of the list of fields.
// Therefore we check if the struct implements `Zeroable` and then zero the memory.
// This allows us to also remove the check that all fields are present (since we
// already set the memory to zero and that is a valid bit pattern).
fn assert_zeroable<T: ?::core::marker::Sized>(_: *mut T)
where T: ::pin_init::Zeroable
{}
// Ensure that the struct is indeed `Zeroable`.
assert_zeroable(#slot);
// SAFETY: The type implements `Zeroable` by the check above.
unsafe { ::core::ptr::write_bytes(#slot, 0, 1) };
},
};
let this = match this {
None => quote!(),
Some(This { ident, .. }) => quote! {
// Create the `this` so it can be referenced by the user inside of the
// expressions creating the individual fields.
let #ident = unsafe { ::core::ptr::NonNull::new_unchecked(slot) };
},
};
// `mixed_site` ensures that the data is not accessible to the user-controlled code.
let data = Ident::new("__data", Span::mixed_site());
let init_fields = init_fields(&fields, pinned, &data, &slot);
let field_check = make_field_check(&fields, init_kind, &path);
Ok(quote! {{
// Get the data about fields from the supplied type.
// SAFETY: TODO
let #data = unsafe {
use ::pin_init::__internal::#has_data_trait;
// Can't use `<#path as #has_data_trait>::#get_data`, since the user is able to omit
// generics (which need to be present with that syntax).
#path::#get_data()
};
// Ensure that `#data` really is of type `#data` and help with type inference:
let init = ::pin_init::__internal::#data_trait::make_closure::<_, #error>(
#data,
move |slot| {
#zeroable_check
#this
#init_fields
#field_check
// SAFETY: we are the `init!` macro that is allowed to call this.
Ok(unsafe { ::pin_init::__internal::InitOk::new() })
}
);
let init = move |slot| -> ::core::result::Result<(), #error> {
init(slot).map(|__InitOk| ())
};
// SAFETY: TODO
let init = unsafe { ::pin_init::#init_from_closure::<_, #error>(init) };
init
}})
}
enum InitKind {
Normal,
Zeroing,
}
fn get_init_kind(rest: Option<(Token![..], Expr)>, dcx: &mut DiagCtxt) -> InitKind {
let Some((dotdot, expr)) = rest else {
return InitKind::Normal;
};
match &expr {
Expr::Call(ExprCall { func, args, .. }) if args.is_empty() => match &**func {
Expr::Path(ExprPath {
attrs,
qself: None,
path:
Path {
leading_colon: None,
segments,
},
}) if attrs.is_empty()
&& segments.len() == 2
&& segments[0].ident == "Zeroable"
&& segments[0].arguments.is_none()
&& segments[1].ident == "init_zeroed"
&& segments[1].arguments.is_none() =>
{
return InitKind::Zeroing;
}
_ => {}
},
_ => {}
}
dcx.error(
dotdot.span().join(expr.span()).unwrap_or(expr.span()),
"expected nothing or `..Zeroable::init_zeroed()`.",
);
InitKind::Normal
}
/// Generate the code that initializes the fields of the struct using the initializers in `field`.
fn init_fields(
fields: &Punctuated<InitializerField, Token![,]>,
pinned: bool,
data: &Ident,
slot: &Ident,
) -> TokenStream {
let mut guards = vec![];
let mut guard_attrs = vec![];
let mut res = TokenStream::new();
for InitializerField { attrs, kind } in fields {
let cfgs = {
let mut cfgs = attrs.clone();
cfgs.retain(|attr| attr.path().is_ident("cfg"));
cfgs
};
let init = match kind {
InitializerKind::Value { ident, value } => {
let mut value_ident = ident.clone();
let value_prep = value.as_ref().map(|value| &value.1).map(|value| {
// Setting the span of `value_ident` to `value`'s span improves error messages
// when the type of `value` is wrong.
value_ident.set_span(value.span());
quote!(let #value_ident = #value;)
});
// Again span for better diagnostics
let write = quote_spanned!(ident.span()=> ::core::ptr::write);
// NOTE: the field accessor ensures that the initialized field is properly aligned.
// Unaligned fields will cause the compiler to emit E0793. We do not support
// unaligned fields since `Init::__init` requires an aligned pointer; the call to
// `ptr::write` below has the same requirement.
let accessor = if pinned {
let project_ident = format_ident!("__project_{ident}");
quote! {
// SAFETY: TODO
unsafe { #data.#project_ident(&mut (*#slot).#ident) }
}
} else {
quote! {
// SAFETY: TODO
unsafe { &mut (*#slot).#ident }
}
};
quote! {
#(#attrs)*
{
#value_prep
// SAFETY: TODO
unsafe { #write(::core::ptr::addr_of_mut!((*#slot).#ident), #value_ident) };
}
#(#cfgs)*
#[allow(unused_variables)]
let #ident = #accessor;
}
}
InitializerKind::Init { ident, value, .. } => {
// Again span for better diagnostics
let init = format_ident!("init", span = value.span());
// NOTE: the field accessor ensures that the initialized field is properly aligned.
// Unaligned fields will cause the compiler to emit E0793. We do not support
// unaligned fields since `Init::__init` requires an aligned pointer; the call to
// `ptr::write` below has the same requirement.
let (value_init, accessor) = if pinned {
let project_ident = format_ident!("__project_{ident}");
(
quote! {
// SAFETY:
// - `slot` is valid, because we are inside of an initializer closure, we
// return when an error/panic occurs.
// - We also use `#data` to require the correct trait (`Init` or `PinInit`)
// for `#ident`.
unsafe { #data.#ident(::core::ptr::addr_of_mut!((*#slot).#ident), #init)? };
},
quote! {
// SAFETY: TODO
unsafe { #data.#project_ident(&mut (*#slot).#ident) }
},
)
} else {
(
quote! {
// SAFETY: `slot` is valid, because we are inside of an initializer
// closure, we return when an error/panic occurs.
unsafe {
::pin_init::Init::__init(
#init,
::core::ptr::addr_of_mut!((*#slot).#ident),
)?
};
},
quote! {
// SAFETY: TODO
unsafe { &mut (*#slot).#ident }
},
)
};
quote! {
#(#attrs)*
{
let #init = #value;
#value_init
}
#(#cfgs)*
#[allow(unused_variables)]
let #ident = #accessor;
}
}
InitializerKind::Code { block: value, .. } => quote! {
#(#attrs)*
#[allow(unused_braces)]
#value
},
};
res.extend(init);
if let Some(ident) = kind.ident() {
// `mixed_site` ensures that the guard is not accessible to the user-controlled code.
let guard = format_ident!("__{ident}_guard", span = Span::mixed_site());
res.extend(quote! {
#(#cfgs)*
// Create the drop guard:
//
// We rely on macro hygiene to make it impossible for users to access this local
// variable.
// SAFETY: We forget the guard later when initialization has succeeded.
let #guard = unsafe {
::pin_init::__internal::DropGuard::new(
::core::ptr::addr_of_mut!((*slot).#ident)
)
};
});
guards.push(guard);
guard_attrs.push(cfgs);
}
}
quote! {
#res
// If execution reaches this point, all fields have been initialized. Therefore we can now
// dismiss the guards by forgetting them.
#(
#(#guard_attrs)*
::core::mem::forget(#guards);
)*
}
}
/// Generate the check for ensuring that every field has been initialized.
fn make_field_check(
fields: &Punctuated<InitializerField, Token![,]>,
init_kind: InitKind,
path: &Path,
) -> TokenStream {
let field_attrs = fields
.iter()
.filter_map(|f| f.kind.ident().map(|_| &f.attrs));
let field_name = fields.iter().filter_map(|f| f.kind.ident());
match init_kind {
InitKind::Normal => quote! {
// We use unreachable code to ensure that all fields have been mentioned exactly once,
// this struct initializer will still be type-checked and complain with a very natural
// error message if a field is forgotten/mentioned more than once.
#[allow(unreachable_code, clippy::diverging_sub_expression)]
// SAFETY: this code is never executed.
let _ = || unsafe {
::core::ptr::write(slot, #path {
#(
#(#field_attrs)*
#field_name: ::core::panic!(),
)*
})
};
},
InitKind::Zeroing => quote! {
// We use unreachable code to ensure that all fields have been mentioned at most once.
// Since the user specified `..Zeroable::zeroed()` at the end, all missing fields will
// be zeroed. This struct initializer will still be type-checked and complain with a
// very natural error message if a field is mentioned more than once, or doesn't exist.
#[allow(unreachable_code, clippy::diverging_sub_expression, unused_assignments)]
// SAFETY: this code is never executed.
let _ = || unsafe {
::core::ptr::write(slot, #path {
#(
#(#field_attrs)*
#field_name: ::core::panic!(),
)*
..::core::mem::zeroed()
})
};
},
}
}
impl Parse for Initializer {
fn parse(input: syn::parse::ParseStream<'_>) -> syn::Result<Self> {
let attrs = input.call(Attribute::parse_outer)?;
let this = input.peek(Token![&]).then(|| input.parse()).transpose()?;
let path = input.parse()?;
let content;
let brace_token = braced!(content in input);
let mut fields = Punctuated::new();
loop {
let lh = content.lookahead1();
if lh.peek(End) || lh.peek(Token![..]) {
break;
} else if lh.peek(Ident) || lh.peek(Token![_]) || lh.peek(Token![#]) {
fields.push_value(content.parse()?);
let lh = content.lookahead1();
if lh.peek(End) {
break;
} else if lh.peek(Token![,]) {
fields.push_punct(content.parse()?);
} else {
return Err(lh.error());
}
} else {
return Err(lh.error());
}
}
let rest = content
.peek(Token![..])
.then(|| Ok::<_, syn::Error>((content.parse()?, content.parse()?)))
.transpose()?;
let error = input
.peek(Token![?])
.then(|| Ok::<_, syn::Error>((input.parse()?, input.parse()?)))
.transpose()?;
let attrs = attrs
.into_iter()
.map(|a| {
if a.path().is_ident("default_error") {
a.parse_args::<DefaultErrorAttribute>()
.map(InitializerAttribute::DefaultError)
} else {
Err(syn::Error::new_spanned(a, "unknown initializer attribute"))
}
})
.collect::<Result<Vec<_>, _>>()?;
Ok(Self {
attrs,
this,
path,
brace_token,
fields,
rest,
error,
})
}
}
impl Parse for DefaultErrorAttribute {
    /// Parses the argument of `#[default_error(...)]`, which is a single type.
    fn parse(input: syn::parse::ParseStream<'_>) -> syn::Result<Self> {
        input.parse().map(|ty| Self { ty })
    }
}
impl Parse for This {
    /// Parses the `&<ident> in` prefix that names the pointer to the value being initialized.
    fn parse(input: syn::parse::ParseStream<'_>) -> syn::Result<Self> {
        let and_token = input.parse()?;
        let ident = input.parse()?;
        let in_token = input.parse()?;
        Ok(Self {
            _and_token: and_token,
            ident,
            _in_token: in_token,
        })
    }
}
impl Parse for InitializerField {
    /// Parses one field-list entry: any outer attributes followed by an [`InitializerKind`].
    fn parse(input: syn::parse::ParseStream<'_>) -> syn::Result<Self> {
        // Struct-expression fields are evaluated in written order: attributes first.
        Ok(Self {
            attrs: input.call(Attribute::parse_outer)?,
            kind: input.parse()?,
        })
    }
}
impl Parse for InitializerKind {
    /// Parses one of `_: { ... }`, `field <- expr`, `field: expr` or the shorthand `field`.
    ///
    /// The shorthand is only accepted when followed by `,` or the end of the field list.
    fn parse(input: syn::parse::ParseStream<'_>) -> syn::Result<Self> {
        let lookahead = input.lookahead1();
        if lookahead.peek(Token![_]) {
            return Ok(Self::Code {
                _underscore_token: input.parse()?,
                _colon_token: input.parse()?,
                block: input.parse()?,
            });
        }
        if !lookahead.peek(Ident) {
            return Err(lookahead.error());
        }

        let ident = input.parse()?;
        let after_ident = input.lookahead1();
        if after_ident.peek(Token![<-]) {
            return Ok(Self::Init {
                ident,
                _left_arrow_token: input.parse()?,
                value: input.parse()?,
            });
        }
        if after_ident.peek(Token![:]) {
            let colon = input.parse()?;
            let expr = input.parse()?;
            return Ok(Self::Value {
                ident,
                value: Some((colon, expr)),
            });
        }
        if after_ident.peek(Token![,]) || after_ident.peek(End) {
            return Ok(Self::Value { ident, value: None });
        }
        Err(after_ident.error())
    }
}