Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
4697c13
make context ref mut
axos88 Nov 16, 2025
1b605a3
run cargo fmt
axos88 Nov 23, 2025
02ca335
cargo clippy
axos88 Nov 23, 2025
8e1dce4
separate context into shared, client and server contexts. only transm…
axos88 Nov 18, 2025
d1afa2c
allow transports to see and manipulate client and server contexts.
axos88 Nov 23, 2025
97d0a37
run cargo fmt
axos88 Nov 23, 2025
15b84e4
run cargo clippy
axos88 Nov 23, 2025
117ae57
simplify api
axos88 Nov 24, 2025
b2eb13b
allow transport to access server context on response as well
axos88 Nov 25, 2025
54b8fe8
allow server to mutate shared context
axos88 Nov 25, 2025
6ded9ed
fix typo
axos88 Nov 25, 2025
6fa1292
make servertransport generic, defined by the service implementation.
axos88 Nov 26, 2025
0045581
remove servercontext entirely
axos88 Nov 26, 2025
7989bc0
make clientContext generic as well
axos88 Nov 26, 2025
8bd243a
fix merge conflict
axos88 Nov 26, 2025
cf3fa53
run cargo fmt
axos88 Nov 26, 2025
34a87d6
cleanup...
axos88 Nov 26, 2025
116c718
cleanup
axos88 Nov 26, 2025
044629d
cleanup
axos88 Nov 26, 2025
b112017
more
axos88 Nov 26, 2025
b988a2d
cleanup
axos88 Nov 26, 2025
21e9223
cleanup
axos88 Nov 26, 2025
835d92c
cleanup
axos88 Nov 26, 2025
9d9f646
allow shared context sent between the client and server to be customi…
axos88 Dec 26, 2025
067e3ae
run cargo fmt and clippy. Rename CustomContext to CustomSharedContext
axos88 Dec 26, 2025
8348cf6
separate context extraction and updating, since the server side only …
axos88 Dec 26, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions example-service/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
// Use of this source code is governed by an MIT-style
// license that can be found in the LICENSE file or at
// https://opensource.org/licenses/MIT.
#![deny(warnings, unused, dead_code)]

