diff --git a/r2r/src/context.rs b/r2r/src/context.rs
index 20c4cb6..379ba1d 100644
--- a/r2r/src/context.rs
+++ b/r2r/src/context.rs
@@ -37,7 +37,7 @@ unsafe impl Sync for Context {}
 
 // Memory corruption (double free and others) was observed creating multiple
 // `Context` objects in a single thread
 //
-// To reproduce, run the tests from `tokio_testing` or `tokio_test_raw` 
+// To reproduce, run the tests from `tokio_testing` or `tokio_test_raw`
 // without this OnceLock
 static CONTEXT: OnceLock<Result<Context, Error>> = OnceLock::new();
 
@@ -45,49 +45,52 @@ static CONTEXT: OnceLock<Result<Context, Error>> = OnceLock::new();
 impl Context {
     /// Create a ROS context.
     pub fn create() -> Result<Context, Error> {
-        CONTEXT.get_or_init(|| {
-            let mut ctx: Box<rcl_context_t> = unsafe { Box::new(rcl_get_zero_initialized_context()) };
-            // argc/v
-            let args = std::env::args()
-                .map(|arg| CString::new(arg).unwrap())
-                .collect::<Vec<CString>>();
-            let mut c_args = args
-                .iter()
-                .map(|arg| arg.as_ptr())
-                .collect::<Vec<*const ::std::os::raw::c_char>>();
-            c_args.push(std::ptr::null());
+        CONTEXT
+            .get_or_init(|| {
+                let mut ctx: Box<rcl_context_t> =
+                    unsafe { Box::new(rcl_get_zero_initialized_context()) };
+                // argc/v
+                let args = std::env::args()
+                    .map(|arg| CString::new(arg).unwrap())
+                    .collect::<Vec<CString>>();
+                let mut c_args = args
+                    .iter()
+                    .map(|arg| arg.as_ptr())
+                    .collect::<Vec<*const ::std::os::raw::c_char>>();
+                c_args.push(std::ptr::null());
 
-            let is_valid = unsafe {
-                let allocator = rcutils_get_default_allocator();
-                let mut init_options = rcl_get_zero_initialized_init_options();
-                check_rcl_ret!(rcl_init_options_init(&mut init_options, allocator));
-                check_rcl_ret!(rcl_init(
-                    (c_args.len() - 1) as ::std::os::raw::c_int,
-                    c_args.as_ptr(),
-                    &init_options,
-                    ctx.as_mut(),
-                ));
-                check_rcl_ret!(rcl_init_options_fini(&mut init_options as *mut _));
-                rcl_context_is_valid(ctx.as_mut())
-            };
+                let is_valid = unsafe {
+                    let allocator = rcutils_get_default_allocator();
+                    let mut init_options = rcl_get_zero_initialized_init_options();
+                    check_rcl_ret!(rcl_init_options_init(&mut init_options, allocator));
+                    check_rcl_ret!(rcl_init(
+                        (c_args.len() - 1) as ::std::os::raw::c_int,
+                        c_args.as_ptr(),
+                        &init_options,
+                        ctx.as_mut(),
+                    ));
+                    check_rcl_ret!(rcl_init_options_fini(&mut init_options as *mut _));
+                    rcl_context_is_valid(ctx.as_mut())
+                };
 
-            let logging_ok = unsafe {
-                let _guard = log_guard();
-                let ret = rcl_logging_configure(
-                    &ctx.as_ref().global_arguments,
-                    &rcutils_get_default_allocator(),
-                );
-                ret == RCL_RET_OK as i32
-            };
+                let logging_ok = unsafe {
+                    let _guard = log_guard();
+                    let ret = rcl_logging_configure(
+                        &ctx.as_ref().global_arguments,
+                        &rcutils_get_default_allocator(),
+                    );
+                    ret == RCL_RET_OK as i32
+                };
 
-            if is_valid && logging_ok {
-                Ok(Context {
-                    context_handle: Arc::new(Mutex::new(ContextHandle(ctx))),
-                })
-            } else {
-                Err(Error::RCL_RET_ERROR) // TODO
-            }
-        }).clone()
+                if is_valid && logging_ok {
+                    Ok(Context {
+                        context_handle: Arc::new(Mutex::new(ContextHandle(ctx))),
+                    })
+                } else {
+                    Err(Error::RCL_RET_ERROR) // TODO
+                }
+            })
+            .clone()
     }
 
     /// Check if the ROS context is valid.
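[Reviewer note] Aside from the rustfmt reflow, the behavioural point of this context.rs hunk is that `Context::create()` now returns a clone of a single process-wide `Result<Context, Error>` stored in a `OnceLock`, so `rcl_init` and logging setup run at most once per process. A minimal, self-contained sketch of that pattern (the `Ctx`/`create` names below are illustrative stand-ins, not r2r API):

```rust
use std::sync::{Arc, Mutex, OnceLock};

// Stand-in for r2r's Context: a cheaply clonable handle around shared state.
#[derive(Clone)]
struct Ctx {
    handle: Arc<Mutex<u32>>,
}

// The one process-wide context, initialized exactly once.
static CTX: OnceLock<Result<Ctx, String>> = OnceLock::new();

fn create() -> Result<Ctx, String> {
    // get_or_init runs the closure at most once; every later caller gets a
    // clone of the same Result and therefore shares the same handle.
    CTX.get_or_init(|| Ok(Ctx { handle: Arc::new(Mutex::new(0)) }))
        .clone()
}

fn main() {
    let a = create().unwrap();
    let b = create().unwrap();
    // Both "contexts" are handles to the same underlying state.
    assert!(Arc::ptr_eq(&a.handle, &b.handle));
}
```

One consequence of the pattern worth keeping in mind: because the `OnceLock` also caches an `Err`, a failed initialization stays cached for the lifetime of the process.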
diff --git a/r2r/src/lib.rs b/r2r/src/lib.rs
index 122689e..3feba96 100644
--- a/r2r/src/lib.rs
+++ b/r2r/src/lib.rs
@@ -79,9 +79,9 @@ mod msg_types;
 pub use msg_types::generated_msgs::*;
 pub use msg_types::WrappedActionTypeSupport;
 pub use msg_types::WrappedNativeMsg as NativeMsg;
+pub use msg_types::WrappedNativeMsgUntyped;
 pub use msg_types::WrappedServiceTypeSupport;
 pub use msg_types::WrappedTypesupport;
-pub use msg_types::WrappedNativeMsgUntyped;
 
 mod utils;
 pub use utils::*;
diff --git a/r2r/src/msg_types.rs b/r2r/src/msg_types.rs
index da426c1..3c1c792 100644
--- a/r2r/src/msg_types.rs
+++ b/r2r/src/msg_types.rs
@@ -423,11 +423,7 @@ impl WrappedNativeMsgUntyped {
         // any part of msg_buf. However it shouldn't matter since from_native
         // clones everything again anyway ..
         let result = unsafe {
-            rmw_deserialize(
-                &msg_buf as *const rcl_serialized_message_t,
-                self.ts,
-                self.msg,
-            )
+            rmw_deserialize(&msg_buf as *const rcl_serialized_message_t, self.ts, self.msg)
         };
 
         if result == RCL_RET_OK as i32 {
diff --git a/r2r/src/publishers.rs b/r2r/src/publishers.rs
index f984b1d..5b84374 100644
--- a/r2r/src/publishers.rs
+++ b/r2r/src/publishers.rs
@@ -1,13 +1,13 @@
+use futures::channel::oneshot;
+use futures::Future;
+use futures::TryFutureExt;
 use std::ffi::c_void;
 use std::ffi::CString;
 use std::fmt::Debug;
 use std::marker::PhantomData;
+use std::sync::Mutex;
 use std::sync::Once;
 use std::sync::Weak;
-use std::sync::Mutex;
-use futures::Future;
-use futures::channel::oneshot;
-use futures::TryFutureExt;
 
 use crate::error::*;
 use crate::msg_types::*;
@@ -45,11 +45,10 @@ pub(crate) struct Publisher_ {
     handle: rcl_publisher_t,
 
     // TODO use a mpsc to avoid the mutex?
-    poll_inter_process_subscriber_channels: Mutex<Vec<oneshot::Sender<()>>>
+    poll_inter_process_subscriber_channels: Mutex<Vec<oneshot::Sender<()>>>,
 }
 
-impl Publisher_
-{
+impl Publisher_ {
     fn get_inter_process_subscription_count(&self) -> Result<usize, Error> {
 
         // See https://github.com/ros2/rclcpp/issues/623
@@ -70,8 +69,7 @@ impl Publisher_
     }
 
     pub(crate) fn poll_has_inter_process_subscribers(&self) {
-
-        let mut poll_inter_process_subscriber_channels = 
+        let mut poll_inter_process_subscriber_channels =
             self.poll_inter_process_subscriber_channels.lock().unwrap();
 
         if poll_inter_process_subscriber_channels.is_empty() {
@@ -102,7 +100,6 @@ impl Publisher_
     }
 }
 
-
 /// A ROS (typed) publisher.
 ///
 /// This contains a `Weak Arc` to a typed publisher. As such it is safe to
@@ -139,10 +136,7 @@ where
 }
 
 pub fn make_publisher_untyped(handle: Weak<Publisher_>, type_: String) -> PublisherUntyped {
-    PublisherUntyped {
-        handle,
-        type_,
-    }
+    PublisherUntyped { handle, type_ }
 }
 
 pub fn create_publisher_helper(
@@ -166,7 +160,7 @@ pub fn create_publisher_helper(
     if result == RCL_RET_OK as i32 {
         Ok(Publisher_ {
             handle: publisher_handle,
-            poll_inter_process_subscriber_channels: Mutex::new(Vec::new())
+            poll_inter_process_subscriber_channels: Mutex::new(Vec::new()),
         })
     } else {
         Err(Error::from_rcl_error(result))
@@ -187,12 +181,13 @@ impl PublisherUntyped {
         let native_msg = WrappedNativeMsgUntyped::new_from(&self.type_)?;
         native_msg.from_json(msg)?;
 
-        let result =
-            unsafe { rcl_publish(
+        let result = unsafe {
+            rcl_publish(
                 &publisher.handle as *const rcl_publisher_t,
                 native_msg.void_ptr(),
-                std::ptr::null_mut())
-        };
+                std::ptr::null_mut(),
+            )
+        };
 
         if result == RCL_RET_OK as i32 {
             Ok(())
@@ -221,15 +216,16 @@ impl PublisherUntyped {
             buffer_capacity: data.len(),
             // Since its read only, this should never be used ..
-            allocator: unsafe { rcutils_get_default_allocator() }
+            allocator: unsafe { rcutils_get_default_allocator() },
         };
 
-        let result =
-            unsafe { rcl_publish_serialized_message(
-                &publisher.handle,
-                &msg_buf as *const rcl_serialized_message_t,
-                std::ptr::null_mut()
-            ) };
+        let result = unsafe {
+            rcl_publish_serialized_message(
+                &publisher.handle,
+                &msg_buf as *const rcl_serialized_message_t,
+                std::ptr::null_mut(),
+            )
+        };
 
         if result == RCL_RET_OK as i32 {
             Ok(())
@@ -265,7 +261,6 @@ impl PublisherUntyped {
     }
 }
 
-
 impl<T> Publisher<T>
 where
     T: WrappedTypesupport,
@@ -281,12 +276,13 @@ where
             .upgrade()
             .ok_or(Error::RCL_RET_PUBLISHER_INVALID)?;
         let native_msg: WrappedNativeMsg<T> = WrappedNativeMsg::<T>::from(msg);
-        let result =
-            unsafe { rcl_publish(
+        let result = unsafe {
+            rcl_publish(
                 &publisher.handle as *const rcl_publisher_t,
                 native_msg.void_ptr(),
-                std::ptr::null_mut())
-        };
+                std::ptr::null_mut(),
+            )
+        };
 
         if result == RCL_RET_OK as i32 {
             Ok(())
@@ -312,7 +308,7 @@ where
             rcl_borrow_loaned_message(
                 &publisher.handle as *const rcl_publisher_t,
                 T::get_ts(),
-                &mut loaned_msg
+                &mut loaned_msg,
             )
         };
         if ret != RCL_RET_OK as i32 {
@@ -379,11 +375,13 @@ where
                 )
             }
         } else {
-            unsafe { rcl_publish(
-                &publisher.handle as *const rcl_publisher_t,
-                msg.void_ptr(),
-                std::ptr::null_mut()
-            ) }
+            unsafe {
+                rcl_publish(
+                    &publisher.handle as *const rcl_publisher_t,
+                    msg.void_ptr(),
+                    std::ptr::null_mut(),
+                )
+            }
         };
 
         if result == RCL_RET_OK as i32 {
@@ -394,7 +392,7 @@ where
         }
     }
 
-    /// Gets the number of external subscribers (i.e. it doesn't 
+    /// Gets the number of external subscribers (i.e. it doesn't
     /// count subscribers from the same process).
     pub fn get_inter_process_subscription_count(&self) -> Result<usize, Error> {
         self.handle
@@ -418,5 +416,4 @@ where
 
         Ok(receiver.map_err(|_| Error::RCL_RET_CLIENT_INVALID))
     }
-
 }
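[Reviewer note] The publishers.rs hunks are formatting-only, but they touch the machinery behind `wait_for_inter_process_subscribers`: a waiter parks a `oneshot::Sender<()>` in `poll_inter_process_subscriber_channels`, and `poll_has_inter_process_subscribers` drains and completes the parked senders once `get_inter_process_subscription_count` reports a subscriber. A rough, self-contained sketch of that notify-on-poll pattern (the `Notifier` type is illustrative, not r2r API; error handling is elided):

```rust
use futures::channel::oneshot;
use futures::executor::block_on;
use std::sync::Mutex;

// Waiters park a oneshot sender here, cf. poll_inter_process_subscriber_channels.
struct Notifier {
    waiting: Mutex<Vec<oneshot::Sender<()>>>,
}

impl Notifier {
    // Called by code that wants to await the first (inter-process) subscriber.
    fn wait(&self) -> oneshot::Receiver<()> {
        let (sender, receiver) = oneshot::channel();
        self.waiting.lock().unwrap().push(sender);
        receiver
    }

    // Called periodically from the spin loop with the current subscription count.
    fn poll(&self, subscription_count: usize) {
        let mut waiting = self.waiting.lock().unwrap();
        if waiting.is_empty() || subscription_count == 0 {
            return;
        }
        // Complete every parked waiter; senders whose receiver was dropped are ignored.
        for sender in waiting.drain(..) {
            let _ = sender.send(());
        }
    }
}

fn main() {
    let notifier = Notifier { waiting: Mutex::new(Vec::new()) };
    let receiver = notifier.wait();
    notifier.poll(1); // pretend one inter-process subscriber was just discovered
    block_on(receiver).unwrap();
    println!("subscriber discovered");
}
```

The `// TODO use a mpsc to avoid the mutex?` in the struct would presumably trade the lock on the registration path for a receiver that the polling side has to own.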
diff --git a/r2r/tests/threads.rs b/r2r/tests/threads.rs
index 0b7c2d1..abe9b92 100644
--- a/r2r/tests/threads.rs
+++ b/r2r/tests/threads.rs
@@ -10,68 +10,71 @@ const N_TEARDOWN_CYCLES: usize = 2;
 #[test]
 // Let's create and drop a lot of node and publishers for a while to see that we can cope.
 fn doesnt_crash() -> Result<(), Box<dyn std::error::Error>> {
+    let threads = (0..N_CONCURRENT_ROS_CONTEXT).map(|i_context| {
+        std::thread::spawn(move || {
+            for _i_cycle in 0..N_TEARDOWN_CYCLES {
+                // a global shared context.
+                let ctx = r2r::Context::create().unwrap();
 
-    let threads = (0..N_CONCURRENT_ROS_CONTEXT).map(|i_context| std::thread::spawn(move || {
+                for c in 0..10 {
+                    let mut ths = Vec::new();
+                    // I have lowered this from 30 to (10 / N_CONCURRENT_ROS_CONTEXT) because cyclonedds can only handle a hard-coded number of
+                    // publishers in threads. See
+                    // https://github.com/eclipse-cyclonedds/cyclonedds/blob/cd2136d9321212bd52fdc613f07bbebfddd90dec/src/core/ddsc/src/dds_init.c#L115
+                    for i_node in 0..N_NODE_PER_CONTEXT {
+                        // create concurrent nodes that max out the cpu
+                        let ctx = ctx.clone();
+                        ths.push(thread::spawn(move || {
+                            let mut node = r2r::Node::create(
+                                ctx,
+                                &format!("testnode_{}_{}", i_context, i_node),
+                                "",
+                            )
+                            .unwrap();
 
-        for _i_cycle in 0..N_TEARDOWN_CYCLES {
-            // a global shared context.
-            let ctx = r2r::Context::create().unwrap();
+                            // each with 10 publishers
+                            for _j in 0..10 {
+                                let p = node
+                                    .create_publisher::<r2r::std_msgs::msg::String>(
+                                        &format!("/r2r{}", i_node),
+                                        QosProfile::default(),
+                                    )
+                                    .unwrap();
+                                let to_send = r2r::std_msgs::msg::String {
+                                    data: format!("[node{}]: {}", i_node, c),
+                                };
 
-            for c in 0..10 {
-                let mut ths = Vec::new();
-                // I have lowered this from 30 to (10 / N_CONCURRENT_ROS_CONTEXT) because cyclonedds can only handle a hard-coded number of
-                // publishers in threads. See
-                // https://github.com/eclipse-cyclonedds/cyclonedds/blob/cd2136d9321212bd52fdc613f07bbebfddd90dec/src/core/ddsc/src/dds_init.c#L115
-                for i_node in 0..N_NODE_PER_CONTEXT {
-                    // create concurrent nodes that max out the cpu
-                    let ctx = ctx.clone();
-                    ths.push(thread::spawn(move || {
-                        let mut node = r2r::Node::create(ctx, &format!("testnode_{}_{}", i_context, i_node), "").unwrap();
-
-                        // each with 10 publishers
-                        for _j in 0..10 {
-                            let p = node
-                                .create_publisher::<r2r::std_msgs::msg::String>(
-                                    &format!("/r2r{}", i_node),
-                                    QosProfile::default(),
-                                )
-                                .unwrap();
-                            let to_send = r2r::std_msgs::msg::String {
-                                data: format!("[node{}]: {}", i_node, c),
-                            };
-
-                            // move publisher to its own thread and publish as fast as we can
-                            thread::spawn(move || loop {
-                                let res = p.publish(&to_send);
-                                thread::sleep(Duration::from_millis(1));
-                                match res {
-                                    Ok(_) => (),
-                                    Err(_) => {
-                                        // println!("publisher died, quitting thread.");
-                                        break;
+                                // move publisher to its own thread and publish as fast as we can
+                                thread::spawn(move || loop {
+                                    let res = p.publish(&to_send);
+                                    thread::sleep(Duration::from_millis(1));
+                                    match res {
+                                        Ok(_) => (),
+                                        Err(_) => {
+                                            // println!("publisher died, quitting thread.");
+                                            break;
+                                        }
                                     }
-                                }
-                            });
-                        }
+                                });
+                            }
 
-                            // spin to simulate some load
-                            for _j in 0..100 {
-                                node.spin_once(Duration::from_millis(10));
-                            }
+                            // spin to simulate some load
+                            for _j in 0..100 {
+                                node.spin_once(Duration::from_millis(10));
+                            }
 
-                            // println!("all done {}-{}", c, i);
-                        }));
+                            // println!("all done {}-{}", c, i);
+                        }));
+                    }
+
+                    for t in ths {
+                        t.join().unwrap();
+                    }
+                    // println!("all threads done {}", c);
                 }
-
-                for t in ths {
-                    t.join().unwrap();
-                }
-                // println!("all threads done {}", c);
-
-            }
-        }
-
-    }));
+            }
+        })
+    });
 
     for thread in threads.into_iter() {
         thread.join().unwrap();
diff --git a/r2r/tests/tokio_test_raw.rs b/r2r/tests/tokio_test_raw.rs
index 1819143..8ec0f31 100644
--- a/r2r/tests/tokio_test_raw.rs
+++ b/r2r/tests/tokio_test_raw.rs
@@ -1,8 +1,7 @@
 use futures::stream::StreamExt;
 use r2r::QosProfile;
-use tokio::task;
 use r2r::WrappedTypesupport;
-
+use tokio::task;
 
 const N_CONCURRENT_ROS_CONTEXT: usize = 3;
 const N_TEARDOWN_CYCLES: usize = 2;
@@ -10,72 +9,88 @@ const N_TEARDOWN_CYCLES: usize = 2;
 #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
 async fn tokio_subscribe_raw_testing() -> Result<(), Box<dyn std::error::Error>> {
     let mut threads = futures::stream::FuturesUnordered::from_iter(
-        (0..N_CONCURRENT_ROS_CONTEXT).map(|i_context| tokio::spawn(async move {
-            // Iterate to check for memory corruption on node setup/teardown
-            for i_cycle in 0..N_TEARDOWN_CYCLES {
-                println!("tokio_subscribe_raw_testing iteration {i_cycle}");
+        (0..N_CONCURRENT_ROS_CONTEXT).map(|i_context| {
+            tokio::spawn(async move {
+                // Iterate to check for memory corruption on node setup/teardown
+                for i_cycle in 0..N_TEARDOWN_CYCLES {
+                    println!("tokio_subscribe_raw_testing iteration {i_cycle}");
 
-                let ctx = r2r::Context::create().unwrap();
-                let mut node = r2r::Node::create(ctx, &format!("testnode2_{i_context}"), "").unwrap();
+                    let ctx = r2r::Context::create().unwrap();
+                    let mut node =
+                        r2r::Node::create(ctx, &format!("testnode2_{i_context}"), "").unwrap();
 
-                let mut sub_int = node.subscribe_raw("/int", "std_msgs/msg/Int32", QosProfile::default()).unwrap();
+                    let mut sub_int = node
+                        .subscribe_raw("/int", "std_msgs/msg/Int32", QosProfile::default())
+                        .unwrap();
 
-                let mut sub_array =
-                    node.subscribe_raw("/int_array", "std_msgs/msg/Int32MultiArray", QosProfile::default()).unwrap();
+                    let mut sub_array = node
+                        .subscribe_raw(
+                            "/int_array",
+                            "std_msgs/msg/Int32MultiArray",
+                            QosProfile::default(),
+                        )
+                        .unwrap();
 
-                let pub_int =
-                    node.create_publisher::<r2r::std_msgs::msg::Int32>("/int", QosProfile::default()).unwrap();
+                    let pub_int = node
+                        .create_publisher::<r2r::std_msgs::msg::Int32>(
+                            "/int",
+                            QosProfile::default(),
+                        )
+                        .unwrap();
 
-                // Use an array as well since its a variable sized type
-                let pub_array = node.create_publisher::<r2r::std_msgs::msg::Int32MultiArray>(
-                    "/int_array",
-                    QosProfile::default(),
-                ).unwrap();
+                    // Use an array as well since its a variable sized type
+                    let pub_array = node
+                        .create_publisher::<r2r::std_msgs::msg::Int32MultiArray>(
+                            "/int_array",
+                            QosProfile::default(),
+                        )
+                        .unwrap();
 
-                task::spawn(async move {
-                    tokio::time::sleep(std::time::Duration::from_millis(100)).await;
-                    (0..10).for_each(|i| {
-                        pub_int
-                            .publish(&r2r::std_msgs::msg::Int32 { data: i })
-                            .unwrap();
+                    task::spawn(async move {
+                        tokio::time::sleep(std::time::Duration::from_millis(100)).await;
+                        (0..10).for_each(|i| {
+                            pub_int
+                                .publish(&r2r::std_msgs::msg::Int32 { data: i })
+                                .unwrap();
 
-                        pub_array
-                            .publish(&r2r::std_msgs::msg::Int32MultiArray {
-                                layout: r2r::std_msgs::msg::MultiArrayLayout::default(),
-                                data: vec![i],
-                            })
-                            .unwrap();
+                            pub_array
+                                .publish(&r2r::std_msgs::msg::Int32MultiArray {
+                                    layout: r2r::std_msgs::msg::MultiArrayLayout::default(),
+                                    data: vec![i],
+                                })
+                                .unwrap();
+                        });
                     });
-                });
 
-                let sub_int_handle = task::spawn(async move {
-                    while let Some(msg) = sub_int.next().await {
-                        println!("Got int msg of len {}", msg.len());
-                        assert_eq!(msg.len(), 8);
-                    }
-                });
+                    let sub_int_handle = task::spawn(async move {
+                        while let Some(msg) = sub_int.next().await {
+                            println!("Got int msg of len {}", msg.len());
+                            assert_eq!(msg.len(), 8);
+                        }
+                    });
 
-                let sub_array_handle = task::spawn(async move {
-                    while let Some(msg) = sub_array.next().await {
-                        println!("Got array msg of len {}", msg.len());
-                        assert_eq!(msg.len(), 20);
-                    }
-                });
+                    let sub_array_handle = task::spawn(async move {
+                        while let Some(msg) = sub_array.next().await {
+                            println!("Got array msg of len {}", msg.len());
+                            assert_eq!(msg.len(), 20);
+                        }
+                    });
 
-                let handle = std::thread::spawn(move || {
-                    for _ in 1..=30 {
-                        node.spin_once(std::time::Duration::from_millis(100));
-                    }
-                });
+                    let handle = std::thread::spawn(move || {
+                        for _ in 1..=30 {
+                            node.spin_once(std::time::Duration::from_millis(100));
+                        }
+                    });
 
-                sub_int_handle.await.unwrap();
-                sub_array_handle.await.unwrap();
-                handle.join().unwrap();
+                    sub_int_handle.await.unwrap();
+                    sub_array_handle.await.unwrap();
+                    handle.join().unwrap();
 
-                println!("Going to drop tokio_subscribe_raw_testing iteration {i_cycle}");
-            }
-
-        })));
+                    println!("Going to drop tokio_subscribe_raw_testing iteration {i_cycle}");
+                }
+            })
+        }),
+    );
 
     while let Some(thread) = threads.next().await {
         thread.unwrap();
@@ -84,95 +99,110 @@ async fn tokio_subscribe_raw_testing() -> Result<(), Box<dyn std::error::Error>
     Ok(())
 }
 
-
 // Limit the number of threads to force threads to be reused
 #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
 async fn tokio_publish_raw_testing() -> Result<(), Box<dyn std::error::Error>> {
-
     let mut threads = futures::stream::FuturesUnordered::from_iter(
-        (0..N_CONCURRENT_ROS_CONTEXT).map(|i_context| tokio::spawn(async move {
-            // Iterate to check for memory corruption on node setup/teardown
-            for i_cycle in 0..N_TEARDOWN_CYCLES {
+        (0..N_CONCURRENT_ROS_CONTEXT).map(|i_context| {
+            tokio::spawn(async move {
+                // Iterate to check for memory corruption on node setup/teardown
+                for i_cycle in 0..N_TEARDOWN_CYCLES {
+                    println!("tokio_publish_raw_testing iteration {i_cycle}");
 
-                println!("tokio_publish_raw_testing iteration {i_cycle}");
+                    let ctx = r2r::Context::create().unwrap();
+                    let mut node =
+                        r2r::Node::create(ctx, &format!("testnode3_{i_context}"), "").unwrap();
 
-                let ctx = r2r::Context::create().unwrap();
-                let mut node = r2r::Node::create(ctx, &format!("testnode3_{i_context}"), "").unwrap();
+                    let mut sub_int = node
+                        .subscribe::<r2r::std_msgs::msg::Int32>("/int", QosProfile::default())
+                        .unwrap();
 
-                let mut sub_int = node.subscribe::<r2r::std_msgs::msg::Int32>("/int", QosProfile::default()).unwrap();
+                    let mut sub_array = node
+                        .subscribe::<r2r::std_msgs::msg::Int32MultiArray>(
+                            "/int_array",
+                            QosProfile::default(),
+                        )
+                        .unwrap();
 
-                let mut sub_array =
-                    node.subscribe::<r2r::std_msgs::msg::Int32MultiArray>("/int_array", QosProfile::default()).unwrap();
+                    let pub_int = node
+                        .create_publisher_untyped(
+                            "/int",
+                            "std_msgs/msg/Int32",
+                            QosProfile::default(),
+                        )
+                        .unwrap();
 
-                let pub_int = node.create_publisher_untyped(
-                    "/int",
-                    "std_msgs/msg/Int32",
-                    QosProfile::default()
-                ).unwrap();
+                    // Use an array as well since its a variable sized type
+                    let pub_array = node
+                        .create_publisher_untyped(
+                            "/int_array",
+                            "std_msgs/msg/Int32MultiArray",
+                            QosProfile::default(),
+                        )
+                        .unwrap();
 
-                // Use an array as well since its a variable sized type
-                let pub_array = node.create_publisher_untyped(
-                    "/int_array",
-                    "std_msgs/msg/Int32MultiArray",
-                    QosProfile::default(),
-                ).unwrap();
+                    task::spawn(async move {
+                        tokio::time::sleep(std::time::Duration::from_millis(100)).await;
+                        (0..10).for_each(|i| {
+                            pub_int
+                                .publish_raw(
+                                    &r2r::std_msgs::msg::Int32 { data: i }
+                                        .to_serialized_bytes()
+                                        .unwrap(),
+                                )
+                                .unwrap();
 
-                task::spawn(async move {
-                    tokio::time::sleep(std::time::Duration::from_millis(100)).await;
-                    (0..10).for_each(|i| {
-                        pub_int
-                            .publish_raw(&r2r::std_msgs::msg::Int32 { data: i }.to_serialized_bytes().unwrap())
-                            .unwrap();
-
-                        pub_array
-                            .publish_raw(
-                                &r2r::std_msgs::msg::Int32MultiArray {
-                                    layout: r2r::std_msgs::msg::MultiArrayLayout::default(),
-                                    data: vec![i],
-                                }.to_serialized_bytes().unwrap()
-                            )
-                            .unwrap();
+                            pub_array
+                                .publish_raw(
+                                    &r2r::std_msgs::msg::Int32MultiArray {
+                                        layout: r2r::std_msgs::msg::MultiArrayLayout::default(),
+                                        data: vec![i],
+                                    }
+                                    .to_serialized_bytes()
+                                    .unwrap(),
+                                )
+                                .unwrap();
+                        });
                     });
-                });
 
-                let sub_int_handle = task::spawn(async move {
-                    while let Some(msg) = sub_int.next().await {
-                        // Try to check for any possible corruption
-                        msg.to_serialized_bytes().unwrap();
+                    let sub_int_handle = task::spawn(async move {
+                        while let Some(msg) = sub_int.next().await {
+                            // Try to check for any possible corruption
+                            msg.to_serialized_bytes().unwrap();
 
-                        println!("Got int msg with value {}", msg.data);
-                        assert!(msg.data >= 0);
-                        assert!(msg.data < 10);
+                            println!("Got int msg with value {}", msg.data);
+                            assert!(msg.data >= 0);
+                            assert!(msg.data < 10);
+                        }
+                    });
 
-                    }
-                });
+                    let sub_array_handle = task::spawn(async move {
+                        while let Some(msg) = sub_array.next().await {
+                            // Try to check for any possible corruption
+                            msg.to_serialized_bytes().unwrap();
 
-                let sub_array_handle = task::spawn(async move {
-                    while let Some(msg) = sub_array.next().await {
-                        // Try to check for any possible corruption
-                        msg.to_serialized_bytes().unwrap();
+                            println!("Got array msg with value {:?}", msg.data);
+                            assert_eq!(msg.data.len(), 1);
+                            assert!(msg.data[0] >= 0);
+                            assert!(msg.data[0] < 10);
+                        }
+                    });
 
-                        println!("Got array msg with value {:?}", msg.data);
-                        assert_eq!(msg.data.len(), 1);
-                        assert!(msg.data[0] >= 0);
-                        assert!(msg.data[0] < 10);
-                    }
-                });
+                    let handle = std::thread::spawn(move || {
+                        for _ in 1..=30 {
+                            node.spin_once(std::time::Duration::from_millis(100));
+                        }
+                    });
 
-                let handle = std::thread::spawn(move || {
-                    for _ in 1..=30 {
-                        node.spin_once(std::time::Duration::from_millis(100));
-                    }
-                });
+                    sub_int_handle.await.unwrap();
+                    sub_array_handle.await.unwrap();
+                    handle.join().unwrap();
 
-                sub_int_handle.await.unwrap();
-                sub_array_handle.await.unwrap();
-                handle.join().unwrap();
-                println!("Going to drop tokio_publish_raw_testing iteration {i_cycle}");
-
-            }
-        })));
+                    println!("Going to drop tokio_publish_raw_testing iteration {i_cycle}");
+                }
+            })
+        }),
+    );
 
     while let Some(thread) = threads.next().await {
         thread.unwrap();
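[Reviewer note] On the magic numbers asserted in these raw tests: `subscribe_raw` hands back the message exactly as serialized by the RMW layer, so the expected lengths follow from CDR encoding, assuming the default 4-byte encapsulation header prepended during serialization (a different RMW serialization format would change these numbers). A back-of-the-envelope check:

```rust
// Sizes asserted by the raw subscription tests, assuming default CDR encoding.
fn main() {
    let encapsulation = 4; // CDR encapsulation header prepended by the RMW layer

    // std_msgs/msg/Int32: a single int32 field.
    let int32_len = encapsulation + 4;
    assert_eq!(int32_len, 8); // matches assert_eq!(msg.len(), 8)

    // std_msgs/msg/Int32MultiArray with layout.dim == [] and data == vec![i]:
    // dim sequence length (u32) + data_offset (u32) + data sequence length (u32) + one int32.
    let multi_array_len = encapsulation + 4 + 4 + 4 + 4;
    assert_eq!(multi_array_len, 20); // matches assert_eq!(msg.len(), 20)

    println!("int32: {int32_len} bytes, int32 multi-array: {multi_array_len} bytes");
}
```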