Skip to content

Commit fbd8b33

Browse files
Rework Rust method chaining system.
1 parent 2e00369 commit fbd8b33

9 files changed

Lines changed: 348 additions & 55 deletions

File tree

Lines changed: 201 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,201 @@
1+
use anyhow::{Result, bail};
2+
use std::collections::HashSet;
3+
use std::fmt;
4+
use wit_parser::{Function, FunctionKind, Resolve, WorldKey};
5+
6+
/// Structure used to parse the command line argument `--chainable-method` consistently
7+
/// across guest generators.
8+
#[cfg_attr(feature = "clap", derive(clap::Parser))]
9+
#[cfg_attr(feature = "serde", derive(serde::Deserialize))]
10+
#[derive(Clone, Default, Debug)]
11+
pub struct ChainableMethodFilterSet {
12+
/// Determines which resource methods should have chaining enabled.
13+
/// Chaining takes a WIT method import returning nothing, and modifies bindgen
14+
/// in a language-dependent way to return `self` in the glue code. This does
15+
/// not affect the ABI in any way.
16+
///
17+
/// This option can be passed multiple times and additionally accepts
18+
/// comma-separated values for each option passed. Each individual argument
19+
/// passed here can be one of:
20+
///
21+
/// - `all` - all applicable methods will be chainable
22+
/// - `-all` - no methods will be chainable
23+
/// - `foo:bar/baz#my-resource` - enable chaining for all methods in a resource
24+
/// - `foo:bar/baz#my-resource.some-method` - enable chaining for particular method
25+
///
26+
/// Options are processed in the order they are passed here, so if a method
27+
/// matches two directives passed the least-specific one should be last.
28+
#[cfg_attr(
29+
feature = "clap",
30+
arg(
31+
long = "chainable-methods",
32+
value_parser = parse_chainable_method,
33+
value_delimiter =',',
34+
value_name = "FILTER",
35+
),
36+
)]
37+
chainable_methods: Vec<ChainableMethod>,
38+
39+
#[cfg_attr(feature = "clap", arg(skip))]
40+
#[cfg_attr(feature = "serde", serde(skip))]
41+
used_options: HashSet<usize>,
42+
}
43+
44+
#[cfg(feature = "clap")]
45+
fn parse_chainable_method(s: &str) -> Result<ChainableMethod, String> {
46+
Ok(ChainableMethod::parse(s))
47+
}
48+
49+
impl ChainableMethodFilterSet {
50+
/// Returns a set where all functions should be chainable or not depending on
51+
/// `enable` provided.
52+
pub fn all(enable: bool) -> ChainableMethodFilterSet {
53+
ChainableMethodFilterSet {
54+
chainable_methods: vec![ChainableMethod {
55+
enabled: enable,
56+
filter: ChainableMethodFilter::All,
57+
}],
58+
used_options: HashSet::new(),
59+
}
60+
}
61+
62+
/// Returns whether the `func` provided should be made chainable
63+
pub fn should_be_chainable(
64+
&mut self,
65+
resolve: &Resolve,
66+
interface: Option<&WorldKey>,
67+
func: &Function,
68+
is_import: bool,
69+
) -> bool {
70+
if !is_import {
71+
return false;
72+
}
73+
74+
if func.result.is_some() {
75+
return false;
76+
}
77+
78+
match func.kind {
79+
FunctionKind::AsyncMethod(resource) | FunctionKind::Method(resource) => {
80+
let interface_name = match interface.map(|key| resolve.name_world_key(key)) {
81+
Some(str) => str + "#",
82+
None => "".into(),
83+
};
84+
85+
let resource_name_to_test = format!(
86+
"{}{}",
87+
interface_name,
88+
resolve.types[resource].name.as_ref().unwrap()
89+
);
90+
91+
let method_name_to_test = format!("{}{}", interface_name, func.name);
92+
93+
for (i, opt) in self.chainable_methods.iter().enumerate() {
94+
match &opt.filter {
95+
ChainableMethodFilter::All => {
96+
self.used_options.insert(i);
97+
return opt.enabled;
98+
}
99+
ChainableMethodFilter::Resource(s) => {
100+
if *s == resource_name_to_test {
101+
self.used_options.insert(i);
102+
return opt.enabled;
103+
}
104+
}
105+
ChainableMethodFilter::Method(s) => {
106+
if *s == method_name_to_test {
107+
self.used_options.insert(i);
108+
return opt.enabled;
109+
}
110+
}
111+
};
112+
}
113+
114+
return false;
115+
}
116+
_ => {
117+
return false;
118+
}
119+
}
120+
}
121+
122+
/// Intended to be used in the header comment of generated code to help
123+
/// indicate what options were specified.
124+
pub fn debug_opts(&self) -> impl Iterator<Item = String> + '_ {
125+
self.chainable_methods.iter().map(|opt| opt.to_string())
126+
}
127+
128+
/// Tests whether all `--chainable-method` options were used throughout bindings
129+
/// generation, returning an error if any were unused.
130+
pub fn ensure_all_used(&self) -> Result<()> {
131+
for (i, opt) in self.chainable_methods.iter().enumerate() {
132+
if self.used_options.contains(&i) {
133+
continue;
134+
}
135+
if !matches!(opt.filter, ChainableMethodFilter::All) {
136+
bail!("unused chainable option: {opt}");
137+
}
138+
}
139+
Ok(())
140+
}
141+
142+
/// Pushes a new option into this set.
143+
pub fn push(&mut self, directive: &str) {
144+
self.chainable_methods
145+
.push(ChainableMethod::parse(directive));
146+
}
147+
}
148+
149+
#[derive(Debug, Clone)]
150+
#[cfg_attr(feature = "serde", derive(serde::Deserialize))]
151+
struct ChainableMethod {
152+
enabled: bool,
153+
filter: ChainableMethodFilter,
154+
}
155+
156+
impl ChainableMethod {
157+
fn parse(s: &str) -> ChainableMethod {
158+
let (s, enabled) = match s.strip_prefix('-') {
159+
Some(s) => (s, false),
160+
None => (s, true),
161+
};
162+
let filter = match s {
163+
"all" => ChainableMethodFilter::All,
164+
other => {
165+
if other.contains("[method]") {
166+
ChainableMethodFilter::Method(other.to_string())
167+
} else {
168+
ChainableMethodFilter::Resource(other.to_string())
169+
}
170+
}
171+
};
172+
ChainableMethod { enabled, filter }
173+
}
174+
}
175+
176+
impl fmt::Display for ChainableMethod {
177+
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
178+
if !self.enabled {
179+
write!(f, "-")?;
180+
}
181+
self.filter.fmt(f)
182+
}
183+
}
184+
185+
#[derive(Debug, Clone)]
186+
#[cfg_attr(feature = "serde", derive(serde::Deserialize))]
187+
enum ChainableMethodFilter {
188+
All,
189+
Resource(String),
190+
Method(String),
191+
}
192+
193+
impl fmt::Display for ChainableMethodFilter {
194+
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
195+
match self {
196+
ChainableMethodFilter::All => write!(f, "all"),
197+
ChainableMethodFilter::Resource(s) => write!(f, "{s}"),
198+
ChainableMethodFilter::Method(s) => write!(f, "{s}"),
199+
}
200+
}
201+
}