use clap::Parser;
use service::{WorldClient, init_tracing};
Expand Down Expand Up @@ -34,10 +35,13 @@ async fn main() -> anyhow::Result<()> {
let client = WorldClient::new(client::Config::default(), transport.await?).spawn();

let hello = async move {
let mut context = context::current();
let mut context2 = context::current();

// Send the request twice, just to be safe! ;)
tokio::select! {
hello1 = client.hello(context::current(), format!("{}1", flags.name)) => { hello1 }
hello2 = client.hello(context::current(), format!("{}2", flags.name)) => { hello2 }
hello1 = client.hello(&mut context, format!("{}1", flags.name)) => { hello1 }
hello2 = client.hello(&mut context2, format!("{}2", flags.name)) => { hello2 }
}
}
.instrument(tracing::info_span!("Two Hellos"))
Expand Down
2 changes: 2 additions & 0 deletions example-service/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
// license that can be found in the LICENSE file or at
// https://opensource.org/licenses/MIT.

#![deny(warnings, unused, dead_code)]

use opentelemetry::trace::TracerProvider as _;
use tracing_subscriber::{fmt::format::FmtSpan, prelude::*};

Expand Down
4 changes: 3 additions & 1 deletion example-service/src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
// Use of this source code is governed by an MIT-style
// license that can be found in the LICENSE file or at
// https://opensource.org/licenses/MIT.
#![deny(warnings, unused, dead_code)]

use clap::Parser;
use futures::{future, prelude::*};
Expand Down Expand Up @@ -35,7 +36,8 @@ struct Flags {
struct HelloServer(SocketAddr);

impl World for HelloServer {
async fn hello(self, _: context::Context, name: String) -> String {
type Context = context::DefaultContext;
async fn hello(self, _: &mut Self::Context, name: String) -> String {
let sleep_time =
Duration::from_millis(Uniform::new_inclusive(1, 10).sample(&mut thread_rng()));
time::sleep(sleep_time).await;
Expand Down
138 changes: 96 additions & 42 deletions plugins/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,19 +4,15 @@
// license that can be found in the LICENSE file or at
// https://opensource.org/licenses/MIT.

#![deny(warnings, unused, dead_code)]
#![recursion_limit = "512"]

extern crate proc_macro;
extern crate proc_macro2;
extern crate quote;
extern crate syn;

use proc_macro::TokenStream;
use proc_macro2::TokenStream as TokenStream2;
use quote::{ToTokens, format_ident, quote};
use syn::{
AttrStyle, Attribute, Expr, FnArg, Ident, Lit, LitBool, MetaNameValue, Pat, PatType, Path,
ReturnType, Token, Type, Visibility, braced,
AttrStyle, Attribute, Expr, ExprLit, FnArg, Ident, Lit, LitBool, MetaNameValue, Pat, PatType,
Path, ReturnType, Token, Type, Visibility, braced,
ext::IdentExt,
parenthesized,
parse::{Parse, ParseStream},
Expand Down Expand Up @@ -143,6 +139,7 @@ impl Parse for RpcMethod {
#[derive(Default)]
struct DeriveMeta {
derive: Option<Derive>,
shared_context: Option<Type>,
warnings: Vec<TokenStream2>,
}

Expand Down Expand Up @@ -255,6 +252,37 @@ impl Parse for DeriveMeta {
),
}
derive_serde.push(meta);
} else if segment.ident == "shared_context" {
let Expr::Lit(ExprLit {
lit: Lit::Str(ref v),
..
}) = meta.value
else {
extend_errors!(
result,
syn::Error::new(
meta.span(),
"tarpc::service requires a literal string value for the shared_context attribute"
)
);
continue;
};

let Ok(ty) = syn::parse_str(&v.value()) else {
extend_errors!(
result,
syn::Error::new(
meta.span(),
"tarpc::service could not parse the value of the shared_context attribute as a type"
)
);
continue;
};

result = result.map(|d| DeriveMeta {
shared_context: Some(ty),
..d
})
} else {
extend_errors!(
result,
Expand Down Expand Up @@ -375,7 +403,7 @@ fn collect_cfg_attrs(rpcs: &[RpcMethod]) -> Vec<Vec<&Attribute>> {
/// # Example
///
/// ```no_run
/// use tarpc::{client, transport, service, server::{self, Channel}, context::Context};
/// use tarpc::{client, context, transport, service, server::{self, Channel}, context::DefaultContext};
///
/// #[service]
/// pub trait Calculator {
Expand All @@ -397,22 +425,32 @@ fn collect_cfg_attrs(rpcs: &[RpcMethod]) -> Vec<Vec<&Attribute>> {
/// // A client can be made like so:
/// let client = CalculatorClient::new(client::Config::default(), client_side);
///
/// // You would usually call it like so.
/// #[cfg(feature = "tokio1")]
/// let client = client.spawn();
/// #[cfg(not(feature = "tokio1"))]
/// let client = client.client; // Don't forget to run the dispatch future!
///
/// // And a server like so:
/// #[derive(Clone)]
/// struct CalculatorServer;
/// impl Calculator for CalculatorServer {
/// async fn add(self, context: Context, a: i32, b: i32) -> i32 {
/// type Context = context::DefaultContext;
/// async fn add(self, context: &mut Self::Context, a: i32, b: i32) -> i32 {
/// a + b
/// }
/// }
///
/// // You would usually spawn on an async runtime.
/// let server = server::BaseChannel::with_defaults(server_side);
/// let _ = server.execute(CalculatorServer.serve());
///
/// let _ = client.add(&mut context::current(), 1,2);
/// ```
#[proc_macro_attribute]
pub fn service(attr: TokenStream, input: TokenStream) -> TokenStream {
let derive_meta = parse_macro_input!(attr as DeriveMeta);

let unit_type: &Type = &parse_quote!(());
let Service {
ref attrs,
Expand All @@ -427,6 +465,9 @@ pub fn service(attr: TokenStream, input: TokenStream) -> TokenStream {
.collect();
let args: &[&[PatType]] = &rpcs.iter().map(|rpc| &*rpc.args).collect::<Vec<_>>();

let shared_context = derive_meta
.shared_context
.unwrap_or(parse_quote!(::tarpc::context::DefaultContext));
let derives = match derive_meta.derive.as_ref() {
Some(Derive::Explicit(paths)) => {
if !paths.is_empty() {
Expand Down Expand Up @@ -501,6 +542,7 @@ pub fn service(attr: TokenStream, input: TokenStream) -> TokenStream {
.map(|(rpc, name)| Ident::new(name, rpc.ident.span()))
.collect::<Vec<_>>(),
derives: derives.as_ref(),
shared_context: &shared_context,
warnings: &derive_meta.warnings,
}
.into_token_stream()
Expand Down Expand Up @@ -528,6 +570,7 @@ struct ServiceGenerator<'a> {
return_types: &'a [&'a Type],
arg_pats: &'a [Vec<&'a Pat>],
derives: Option<&'a TokenStream2>,
shared_context: &'a Type,
warnings: &'a [TokenStream2],
}

Expand All @@ -543,30 +586,30 @@ impl ServiceGenerator<'_> {
request_ident,
response_ident,
server_ident,
shared_context,
..
} = self;

let rpc_fns = rpcs
.iter()
.zip(return_types.iter())
.map(
|(
RpcMethod {
attrs, ident, args, ..
},
output,
)| {
quote! {
#( #attrs )*
async fn #ident(self, context: ::tarpc::context::Context, #( #args ),*) -> #output;
}
let rpc_fns = rpcs.iter().zip(return_types.iter()).map(
|(
RpcMethod {
attrs, ident, args, ..
},
);
output,
)| {
quote! {
#( #attrs )*
async fn #ident(self, context: &mut Self::Context, #( #args ),*) -> #output;
}
},
);

let stub_doc = format!("The stub trait for service [`{service_ident}`].");
quote! {
#( #attrs )*
#vis trait #service_ident: ::core::marker::Sized {
type Context: ::tarpc::context::ExtractContext<#shared_context>; // = ::tarpc::context::DefaultContext; TODO: Add associated type default once https://github.com/rust-lang/rust/issues/29661 is stabilized

#( #rpc_fns )*

/// Returns a serving function to use with
Expand All @@ -577,11 +620,11 @@ impl ServiceGenerator<'_> {
}

#[doc = #stub_doc]
#vis trait #client_stub_ident: ::tarpc::client::stub::Stub<Req = #request_ident, Resp = #response_ident> {
#vis trait #client_stub_ident<ClientCtx>: ::tarpc::client::stub::Stub<ClientCtx = ClientCtx, Req = #request_ident, Resp = #response_ident> {
}

impl<S> #client_stub_ident for S
where S: ::tarpc::client::stub::Stub<Req = #request_ident, Resp = #response_ident>
impl<S, ClientCtx> #client_stub_ident<ClientCtx> for S
where S: ::tarpc::client::stub::Stub<ClientCtx = ClientCtx, Req = #request_ident, Resp = #response_ident>
{
}
}
Expand Down Expand Up @@ -620,9 +663,9 @@ impl ServiceGenerator<'_> {
{
type Req = #request_ident;
type Resp = #response_ident;
type ServerCtx = S::Context;


async fn serve(self, ctx: ::tarpc::context::Context, req: #request_ident)
async fn serve(self, ctx: &mut Self::ServerCtx, req: #request_ident)
-> ::core::result::Result<#response_ident, ::tarpc::ServerError> {
match req {
#(
Expand Down Expand Up @@ -706,17 +749,25 @@ impl ServiceGenerator<'_> {
client_ident,
request_ident,
response_ident,
shared_context,
..
} = self;

quote! {
#[allow(unused)]
#[derive(Clone, Debug)]
#[derive(Debug)]
/// The client stub that makes RPC calls to the server. All request methods return
/// [Futures](::core::future::Future).
#vis struct #client_ident<
Stub = ::tarpc::client::Channel<#request_ident, #response_ident>
>(Stub);
ClientCtx,
Stub = ::tarpc::client::Channel<#request_ident, #response_ident, ClientCtx, #shared_context>
>(Stub, ::std::marker::PhantomData<ClientCtx>);

impl<ClientCtx, Stub: ::std::clone::Clone> ::std::clone::Clone for #client_ident<ClientCtx, Stub> {
fn clone(&self) -> Self {
Self(self.0.clone(), ::std::marker::PhantomData)
}
}
}
}

Expand All @@ -726,36 +777,38 @@ impl ServiceGenerator<'_> {
vis,
request_ident,
response_ident,
shared_context,
..
} = self;

quote! {
impl #client_ident {
impl<ClientCtx> #client_ident<ClientCtx> {
/// Returns a new client stub that sends requests over the given transport.
#vis fn new<T>(config: ::tarpc::client::Config, transport: T)
-> ::tarpc::client::NewClient<
Self,
::tarpc::client::RequestDispatch<#request_ident, #response_ident, T>
::tarpc::client::RequestDispatch<#request_ident, #response_ident, ClientCtx, #shared_context, T>
>
where
T: ::tarpc::Transport<::tarpc::ClientMessage<#request_ident>, ::tarpc::Response<#response_ident>>
T: ::tarpc::Transport<::tarpc::ClientMessage<ClientCtx, #request_ident>, ::tarpc::Response<ClientCtx, #response_ident>>
{
let new_client = ::tarpc::client::new(config, transport);
::tarpc::client::NewClient {
client: #client_ident(new_client.client),
client: #client_ident(new_client.client, ::std::marker::PhantomData),
dispatch: new_client.dispatch,
}
}
}

impl<Stub> ::core::convert::From<Stub> for #client_ident<Stub>
impl<ClientCtx, Stub> ::core::convert::From<Stub> for #client_ident<ClientCtx, Stub>
where Stub: ::tarpc::client::stub::Stub<
Req = #request_ident,
Resp = #response_ident>
Resp = #response_ident,
ClientCtx = ClientCtx>
{
/// Returns a new client stub that sends requests over the given transport.
fn from(stub: Stub) -> Self {
#client_ident(stub)
#client_ident::<ClientCtx, Stub>(stub, ::std::marker::PhantomData)
}

}
Expand All @@ -778,15 +831,16 @@ impl ServiceGenerator<'_> {
} = self;

quote! {
impl<Stub> #client_ident<Stub>
impl<ClientCtx, Stub> #client_ident<ClientCtx, Stub>
where Stub: ::tarpc::client::stub::Stub<
Req = #request_ident,
Resp = #response_ident>
Resp = #response_ident,
ClientCtx = ClientCtx>
{
#(
#[allow(unused)]
#( #method_attrs )*
#vis fn #method_idents(&self, ctx: ::tarpc::context::Context, #( #args ),*)
#vis fn #method_idents<'a>(&'a self, ctx: &'a mut Stub::ClientCtx, #( #args ),*)
-> impl ::core::future::Future<Output = ::core::result::Result<#return_types, ::tarpc::client::RpcError>> + '_ {
let request = #request_ident::#camel_case_idents { #( #arg_pats ),* };
let resp = self.0.call(ctx, request);
Expand Down
Loading