diff --git a/example-service/src/client.rs b/example-service/src/client.rs index 8a4ff72eb..6f3930343 100644 --- a/example-service/src/client.rs +++ b/example-service/src/client.rs @@ -3,6 +3,7 @@ // Use of this source code is governed by an MIT-style // license that can be found in the LICENSE file or at // https://opensource.org/licenses/MIT. +#![deny(warnings, unused, dead_code)] use clap::Parser; use service::{WorldClient, init_tracing}; @@ -34,10 +35,13 @@ async fn main() -> anyhow::Result<()> { let client = WorldClient::new(client::Config::default(), transport.await?).spawn(); let hello = async move { + let mut context = context::current(); + let mut context2 = context::current(); + // Send the request twice, just to be safe! ;) tokio::select! { - hello1 = client.hello(context::current(), format!("{}1", flags.name)) => { hello1 } - hello2 = client.hello(context::current(), format!("{}2", flags.name)) => { hello2 } + hello1 = client.hello(&mut context, format!("{}1", flags.name)) => { hello1 } + hello2 = client.hello(&mut context2, format!("{}2", flags.name)) => { hello2 } } } .instrument(tracing::info_span!("Two Hellos")) diff --git a/example-service/src/lib.rs b/example-service/src/lib.rs index 26b49e2ac..fd5031e75 100644 --- a/example-service/src/lib.rs +++ b/example-service/src/lib.rs @@ -4,6 +4,8 @@ // license that can be found in the LICENSE file or at // https://opensource.org/licenses/MIT. +#![deny(warnings, unused, dead_code)] + use opentelemetry::trace::TracerProvider as _; use tracing_subscriber::{fmt::format::FmtSpan, prelude::*}; diff --git a/example-service/src/server.rs b/example-service/src/server.rs index 896280c3d..37bdb0f42 100644 --- a/example-service/src/server.rs +++ b/example-service/src/server.rs @@ -3,6 +3,7 @@ // Use of this source code is governed by an MIT-style // license that can be found in the LICENSE file or at // https://opensource.org/licenses/MIT. +#![deny(warnings, unused, dead_code)] use clap::Parser; use futures::{future, prelude::*}; @@ -35,7 +36,8 @@ struct Flags { struct HelloServer(SocketAddr); impl World for HelloServer { - async fn hello(self, _: context::Context, name: String) -> String { + type Context = context::DefaultContext; + async fn hello(self, _: &mut Self::Context, name: String) -> String { let sleep_time = Duration::from_millis(Uniform::new_inclusive(1, 10).sample(&mut thread_rng())); time::sleep(sleep_time).await; diff --git a/plugins/src/lib.rs b/plugins/src/lib.rs index da6443edf..46375fce0 100644 --- a/plugins/src/lib.rs +++ b/plugins/src/lib.rs @@ -4,19 +4,15 @@ // license that can be found in the LICENSE file or at // https://opensource.org/licenses/MIT. +#![deny(warnings, unused, dead_code)] #![recursion_limit = "512"] -extern crate proc_macro; -extern crate proc_macro2; -extern crate quote; -extern crate syn; - use proc_macro::TokenStream; use proc_macro2::TokenStream as TokenStream2; use quote::{ToTokens, format_ident, quote}; use syn::{ - AttrStyle, Attribute, Expr, FnArg, Ident, Lit, LitBool, MetaNameValue, Pat, PatType, Path, - ReturnType, Token, Type, Visibility, braced, + AttrStyle, Attribute, Expr, ExprLit, FnArg, Ident, Lit, LitBool, MetaNameValue, Pat, PatType, + Path, ReturnType, Token, Type, Visibility, braced, ext::IdentExt, parenthesized, parse::{Parse, ParseStream}, @@ -143,6 +139,7 @@ impl Parse for RpcMethod { #[derive(Default)] struct DeriveMeta { derive: Option, + shared_context: Option, warnings: Vec, } @@ -255,6 +252,37 @@ impl Parse for DeriveMeta { ), } derive_serde.push(meta); + } else if segment.ident == "shared_context" { + let Expr::Lit(ExprLit { + lit: Lit::Str(ref v), + .. + }) = meta.value + else { + extend_errors!( + result, + syn::Error::new( + meta.span(), + "tarpc::service requires a literal string value for the shared_context attribute" + ) + ); + continue; + }; + + let Ok(ty) = syn::parse_str(&v.value()) else { + extend_errors!( + result, + syn::Error::new( + meta.span(), + "tarpc::service could not parse the value of the shared_context attribute as a type" + ) + ); + continue; + }; + + result = result.map(|d| DeriveMeta { + shared_context: Some(ty), + ..d + }) } else { extend_errors!( result, @@ -375,7 +403,7 @@ fn collect_cfg_attrs(rpcs: &[RpcMethod]) -> Vec> { /// # Example /// /// ```no_run -/// use tarpc::{client, transport, service, server::{self, Channel}, context::Context}; +/// use tarpc::{client, context, transport, service, server::{self, Channel}, context::DefaultContext}; /// /// #[service] /// pub trait Calculator { @@ -397,11 +425,18 @@ fn collect_cfg_attrs(rpcs: &[RpcMethod]) -> Vec> { /// // A client can be made like so: /// let client = CalculatorClient::new(client::Config::default(), client_side); /// +/// // You would usually call it like so. +/// #[cfg(feature = "tokio1")] +/// let client = client.spawn(); +/// #[cfg(not(feature = "tokio1"))] +/// let client = client.client; // Don't forget to run the dispatch future! +/// /// // And a server like so: /// #[derive(Clone)] /// struct CalculatorServer; /// impl Calculator for CalculatorServer { -/// async fn add(self, context: Context, a: i32, b: i32) -> i32 { +/// type Context = context::DefaultContext; +/// async fn add(self, context: &mut Self::Context, a: i32, b: i32) -> i32 { /// a + b /// } /// } @@ -409,10 +444,13 @@ fn collect_cfg_attrs(rpcs: &[RpcMethod]) -> Vec> { /// // You would usually spawn on an async runtime. /// let server = server::BaseChannel::with_defaults(server_side); /// let _ = server.execute(CalculatorServer.serve()); +/// +/// let _ = client.add(&mut context::current(), 1,2); /// ``` #[proc_macro_attribute] pub fn service(attr: TokenStream, input: TokenStream) -> TokenStream { let derive_meta = parse_macro_input!(attr as DeriveMeta); + let unit_type: &Type = &parse_quote!(()); let Service { ref attrs, @@ -427,6 +465,9 @@ pub fn service(attr: TokenStream, input: TokenStream) -> TokenStream { .collect(); let args: &[&[PatType]] = &rpcs.iter().map(|rpc| &*rpc.args).collect::>(); + let shared_context = derive_meta + .shared_context + .unwrap_or(parse_quote!(::tarpc::context::DefaultContext)); let derives = match derive_meta.derive.as_ref() { Some(Derive::Explicit(paths)) => { if !paths.is_empty() { @@ -501,6 +542,7 @@ pub fn service(attr: TokenStream, input: TokenStream) -> TokenStream { .map(|(rpc, name)| Ident::new(name, rpc.ident.span())) .collect::>(), derives: derives.as_ref(), + shared_context: &shared_context, warnings: &derive_meta.warnings, } .into_token_stream() @@ -528,6 +570,7 @@ struct ServiceGenerator<'a> { return_types: &'a [&'a Type], arg_pats: &'a [Vec<&'a Pat>], derives: Option<&'a TokenStream2>, + shared_context: &'a Type, warnings: &'a [TokenStream2], } @@ -543,30 +586,30 @@ impl ServiceGenerator<'_> { request_ident, response_ident, server_ident, + shared_context, .. } = self; - let rpc_fns = rpcs - .iter() - .zip(return_types.iter()) - .map( - |( - RpcMethod { - attrs, ident, args, .. - }, - output, - )| { - quote! { - #( #attrs )* - async fn #ident(self, context: ::tarpc::context::Context, #( #args ),*) -> #output; - } + let rpc_fns = rpcs.iter().zip(return_types.iter()).map( + |( + RpcMethod { + attrs, ident, args, .. }, - ); + output, + )| { + quote! { + #( #attrs )* + async fn #ident(self, context: &mut Self::Context, #( #args ),*) -> #output; + } + }, + ); let stub_doc = format!("The stub trait for service [`{service_ident}`]."); quote! { #( #attrs )* #vis trait #service_ident: ::core::marker::Sized { + type Context: ::tarpc::context::ExtractContext<#shared_context>; // = ::tarpc::context::DefaultContext; TODO: Add associated type default once https://github.com/rust-lang/rust/issues/29661 is stabilized + #( #rpc_fns )* /// Returns a serving function to use with @@ -577,11 +620,11 @@ impl ServiceGenerator<'_> { } #[doc = #stub_doc] - #vis trait #client_stub_ident: ::tarpc::client::stub::Stub { + #vis trait #client_stub_ident: ::tarpc::client::stub::Stub { } - impl #client_stub_ident for S - where S: ::tarpc::client::stub::Stub + impl #client_stub_ident for S + where S: ::tarpc::client::stub::Stub { } } @@ -620,9 +663,9 @@ impl ServiceGenerator<'_> { { type Req = #request_ident; type Resp = #response_ident; + type ServerCtx = S::Context; - - async fn serve(self, ctx: ::tarpc::context::Context, req: #request_ident) + async fn serve(self, ctx: &mut Self::ServerCtx, req: #request_ident) -> ::core::result::Result<#response_ident, ::tarpc::ServerError> { match req { #( @@ -706,17 +749,25 @@ impl ServiceGenerator<'_> { client_ident, request_ident, response_ident, + shared_context, .. } = self; quote! { #[allow(unused)] - #[derive(Clone, Debug)] + #[derive(Debug)] /// The client stub that makes RPC calls to the server. All request methods return /// [Futures](::core::future::Future). #vis struct #client_ident< - Stub = ::tarpc::client::Channel<#request_ident, #response_ident> - >(Stub); + ClientCtx, + Stub = ::tarpc::client::Channel<#request_ident, #response_ident, ClientCtx, #shared_context> + >(Stub, ::std::marker::PhantomData); + + impl ::std::clone::Clone for #client_ident { + fn clone(&self) -> Self { + Self(self.0.clone(), ::std::marker::PhantomData) + } + } } } @@ -726,36 +777,38 @@ impl ServiceGenerator<'_> { vis, request_ident, response_ident, + shared_context, .. } = self; quote! { - impl #client_ident { + impl #client_ident { /// Returns a new client stub that sends requests over the given transport. #vis fn new(config: ::tarpc::client::Config, transport: T) -> ::tarpc::client::NewClient< Self, - ::tarpc::client::RequestDispatch<#request_ident, #response_ident, T> + ::tarpc::client::RequestDispatch<#request_ident, #response_ident, ClientCtx, #shared_context, T> > where - T: ::tarpc::Transport<::tarpc::ClientMessage<#request_ident>, ::tarpc::Response<#response_ident>> + T: ::tarpc::Transport<::tarpc::ClientMessage, ::tarpc::Response> { let new_client = ::tarpc::client::new(config, transport); ::tarpc::client::NewClient { - client: #client_ident(new_client.client), + client: #client_ident(new_client.client, ::std::marker::PhantomData), dispatch: new_client.dispatch, } } } - impl ::core::convert::From for #client_ident + impl ::core::convert::From for #client_ident where Stub: ::tarpc::client::stub::Stub< Req = #request_ident, - Resp = #response_ident> + Resp = #response_ident, + ClientCtx = ClientCtx> { /// Returns a new client stub that sends requests over the given transport. fn from(stub: Stub) -> Self { - #client_ident(stub) + #client_ident::(stub, ::std::marker::PhantomData) } } @@ -778,15 +831,16 @@ impl ServiceGenerator<'_> { } = self; quote! { - impl #client_ident + impl #client_ident where Stub: ::tarpc::client::stub::Stub< Req = #request_ident, - Resp = #response_ident> + Resp = #response_ident, + ClientCtx = ClientCtx> { #( #[allow(unused)] #( #method_attrs )* - #vis fn #method_idents(&self, ctx: ::tarpc::context::Context, #( #args ),*) + #vis fn #method_idents<'a>(&'a self, ctx: &'a mut Stub::ClientCtx, #( #args ),*) -> impl ::core::future::Future> + '_ { let request = #request_ident::#camel_case_idents { #( #arg_pats ),* }; let resp = self.0.call(ctx, request); diff --git a/plugins/tests/service.rs b/plugins/tests/service.rs index 26ee1ec39..7473cac3b 100644 --- a/plugins/tests/service.rs +++ b/plugins/tests/service.rs @@ -12,15 +12,16 @@ fn att_service_trait() { } impl Foo for () { - async fn two_part(self, _: context::Context, s: String, i: i32) -> (String, i32) { + type Context = context::DefaultContext; + async fn two_part(self, _: &mut Self::Context, s: String, i: i32) -> (String, i32) { (s, i) } - async fn bar(self, _: context::Context, s: String) -> String { + async fn bar(self, _: &mut Self::Context, s: String) -> String { s } - async fn baz(self, _: context::Context) {} + async fn baz(self, _: &mut Self::Context) {} } } @@ -37,20 +38,21 @@ fn raw_idents() { } impl r#trait for () { + type Context = context::DefaultContext; async fn r#await( self, - _: context::Context, + _: &mut Self::Context, r#struct: r#yield, r#enum: i32, ) -> (r#yield, i32) { (r#struct, r#enum) } - async fn r#fn(self, _: context::Context, r#impl: r#yield) -> r#yield { + async fn r#fn(self, _: &mut Self::Context, r#impl: r#yield) -> r#yield { r#impl } - async fn r#async(self, _: context::Context) {} + async fn r#async(self, _: &mut Self::Context) {} } } @@ -64,7 +66,8 @@ fn service_with_cfg_rpc() { } impl Foo for () { - async fn foo(self, _: context::Context) {} + type Context = context::DefaultContext; + async fn foo(self, _: &mut Self::Context) {} } } diff --git a/tarpc/examples/compression.rs b/tarpc/examples/compression.rs index d66261d19..3eebe963b 100644 --- a/tarpc/examples/compression.rs +++ b/tarpc/examples/compression.rs @@ -3,6 +3,7 @@ // Use of this source code is governed by an MIT-style // license that can be found in the LICENSE file or at // https://opensource.org/licenses/MIT. +#![deny(warnings, unused, dead_code)] use flate2::{Compression, read::DeflateDecoder, write::DeflateEncoder}; use futures::{Sink, SinkExt, Stream, StreamExt, TryStreamExt, prelude::*}; @@ -108,7 +109,8 @@ pub trait World { struct HelloServer; impl World for HelloServer { - async fn hello(self, _: context::Context, name: String) -> String { + type Context = context::DefaultContext; + async fn hello(self, _: &mut Self::Context, name: String) -> String { format!("Hey, {name}!") } } @@ -134,7 +136,9 @@ async fn main() -> anyhow::Result<()> { println!( "{}", - client.hello(context::current(), "friend".into()).await? + client + .hello(&mut context::current(), "friend".into()) + .await? ); Ok(()) } diff --git a/tarpc/examples/custom_context.rs b/tarpc/examples/custom_context.rs new file mode 100644 index 000000000..6a76bcb12 --- /dev/null +++ b/tarpc/examples/custom_context.rs @@ -0,0 +1,347 @@ +// Copyright 2018 Google LLC +// +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file or at +// https://opensource.org/licenses/MIT. +#![deny(warnings, unused, dead_code)] + +use futures::{FutureExt, SinkExt, StreamExt, TryStreamExt}; +use serde::{Deserialize, Serialize}; +use std::collections::HashMap; +use std::ops::Add; +use std::sync::Arc; +use std::time::{Duration, Instant}; +use tarpc::context::{ExtractContext, SharedContext, UpdateContext}; +use tarpc::server::request_hook::{AfterRequest, BeforeRequest, RequestHook}; +use tarpc::transport::channel::UnboundedChannel; +use tarpc::{ + ClientMessage, Request, Response, ServerError, Transport, client, + server::{self, Channel}, + trace, +}; +use tokio::sync::Mutex; + + +/// This is the context that is sent between the client and server. +#[derive(Serialize, Deserialize, Clone)] +struct CustomSharedContext { + #[serde(with = "absolute_to_relative_time")] + pub deadline: Instant, + pub trace_context: trace::Context, + pub session_id: Option, +} + +/// This context is only seen by the client side. It can be used by the transport, or any step before sending to dispatch, +/// In this case used by the very simple example of delaying the request, but it can be used for batching, prioritisation, etc. +#[derive(Clone, Debug)] +struct ClientContext { + pub session_id: Option, + pub delay_sending_by_seconds: u32, +} + +/// This context is only seen by the server. It can be used by the transport, the hooks or the service implementation itself to +/// influence its behaviour. In our case the SessionHook extracts the session data +struct ServerContext { + pub deadline: Instant, + pub trace_context: trace::Context, + pub session_id: Option, + pub balance: u64 +} +impl SharedContext for CustomSharedContext { + fn deadline(&self) -> Instant { + self.deadline + } + + fn trace_context(&self) -> trace::Context { + self.trace_context + } + + fn set_trace_context(&mut self, trace_context: trace::Context) { + self.trace_context = trace_context; + } +} + + +impl ExtractContext for ClientContext { + fn extract(&self) -> CustomSharedContext { + CustomSharedContext { + deadline: Instant::now().add(Duration::from_secs(60)), + trace_context: Default::default(), + session_id: self.session_id, + } + } +} + +impl UpdateContext for ClientContext { + fn update(&mut self, value: CustomSharedContext) { + self.session_id = value.session_id; + } +} + +impl ExtractContext for ServerContext { + fn extract(&self) -> CustomSharedContext { + CustomSharedContext { + deadline: self.deadline, + trace_context: self.trace_context, + session_id: self.session_id, + } + } +} + +/// This is the service definition. It looks a lot like a trait definition. +/// It defines one RPC, hello, which takes one arg, name, and returns a String. +#[tarpc::service(shared_context = "CustomSharedContext")] +pub trait World { + async fn create_session() -> (); + async fn increase_balance(credits: u32) -> (); + async fn hello(name: String) -> Result; +} + +/// This is the type that implements the generated World trait. It is the business logic +/// and is used to start the server. +#[derive(Clone)] +struct HelloServer; + +impl World for HelloServer { + type Context = ServerContext; + + async fn create_session(self, ctx: &mut Self::Context) -> () { + ctx.session_id = Some(42); + ctx.balance = 0; + } + + async fn increase_balance(self, ctx: &mut Self::Context, credits: u32) -> () { + ctx.balance = ctx.balance + credits as u64; + } + + async fn hello(self, ctx: &mut Self::Context, name: String) -> Result { + if ctx.session_id != Some(42) { + Err("Session not yet initialized!")? + } + + if ctx.balance == 0 { + Err("Give me more money")? + } + + ctx.balance = ctx.balance - 1; + + Ok(format!("Hello, {name}!")) + } +} + +async fn spawn(fut: impl Future + Send + 'static) { + tokio::spawn(fut); +} + +#[derive(Clone)] +struct SessionHook { + balances: Arc>>, +} + +impl BeforeRequest for SessionHook { + async fn before( + &mut self, + ctx: &mut ServerContext, + _req: &WorldRequest, + ) -> Result<(), ServerError> { + if let Some(id) = ctx.session_id { + let balances = self.balances.lock().await; + ctx.balance = *balances.get(&id).unwrap_or(&0u64) + } + + Ok(()) + } +} + +impl AfterRequest for SessionHook { + async fn after( + &mut self, + ctx: &mut ServerContext, + resp: &mut Result, + ) { + if resp.is_ok() + && let Some(id) = ctx.session_id + { + let mut balances = self.balances.lock().await; + + let b = balances.entry(id).or_insert(0); + *b = ctx.balance; + } + } +} + +#[tokio::main] +async fn main() -> anyhow::Result<()> { + let (client_transport, server_transport) = create_channel(); + + let client_transport = + client_transport.with(|f: ClientMessage| { + async move { + if let ClientMessage::Request(Request { + ref context, + ref message, + .. + }) = f + { + println!("msg = {:?}, ctx = {:?}", message, context); + tokio::time::sleep(Duration::from_secs( + context.delay_sending_by_seconds as u64, + )) + .await; + } + + Ok(f) + } + .boxed() + }); + + let server = server::BaseChannel::with_defaults(server_transport); + let hook = SessionHook { + balances: Arc::new(Mutex::new(HashMap::new())), + }; + tokio::spawn( + server + .execute(HelloServer.serve().before_and_after(hook)) + .for_each(spawn), + ); + + // WorldClient is generated by the #[tarpc::service] attribute. It has a constructor `new` + // that takes a config and any Transport as input. + let client = WorldClient::new(client::Config::default(), client_transport).spawn(); + + // The client has an RPC method for each RPC defined in the annotated trait. It takes the same + // args as defined, with the addition of a Context, which is always the first arg. The Context + // specifies a deadline and trace information which can be helpful in debugging requests. + + let mut client_context: ClientContext = ClientContext { + session_id: None, + delay_sending_by_seconds: 1, + }; + + let hello = client + .hello(&mut client_context, "Stan".to_string()) + .await?; + + assert_eq!(hello, Err("Session not yet initialized!".to_string())); + + let _ = client.create_session(&mut client_context).await?; + let hello = client + .hello(&mut client_context, "Stan".to_string()) + .await?; + assert_eq!(hello, Err("Give me more money".to_string())); + + let _ = client.increase_balance(&mut client_context, 2u32).await?; + + let hello = client + .hello(&mut client_context, "Stan".to_string()) + .await?; + assert_eq!(hello, Ok("Hello, Stan!".to_string())); + + let hello = client + .hello(&mut client_context, "Frank".to_string()) + .await?; + assert_eq!(hello, Ok("Hello, Frank!".to_string())); + + let hello = client + .hello(&mut client_context, "Joshua".to_string()) + .await?; + assert_eq!(hello, Err("Give me more money".to_string())); + + Ok(()) +} + +//*** Helper functions below ***// + +mod absolute_to_relative_time { + pub use serde::{Deserialize, Deserializer, Serialize, Serializer}; + pub use std::time::{Duration, Instant}; + + pub fn serialize(deadline: &Instant, serializer: S) -> Result + where + S: Serializer, + { + let deadline = deadline.duration_since(Instant::now()); + deadline.serialize(serializer) + } + + pub fn deserialize<'de, D>(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + let deadline = Duration::deserialize(deserializer)?; + Ok(Instant::now() + deadline) + } +} + +fn map_request_context( + req: ClientMessage, + f: impl FnOnce(Ctx) -> Ctx2, +) -> ClientMessage { + match req { + ClientMessage::Request(Request { + context, + id, + message, + }) => ClientMessage::Request(Request { + context: f(context), + id, + message, + }), + ClientMessage::Cancel { + trace_context, + request_id, + } => ClientMessage::Cancel { + trace_context, + request_id, + }, + _ => unimplemented!(), + } +} + +fn map_response_context( + res: Response, + f: impl FnOnce(Ctx) -> Ctx2, +) -> Response { + Response { + request_id: res.request_id, + context: f(res.context), + message: res.message, + } +} + +fn create_channel() -> ( + impl Transport, Response>, + impl Transport, ClientMessage>, +) { + let (client, server): ( + UnboundedChannel< + Response, + ClientMessage, + >, + UnboundedChannel< + ClientMessage, + Response, + >, + ) = tarpc::transport::channel::unbounded(); + + let client = client + .with(|m| futures::future::ok(map_request_context(m, |c: ClientContext| c.extract()))) + .map_ok(|r| { + map_response_context(r, |c: CustomSharedContext| ClientContext { + session_id: c.session_id, + delay_sending_by_seconds: 0, + }) + }); + let server = server + .with(|r| futures::future::ok(map_response_context(r, |c: ServerContext| c.extract()))) + .map_ok(|m| { + map_request_context(m, |c| ServerContext { + deadline: c.deadline, + trace_context: c.trace_context, + session_id: c.session_id, + balance: 0, + }) + }); + + (client, server) +} diff --git a/tarpc/examples/custom_transport.rs b/tarpc/examples/custom_transport.rs index 5f5386785..aa62baf99 100644 --- a/tarpc/examples/custom_transport.rs +++ b/tarpc/examples/custom_transport.rs @@ -3,9 +3,10 @@ // Use of this source code is governed by an MIT-style // license that can be found in the LICENSE file or at // https://opensource.org/licenses/MIT. +#![deny(warnings, unused, dead_code)] use futures::prelude::*; -use tarpc::context::Context; +use tarpc::context; use tarpc::serde_transport as transport; use tarpc::server::{BaseChannel, Channel}; use tarpc::tokio_serde::formats::Bincode; @@ -21,7 +22,8 @@ pub trait PingService { struct Service; impl PingService for Service { - async fn ping(self, _: Context) {} + type Context = context::DefaultContext; + async fn ping(self, _: &mut Self::Context) {} } #[tokio::main] @@ -52,7 +54,7 @@ async fn main() -> anyhow::Result<()> { let transport = transport::new(codec_builder.new_framed(conn), Bincode::default()); PingServiceClient::new(Default::default(), transport) .spawn() - .ping(tarpc::context::current()) + .ping(&mut context::current()) .await?; Ok(()) diff --git a/tarpc/examples/pubsub.rs b/tarpc/examples/pubsub.rs index d61f68c48..6fc08d7b5 100644 --- a/tarpc/examples/pubsub.rs +++ b/tarpc/examples/pubsub.rs @@ -3,6 +3,7 @@ // Use of this source code is governed by an MIT-style // license that can be found in the LICENSE file or at // https://opensource.org/licenses/MIT. +#![deny(warnings, unused, dead_code)] /// - The PubSub server sets up TCP listeners on 2 ports, the "subscriber" port and the "publisher" /// port. Because both publishers and subscribers initiate their connections to the PubSub @@ -48,6 +49,7 @@ use std::{ sync::{Arc, Mutex, RwLock}, }; use subscriber::Subscriber as _; +use tarpc::context::DefaultContext; use tarpc::{ client, context, serde_transport::tcp, @@ -80,11 +82,12 @@ struct Subscriber { } impl subscriber::Subscriber for Subscriber { - async fn topics(self, _: context::Context) -> Vec { + type Context = context::DefaultContext; + async fn topics(self, _: &mut Self::Context) -> Vec { self.topics.clone() } - async fn receive(self, _: context::Context, topic: String, message: String) { + async fn receive(self, _: &mut Self::Context, topic: String, message: String) { info!(local_addr = %self.local_addr, %topic, %message, "ReceivedMessage") } } @@ -132,10 +135,21 @@ struct Subscription { topics: Vec, } -#[derive(Clone, Debug)] +#[derive(Debug)] struct Publisher { clients: Arc>>, - subscriptions: Arc>>>, + subscriptions: Arc< + RwLock>>>, + >, +} + +impl Clone for Publisher { + fn clone(&self) -> Self { + Publisher { + clients: self.clients.clone(), + subscriptions: self.subscriptions.clone(), + } + } } struct PublisherAddrs { @@ -183,7 +197,6 @@ impl Publisher { tokio::spawn(async move { while let Some(conn) = connecting_subscribers.next().await { let subscriber_addr = conn.peer_addr().unwrap(); - let tarpc::client::NewClient { client: subscriber, dispatch, @@ -207,10 +220,10 @@ impl Publisher { async fn initialize_subscription( &mut self, subscriber_addr: SocketAddr, - subscriber: subscriber::SubscriberClient, + subscriber: subscriber::SubscriberClient, ) { // Populate the topics - if let Ok(topics) = subscriber.topics(context::current()).await { + if let Ok(topics) = subscriber.topics(&mut context::current()).await { self.clients.lock().unwrap().insert( subscriber_addr, Subscription { @@ -263,7 +276,9 @@ impl Publisher { } impl publisher::Publisher for Publisher { - async fn publish(self, _: context::Context, topic: String, message: String) { + type Context = DefaultContext; + + async fn publish(self, ctx: &mut Self::Context, topic: String, message: String) { info!("received message to publish."); let mut subscribers = match self.subscriptions.read().unwrap().get(&topic) { None => return, @@ -271,7 +286,12 @@ impl publisher::Publisher for Publisher { }; let mut publications = Vec::new(); for client in subscribers.values_mut() { - publications.push(client.receive(context::current(), topic.clone(), message.clone())); + publications.push(async { + let mut context = ctx.clone(); + client + .receive(&mut context, topic.clone(), message.clone()) + .await + }); } // Ignore failing subscribers. In a real pubsub, you'd want to continually retry until // subscribers ack. Of course, a lot would be different in a real pubsub :) @@ -342,26 +362,30 @@ async fn main() -> anyhow::Result<()> { .spawn(); publisher - .publish(context::current(), "calculus".into(), "sqrt(2)".into()) + .publish(&mut context::current(), "calculus".into(), "sqrt(2)".into()) .await?; publisher .publish( - context::current(), + &mut context::current(), "cool shorts".into(), "hello to all".into(), ) .await?; publisher - .publish(context::current(), "history".into(), "napoleon".to_string()) + .publish( + &mut context::current(), + "history".into(), + "napoleon".to_string(), + ) .await?; drop(_subscriber0); publisher .publish( - context::current(), + &mut context::current(), "cool shorts".into(), "hello to who?".into(), ) diff --git a/tarpc/examples/readme.rs b/tarpc/examples/readme.rs index c328bd884..acbade9be 100644 --- a/tarpc/examples/readme.rs +++ b/tarpc/examples/readme.rs @@ -3,6 +3,7 @@ // Use of this source code is governed by an MIT-style // license that can be found in the LICENSE file or at // https://opensource.org/licenses/MIT. +#![deny(warnings, unused, dead_code)] use futures::prelude::*; use tarpc::{ @@ -23,7 +24,8 @@ pub trait World { struct HelloServer; impl World for HelloServer { - async fn hello(self, _: context::Context, name: String) -> String { + type Context = context::DefaultContext; + async fn hello(self, _: &mut Self::Context, name: String) -> String { format!("Hello, {name}!") } } @@ -46,7 +48,9 @@ async fn main() -> anyhow::Result<()> { // The client has an RPC method for each RPC defined in the annotated trait. It takes the same // args as defined, with the addition of a Context, which is always the first arg. The Context // specifies a deadline and trace information which can be helpful in debugging requests. - let hello = client.hello(context::current(), "Stim".to_string()).await?; + let hello = client + .hello(&mut context::current(), "Stim".to_string()) + .await?; println!("{hello}"); diff --git a/tarpc/examples/tls_over_tcp.rs b/tarpc/examples/tls_over_tcp.rs index 968f76c17..e7084fe84 100644 --- a/tarpc/examples/tls_over_tcp.rs +++ b/tarpc/examples/tls_over_tcp.rs @@ -3,6 +3,7 @@ // Use of this source code is governed by an MIT-style // license that can be found in the LICENSE file or at // https://opensource.org/licenses/MIT. +#![deny(warnings, unused, dead_code)] use futures::prelude::*; use rustls_pemfile::certs; @@ -18,7 +19,7 @@ use tokio_rustls::rustls::{ }; use tokio_rustls::{TlsAcceptor, TlsConnector}; -use tarpc::context::Context; +use tarpc::context; use tarpc::serde_transport as transport; use tarpc::server::{BaseChannel, Channel}; use tarpc::tokio_serde::formats::Bincode; @@ -33,7 +34,8 @@ pub trait PingService { struct Service; impl PingService for Service { - async fn ping(self, _: Context) -> String { + type Context = context::DefaultContext; + async fn ping(self, _: &mut Self::Context) -> String { "🔒".to_owned() } } @@ -146,7 +148,7 @@ async fn main() -> anyhow::Result<()> { let transport = transport::new(codec_builder.new_framed(stream), Bincode::default()); let answer = PingServiceClient::new(Default::default(), transport) .spawn() - .ping(tarpc::context::current()) + .ping(&mut context::current()) .await?; println!("ping answer: {answer}"); diff --git a/tarpc/examples/tracing.rs b/tarpc/examples/tracing.rs index 79a7026c0..e281f39fd 100644 --- a/tarpc/examples/tracing.rs +++ b/tarpc/examples/tracing.rs @@ -3,7 +3,7 @@ // Use of this source code is governed by an MIT-style // license that can be found in the LICENSE file or at // https://opensource.org/licenses/MIT. - +#![deny(warnings, unused, dead_code)] #![allow(clippy::type_complexity)] use crate::{ @@ -19,6 +19,7 @@ use std::{ atomic::{AtomicBool, Ordering}, }, }; +use tarpc::context::DefaultContext; use tarpc::{ ClientMessage, RequestName, Response, ServerError, Transport, client::{ @@ -56,23 +57,25 @@ pub mod double { struct AddServer; impl AddService for AddServer { - async fn add(self, _: context::Context, x: i32, y: i32) -> i32 { + type Context = context::DefaultContext; + async fn add(self, _: &mut Self::Context, x: i32, y: i32) -> i32 { x + y } } #[derive(Clone)] struct DoubleServer { - add_client: add::AddClient, + add_client: add::AddClient, } impl DoubleService for DoubleServer where - Stub: AddStub + Clone + Send + Sync + 'static, + Stub: AddStub + Clone + Send + Sync + 'static, { - async fn double(self, _: context::Context, x: i32) -> Result { + type Context = context::DefaultContext; + async fn double(self, _: &mut Self::Context, x: i32) -> Result { self.add_client - .add(context::current(), x, x) + .add(&mut context::current(), x, x) .await .map_err(|e| e.to_string()) } @@ -124,10 +127,13 @@ where } fn make_stub( - backends: [impl Transport>, Response> + Send + Sync + 'static; N], + backends: [impl Transport>, Response> + + Send + + Sync + + 'static; N], ) -> retry::Retry< impl Fn(&Result, u32) -> bool + Clone, - load_balance::RoundRobin, Resp>>, + load_balance::RoundRobin, Resp, DefaultContext, DefaultContext>>, > where Req: RequestName + Send + Sync + 'static, @@ -193,9 +199,11 @@ async fn main() -> anyhow::Result<()> { let double_client = double::DoubleClient::new(client::Config::default(), to_double_server).spawn(); - let ctx = context::current(); for _ in 1..=5 { - tracing::info!("{:?}", double_client.double(ctx, 1).await?); + tracing::info!( + "{:?}", + double_client.double(&mut context::current(), 1).await? + ); } tracer_provider.shutdown()?; diff --git a/tarpc/src/client.rs b/tarpc/src/client.rs index 3cf9ff07a..7ebe88cb2 100644 --- a/tarpc/src/client.rs +++ b/tarpc/src/client.rs @@ -9,6 +9,7 @@ mod in_flight_requests; pub mod stub; +use crate::context::{ExtractContext, SharedContext, UpdateContext}; use crate::{ ChannelError, ClientMessage, Request, RequestName, Response, ServerError, Transport, cancellations::{CanceledRequests, RequestCancellation, cancellations}, @@ -18,6 +19,7 @@ use crate::{ use futures::{prelude::*, ready, stream::Fuse, task::*}; use in_flight_requests::InFlightRequests; use pin_project::pin_project; +use std::marker::PhantomData; use std::{ any::Any, convert::TryFrom, @@ -95,27 +97,33 @@ const _CHECK_USIZE: () = assert!( /// Handles communication from the client to request dispatch. #[derive(Debug)] -pub struct Channel { - to_dispatch: mpsc::Sender>, +pub struct Channel { + to_dispatch: mpsc::Sender>, /// Channel to send a cancel message to the dispatcher. cancellation: RequestCancellation, /// The ID to use for the next request to stage. next_request_id: Arc, + + ///TODO: Document + ghost: PhantomData, } -impl Clone for Channel { +impl Clone for Channel { fn clone(&self) -> Self { Self { to_dispatch: self.to_dispatch.clone(), cancellation: self.cancellation.clone(), next_request_id: self.next_request_id.clone(), + ghost: PhantomData, } } } -impl Channel +impl Channel where Req: RequestName, + ClientCtx: UpdateContext + Clone, + SharedCtx: SharedContext, { /// Sends a request to the dispatch task to forward to the server, returning a [`Future`] that /// resolves to the response. @@ -124,19 +132,26 @@ where skip(self, ctx, request), fields( rpc.trace_id = tracing::field::Empty, - rpc.deadline = %humantime::format_rfc3339(SystemTime::now() + ctx.deadline.time_until()), + rpc.deadline = %humantime::format_rfc3339(SystemTime::now() + ctx.extract().deadline().time_until()), otel.kind = "client", otel.name = %request.name()) )] - pub async fn call(&self, mut ctx: context::Context, request: Req) -> Result { + pub async fn call(&self, ctx: &mut ClientCtx, request: Req) -> Result { let span = Span::current(); - ctx.trace_context = trace::Context::try_from(&span).unwrap_or_else(|_| { + let mut shared_context = ctx.extract(); + shared_context.set_trace_context(trace::Context::try_from(&span).unwrap_or_else(|_| { tracing::trace!( "OpenTelemetry subscriber not installed; making unsampled child context." ); - ctx.trace_context.new_child() - }); - span.record("rpc.trace_id", tracing::field::display(ctx.trace_id())); + shared_context.trace_context().new_child() + })); + span.record( + "rpc.trace_id", + tracing::field::display(shared_context.trace_context().trace_id), + ); + + ctx.update(shared_context); + let (response_completion, mut response) = oneshot::channel(); let request_id = u64::try_from(self.next_request_id.fetch_add(1, Ordering::Relaxed)).unwrap(); @@ -151,9 +166,10 @@ where cancellation: &self.cancellation, cancel: true, }; + self.to_dispatch .send(DispatchRequest { - ctx, + ctx: ctx.clone(), // TODO: It would be best to forward the &mut ctx here, but unsure how to make the lifetimes work. span, request_id, request, @@ -161,14 +177,19 @@ where }) .await .map_err(|mpsc::error::SendError(_)| RpcError::Shutdown)?; - response_guard.response().await + + let (response_ctx, r) = response_guard.response().await?; + + ctx.update(response_ctx.extract()); + + Ok(r) } } /// A server response that is completed by request dispatch when the corresponding response /// arrives off the wire. -struct ResponseGuard<'a, Resp> { - response: &'a mut oneshot::Receiver>, +struct ResponseGuard<'a, Resp, SharedCtx> { + response: &'a mut oneshot::Receiver>, cancellation: &'a RequestCancellation, request_id: u64, cancel: bool, @@ -195,8 +216,8 @@ pub enum RpcError { Server(#[from] ServerError), } -impl ResponseGuard<'_, Resp> { - async fn response(mut self) -> Result { +impl ResponseGuard<'_, Resp, SharedCtx> { + async fn response(mut self) -> Result<(SharedCtx, Resp), RpcError> { let response = (&mut self.response).await; // Cancel drop logic once a response has been received. self.cancel = false; @@ -213,7 +234,7 @@ impl ResponseGuard<'_, Resp> { } // Cancels the request when dropped, if not already complete. -impl Drop for ResponseGuard<'_, Resp> { +impl Drop for ResponseGuard<'_, Resp, SharedCtx> { fn drop(&mut self) { // The receiver needs to be closed to handle the edge case that the request has not // yet been received by the dispatch task. It is possible for the cancel message to @@ -234,12 +255,15 @@ impl Drop for ResponseGuard<'_, Resp> { /// Returns a channel and dispatcher that manages the lifecycle of requests initiated by the /// channel. -pub fn new( +pub fn new( config: Config, transport: C, -) -> NewClient, RequestDispatch> +) -> NewClient< + Channel, + RequestDispatch, +> where - C: Transport, Response>, + C: Transport, Response>, { let (to_dispatch, pending_requests) = mpsc::channel(config.pending_request_buffer); let (cancellation, canceled_requests) = cancellations(); @@ -249,6 +273,7 @@ where to_dispatch, cancellation, next_request_id: Arc::new(AtomicUsize::new(0)), + ghost: PhantomData, }, dispatch: RequestDispatch { config, @@ -257,6 +282,7 @@ where in_flight_requests: InFlightRequests::default(), pending_requests, terminal_error: None, + ghost: PhantomData, }, } } @@ -266,16 +292,16 @@ where #[must_use] #[pin_project()] #[derive(Debug)] -pub struct RequestDispatch { +pub struct RequestDispatch { /// Writes requests to the wire and reads responses off the wire. #[pin] transport: Fuse, /// Requests waiting to be written to the wire. - pending_requests: mpsc::Receiver>, + pending_requests: mpsc::Receiver>, /// Requests that were dropped. canceled_requests: CanceledRequests, /// Requests already written to the wire that haven't yet received responses. - in_flight_requests: InFlightRequests>, + in_flight_requests: InFlightRequests, /// Configures limits to prevent unlimited resource usage. config: Config, /// Produces errors that can be sent in response to any unprocessed requests at the time @@ -283,15 +309,19 @@ pub struct RequestDispatch { /// RequestDispatch::poll, which relies on downcasting the Any to a concrete error type /// determined within the poll function. terminal_error: Option>, + + ghost: PhantomData, } -impl RequestDispatch +impl RequestDispatch where - C: Transport, Response>, + C: Transport, Response>, + ClientCtx: ExtractContext + Clone, //TODO: We need to clone the ClientCtx to be able to send it through the dispatch channel which requires a 'static type, thus we need to move it. I feel this limitation can be lifted with some smart lifetime magic. + SharedCtx: SharedContext, { fn in_flight_requests<'a>( self: &'a mut Pin<&mut Self>, - ) -> &'a mut InFlightRequests> { + ) -> &'a mut InFlightRequests { self.as_mut().project().in_flight_requests } @@ -308,7 +338,10 @@ where .map_err(|e| ChannelError::Ready(Arc::new(e))) } - fn start_send(self: &mut Pin<&mut Self>, message: ClientMessage) -> Result<(), C::Error> { + fn start_send( + self: &mut Pin<&mut Self>, + message: ClientMessage, + ) -> Result<(), C::Error> { self.transport_pin_mut().start_send(message) } @@ -336,7 +369,7 @@ where fn pending_requests_mut<'a>( self: &'a mut Pin<&mut Self>, - ) -> &'a mut mpsc::Receiver> { + ) -> &'a mut mpsc::Receiver> { self.as_mut().project().pending_requests } @@ -417,7 +450,7 @@ where fn poll_next_request( mut self: Pin<&mut Self>, cx: &mut Context<'_>, - ) -> Poll, ChannelError>>> { + ) -> Poll, ChannelError>>> { if self.in_flight_requests().len() >= self.config.max_in_flight_requests { tracing::debug!( "At in-flight request capacity ({}/{}).", @@ -457,7 +490,7 @@ where fn poll_next_cancellation( mut self: Pin<&mut Self>, cx: &mut Context<'_>, - ) -> Poll>>> { + ) -> Poll>>> { ready!(self.ensure_writeable(cx)?); loop { @@ -510,16 +543,26 @@ where // poll_next_request only returns Ready if there is room to buffer another request. // Therefore, we can call write_request without fear of erroring due to a full // buffer. + + let shared_context = ctx.extract(); + + let trace_context = shared_context.trace_context(); + let deadline = shared_context.deadline(); + let request = ClientMessage::Request(Request { id: request_id, message: request, - context: context::Context { - deadline: ctx.deadline, - trace_context: ctx.trace_context, - }, + context: ctx, }); + self.in_flight_requests() - .insert_request(request_id, ctx, span.clone(), response_completion) + .insert_request( + request_id, + trace_context, + deadline, + span.clone(), + response_completion, + ) .expect("Request IDs should be unique"); match self.start_send(request) { Ok(()) => tracing::debug!("SendRequest"), @@ -541,14 +584,15 @@ where self: &mut Pin<&mut Self>, cx: &mut Context<'_>, ) -> Poll>>> { - let (context, span, request_id) = match ready!(self.as_mut().poll_next_cancellation(cx)?) { - Some(triple) => triple, - None => return Poll::Ready(None), - }; + let (trace_context, span, request_id) = + match ready!(self.as_mut().poll_next_cancellation(cx)?) { + Some(triple) => triple, + None => return Poll::Ready(None), + }; let _entered = span.enter(); let cancel = ClientMessage::Cancel { - trace_context: context.trace_context, + trace_context, request_id, }; self.start_send(cancel) @@ -558,10 +602,13 @@ where } /// Sends a server response to the client task that initiated the associated request. - fn complete(mut self: Pin<&mut Self>, response: Response) -> bool { + fn complete(mut self: Pin<&mut Self>, response: Response) -> bool { if let Some(span) = self.in_flight_requests().complete_request( response.request_id, - response.message.map_err(RpcError::Server), + response + .message + .map_err(RpcError::Server) + .map(|m| (response.context, m)), ) { let _entered = span.enter(); tracing::debug!("ReceiveResponse"); @@ -636,9 +683,12 @@ where } } -impl Future for RequestDispatch +impl Future + for RequestDispatch where - C: Transport, Response>, + C: Transport, Response>, + ClientCtx: ExtractContext + Clone, + SharedCtx: context::SharedContext, { type Output = Result<(), ChannelError>; @@ -668,12 +718,13 @@ where /// A server-bound request sent from a [`Channel`] to request dispatch, which will then manage /// the lifecycle of the request. #[derive(Debug)] -struct DispatchRequest { - pub ctx: context::Context, +struct DispatchRequest { + pub ctx: ClientCtx, + ///TODO: this should be a &mut ClientCtx pub span: Span, pub request_id: u64, pub request: Req, - pub response_completion: oneshot::Sender>, + pub response_completion: oneshot::Sender>, } #[cfg(test)] @@ -681,10 +732,11 @@ mod tests { use super::{ Channel, DispatchRequest, RequestDispatch, ResponseGuard, RpcError, cancellations, }; + use crate::context::DefaultContext; use crate::{ ChannelError, ClientMessage, Response, client::{Config, in_flight_requests::InFlightRequests}, - context::{self, current}, + context, transport::{self, channel::UnboundedChannel}, }; use assert_matches::assert_matches; @@ -708,30 +760,39 @@ mod tests { #[tokio::test] async fn response_completes_request_future() { - let (mut dispatch, mut _channel, mut server_channel) = set_up(); + let (mut dispatch, _channel, mut server_channel) = set_up(); let cx = &mut Context::from_waker(noop_waker_ref()); let (tx, mut rx) = oneshot::channel(); + let context = context::current(); + dispatch .in_flight_requests - .insert_request(0, context::current(), Span::current(), tx) + .insert_request( + 0, + context.trace_context, + context.deadline, + Span::current(), + tx, + ) .unwrap(); server_channel .send(Response { request_id: 0, + context: context::current(), message: Ok("Resp".into()), }) .await .unwrap(); assert_matches!(dispatch.as_mut().poll(cx), Poll::Pending); - assert_matches!(rx.try_recv(), Ok(Ok(resp)) if resp == "Resp"); + assert_matches!(rx.try_recv(), Ok(Ok((_, resp))) if resp == "Resp"); } #[tokio::test] async fn dispatch_response_cancels_on_drop() { let (cancellation, mut canceled_requests) = cancellations(); let (_, mut response) = oneshot::channel(); - drop(ResponseGuard:: { + drop(ResponseGuard:: { response: &mut response, cancellation: &cancellation, request_id: 3, @@ -746,11 +807,7 @@ mod tests { async fn dispatch_response_doesnt_cancel_after_complete() { let (cancellation, mut canceled_requests) = cancellations(); let (tx, mut response) = oneshot::channel(); - tx.send(Ok(Response { - request_id: 0, - message: Ok("well done"), - })) - .unwrap(); + tx.send(Ok((context::current(), "well done"))).unwrap(); // resp's drop() is run, but should not send a cancel message. ResponseGuard { response: &mut response, @@ -768,11 +825,11 @@ mod tests { #[tokio::test] async fn stage_request() { - let (mut dispatch, mut channel, _server_channel) = set_up(); + let (mut dispatch, mut channel, _server_channel) = set_up::(); let cx = &mut Context::from_waker(noop_waker_ref()); let (tx, mut rx) = oneshot::channel(); - let _resp = send_request(&mut channel, "hi", tx, &mut rx).await; + let _resp = send_request(&mut channel, context::current(), "hi", tx, &mut rx).await; #[allow(unstable_name_collisions)] let req = dispatch.as_mut().poll_next_request(cx).ready(); @@ -790,7 +847,7 @@ mod tests { let cx = &mut Context::from_waker(noop_waker_ref()); let (tx, mut rx) = oneshot::channel(); - let _ = send_request(&mut channel, "hi", tx, &mut rx).await; + let _ = send_request(&mut channel, context::current(), "hi", tx, &mut rx).await; drop(channel); assert!(dispatch.as_mut().poll(cx).is_ready()); @@ -798,6 +855,7 @@ mod tests { &mut server_channel, Response { request_id: 0, + context: context::current(), message: Ok("hello".into()), }, ) @@ -808,11 +866,11 @@ mod tests { #[allow(unstable_name_collisions)] #[tokio::test] async fn stage_request_response_future_dropped_is_canceled_before_sending() { - let (mut dispatch, mut channel, _server_channel) = set_up(); + let (mut dispatch, mut channel, _server_channel) = set_up::(); let cx = &mut Context::from_waker(noop_waker_ref()); let (tx, mut rx) = oneshot::channel(); - let _ = send_request(&mut channel, "hi", tx, &mut rx).await; + let _ = send_request(&mut channel, context::current(), "hi", tx, &mut rx).await; // Drop the channel so polling returns none if no requests are currently ready. drop(channel); @@ -824,11 +882,11 @@ mod tests { #[allow(unstable_name_collisions)] #[tokio::test] async fn stage_request_response_future_dropped_is_canceled_after_sending() { - let (mut dispatch, mut channel, _server_channel) = set_up(); + let (mut dispatch, mut channel, _server_channel) = set_up::(); let cx = &mut Context::from_waker(noop_waker_ref()); let (tx, mut rx) = oneshot::channel(); - let req = send_request(&mut channel, "hi", tx, &mut rx).await; + let req = send_request(&mut channel, context::current(), "hi", tx, &mut rx).await; assert!(dispatch.as_mut().pump_write(cx).ready().is_some()); assert!(!dispatch.in_flight_requests.is_empty()); @@ -845,14 +903,14 @@ mod tests { #[tokio::test] async fn stage_request_response_closed_skipped() { - let (mut dispatch, mut channel, _server_channel) = set_up(); + let (mut dispatch, mut channel, _server_channel) = set_up::(); let cx = &mut Context::from_waker(noop_waker_ref()); let (tx, mut rx) = oneshot::channel(); // Test that a request future that's closed its receiver but not yet canceled its request -- // i.e. still in `drop fn` -- will cause the request to not be added to the in-flight request // map. - let resp = send_request(&mut channel, "hi", tx, &mut rx).await; + let resp = send_request(&mut channel, context::current(), "hi", tx, &mut rx).await; resp.response.close(); assert!(dispatch.as_mut().poll_next_request(cx).is_pending()); @@ -861,14 +919,15 @@ mod tests { #[tokio::test] async fn test_permit_before_transport_error() { let _ = tracing_subscriber::fmt().with_test_writer().try_init(); - let (mut dispatch, mut channel, mut cx) = set_up_always_err(TransportError::Flush); + let (mut dispatch, mut channel, mut cx) = + set_up_always_err::(TransportError::Flush); let (tx, mut rx) = oneshot::channel(); // reserve succeeds let permit = reserve_for_send(&mut channel, tx, &mut rx).await; // Since there's an outstanding permit, dispatch should not complete yet. assert_matches!(dispatch.as_mut().poll(&mut cx), Poll::Pending); - let resp = permit("hi"); + let resp = permit(context::current(), "hi"); // errors from both the dispatch future and the request assert_matches!(dispatch.as_mut().poll(&mut cx), Poll::Ready(Err(ChannelError::Flush(e))) if matches!(*e, TransportError::Flush)); @@ -878,20 +937,23 @@ mod tests { #[tokio::test] async fn test_shutdown() { let _ = tracing_subscriber::fmt().with_test_writer().try_init(); - let (dispatch, channel, _server_channel) = set_up(); + let (dispatch, channel, _server_channel) = set_up::(); drop(dispatch); // error on send - let resp = channel.call(current(), "hi".to_string()).await; + let resp = channel + .call(&mut context::current(), "hi".to_string()) + .await; assert_matches!(resp, Err(RpcError::Shutdown)); } #[tokio::test] async fn test_transport_error_write() { let cause = TransportError::Write; - let (mut dispatch, mut channel, mut cx) = set_up_always_err(cause); + let (mut dispatch, mut channel, mut cx) = + set_up_always_err::(cause); let (tx, mut rx) = oneshot::channel(); - let resp = send_request(&mut channel, "hi", tx, &mut rx).await; + let resp = send_request(&mut channel, context::current(), "hi", tx, &mut rx).await; assert!(dispatch.as_mut().poll(&mut cx).is_pending()); let res = resp.response().await; assert_matches!(res, Err(RpcError::Send(_))); @@ -911,9 +973,10 @@ mod tests { #[tokio::test] async fn test_transport_error_read() { let cause = TransportError::Read; - let (mut dispatch, mut channel, mut cx) = set_up_always_err(cause); + let (mut dispatch, mut channel, mut cx) = + set_up_always_err::(cause); let (tx, mut rx) = oneshot::channel(); - let resp = send_request(&mut channel, "hi", tx, &mut rx).await; + let resp = send_request(&mut channel, context::current(), "hi", tx, &mut rx).await; assert_eq!( dispatch.as_mut().pump_write(&mut cx), Poll::Ready(Some(Ok(()))) @@ -928,7 +991,7 @@ mod tests { #[tokio::test] async fn test_transport_error_ready() { let cause = TransportError::Ready; - let (mut dispatch, _, mut cx) = set_up_always_err(cause); + let (mut dispatch, _, mut cx) = set_up_always_err::(cause); assert_eq!( dispatch.as_mut().poll(&mut cx), Poll::Ready(Err(ChannelError::Ready(Arc::new(cause)))) @@ -938,7 +1001,7 @@ mod tests { #[tokio::test] async fn test_transport_error_flush() { let cause = TransportError::Flush; - let (mut dispatch, _, mut cx) = set_up_always_err(cause); + let (mut dispatch, _, mut cx) = set_up_always_err::(cause); assert_eq!( dispatch.as_mut().poll(&mut cx), Poll::Ready(Err(ChannelError::Flush(Arc::new(cause)))) @@ -948,7 +1011,8 @@ mod tests { #[tokio::test] async fn test_transport_error_close() { let cause = TransportError::Close; - let (mut dispatch, channel, mut cx) = set_up_always_err(cause); + let (mut dispatch, channel, mut cx) = + set_up_always_err::(cause); drop(channel); assert_eq!( dispatch.as_mut().poll(&mut cx), @@ -957,34 +1021,47 @@ mod tests { } /// Sets up a RequestDispatch with a transport that always errors. - fn set_up_always_err( + fn set_up_always_err( cause: TransportError, ) -> ( - Pin>>>, - Channel, + Pin< + Box< + RequestDispatch< + String, + String, + ClientCtx, + SharedCtx, + AlwaysErrorTransport, + >, + >, + >, + Channel, Context<'static>, ) { let (to_dispatch, pending_requests) = mpsc::channel(1); let (cancellation, canceled_requests) = cancellations(); - let transport: AlwaysErrorTransport = AlwaysErrorTransport(cause, PhantomData); - let dispatch = Box::pin(RequestDispatch:: { + let transport: AlwaysErrorTransport = + AlwaysErrorTransport(cause, PhantomData); + let dispatch = Box::pin(RequestDispatch:: { transport: transport.fuse(), pending_requests, canceled_requests, in_flight_requests: InFlightRequests::default(), config: Config::default(), terminal_error: None, + ghost: PhantomData, }); let channel = Channel { to_dispatch, cancellation, next_request_id: Arc::new(AtomicUsize::new(0)), + ghost: PhantomData, }; let cx = Context::from_waker(noop_waker_ref()); (dispatch, channel, cx) } - struct AlwaysErrorTransport(TransportError, PhantomData); + struct AlwaysErrorTransport(TransportError, PhantomData<(I, ClientCtx)>); #[derive(Debug, Error, PartialEq, Eq, Clone, Copy)] enum TransportError { @@ -1001,7 +1078,7 @@ mod tests { } } - impl Sink for AlwaysErrorTransport { + impl Sink for AlwaysErrorTransport { type Error = TransportError; fn poll_ready(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll> { match self.0 { @@ -1033,8 +1110,8 @@ mod tests { } } - impl Stream for AlwaysErrorTransport { - type Item = Result, TransportError>; + impl Stream for AlwaysErrorTransport { + type Item = Result, TransportError>; fn poll_next(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll> { if matches!(self.0, TransportError::Read) { Poll::Ready(Some(Err(self.0))) @@ -1044,18 +1121,20 @@ mod tests { } } - fn set_up() -> ( + fn set_up() -> ( Pin< Box< RequestDispatch< String, String, - UnboundedChannel, ClientMessage>, + ClientCtx, + DefaultContext, + UnboundedChannel, ClientMessage>, >, >, >, - Channel, - UnboundedChannel, Response>, + Channel, + UnboundedChannel, Response>, ) { let _ = tracing_subscriber::fmt().with_test_writer().try_init(); @@ -1063,35 +1142,37 @@ mod tests { let (cancellation, canceled_requests) = cancellations(); let (client_channel, server_channel) = transport::channel::unbounded(); - let dispatch = RequestDispatch:: { + let dispatch = RequestDispatch:: { transport: client_channel.fuse(), pending_requests, canceled_requests, in_flight_requests: InFlightRequests::default(), config: Config::default(), terminal_error: None, + ghost: PhantomData, }; let channel = Channel { to_dispatch, cancellation, next_request_id: Arc::new(AtomicUsize::new(0)), + ghost: PhantomData, }; (Box::pin(dispatch), channel, server_channel) } - async fn reserve_for_send<'a>( - channel: &'a mut Channel, - response_completion: oneshot::Sender>, - response: &'a mut oneshot::Receiver>, - ) -> impl FnOnce(&str) -> ResponseGuard<'a, String> { + async fn reserve_for_send<'a, ClientCtx>( + channel: &'a mut Channel, + response_completion: oneshot::Sender>, + response: &'a mut oneshot::Receiver>, + ) -> impl FnOnce(ClientCtx, &str) -> ResponseGuard<'a, String, DefaultContext> { let permit = channel.to_dispatch.reserve().await.unwrap(); - |request| { + |ctx, request| { let request_id = u64::try_from(channel.next_request_id.fetch_add(1, Ordering::Relaxed)).unwrap(); let request = DispatchRequest { - ctx: context::current(), + ctx, span: Span::current(), request_id, request: request.to_string(), @@ -1107,16 +1188,17 @@ mod tests { } } - async fn send_request<'a>( - channel: &'a mut Channel, + async fn send_request<'a, ClientCtx>( + channel: &'a mut Channel, + context: ClientCtx, request: &str, - response_completion: oneshot::Sender>, - response: &'a mut oneshot::Receiver>, - ) -> ResponseGuard<'a, String> { + response_completion: oneshot::Sender>, + response: &'a mut oneshot::Receiver>, + ) -> ResponseGuard<'a, String, ClientCtx> { let request_id = u64::try_from(channel.next_request_id.fetch_add(1, Ordering::Relaxed)).unwrap(); let request = DispatchRequest { - ctx: context::current(), + ctx: context, span: Span::current(), request_id, request: request.to_string(), @@ -1132,9 +1214,12 @@ mod tests { response_guard } - async fn send_response( - channel: &mut UnboundedChannel, Response>, - response: Response, + async fn send_response( + channel: &mut UnboundedChannel< + ClientMessage, + Response, + >, + response: Response, ) { channel.send(response).await.unwrap(); } diff --git a/tarpc/src/client/in_flight_requests.rs b/tarpc/src/client/in_flight_requests.rs index 1776a74a0..ec71ad628 100644 --- a/tarpc/src/client/in_flight_requests.rs +++ b/tarpc/src/client/in_flight_requests.rs @@ -1,8 +1,10 @@ +use crate::client::RpcError; use crate::{ - context, + trace, util::{Compact, TimeUntil}, }; use fnv::FnvHashMap; +use std::time::Instant; use std::{ collections::hash_map, task::{Context, Poll}, @@ -13,12 +15,12 @@ use tracing::Span; /// Requests already written to the wire that haven't yet received responses. #[derive(Debug)] -pub struct InFlightRequests { - request_data: FnvHashMap>, +pub struct InFlightRequests { + request_data: FnvHashMap>, deadlines: DelayQueue, } -impl Default for InFlightRequests { +impl Default for InFlightRequests { fn default() -> Self { Self { request_data: Default::default(), @@ -28,10 +30,10 @@ impl Default for InFlightRequests { } #[derive(Debug)] -struct RequestData { - ctx: context::Context, +struct RequestData { + ctx: trace::Context, span: Span, - response_completion: oneshot::Sender, + response_completion: oneshot::Sender>, /// The key to remove the timer for the request's deadline. deadline_key: delay_queue::Key, } @@ -41,7 +43,7 @@ struct RequestData { #[derive(Debug)] pub struct AlreadyExistsError; -impl InFlightRequests { +impl InFlightRequests { /// Returns the number of in-flight requests. pub fn len(&self) -> usize { self.request_data.len() @@ -56,13 +58,14 @@ impl InFlightRequests { pub fn insert_request( &mut self, request_id: u64, - ctx: context::Context, + ctx: trace::Context, + deadline: Instant, span: Span, - response_completion: oneshot::Sender, + response_completion: oneshot::Sender>, ) -> Result<(), AlreadyExistsError> { match self.request_data.entry(request_id) { hash_map::Entry::Vacant(vacant) => { - let timeout = ctx.deadline.time_until(); + let timeout = deadline.time_until(); let deadline_key = self.deadlines.insert(request_id, timeout); vacant.insert(RequestData { ctx, @@ -76,8 +79,12 @@ impl InFlightRequests { } } - /// Removes a request without aborting. Returns true iff the request was found. - pub fn complete_request(&mut self, request_id: u64, result: Res) -> Option { + /// Removes a request without aborting. Returns true if the request was found. + pub fn complete_request( + &mut self, + request_id: u64, + result: Result<(SharedCtx, Res), RpcError>, + ) -> Option { if let Some(request_data) = self.request_data.remove(&request_id) { self.request_data.compact(0.1); self.deadlines.remove(&request_data.deadline_key); @@ -95,7 +102,7 @@ impl InFlightRequests { /// Returns Spans for all completes requests. pub fn complete_all_requests<'a>( &'a mut self, - mut result: impl FnMut() -> Res + 'a, + mut result: impl FnMut() -> Result<(SharedCtx, Res), RpcError> + 'a, ) -> impl Iterator + 'a { self.deadlines.clear(); self.request_data.drain().map(move |(_, request_data)| { @@ -106,7 +113,7 @@ impl InFlightRequests { /// Cancels a request without completing (typically used when a request handle was dropped /// before the request completed). - pub fn cancel_request(&mut self, request_id: u64) -> Option<(context::Context, Span)> { + pub fn cancel_request(&mut self, request_id: u64) -> Option<(trace::Context, Span)> { if let Some(request_data) = self.request_data.remove(&request_id) { self.request_data.compact(0.1); self.deadlines.remove(&request_data.deadline_key); @@ -121,7 +128,7 @@ impl InFlightRequests { pub fn poll_expired( &mut self, cx: &mut Context, - expired_error: impl Fn() -> Res, + expired_error: impl Fn() -> Result<(SharedCtx, Res), RpcError>, ) -> Poll> { self.deadlines.poll_expired(cx).map(|expired| { let request_id = expired?.into_inner(); diff --git a/tarpc/src/client/stub.rs b/tarpc/src/client/stub.rs index 85746b7f2..06e0e438d 100644 --- a/tarpc/src/client/stub.rs +++ b/tarpc/src/client/stub.rs @@ -4,6 +4,7 @@ use crate::{ RequestName, client::{Channel, RpcError}, context, + context::UpdateContext, server::Serve, }; @@ -23,19 +24,28 @@ pub trait Stub { /// The service response type. type Resp; + ///TODO: document + type ClientCtx; + /// Calls a remote service. - async fn call(&self, ctx: context::Context, request: Self::Req) - -> Result; + async fn call( + &self, + ctx: &mut Self::ClientCtx, + request: Self::Req, + ) -> Result; } -impl Stub for Channel +impl Stub for Channel where Req: RequestName, + ClientCtx: UpdateContext + Clone, + SharedCtx: context::SharedContext, { type Req = Req; type Resp = Resp; + type ClientCtx = ClientCtx; - async fn call(&self, ctx: context::Context, request: Req) -> Result { + async fn call(&self, ctx: &mut Self::ClientCtx, request: Req) -> Result { Self::call(self, ctx, request).await } } @@ -46,7 +56,12 @@ where { type Req = S::Req; type Resp = S::Resp; - async fn call(&self, ctx: context::Context, req: Self::Req) -> Result { + type ClientCtx = S::ServerCtx; + async fn call( + &self, + ctx: &mut Self::ClientCtx, + req: Self::Req, + ) -> Result { self.clone().serve(ctx, req).await.map_err(RpcError::Server) } } diff --git a/tarpc/src/client/stub/load_balance.rs b/tarpc/src/client/stub/load_balance.rs index d28a3c137..5c6cc9aca 100644 --- a/tarpc/src/client/stub/load_balance.rs +++ b/tarpc/src/client/stub/load_balance.rs @@ -5,10 +5,7 @@ pub use round_robin::RoundRobin; /// Provides a stub that load-balances with a simple round-robin strategy. mod round_robin { - use crate::{ - client::{RpcError, stub}, - context, - }; + use crate::client::{RpcError, stub}; use cycle::AtomicCycle; impl stub::Stub for RoundRobin @@ -17,10 +14,11 @@ mod round_robin { { type Req = Stub::Req; type Resp = Stub::Resp; + type ClientCtx = Stub::ClientCtx; async fn call( &self, - ctx: context::Context, + ctx: &mut Self::ClientCtx, request: Self::Req, ) -> Result { let next = self.stubs.next(); @@ -98,10 +96,7 @@ mod round_robin { /// Each request is hashed, then mapped to a stub based on the hash. Equivalent requests will use /// the same stub. mod consistent_hash { - use crate::{ - client::{RpcError, stub}, - context, - }; + use crate::client::{RpcError, stub}; use std::{ collections::hash_map::RandomState, hash::{BuildHasher, Hash}, @@ -116,10 +111,11 @@ mod consistent_hash { { type Req = Stub::Req; type Resp = Stub::Resp; + type ClientCtx = Stub::ClientCtx; async fn call( &self, - ctx: context::Context, + ctx: &mut Self::ClientCtx, request: Self::Req, ) -> Result { let index = usize::try_from(self.hasher.hash_one(&request) % self.stubs_len).expect( @@ -200,13 +196,13 @@ mod consistent_hash { )?; for _ in 0..2 { - let resp = stub.call(context::current(), 'a').await?; + let resp = stub.call(&mut context::current(), 'a').await?; assert_eq!(resp, 1); - let resp = stub.call(context::current(), 'b').await?; + let resp = stub.call(&mut context::current(), 'b').await?; assert_eq!(resp, 2); - let resp = stub.call(context::current(), 'c').await?; + let resp = stub.call(&mut context::current(), 'c').await?; assert_eq!(resp, 3); } diff --git a/tarpc/src/client/stub/mock.rs b/tarpc/src/client/stub/mock.rs index 145c14c1f..171f8918e 100644 --- a/tarpc/src/client/stub/mock.rs +++ b/tarpc/src/client/stub/mock.rs @@ -1,16 +1,17 @@ use crate::{ RequestName, ServerError, client::{RpcError, stub::Stub}, - context, }; +use std::marker::PhantomData; use std::{collections::HashMap, hash::Hash, io}; /// A mock stub that returns user-specified responses. -pub struct Mock { +pub struct Mock { responses: HashMap, + ghost: PhantomData, } -impl Mock +impl Mock where Req: Eq + Hash, { @@ -18,19 +19,21 @@ where pub fn new(responses: [(Req, Resp); N]) -> Self { Self { responses: HashMap::from(responses), + ghost: PhantomData, } } } -impl Stub for Mock +impl Stub for Mock where Req: Eq + Hash + RequestName, Resp: Clone, { type Req = Req; type Resp = Resp; + type ClientCtx = ServerCtx; - async fn call(&self, _: context::Context, request: Self::Req) -> Result { + async fn call(&self, _: &mut Self::ClientCtx, request: Self::Req) -> Result { self.responses .get(&request) .cloned() diff --git a/tarpc/src/client/stub/retry.rs b/tarpc/src/client/stub/retry.rs index a07b05fc5..5499f60e4 100644 --- a/tarpc/src/client/stub/retry.rs +++ b/tarpc/src/client/stub/retry.rs @@ -3,7 +3,6 @@ use crate::{ RequestName, client::{RpcError, stub}, - context, }; use std::sync::Arc; @@ -15,10 +14,11 @@ where { type Req = Req; type Resp = Stub::Resp; + type ClientCtx = Stub::ClientCtx; async fn call( &self, - ctx: context::Context, + ctx: &mut Self::ClientCtx, request: Self::Req, ) -> Result { let request = Arc::new(request); diff --git a/tarpc/src/context.rs b/tarpc/src/context.rs index 8e77cf223..5a510219b 100644 --- a/tarpc/src/context.rs +++ b/tarpc/src/context.rs @@ -21,10 +21,9 @@ use tracing_opentelemetry::OpenTelemetrySpanExt; /// /// The context should not be stored directly in a server implementation, because the context will /// be different for each request in scope. -#[derive(Clone, Copy, Debug)] -#[non_exhaustive] +#[derive(Debug, Clone)] #[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))] -pub struct Context { +pub struct DefaultContext { /// When the client expects the request to be complete by. The server should cancel the request /// if it is not complete by this time. #[cfg_attr(feature = "serde1", serde(default = "ten_seconds_from_now"))] @@ -38,6 +37,129 @@ pub struct Context { pub trace_context: trace::Context, } +/// A shared, on-the-wire request context. +/// +/// `SharedContext` defines the minimal interface required for contexts that are +/// transmitted between peers as part of an RPC call. +/// +/// Implementations of this trait represent the *wire-level* context: a +/// portable, serializable view of request-scoped metadata such as deadlines +/// and tracing information. Service implementations are free to define +/// richer context types as part of their contract, for example to implement: +/// +/// - ephemeral sessions +/// - authentication / authorization data +/// - cookie-like state +/// - other application-specific metadata +/// +/// while preserving a common core required by the RPC runtime. +pub trait SharedContext { + /// Returns the absolute deadline for the request. + /// + /// The deadline represents the latest instant at which the request + /// should still be processed. RPC runtimes and middleware may use this + /// value to enforce timeouts, cancel in-flight work, or reject requests + /// that have already expired. + fn deadline(&self) -> Instant; + + /// Returns the distributed tracing context associated with the request. + /// + /// This context is propagated across RPC boundaries to enable + /// end-to-end request tracing and correlation. + //TODO: May want to remove this in the long run from the default context, may need https://github.com/rust-lang/rust/issues/144361 for that. + fn trace_context(&self) -> trace::Context; + + /// Updates the distributed tracing context. + /// + /// This is typically used by middleware to attach spans, propagate + /// trace identifiers, or replace the tracing context after deserialization. + //TODO: May want to remove this in the long run from the default context, may need https://github.com/rust-lang/rust/issues/144361 for that. + fn set_trace_context(&mut self, trace_context: trace::Context); +} + +impl SharedContext for DefaultContext { + fn deadline(&self) -> Instant { + self.deadline + } + + fn trace_context(&self) -> trace::Context { + self.trace_context + } + + fn set_trace_context(&mut self, trace_context: trace::Context) { + self.trace_context = trace_context; + } +} + +/// Extracts a wire-level shared context contained within a client or server context. +/// +/// `ExtractContext` defines a mapping between an internal +/// context representation and a *shared* context type (`Ctx`) that is +/// suitable for serialization and transmission over the wire. +/// +/// The shared context typically represents the minimal, stable data +/// exchanged between the client and server, while the +/// implementing type may contain additional, local side only state or +/// a different internal structure. +/// +/// # Design notes +/// +/// If a type implements both UpdateContext and ExtractContext, it is expected that +/// `foo.update(v).extract() == v` will hold. +// TODO: Revisit this trait once try_as_dyn is stabilized, https://github.com/rust-lang/rust/issues/29661. +pub trait ExtractContext { + /// Extracts the inner context from the internal state. + /// + /// This method is typically called before sending the context over + /// the wire, or just after receiving it. The returned value should contain + /// all information required by the remote side to reconstruct or update its own + /// local context. + fn extract(&self) -> Ctx; + +} + +/// Updates a wire-level shared context contained within a client context. +/// +/// `ExtractContext` defines a mapping between a *shared* context type (`Ctx`) +/// and an internal context representation +/// +/// The shared context typically represents the minimal, stable data +/// exchanged between the client and server, while the +/// implementing type may contain additional, local side only state or +/// a different internal structure. +/// +/// # Design notes +/// +/// It is expected that `ctx.update(shared_ctx).extract() == shared_ctx` will always hold. +/// +// TODO: Revisit this trait once try_as_dyn is stabilized, https://github.com/rust-lang/rust/issues/29661. +pub trait UpdateContext: ExtractContext { + /// Updates the internal state from an inner context value. + /// + /// This method is typically called after executing a request and before + /// sending the updated context over the wire. Implementations should apply + /// the provided value to their internal representation, updating any derived or + /// local-only state as necessary. + fn update(&mut self, value: Ctx); +} + +impl ExtractContext for T +where + T: Clone, +{ + fn extract(&self) -> T { + self.clone() + } +} + +impl> UpdateContext for T { + fn update(&mut self, value: T) { + *self = value + } +} + + + #[cfg(feature = "serde1")] mod absolute_to_relative_time { pub use serde::{Deserialize, Deserializer, Serialize, Serializer}; @@ -91,15 +213,15 @@ mod absolute_to_relative_time { } } -assert_impl_all!(Context: Send, Sync); +assert_impl_all!(DefaultContext: Send, Sync); fn ten_seconds_from_now() -> Instant { Instant::now() + Duration::from_secs(10) } /// Returns the context for the current request, or a default Context if no request is active. -pub fn current() -> Context { - Context::current() +pub fn current() -> DefaultContext { + DefaultContext::current() } #[derive(Clone)] @@ -111,7 +233,7 @@ impl Default for Deadline { } } -impl Context { +impl DefaultContext { /// Returns the context for the current request, or a default Context if no request is active. pub fn current() -> Self { let span = tracing::Span::current(); @@ -137,21 +259,23 @@ impl Context { pub(crate) trait SpanExt { /// Sets the given context on this span. Newly-created spans will be children of the given /// context's trace context. - fn set_context(&self, context: &Context); + fn set_context(&self, context: &T); } impl SpanExt for tracing::Span { - fn set_context(&self, context: &Context) { + fn set_context(&self, context: &T) { self.set_parent( opentelemetry::Context::new() .with_remote_span_context(opentelemetry::trace::SpanContext::new( - opentelemetry::trace::TraceId::from(context.trace_context.trace_id), - opentelemetry::trace::SpanId::from(context.trace_context.span_id), - opentelemetry::trace::TraceFlags::from(context.trace_context.sampling_decision), + opentelemetry::trace::TraceId::from(context.trace_context().trace_id), + opentelemetry::trace::SpanId::from(context.trace_context().span_id), + opentelemetry::trace::TraceFlags::from( + context.trace_context().sampling_decision, + ), true, opentelemetry::trace::TraceState::default(), )) - .with_value(Deadline(context.deadline)), + .with_value(Deadline(context.deadline())), ); } } diff --git a/tarpc/src/lib.rs b/tarpc/src/lib.rs index 7e1944305..cff903110 100644 --- a/tarpc/src/lib.rs +++ b/tarpc/src/lib.rs @@ -124,8 +124,9 @@ //! struct HelloServer; //! //! impl World for HelloServer { +//! type Context = context::DefaultContext; //! // Each defined rpc generates an async fn that serves the RPC -//! async fn hello(self, _: context::Context, name: String) -> String { +//! async fn hello(self, _: &mut Self::Context, name: String) -> String { //! format!("Hello, {name}!") //! } //! } @@ -157,8 +158,10 @@ //! # #[derive(Clone)] //! # struct HelloServer; //! # impl World for HelloServer { -//! // Each defined rpc generates an async fn that serves the RPC -//! # async fn hello(self, _: context::Context, name: String) -> String { +//! # type Context = context::DefaultContext; +//! # +//! # // Each defined rpc generates an async fn that serves the RPC +//! # async fn hello(self, _: &mut Self::Context, name: String) -> String { //! # format!("Hello, {name}!") //! # } //! # } @@ -168,7 +171,6 @@ //! #[tokio::main] //! async fn main() -> anyhow::Result<()> { //! let (client_transport, server_transport) = tarpc::transport::channel::unbounded(); -//! //! let server = server::BaseChannel::with_defaults(server_transport); //! tokio::spawn( //! server.execute(HelloServer.serve()) @@ -184,7 +186,8 @@ //! // The client has an RPC method for each RPC defined in the annotated trait. It takes the same //! // args as defined, with the addition of a Context, which is always the first arg. The Context //! // specifies a deadline and trace information which can be helpful in debugging requests. -//! let hello = client.hello(context::current(), "Stim".to_string()).await?; +//! let mut context = context::current(); +//! let hello = client.hello(&mut context, "Stim".to_string()).await?; //! //! println!("{hello}"); //! @@ -197,7 +200,7 @@ //! Use `cargo doc` as you normally would to see the documentation created for all //! items expanded by a `service!` invocation. -#![deny(missing_docs)] +#![deny(missing_docs, warnings, unused, dead_code)] #![allow(clippy::type_complexity)] #![cfg_attr(docsrs, feature(doc_cfg))] @@ -250,17 +253,17 @@ pub(crate) mod util; pub use crate::transport::sealed::Transport; -use std::{any::Any, error::Error, io, sync::Arc, time::Instant}; +use std::{any::Any, error::Error, io, sync::Arc}; /// A message from a client to a server. #[derive(Debug)] #[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))] #[non_exhaustive] -pub enum ClientMessage { +pub enum ClientMessage { /// A request initiated by a user. The server responds to a request by invoking a /// service-provided request handler. The handler completes with a [`response`](Response), which /// the server sends back to the client. - Request(Request), + Request(Request), /// A command to cancel an in-flight request, automatically sent by the client when a response /// future is dropped. /// @@ -279,15 +282,15 @@ pub enum ClientMessage { } /// A request from a client to a server. -#[derive(Clone, Copy, Debug)] +#[derive(Clone, Debug)] #[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))] -pub struct Request { +pub struct Request { /// Trace context, deadline, and other cross-cutting concerns. - pub context: context::Context, + pub context: Ctx, /// Uniquely identifies the request across all requests sent over a single channel. pub id: u64, /// The request body. - pub message: T, + pub message: Req, } /// Implemented by the request types generated by tarpc::service. @@ -360,13 +363,14 @@ impl RequestName for u64 { /// A response from a server to a client. #[derive(Clone, Debug, PartialEq, Eq, Hash)] #[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))] -pub struct Response { +pub struct Response { /// The ID of the request being responded to. pub request_id: u64, + /// Trace context, deadline, and other cross-cutting concerns. + pub context: Ctx, /// The response body, or an error if the request failed. pub message: Result, } - /// An error indicating the server aborted the request early, e.g., due to request throttling. #[derive(thiserror::Error, Clone, Debug, PartialEq, Eq, Hash)] #[error("{kind:?}: {detail}")] @@ -489,14 +493,6 @@ impl ServerError { Self { kind, detail } } } - -impl Request { - /// Returns the deadline for this request. - pub fn deadline(&self) -> &Instant { - &self.context.deadline - } -} - #[test] fn test_channel_any_casts() { use assert_matches::assert_matches; diff --git a/tarpc/src/server.rs b/tarpc/src/server.rs index da3b3ae21..47c8476c4 100644 --- a/tarpc/src/server.rs +++ b/tarpc/src/server.rs @@ -6,10 +6,11 @@ //! Provides a server that concurrently handles many connections sending multiplexed requests. +use crate::context::{ExtractContext, SharedContext}; use crate::{ ChannelError, ClientMessage, Request, RequestName, Response, ServerError, Transport, cancellations::{CanceledRequests, RequestCancellation, cancellations}, - context::{self, SpanExt}, + context::SpanExt, trace, util::TimeUntil, }; @@ -58,9 +59,14 @@ impl Default for Config { impl Config { /// Returns a channel backed by `transport` and configured with `self`. - pub fn channel(self, transport: T) -> BaseChannel + pub fn channel( + self, + transport: T, + ) -> BaseChannel where - T: Transport, ClientMessage>, + T: Transport, ClientMessage>, + ServerCtx: ExtractContext, + SharedCtx: SharedContext, { BaseChannel::new(self, transport) } @@ -69,6 +75,9 @@ impl Config { /// Equivalent to a `FnOnce(Req) -> impl Future`. #[allow(async_fn_in_trait)] pub trait Serve { + ///TODO document + type ServerCtx; + /// Type of request. type Req: RequestName; @@ -76,17 +85,21 @@ pub trait Serve { type Resp; /// Responds to a single request. - async fn serve(self, ctx: context::Context, req: Self::Req) -> Result; + async fn serve( + self, + ctx: &mut Self::ServerCtx, + req: Self::Req, + ) -> Result; } /// A Serve wrapper around a Fn. #[derive(Debug)] -pub struct ServeFn { +pub struct ServeFn { f: F, - data: PhantomData Resp>, + data: PhantomData<(Req, Resp, ServerCtx)>, } -impl Clone for ServeFn +impl Clone for ServeFn where F: Clone, { @@ -98,14 +111,17 @@ where } } -impl Copy for ServeFn where F: Copy {} +impl Copy for ServeFn where F: Copy {} /// Creates a [`Serve`] wrapper around a `FnOnce(context::Context, Req) -> impl Future>`. -pub fn serve(f: F) -> ServeFn +pub fn serve(f: F) -> ServeFn where - F: FnOnce(context::Context, Req) -> Fut, - Fut: Future>, + // This should be -> impl Future<...>, but there is no syntax to express the 'a lifetime. + for<'a> F: FnOnce( + &'a mut ServerCtx, + Req, + ) -> Pin> + 'a + Send>>, { ServeFn { f, @@ -113,16 +129,20 @@ where } } -impl Serve for ServeFn +impl Serve for ServeFn where Req: RequestName, - F: FnOnce(context::Context, Req) -> Fut, - Fut: Future>, + // This should be -> impl Future<...>, but there is no syntax to express the 'a lifetime. + for<'a> F: FnOnce( + &'a mut ServerCtx, + Req, + ) -> Pin> + 'a + Send>>, { + type ServerCtx = ServerCtx; type Req = Req; type Resp = Resp; - async fn serve(self, ctx: context::Context, req: Req) -> Result { + async fn serve(self, ctx: &mut ServerCtx, req: Req) -> Result { (self.f)(ctx, req).await } } @@ -138,7 +158,7 @@ where /// messages. Instead, it internally handles them by cancelling corresponding requests (removing /// the corresponding in-flight requests and aborting their handlers). #[pin_project] -pub struct BaseChannel { +pub struct BaseChannel { config: Config, /// Writes responses to the wire and reads requests off the wire. #[pin] @@ -151,12 +171,14 @@ pub struct BaseChannel { /// Holds data necessary to clean up in-flight requests. in_flight_requests: InFlightRequests, /// Types the request and response. - ghost: PhantomData<(fn() -> Req, fn(Resp))>, + ghost: PhantomData<(Req, Resp, ServerCtx, SharedCtx)>, } -impl BaseChannel +impl BaseChannel where - T: Transport, ClientMessage>, + T: Transport, ClientMessage>, + ServerCtx: ExtractContext, + SharedCtx: SharedContext, { /// Creates a new channel backed by `transport` and configured with `config`. pub fn new(config: Config, transport: T) -> Self { @@ -202,28 +224,29 @@ where fn start_request( mut self: Pin<&mut Self>, - mut request: Request, - ) -> Result, AlreadyExistsError> { + request: Request, + ) -> Result, AlreadyExistsError> { + let mut shared_context = request.context.extract(); let span = info_span!( "RPC", - rpc.trace_id = %request.context.trace_id(), - rpc.deadline = %humantime::format_rfc3339(SystemTime::now() + request.context.deadline.time_until()), + rpc.trace_id = %shared_context.trace_context().trace_id, + rpc.deadline = %humantime::format_rfc3339(SystemTime::now() + shared_context.deadline().time_until()), otel.kind = "server", otel.name = tracing::field::Empty, ); - span.set_context(&request.context); - request.context.trace_context = trace::Context::try_from(&span).unwrap_or_else(|_| { + span.set_context(&shared_context); + shared_context.set_trace_context(trace::Context::try_from(&span).unwrap_or_else(|_| { tracing::trace!( "OpenTelemetry subscriber not installed; making unsampled \ child context." ); - request.context.trace_context.new_child() - }); + shared_context.trace_context().new_child() + })); let entered = span.enter(); tracing::debug!("ReceiveRequest"); let start = self.in_flight_requests_mut().start_request( request.id, - request.context.deadline, + shared_context.deadline(), span.clone(), ); match start { @@ -248,7 +271,9 @@ where } } -impl fmt::Debug for BaseChannel { +impl fmt::Debug + for BaseChannel +{ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { write!(f, "BaseChannel") } @@ -256,9 +281,9 @@ impl fmt::Debug for BaseChannel { /// A request tracked by a [`Channel`]. #[derive(Debug)] -pub struct TrackedRequest { +pub struct TrackedRequest { /// The request sent by the client. - pub request: Request, + pub request: Request, /// A registration to abort a future when the [`Channel`] that produced this request stops /// tracking it. pub abort_registration: AbortRegistration, @@ -295,7 +320,10 @@ pub struct TrackedRequest { /// created by [`BaseChannel`]. pub trait Channel where - Self: Transport::Resp>, TrackedRequest<::Req>>, + Self: Transport< + Response::Resp>, + TrackedRequest::Req>, + >, { /// Type of request item. type Req; @@ -305,6 +333,8 @@ where /// The wrapped transport. type Transport; + ///TODO document + type ServerCtx; /// Configuration of the channel. fn config(&self) -> &Config; @@ -343,6 +373,7 @@ where /// /// ```rust /// use tarpc::{ + /// ClientMessage, /// context, /// client::{self, NewClient}, /// server::{self, BaseChannel, Channel, serve}, @@ -360,10 +391,11 @@ where /// let mut requests = server.requests(); /// tokio::spawn(async move { /// while let Some(Ok(request)) = requests.next().await { - /// tokio::spawn(request.execute(serve(|_, i| async move { Ok(i + 1) }))); + /// tokio::spawn(request.execute(serve(|_, i| async move { Ok(i + 1) }.boxed()))); /// } /// }); - /// assert_eq!(client.call(context::current(), 1).await.unwrap(), 2); + /// let mut context = context::current(); + /// assert_eq!(client.call(&mut context, 1).await.unwrap(), 2); /// } /// ``` fn requests(self) -> Requests @@ -386,7 +418,7 @@ where /// # Example /// /// ```rust - /// use tarpc::{context, client, server::{self, BaseChannel, Channel, serve}, transport}; + /// use tarpc::{ClientMessage, context, client, server::{self, BaseChannel, Channel, serve}, transport}; /// use futures::prelude::*; /// use tracing_subscriber::prelude::*; /// @@ -399,12 +431,13 @@ where /// let client = client::new(client::Config::default(), tx).spawn(); /// let channel = BaseChannel::with_defaults(rx); /// tokio::spawn( - /// channel.execute(serve(|_, i: i32| async move { Ok(i + 1) })) + /// channel.execute(serve(|_, i: i32| async move { Ok(i + 1) }.boxed())) /// .for_each(|response| async move { /// tokio::spawn(response); - /// })); + /// }.boxed())); + /// let mut context = context::current(); /// assert_eq!( - /// client.call(context::current(), 1).await.unwrap(), + /// client.call(&mut context, 1).await.unwrap(), /// 2); /// } /// ``` @@ -412,17 +445,19 @@ where where Self: Sized, Self::Req: RequestName, - S: Serve + Clone, + S: Serve + Clone, { self.requests().execute(serve) } } -impl Stream for BaseChannel +impl Stream for BaseChannel where - T: Transport, ClientMessage>, + T: Transport, ClientMessage>, + ServerCtx: ExtractContext, + SharedCtx: SharedContext, { - type Item = Result, ChannelError>; + type Item = Result, ChannelError>; fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll> { #[derive(Clone, Copy, Debug)] @@ -525,10 +560,13 @@ where } } -impl Sink> for BaseChannel +impl Sink> + for BaseChannel where - T: Transport, ClientMessage>, + T: Transport, ClientMessage>, T::Error: Error, + ServerCtx: ExtractContext, + SharedCtx: SharedContext, { type Error = ChannelError; @@ -539,7 +577,10 @@ where .map_err(|e| ChannelError::Ready(Arc::new(e))) } - fn start_send(mut self: Pin<&mut Self>, response: Response) -> Result<(), Self::Error> { + fn start_send( + mut self: Pin<&mut Self>, + response: Response, + ) -> Result<(), Self::Error> { if let Some(span) = self .in_flight_requests_mut() .remove_request(response.request_id) @@ -572,19 +613,24 @@ where } } -impl AsRef for BaseChannel { +impl AsRef + for BaseChannel +{ fn as_ref(&self) -> &T { self.transport.get_ref() } } -impl Channel for BaseChannel +impl Channel for BaseChannel where - T: Transport, ClientMessage>, + T: Transport, ClientMessage>, + ServerCtx: ExtractContext, + SharedCtx: SharedContext, { type Req = Req; type Resp = Resp; type Transport = T; + type ServerCtx = ServerCtx; fn config(&self) -> &Config { &self.config @@ -609,9 +655,9 @@ where #[pin] channel: C, /// Responses waiting to be written to the wire. - pending_responses: mpsc::Receiver>, + pending_responses: mpsc::Receiver>, /// Handed out to request handlers to fan in responses. - responses_tx: mpsc::Sender>, + responses_tx: mpsc::Sender>, } impl Requests @@ -631,14 +677,14 @@ where /// Returns the inner channel over which messages are sent and received. pub fn pending_responses_mut<'a>( self: &'a mut Pin<&mut Self>, - ) -> &'a mut mpsc::Receiver> { + ) -> &'a mut mpsc::Receiver> { self.as_mut().project().pending_responses } fn pump_read( mut self: Pin<&mut Self>, cx: &mut Context<'_>, - ) -> Poll, C::Error>>> { + ) -> Poll, C::Error>>> { self.channel_pin_mut().poll_next(cx).map_ok( |TrackedRequest { request, @@ -703,7 +749,7 @@ where fn poll_next_response( mut self: Pin<&mut Self>, cx: &mut Context<'_>, - ) -> Poll, C::Error>>> { + ) -> Poll, C::Error>>> { ready!(self.ensure_writeable(cx)?); match ready!(self.pending_responses_mut().poll_recv(cx)) { @@ -736,7 +782,7 @@ where /// # Example /// /// ```rust - /// use tarpc::{context, client, server::{self, BaseChannel, Channel, serve}, transport}; + /// use tarpc::{context, client, server::{self, BaseChannel, Channel, serve}, transport, ClientMessage}; /// use futures::prelude::*; /// /// # #[cfg(not(feature = "tokio1"))] @@ -748,17 +794,18 @@ where /// let requests = BaseChannel::new(server::Config::default(), rx).requests(); /// let client = client::new(client::Config::default(), tx).spawn(); /// tokio::spawn( - /// requests.execute(serve(|_, i| async move { Ok(i + 1) })) + /// requests.execute(serve(|_, i| async move { Ok(i + 1) }.boxed())) /// .for_each(|response| async move { /// tokio::spawn(response); - /// })); - /// assert_eq!(client.call(context::current(), 1).await.unwrap(), 2); + /// }.boxed())); + /// let mut context = context::current(); + /// assert_eq!(client.call(&mut context, 1).await.unwrap(), 2); /// } /// ``` pub fn execute(self, serve: S) -> impl Stream> where C::Req: RequestName, - S: Serve + Clone, + S: Serve + Clone, { self.take_while(|result| { if let Err(e) = result { @@ -805,17 +852,17 @@ impl Drop for ResponseGuard { /// If dropped without calling [`execute`](InFlightRequest::execute), a cancellation message will /// be sent to the Channel to clean up associated request state. #[derive(Debug)] -pub struct InFlightRequest { - request: Request, +pub struct InFlightRequest { + request: Request, abort_registration: AbortRegistration, response_guard: ResponseGuard, span: Span, - response_tx: mpsc::Sender>, + response_tx: mpsc::Sender>, } -impl InFlightRequest { +impl InFlightRequest { /// Returns a reference to the request. - pub fn get(&self) -> &Request { + pub fn get(&self) -> &Request { &self.request } @@ -828,7 +875,7 @@ impl InFlightRequest { /// /// 1. The channel that yielded this request receives a [cancellation /// message](ClientMessage::Cancel) for this request. - /// 2. The request [deadline](crate::context::Context::deadline) is reached. + /// 2. The request [deadline](crate::context::DefaultContext::deadline) is reached. /// 3. The service function completes. /// /// If the returned Future is dropped before completion, a cancellation message will be sent to @@ -838,6 +885,7 @@ impl InFlightRequest { /// /// ```rust /// use tarpc::{ + /// ClientMessage, /// context, /// client::{self, NewClient}, /// server::{self, BaseChannel, Channel, serve}, @@ -855,18 +903,18 @@ impl InFlightRequest { /// tokio::spawn(async move { /// let mut requests = server.requests(); /// while let Some(Ok(in_flight_request)) = requests.next().await { - /// in_flight_request.execute(serve(|_, i| async move { Ok(i + 1) })).await; + /// in_flight_request.execute(serve(|_, i| async move { Ok(i + 1) }.boxed())).await; /// } - /// /// }); - /// assert_eq!(client.call(context::current(), 1).await.unwrap(), 2); + /// let mut context = context::current(); + /// assert_eq!(client.call(&mut context, 1).await.unwrap(), 2); /// } /// ``` /// pub async fn execute(self, serve: S) where Req: RequestName, - S: Serve, + S: Serve, { let Self { response_tx, @@ -875,7 +923,7 @@ impl InFlightRequest { span, request: Request { - context, + mut context, message, id: request_id, }, @@ -883,10 +931,11 @@ impl InFlightRequest { span.record("otel.name", message.name()); let _ = Abortable::new( async move { - let message = serve.serve(context, message).await; + let message = serve.serve(&mut context, message).await; tracing::debug!("CompleteRequest"); let response = Response { request_id, + context, message, }; let _ = response_tx.send(response).await; @@ -914,7 +963,7 @@ impl Stream for Requests where C: Channel, { - type Item = Result, C::Error>; + type Item = Result, C::Error>; fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { loop { @@ -960,6 +1009,7 @@ mod tests { request_hook::{AfterRequest, BeforeRequest, RequestHook}, serve, }; + use crate::context::{DefaultContext}; use crate::{ ClientMessage, Request, Response, ServerError, context, trace, transport::channel::{self, UnboundedChannel}, @@ -979,8 +1029,24 @@ mod tests { }; fn test_channel() -> ( - Pin, Response>>>>, - UnboundedChannel, ClientMessage>, + Pin< + Box< + BaseChannel< + Req, + Resp, + UnboundedChannel< + ClientMessage, + Response, + >, + context::DefaultContext, + context::DefaultContext, + >, + >, + >, + UnboundedChannel< + Response, + ClientMessage, + >, ) { let (tx, rx) = crate::transport::channel::unbounded(); (Box::pin(BaseChannel::new(Config::default(), rx)), tx) @@ -990,11 +1056,23 @@ mod tests { Pin< Box< Requests< - BaseChannel, Response>>, + BaseChannel< + Req, + Resp, + UnboundedChannel< + ClientMessage, + Response, + >, + context::DefaultContext, + context::DefaultContext, + >, >, >, >, - UnboundedChannel, ClientMessage>, + UnboundedChannel< + Response, + ClientMessage, + >, ) { let (tx, rx) = crate::transport::channel::unbounded(); ( @@ -1009,11 +1087,23 @@ mod tests { Pin< Box< Requests< - BaseChannel, Response>>, + BaseChannel< + Req, + Resp, + channel::Channel< + ClientMessage, + Response, + >, + context::DefaultContext, + context::DefaultContext, + >, >, >, >, - channel::Channel, ClientMessage>, + channel::Channel< + Response, + ClientMessage, + >, ) { let (tx, rx) = crate::transport::channel::bounded(capacity); // Add 1 because capacity 0 is not supported (but is supported by transport::channel::bounded). @@ -1023,7 +1113,7 @@ mod tests { (Box::pin(BaseChannel::new(config, rx).requests()), tx) } - fn fake_request(req: Req) -> ClientMessage { + fn fake_request(req: Req) -> ClientMessage { ClientMessage::Request(Request { context: context::current(), id: 0, @@ -1039,19 +1129,16 @@ mod tests { #[tokio::test] async fn test_serve() { - let serve = serve(|_, i| async move { Ok(i) }); - assert_matches!(serve.serve(context::current(), 7).await, Ok(7)); + let serve = serve(|_, i| async move { Ok(i) }.boxed()); + assert_matches!(serve.serve(&mut context::current(), 7).await, Ok(7)); } #[tokio::test] async fn serve_before_mutates_context() -> anyhow::Result<()> { struct SetDeadline(Instant); - impl BeforeRequest for SetDeadline { - async fn before( - &mut self, - ctx: &mut context::Context, - _: &Req, - ) -> Result<(), ServerError> { + impl BeforeRequest for SetDeadline + { + async fn before(&mut self, ctx: &mut DefaultContext, _: &Req) -> Result<(), ServerError> { ctx.deadline = self.0; Ok(()) } @@ -1060,14 +1147,17 @@ mod tests { let some_time = Instant::now() + Duration::from_secs(37); let some_other_time = Instant::now() + Duration::from_secs(83); - let serve = serve(move |ctx: context::Context, i| async move { - assert_eq!(ctx.deadline, some_time); - Ok(i) + let serve = serve(move |ctx: &mut context::DefaultContext, i| { + async move { + assert_eq!(ctx.deadline, some_time); + Ok(i) + } + .boxed() }); let deadline_hook = serve.before(SetDeadline(some_time)); let mut ctx = context::current(); ctx.deadline = some_other_time; - deadline_hook.serve(ctx, 7).await?; + deadline_hook.serve(&mut ctx, 7).await?; Ok(()) } @@ -1085,37 +1175,33 @@ mod tests { } } } - impl BeforeRequest for PrintLatency { - async fn before( - &mut self, - _: &mut context::Context, - _: &Req, - ) -> Result<(), ServerError> { + impl BeforeRequest for PrintLatency { + async fn before(&mut self, _: &mut ServerCtx, _: &Req) -> Result<(), ServerError> { self.start = Instant::now(); Ok(()) } } - impl AfterRequest for PrintLatency { - async fn after(&mut self, _: &mut context::Context, _: &mut Result) { + impl AfterRequest for PrintLatency { + async fn after(&mut self, _: &mut ServerCtx, _: &mut Result) { tracing::debug!("Elapsed: {:?}", self.start.elapsed()); } } - let serve = serve(move |_: context::Context, i| async move { Ok(i) }); + let serve = serve(move |_: &mut context::DefaultContext, i| async move { Ok(i) }.boxed()); serve .before_and_after(PrintLatency::new()) - .serve(context::current(), 7) + .serve(&mut context::current(), 7) .await?; Ok(()) } #[tokio::test] async fn serve_before_error_aborts_request() -> anyhow::Result<()> { - let serve = serve(|_, _| async { panic!("Shouldn't get here") }); - let deadline_hook = serve.before(|_: &mut context::Context, _: &i32| async { + let serve = serve(|_, _| async { panic!("Shouldn't get here") }.boxed()); + let deadline_hook = serve.before(|_: &mut context::DefaultContext, _: &i32| async { Err(ServerError::new(io::ErrorKind::Other, "oops".into())) }); - let resp: Result = deadline_hook.serve(context::current(), 7).await; + let resp: Result = deadline_hook.serve(&mut context::current(), 7).await; assert_matches!(resp, Err(_)); Ok(()) } @@ -1285,6 +1371,7 @@ mod tests { .as_mut() .start_send(Response { request_id: 0, + context: context::current(), message: Ok(()), }) .unwrap(); @@ -1320,7 +1407,9 @@ mod tests { Poll::Ready(Some(Ok(request))) => request, result => panic!("Unexpected result: {result:?}"), }; - request.execute(serve(|_, _| async { Ok(()) })).await; + request + .execute(serve(|_, _| async { Ok(()) }.boxed())) + .await; assert!( requests .as_mut() @@ -1350,6 +1439,7 @@ mod tests { .channel_pin_mut() .start_send(Response { request_id: 0, + context: context::current(), message: Ok(()), }) .unwrap(); @@ -1361,6 +1451,7 @@ mod tests { .responses_tx .send(Response { request_id: 1, + context: context::current(), message: Ok(()), }) .await @@ -1401,6 +1492,7 @@ mod tests { .channel_pin_mut() .start_send(Response { request_id: 0, + context: context::current(), message: Ok(()), }) .unwrap(); @@ -1421,6 +1513,7 @@ mod tests { .responses_tx .send(Response { request_id: 1, + context: context::current(), message: Ok(()), }) .await diff --git a/tarpc/src/server/incoming.rs b/tarpc/src/server/incoming.rs index 428eb1a7d..568ae4495 100644 --- a/tarpc/src/server/incoming.rs +++ b/tarpc/src/server/incoming.rs @@ -33,7 +33,7 @@ where ) -> impl Stream>> where C::Req: RequestName, - S: Serve + Clone, + S: Serve + Clone, { self.map(move |channel| channel.execute(serve.clone())) } @@ -48,6 +48,7 @@ where /// # Example /// ```rust /// use tarpc::{ +/// ClientMessage, /// context, /// client::{self, NewClient}, /// server::{self, BaseChannel, Channel, incoming::{Incoming, spawn_incoming}, serve}, @@ -57,15 +58,17 @@ where /// /// #[tokio::main] /// async fn main() { -/// let (tx, rx) = transport::channel::unbounded(); +/// use tracing_subscriber::filter::FilterExt; +/// let (tx, rx) = transport::channel::unbounded(); /// let NewClient { client, dispatch } = client::new(client::Config::default(), tx); /// tokio::spawn(dispatch); /// /// let incoming = stream::once(async move { /// BaseChannel::new(server::Config::default(), rx) -/// }).execute(serve(|_, i| async move { Ok(i + 1) })); +/// }).execute(serve(|_, i| async move { Ok(i + 1) }.boxed())); /// tokio::spawn(spawn_incoming(incoming)); -/// assert_eq!(client.call(context::current(), 1).await.unwrap(), 2); +/// let mut context = context::current(); +/// assert_eq!(client.call(&mut context, 1).await.unwrap(), 2); /// } /// ``` pub async fn spawn_incoming( diff --git a/tarpc/src/server/limits/channels_per_key.rs b/tarpc/src/server/limits/channels_per_key.rs index 64b644278..3ffdfac89 100644 --- a/tarpc/src/server/limits/channels_per_key.rs +++ b/tarpc/src/server/limits/channels_per_key.rs @@ -107,6 +107,7 @@ where type Req = C::Req; type Resp = C::Resp; type Transport = C::Transport; + type ServerCtx = C::ServerCtx; fn config(&self) -> &server::Config { self.inner.config() diff --git a/tarpc/src/server/limits/requests_per_channel.rs b/tarpc/src/server/limits/requests_per_channel.rs index bd9c103b0..3fa81e580 100644 --- a/tarpc/src/server/limits/requests_per_channel.rs +++ b/tarpc/src/server/limits/requests_per_channel.rs @@ -67,6 +67,7 @@ where self.as_mut().start_send(Response { request_id: r.request.id, + context: r.request.context, message: Err(ServerError { kind: io::ErrorKind::WouldBlock, detail: "server throttled the request.".into(), @@ -80,7 +81,7 @@ where } } -impl Sink::Resp>> for MaxRequests +impl Sink::Resp>> for MaxRequests where C: Channel, { @@ -92,7 +93,7 @@ where fn start_send( self: Pin<&mut Self>, - item: Response<::Resp>, + item: Response::Resp>, ) -> Result<(), Self::Error> { self.project().inner.start_send(item) } @@ -119,6 +120,7 @@ where type Req = ::Req; type Resp = ::Resp; type Transport = ::Transport; + type ServerCtx = ::ServerCtx; fn in_flight_requests(&self) -> usize { self.inner.in_flight_requests() @@ -178,6 +180,7 @@ where mod tests { use super::*; + use crate::context; use crate::server::{ TrackedRequest, testing::{self, FakeChannel, PollExt}, @@ -267,8 +270,10 @@ mod tests { ghost: PhantomData In>, } impl PendingSink<(), ()> { - pub fn default() - -> PendingSink>, Response> { + pub fn default() -> PendingSink< + io::Result>, + Response, + > { PendingSink { ghost: PhantomData } } } @@ -293,10 +298,16 @@ mod tests { Poll::Pending } } - impl Channel for PendingSink>, Response> { + impl Channel + for PendingSink< + io::Result>, + Response, + > + { type Req = Req; type Resp = Resp; type Transport = (); + type ServerCtx = context::DefaultContext; fn config(&self) -> &Config { unimplemented!() } @@ -326,16 +337,16 @@ mod tests { .as_mut() .start_send(Response { request_id: 0, + context: context::current(), message: Ok(1), }) .unwrap(); assert_eq!(throttler.inner.in_flight_requests.len(), 0); - assert_eq!( - throttler.inner.sink.front(), - Some(&Response { - request_id: 0, - message: Ok(1), - }) - ); + + let result = throttler.inner.sink.front(); + + assert_eq!(result.map(|r| r.request_id), Some(0)); + + assert_eq!(result.map(|r| &r.message), Some(&Ok(1))); } } diff --git a/tarpc/src/server/request_hook.rs b/tarpc/src/server/request_hook.rs index 66cf2878c..afc9ad25a 100644 --- a/tarpc/src/server/request_hook.rs +++ b/tarpc/src/server/request_hook.rs @@ -43,12 +43,12 @@ pub trait RequestHook: Serve { /// # Example /// /// ```rust - /// use futures::{executor::block_on, future}; + /// use futures::{executor::block_on, future, FutureExt}; /// use tarpc::{context, ServerError, server::{Serve, request_hook::RequestHook, serve}}; /// use std::io; /// - /// let serve = serve(|_ctx, i| async move { Ok(i + 1) }) - /// .before(|_ctx: &mut context::Context, req: &i32| { + /// let serve = serve(|_ctx, i| async move { Ok(i + 1) }.boxed()) + /// .before(|_ctx: &mut context::DefaultContext, req: &i32| { /// future::ready( /// if *req == 1 { /// Err(ServerError::new( @@ -58,12 +58,13 @@ pub trait RequestHook: Serve { /// Ok(()) /// }) /// }); - /// let response = serve.serve(context::current(), 1); + /// let mut context = context::current(); + /// let response = serve.serve(&mut context, 1); /// assert!(block_on(response).is_err()); /// ``` - fn before(self, hook: Hook) -> HookThenServe + fn before(self, hook: Hook) -> HookThenServe where - Hook: BeforeRequest, + Hook: BeforeRequest, Self: Sized, { HookThenServe::new(self, hook) @@ -80,7 +81,7 @@ pub trait RequestHook: Serve { /// # Example /// /// ```rust - /// use futures::{executor::block_on, future}; + /// use futures::{executor::block_on, future, FutureExt}; /// use tarpc::{context, ServerError, server::{Serve, request_hook::RequestHook, serve}}; /// use std::io; /// @@ -93,20 +94,20 @@ pub trait RequestHook: Serve { /// } else { /// Ok(i + 1) /// } - /// }) - /// .after(|_ctx: &mut context::Context, resp: &mut Result| { + /// }.boxed()) + /// .after(|_ctx: &mut context::DefaultContext, resp: &mut Result| { /// if let Err(e) = resp { /// eprintln!("server error: {e:?}"); /// } /// future::ready(()) /// }); - /// - /// let response = serve.serve(context::current(), 1); + /// let mut context = context::current(); + /// let response = serve.serve(&mut context, 1); /// assert!(block_on(response).is_err()); /// ``` fn after(self, hook: Hook) -> ServeThenHook where - Hook: AfterRequest, + Hook: AfterRequest, Self: Sized, { ServeThenHook::new(self, hook) @@ -123,7 +124,7 @@ pub trait RequestHook: Serve { /// # Example /// /// ```rust - /// use futures::{executor::block_on, future}; + /// use futures::{executor::block_on, future, FutureExt}; /// use tarpc::{ /// context, ServerError, /// server::{Serve, serve, request_hook::{BeforeRequest, AfterRequest, RequestHook}} @@ -132,17 +133,17 @@ pub trait RequestHook: Serve { /// /// struct PrintLatency(Instant); /// - /// impl BeforeRequest for PrintLatency { - /// async fn before(&mut self, _: &mut context::Context, _: &Req) -> Result<(), ServerError> { + /// impl BeforeRequest for PrintLatency { + /// async fn before(&mut self, _: &mut ServerCtx, _: &Req) -> Result<(), ServerError> { /// self.0 = Instant::now(); /// Ok(()) /// } /// } /// - /// impl AfterRequest for PrintLatency { + /// impl AfterRequest for PrintLatency { /// async fn after( /// &mut self, - /// _: &mut context::Context, + /// _: &mut ServerCtx, /// _: &mut Result, /// ) { /// tracing::info!("Elapsed: {:?}", self.0.elapsed()); @@ -151,16 +152,17 @@ pub trait RequestHook: Serve { /// /// let serve = serve(|_ctx, i| async move { /// Ok(i + 1) - /// }).before_and_after(PrintLatency(Instant::now())); - /// let response = serve.serve(context::current(), 1); + /// }.boxed()).before_and_after(PrintLatency(Instant::now())); + /// let mut context = context::current(); + /// let response = serve.serve(&mut context, 1); /// assert!(block_on(response).is_ok()); /// ``` fn before_and_after( self, hook: Hook, - ) -> HookThenServeThenHook + ) -> HookThenServeThenHook where - Hook: BeforeRequest + AfterRequest, + Hook: BeforeRequest + AfterRequest, Self: Sized, { HookThenServeThenHook::new(self, hook) diff --git a/tarpc/src/server/request_hook/after.rs b/tarpc/src/server/request_hook/after.rs index b2ef9ccbd..1fa3cee51 100644 --- a/tarpc/src/server/request_hook/after.rs +++ b/tarpc/src/server/request_hook/after.rs @@ -6,24 +6,24 @@ //! Provides a hook that runs after request execution. -use crate::{ServerError, context, server::Serve}; +use crate::{ServerError, server::Serve}; use futures::prelude::*; /// A hook that runs after request execution. #[allow(async_fn_in_trait)] -pub trait AfterRequest { +pub trait AfterRequest { /// The function that is called after request execution. /// /// The hook can modify the request context and the response. - async fn after(&mut self, ctx: &mut context::Context, resp: &mut Result); + async fn after(&mut self, ctx: &mut ServerCtx, resp: &mut Result); } -impl AfterRequest for F +impl AfterRequest for F where - F: FnMut(&mut context::Context, &mut Result) -> Fut, + F: FnMut(&mut ServerCtx, &mut Result) -> Fut, Fut: Future, { - async fn after(&mut self, ctx: &mut context::Context, resp: &mut Result) { + async fn after(&mut self, ctx: &mut ServerCtx, resp: &mut Result) { self(ctx, resp).await } } @@ -52,21 +52,22 @@ impl Clone for ServeThenHook { impl Serve for ServeThenHook where Serv: Serve, - Hook: AfterRequest, + Hook: AfterRequest, { type Req = Serv::Req; type Resp = Serv::Resp; + type ServerCtx = Serv::ServerCtx; async fn serve( self, - mut ctx: context::Context, + ctx: &mut Serv::ServerCtx, req: Serv::Req, ) -> Result { let ServeThenHook { serve, mut hook, .. } = self; let mut resp = serve.serve(ctx, req).await; - hook.after(&mut ctx, &mut resp).await; + hook.after(ctx, &mut resp).await; resp } } diff --git a/tarpc/src/server/request_hook/before.rs b/tarpc/src/server/request_hook/before.rs index e72e28a42..adfac8e79 100644 --- a/tarpc/src/server/request_hook/before.rs +++ b/tarpc/src/server/request_hook/before.rs @@ -6,12 +6,13 @@ //! Provides a hook that runs before request execution. -use crate::{ServerError, context, server::Serve}; +use crate::{ServerError, server::Serve}; use futures::prelude::*; +use std::marker::PhantomData; /// A hook that runs before request execution. #[allow(async_fn_in_trait)] -pub trait BeforeRequest { +pub trait BeforeRequest { /// The function that is called before request execution. /// /// If this function returns an error, the request will not be executed and the error will be @@ -19,22 +20,22 @@ pub trait BeforeRequest { /// /// This function can also modify the request context. This could be used, for example, to /// enforce a maximum deadline on all requests. - async fn before(&mut self, ctx: &mut context::Context, req: &Req) -> Result<(), ServerError>; + async fn before(&mut self, ctx: &mut ServerCtx, req: &Req) -> Result<(), ServerError>; } /// A list of hooks that run in order before request execution. -pub trait BeforeRequestList: BeforeRequest { +pub trait BeforeRequestList: BeforeRequest { /// The hook returned by `BeforeRequestList::then`. - type Then: BeforeRequest + type Then: BeforeRequest where - Next: BeforeRequest; + Next: BeforeRequest; /// Returns a hook that, when run, runs two hooks, first `self` and then `next`. - fn then>(self, next: Next) -> Self::Then; + fn then>(self, next: Next) -> Self::Then; /// Same as `then`, but helps the compiler with type inference when Next is a closure. fn then_fn< - Next: FnMut(&mut context::Context, &Req) -> Fut, + Next: FnMut(&mut ServerCtx, &Req) -> Fut, Fut: Future>, >( self, @@ -47,53 +48,60 @@ pub trait BeforeRequestList: BeforeRequest { } /// The service fn returned by `BeforeRequestList::serving`. - type Serve>: Serve; + type Serve>: Serve; /// Runs the list of request hooks before execution of the given serve fn. /// This is equivalent to `serve.before(before_request_chain)` but may be syntactically nicer. - fn serving>(self, serve: S) -> Self::Serve; + fn serving>(self, serve: S) -> Self::Serve; } -impl BeforeRequest for F +impl BeforeRequest for F where - F: FnMut(&mut context::Context, &Req) -> Fut, + F: FnMut(&mut ServerCtx, &Req) -> Fut, Fut: Future>, { - async fn before(&mut self, ctx: &mut context::Context, req: &Req) -> Result<(), ServerError> { + async fn before(&mut self, ctx: &mut ServerCtx, req: &Req) -> Result<(), ServerError> { self(ctx, req).await } } /// A Service function that runs a hook before request execution. -#[derive(Clone)] -pub struct HookThenServe { +pub struct HookThenServe { serve: Serv, hook: Hook, + ghost: PhantomData, } -impl HookThenServe { +impl Clone for HookThenServe { + fn clone(&self) -> Self { + Self::new(self.serve.clone(), self.hook.clone()) + } +} + +impl HookThenServe { pub(crate) fn new(serve: Serv, hook: Hook) -> Self { - Self { serve, hook } + Self { + serve, + hook, + ghost: PhantomData, + } } } -impl Serve for HookThenServe +impl Serve for HookThenServe where - Serv: Serve, - Hook: BeforeRequest, + Serv: Serve, + Hook: BeforeRequest, { + type ServerCtx = ServerCtx; type Req = Serv::Req; type Resp = Serv::Resp; - async fn serve( - self, - mut ctx: context::Context, - req: Self::Req, - ) -> Result { + async fn serve(self, ctx: &mut ServerCtx, req: Self::Req) -> Result { let HookThenServe { serve, mut hook, .. } = self; - hook.before(&mut ctx, &req).await?; + hook.before(ctx, &req).await?; serve.serve(ctx, req).await } } @@ -103,7 +111,7 @@ where /// Example /// /// ```rust -/// use futures::{executor::block_on, future}; +/// use futures::{executor::block_on, future, FutureExt}; /// use tarpc::{context, ServerError, server::{Serve, serve, request_hook::{self, /// BeforeRequest, BeforeRequestList}}}; /// use std::{cell::Cell, io}; @@ -120,8 +128,9 @@ where /// i.set(2); /// Ok(()) /// }) -/// .serving(serve(|_ctx, i| async move { Ok(i + 1) })); -/// let response = serve.clone().serve(context::current(), 1); +/// .serving(serve(|_ctx, i| async move { Ok(i + 1) }.boxed())); +/// let mut context = context::current(); +/// let response = serve.clone().serve(&mut context, 1); /// assert!(block_on(response).is_ok()); /// assert!(i.get() == 2); /// ``` @@ -137,10 +146,10 @@ pub struct BeforeRequestCons(First, Rest); #[derive(Clone, Copy)] pub struct BeforeRequestNil; -impl, Rest: BeforeRequest> BeforeRequest - for BeforeRequestCons +impl, Rest: BeforeRequest, ServerCtx> + BeforeRequest for BeforeRequestCons { - async fn before(&mut self, ctx: &mut context::Context, req: &Req) -> Result<(), ServerError> { + async fn before(&mut self, ctx: &mut ServerCtx, req: &Req) -> Result<(), ServerError> { let BeforeRequestCons(first, rest) = self; first.before(ctx, req).await?; rest.before(ctx, req).await?; @@ -148,45 +157,45 @@ impl, Rest: BeforeRequest> BeforeRequest BeforeRequest for BeforeRequestNil { - async fn before(&mut self, _: &mut context::Context, _: &Req) -> Result<(), ServerError> { +impl BeforeRequest for BeforeRequestNil { + async fn before(&mut self, _: &mut ServerCtx, _: &Req) -> Result<(), ServerError> { Ok(()) } } -impl, Rest: BeforeRequestList> BeforeRequestList - for BeforeRequestCons +impl, Rest: BeforeRequestList, ServerCtx> + BeforeRequestList for BeforeRequestCons { type Then = BeforeRequestCons> where - Next: BeforeRequest; + Next: BeforeRequest; - fn then>(self, next: Next) -> Self::Then { + fn then>(self, next: Next) -> Self::Then { let BeforeRequestCons(first, rest) = self; BeforeRequestCons(first, rest.then(next)) } - type Serve> = HookThenServe; + type Serve> = HookThenServe; - fn serving>(self, serve: S) -> Self::Serve { + fn serving>(self, serve: S) -> Self::Serve { HookThenServe::new(serve, self) } } -impl BeforeRequestList for BeforeRequestNil { +impl BeforeRequestList for BeforeRequestNil { type Then = BeforeRequestCons where - Next: BeforeRequest; + Next: BeforeRequest; - fn then>(self, next: Next) -> Self::Then { + fn then>(self, next: Next) -> Self::Then { BeforeRequestCons(next, BeforeRequestNil) } - type Serve> = S; + type Serve> = S; - fn serving>(self, serve: S) -> S { + fn serving>(self, serve: S) -> S { serve } } @@ -209,8 +218,9 @@ fn before_request_list() { i.set(2); Ok(()) }) - .serving(serve(|_ctx, i| async move { Ok(i + 1) })); - let response = serve.clone().serve(context::current(), 1); + .serving(serve(|_ctx, i| async move { Ok(i + 1) }.boxed())); + let mut context = crate::context::current(); + let response = serve.clone().serve(&mut context, 1); assert!(block_on(response).is_ok()); assert!(i.get() == 2); } diff --git a/tarpc/src/server/request_hook/before_and_after.rs b/tarpc/src/server/request_hook/before_and_after.rs index 0761a7df3..f3653a513 100644 --- a/tarpc/src/server/request_hook/before_and_after.rs +++ b/tarpc/src/server/request_hook/before_and_after.rs @@ -7,17 +7,17 @@ //! Provides a hook that runs both before and after request execution. use super::{after::AfterRequest, before::BeforeRequest}; -use crate::{RequestName, ServerError, context, server::Serve}; +use crate::{RequestName, ServerError, server::Serve}; use std::marker::PhantomData; /// A Service function that runs a hook both before and after request execution. -pub struct HookThenServeThenHook { +pub struct HookThenServeThenHook { serve: Serv, hook: Hook, - fns: PhantomData<(fn(Req), fn(Resp))>, + fns: PhantomData<(Req, Resp, ServerCtx)>, } -impl HookThenServeThenHook { +impl HookThenServeThenHook { pub(crate) fn new(serve: Serv, hook: Hook) -> Self { Self { serve, @@ -27,7 +27,9 @@ impl HookThenServeThenHook { } } -impl Clone for HookThenServeThenHook { +impl Clone + for HookThenServeThenHook +{ fn clone(&self) -> Self { Self { serve: self.serve.clone(), @@ -37,22 +39,24 @@ impl Clone for HookThenServeThenHook Serve for HookThenServeThenHook +impl Serve + for HookThenServeThenHook where Req: RequestName, - Serv: Serve, - Hook: BeforeRequest + AfterRequest, + Serv: Serve, + Hook: BeforeRequest + AfterRequest, { type Req = Req; type Resp = Resp; + type ServerCtx = ServerCtx; - async fn serve(self, mut ctx: context::Context, req: Req) -> Result { + async fn serve(self, ctx: &mut ServerCtx, req: Req) -> Result { let HookThenServeThenHook { serve, mut hook, .. } = self; - hook.before(&mut ctx, &req).await?; + hook.before(ctx, &req).await?; let mut resp = serve.serve(ctx, req).await; - hook.after(&mut ctx, &mut resp).await; + hook.after(ctx, &mut resp).await; resp } } diff --git a/tarpc/src/server/testing.rs b/tarpc/src/server/testing.rs index db167c42e..bbb396c45 100644 --- a/tarpc/src/server/testing.rs +++ b/tarpc/src/server/testing.rs @@ -38,14 +38,19 @@ where } } -impl Sink> for FakeChannel> { +impl Sink> + for FakeChannel> +{ type Error = io::Error; fn poll_ready(self: Pin<&mut Self>, cx: &mut Context) -> Poll> { self.project().sink.poll_ready(cx).map_err(|e| match e {}) } - fn start_send(mut self: Pin<&mut Self>, response: Response) -> Result<(), Self::Error> { + fn start_send( + mut self: Pin<&mut Self>, + response: Response, + ) -> Result<(), Self::Error> { self.as_mut() .project() .in_flight_requests @@ -65,13 +70,18 @@ impl Sink> for FakeChannel> { } } -impl Channel for FakeChannel>, Response> +impl Channel + for FakeChannel< + io::Result>, + Response, + > where Req: Unpin, { type Req = Req; type Resp = Resp; type Transport = (); + type ServerCtx = context::DefaultContext; fn config(&self) -> &Config { &self.config @@ -86,13 +96,18 @@ where } } -impl FakeChannel>, Response> { +impl + FakeChannel< + io::Result>, + Response, + > +{ pub fn push_req(&mut self, id: u64, message: Req) { let (_, abort_registration) = futures::future::AbortHandle::new_pair(); let (request_cancellation, _) = cancellations(); self.stream.push_back(Ok(TrackedRequest { request: Request { - context: context::Context { + context: context::DefaultContext { deadline: Instant::now(), trace_context: Default::default(), }, @@ -111,7 +126,10 @@ impl FakeChannel>, Response> { } impl FakeChannel<(), ()> { - pub fn default() -> FakeChannel>, Response> { + pub fn default() -> FakeChannel< + io::Result>, + Response, + > { let (request_cancellation, canceled_requests) = cancellations(); FakeChannel { stream: Default::default(), diff --git a/tarpc/src/trace.rs b/tarpc/src/trace.rs index d172dc2e2..6b00e7f71 100644 --- a/tarpc/src/trace.rs +++ b/tarpc/src/trace.rs @@ -66,13 +66,14 @@ pub struct SpanId(u64); /// dependencies. On the other hand, if an upstream process has chosen to sample this trace, then /// the downstream samplers are expected to respect that decision and also sample the trace. /// Otherwise, the full trace would not be able to be reconstructed reliably. -#[derive(Debug, PartialEq, Eq, Hash, Clone, Copy)] +#[derive(Debug, PartialEq, Eq, Hash, Clone, Copy, Default)] #[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))] #[repr(u8)] pub enum SamplingDecision { /// The associated span was sampled by its creating process. Child spans must also be sampled. Sampled, /// The associated span was not sampled by its creating process. + #[default] Unsampled, } @@ -203,12 +204,6 @@ impl From<&opentelemetry::trace::SpanContext> for SamplingDecision { } } -impl Default for SamplingDecision { - fn default() -> Self { - Self::Unsampled - } -} - /// Returned when a [`Context`] cannot be constructed from a [`Span`](tracing::Span). #[derive(Debug)] pub struct NoActiveSpan; diff --git a/tarpc/src/transport/channel.rs b/tarpc/src/transport/channel.rs index 0268300dc..658925e6d 100644 --- a/tarpc/src/transport/channel.rs +++ b/tarpc/src/transport/channel.rs @@ -191,14 +191,19 @@ mod tests { tokio::spawn( stream::once(future::ready(server_channel)) .map(BaseChannel::with_defaults) - .execute(serve(|_ctx, request: String| async move { - request.parse::().map_err(|_| { - ServerError::new( - io::ErrorKind::InvalidInput, - format!("{request:?} is not an int"), - ) - }) - })) + .execute(serve( + |_ctx: &mut context::DefaultContext, request: String| { + async move { + request.parse::().map_err(|_| { + ServerError::new( + io::ErrorKind::InvalidInput, + format!("{request:?} is not an int"), + ) + }) + } + .boxed() + }, + )) .for_each(|channel| async move { tokio::spawn(channel.for_each(|response| response)); }), @@ -206,8 +211,8 @@ mod tests { let client = client::new(client::Config::default(), client_channel).spawn(); - let response1 = client.call(context::current(), "123".into()).await; - let response2 = client.call(context::current(), "abc".into()).await; + let response1 = client.call(&mut context::current(), "123".into()).await; + let response2 = client.call(&mut context::current(), "abc".into()).await; trace!("response1: {:?}, response2: {:?}", response1, response2); diff --git a/tarpc/tests/compile_fail/must_use_request_dispatch.rs b/tarpc/tests/compile_fail/must_use_request_dispatch.rs index 2915d3237..9498d02b2 100644 --- a/tarpc/tests/compile_fail/must_use_request_dispatch.rs +++ b/tarpc/tests/compile_fail/must_use_request_dispatch.rs @@ -1,5 +1,5 @@ use tarpc::client; - +use tarpc::context::DefaultContext; #[tarpc::service] trait World { async fn hello(name: String) -> String; @@ -10,6 +10,6 @@ fn main() { #[deny(unused_must_use)] { - WorldClient::new(client::Config::default(), client_transport).dispatch; + WorldClient::::new(client::Config::default(), client_transport).dispatch; } } diff --git a/tarpc/tests/compile_fail/must_use_request_dispatch.stderr b/tarpc/tests/compile_fail/must_use_request_dispatch.stderr index e652cc8e8..93e6aef45 100644 --- a/tarpc/tests/compile_fail/must_use_request_dispatch.stderr +++ b/tarpc/tests/compile_fail/must_use_request_dispatch.stderr @@ -1,8 +1,8 @@ error: unused `RequestDispatch` that must be used --> tests/compile_fail/must_use_request_dispatch.rs:13:9 | -13 | WorldClient::new(client::Config::default(), client_transport).dispatch; - | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +13 | WorldClient::::new(client::Config::default(), client_transport).dispatch; + | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ | note: the lint level is defined here --> tests/compile_fail/must_use_request_dispatch.rs:11:12 @@ -11,5 +11,5 @@ note: the lint level is defined here | ^^^^^^^^^^^^^^^ help: use `let _ = ...` to ignore the resulting value | -13 | let _ = WorldClient::new(client::Config::default(), client_transport).dispatch; +13 | let _ = WorldClient::::new(client::Config::default(), client_transport).dispatch; | +++++++ diff --git a/tarpc/tests/compile_fail/serde1/opt_out_serde.stderr b/tarpc/tests/compile_fail/serde1/opt_out_serde.stderr index bdb46999c..630924f71 100644 --- a/tarpc/tests/compile_fail/serde1/opt_out_serde.stderr +++ b/tarpc/tests/compile_fail/serde1/opt_out_serde.stderr @@ -2,10 +2,15 @@ error[E0277]: the trait bound `FooRequest: serde::Serialize` is not satisfied --> tests/compile_fail/serde1/opt_out_serde.rs:12:40 | 12 | tarpc::serde::Serialize::serialize(&x, f); - | ---------------------------------- ^^ the trait `Serialize` is not implemented for `FooRequest` + | ---------------------------------- ^^ unsatisfied trait bound | | | required by a bound introduced by this call | +help: the trait `Serialize` is not implemented for `FooRequest` + --> tests/compile_fail/serde1/opt_out_serde.rs:5:1 + | + 5 | #[tarpc::service(derive_serde = false)] + | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ = note: for local types consider adding `#[derive(serde::Serialize)]` to your `FooRequest` type = note: for types from other crates check whether the crate offers a `serde` feature flag = help: the following other types implement trait `Serialize`: @@ -18,3 +23,4 @@ error[E0277]: the trait bound `FooRequest: serde::Serialize` is not satisfied (T0, T1, T2, T3) (T0, T1, T2, T3, T4) and $N others + = note: this error originates in the attribute macro `tarpc::service` (in Nightly builds, run with -Z macro-backtrace for more info) diff --git a/tarpc/tests/dataservice.rs b/tarpc/tests/dataservice.rs index 18bb3a997..ac4d17027 100644 --- a/tarpc/tests/dataservice.rs +++ b/tarpc/tests/dataservice.rs @@ -22,7 +22,8 @@ pub trait ColorProtocol { struct ColorServer; impl ColorProtocol for ColorServer { - async fn get_opposite_color(self, _: context::Context, color: TestData) -> TestData { + type Context = context::DefaultContext; + async fn get_opposite_color(self, _: &mut Self::Context, color: TestData) -> TestData { match color { TestData::White => TestData::Black, TestData::Black => TestData::White, @@ -53,7 +54,7 @@ async fn test_call() -> anyhow::Result<()> { let client = ColorProtocolClient::new(client::Config::default(), transport).spawn(); let color = client - .get_opposite_color(context::current(), TestData::White) + .get_opposite_color(&mut context::current(), TestData::White) .await?; assert_eq!(color, TestData::Black); diff --git a/tarpc/tests/service_functional.rs b/tarpc/tests/service_functional.rs index 06542b43b..28157b25f 100644 --- a/tarpc/tests/service_functional.rs +++ b/tarpc/tests/service_functional.rs @@ -22,11 +22,12 @@ trait Service { struct Server; impl Service for Server { - async fn add(self, _: context::Context, x: i32, y: i32) -> i32 { + type Context = context::DefaultContext; + async fn add(self, _: &mut Self::Context, x: i32, y: i32) -> i32 { x + y } - async fn hey(self, _: context::Context, name: String) -> String { + async fn hey(self, _: &mut Self::Context, name: String) -> String { format!("Hey, {name}.") } } @@ -38,10 +39,12 @@ async fn sequential() { let channel = BaseChannel::with_defaults(rx); tokio::spawn( channel - .execute(tarpc::server::serve(|_, i: u32| async move { Ok(i + 1) })) + .execute(tarpc::server::serve(|_, i: u32| { + async move { Ok(i + 1) }.boxed() + })) .for_each(|response| response), ); - assert_eq!(client.call(context::current(), 1).await.unwrap(), 2); + assert_eq!(client.call(&mut context::current(), 1).await.unwrap(), 2); } #[tokio::test] @@ -55,7 +58,8 @@ async fn dropped_channel_aborts_in_flight_requests() -> anyhow::Result<()> { struct LoopServer; impl Loop for LoopServer { - async fn r#loop(self, _: context::Context) { + type Context = context::DefaultContext; + async fn r#loop(self, _: &mut Self::Context) { loop { futures::pending!(); } @@ -64,7 +68,7 @@ async fn dropped_channel_aborts_in_flight_requests() -> anyhow::Result<()> { let _ = tracing_subscriber::fmt::try_init(); - let (tx, rx) = channel::unbounded(); + let (tx, rx) = tarpc::transport::channel::unbounded(); // Set up a client that initiates a long-lived request. // The request will complete in error when the server drops the connection. @@ -73,7 +77,7 @@ async fn dropped_channel_aborts_in_flight_requests() -> anyhow::Result<()> { let mut ctx = context::current(); ctx.deadline = Instant::now() + Duration::from_secs(60 * 60); - let _ = client.r#loop(ctx).await; + let _ = client.r#loop(&mut ctx).await; }); let mut requests = BaseChannel::with_defaults(rx).requests(); @@ -112,9 +116,9 @@ async fn serde_tcp() -> anyhow::Result<()> { let transport = serde_transport::tcp::connect(addr, Json::default).await?; let client = ServiceClient::new(client::Config::default(), transport).spawn(); - assert_matches!(client.add(context::current(), 1, 2).await, Ok(3)); + assert_matches!(client.add(&mut context::current(), 1, 2).await, Ok(3)); assert_matches!( - client.hey(context::current(), "Tim".to_string()).await, + client.hey(&mut context::current(), "Tim".to_string()).await, Ok(ref s) if s == "Hey, Tim." ); @@ -145,8 +149,8 @@ async fn serde_uds() -> anyhow::Result<()> { let client = ServiceClient::new(client::Config::default(), transport).spawn(); // Save results using socket so we can clean the socket even if our test assertions fail - let res1 = client.add(context::current(), 1, 2).await; - let res2 = client.hey(context::current(), "Tim".to_string()).await; + let res1 = client.add(&mut context::current(), 1, 2).await; + let res2 = client.hey(&mut context::current(), "Tim".to_string()).await; assert_matches!(res1, Ok(3)); assert_matches!(res2, Ok(ref s) if s == "Hey, Tim."); @@ -158,7 +162,8 @@ async fn serde_uds() -> anyhow::Result<()> { async fn concurrent() -> anyhow::Result<()> { let _ = tracing_subscriber::fmt::try_init(); - let (tx, rx) = channel::unbounded(); + let (tx, rx) = tarpc::transport::channel::unbounded(); + tokio::spawn( stream::once(ready(rx)) .map(BaseChannel::with_defaults) @@ -169,12 +174,15 @@ async fn concurrent() -> anyhow::Result<()> { let client = ServiceClient::new(client::Config::default(), tx).spawn(); - let req1 = client.add(context::current(), 1, 2); - let req2 = client.add(context::current(), 3, 4); - let req3 = client.hey(context::current(), "Tim".to_string()); + let mut context = context::current(); + let req1 = client.add(&mut context, 1, 2); assert_matches!(req1.await, Ok(3)); + + let req2 = client.add(&mut context, 3, 4); assert_matches!(req2.await, Ok(7)); + + let req3 = client.hey(&mut context, "Tim".to_string()); assert_matches!(req3.await, Ok(ref s) if s == "Hey, Tim."); Ok(()) @@ -184,7 +192,8 @@ async fn concurrent() -> anyhow::Result<()> { async fn concurrent_join() -> anyhow::Result<()> { let _ = tracing_subscriber::fmt::try_init(); - let (tx, rx) = channel::unbounded(); + let (tx, rx) = tarpc::transport::channel::unbounded(); + tokio::spawn( stream::once(ready(rx)) .map(BaseChannel::with_defaults) @@ -195,9 +204,13 @@ async fn concurrent_join() -> anyhow::Result<()> { let client = ServiceClient::new(client::Config::default(), tx).spawn(); - let req1 = client.add(context::current(), 1, 2); - let req2 = client.add(context::current(), 3, 4); - let req3 = client.hey(context::current(), "Tim".to_string()); + let mut context1 = context::current(); + let mut context2 = context::current(); + let mut context3 = context::current(); + + let req1 = client.add(&mut context1, 1, 2); + let req2 = client.add(&mut context2, 3, 4); + let req3 = client.hey(&mut context3, "Tim".to_string()); let (resp1, resp2, resp3) = join!(req1, req2, req3); assert_matches!(resp1, Ok(3)); @@ -216,7 +229,7 @@ async fn spawn(fut: impl Future + Send + 'static) { async fn concurrent_join_all() -> anyhow::Result<()> { let _ = tracing_subscriber::fmt::try_init(); - let (tx, rx) = channel::unbounded(); + let (tx, rx) = tarpc::transport::channel::unbounded(); tokio::spawn( BaseChannel::with_defaults(rx) .execute(Server.serve()) @@ -225,8 +238,11 @@ async fn concurrent_join_all() -> anyhow::Result<()> { let client = ServiceClient::new(client::Config::default(), tx).spawn(); - let req1 = client.add(context::current(), 1, 2); - let req2 = client.add(context::current(), 3, 4); + let mut context1 = context::current(); + let mut context2 = context::current(); + + let req1 = client.add(&mut context1, 1, 2); + let req2 = client.add(&mut context2, 3, 4); let responses = join_all(vec![req1, req2]).await; assert_matches!(responses[0], Ok(3)); @@ -245,14 +261,16 @@ async fn counter() -> anyhow::Result<()> { struct CountService(u32); impl Counter for &mut CountService { - async fn count(self, _: context::Context) -> u32 { + type Context = context::DefaultContext; + async fn count(self, _: &mut Self::Context) -> u32 { self.0 += 1; self.0 } } let (tx, rx) = channel::unbounded(); - tokio::spawn(async { + + tokio::task::spawn(async { let mut requests = BaseChannel::with_defaults(rx).requests(); let mut counter = CountService(0); @@ -262,8 +280,8 @@ async fn counter() -> anyhow::Result<()> { }); let client = CounterClient::new(client::Config::default(), tx).spawn(); - assert_matches!(client.count(context::current()).await, Ok(1)); - assert_matches!(client.count(context::current()).await, Ok(2)); + assert_matches!(client.count(&mut context::current()).await, Ok(1)); + assert_matches!(client.count(&mut context::current()).await, Ok(2)); Ok(()) }