crates/core/src/lib.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@ mod path;
1414
pub use path::name_package_module;
1515
mod async_;
1616
pub use async_::AsyncFilterSet;
17+
mod chainable_method;
18+
pub use chainable_method::ChainableMethodFilterSet;
1719

1820
#[derive(Default, Copy, Clone, PartialEq, Eq, Debug)]
1921
pub enum Direction {

crates/guest-rust/macro/src/lib.rs

Lines changed: 23 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,9 @@ use std::sync::atomic::{AtomicUsize, Ordering::Relaxed};
66
use syn::parse::{Error, Parse, ParseStream, Result};
77
use syn::punctuated::Punctuated;
88
use syn::{Token, braced, token};
9-
use wit_bindgen_core::AsyncFilterSet;
109
use wit_bindgen_core::WorldGenerator;
1110
use wit_bindgen_core::wit_parser::{PackageId, Resolve, UnresolvedPackageGroup, WorldId};
11+
use wit_bindgen_core::{AsyncFilterSet, ChainableMethodFilterSet};
1212
use wit_bindgen_rust::{Opts, Ownership, WithOption};
1313

1414
#[proc_macro]
@@ -66,6 +66,7 @@ impl Parse for Config {
6666
let mut source = None;
6767
let mut features = Vec::new();
6868
let mut async_configured = false;
69+
let mut method_chaining_configured = false;
6970
let mut debug = false;
7071

7172
if input.peek(token::Brace) {
@@ -165,8 +166,15 @@ impl Parse for Config {
165166
async_configured = true;
166167
opts.async_ = val;
167168
}
168-
Opt::EnableMethodChaining(enable) => {
169-
opts.enable_method_chaining = enable.value();
169+
Opt::ChainableMethods(val, span) => {
170+
if method_chaining_configured {
171+
return Err(Error::new(
172+
span,
173+
"cannot specify second method chaining config",
174+
));
175+
}
176+
method_chaining_configured = true;
177+
opts.chainable_methods = val;
170178
}
171179
}
172180
}
@@ -321,7 +329,7 @@ mod kw {
321329
syn::custom_keyword!(disable_custom_section_link_helpers);
322330
syn::custom_keyword!(imports);
323331
syn::custom_keyword!(debug);
324-
syn::custom_keyword!(enable_method_chaining);
332+
syn::custom_keyword!(chainable_methods);
325333
}
326334

327335
#[derive(Clone)]
@@ -402,7 +410,7 @@ enum Opt {
402410
DisableCustomSectionLinkHelpers(syn::LitBool),
403411
Async(AsyncFilterSet, Span),
404412
Debug(syn::LitBool),
405-
EnableMethodChaining(syn::LitBool),
413+
ChainableMethods(ChainableMethodFilterSet, Span),
406414
}
407415

408416
impl Parse for Opt {
@@ -567,10 +575,17 @@ impl Parse for Opt {
567575
input.parse::<kw::debug>()?;
568576
input.parse::<Token![:]>()?;
569577
Ok(Opt::Debug(input.parse()?))
570-
} else if l.peek(kw::enable_method_chaining) {
571-
input.parse::<kw::enable_method_chaining>()?;
578+
} else if l.peek(kw::chainable_methods) {
579+
let span = input.parse::<kw::chainable_methods>()?.span;
572580
input.parse::<Token![:]>()?;
573-
Ok(Opt::EnableMethodChaining(input.parse()?))
581+
582+
let mut set = ChainableMethodFilterSet::default();
583+
let contents;
584+
syn::bracketed!(contents in input);
585+
for val in contents.parse_terminated(|p| p.parse::<syn::LitStr>(), Token![,])? {
586+
set.push(&val.value());
587+
}
588+
Ok(Opt::ChainableMethods(set, span))
574589
} else if l.peek(Token![async]) {
575590
let span = input.parse::<Token![async]>()?.span;
576591
input.parse::<Token![:]>()?;

0 commit comments

Comments
 (0)