Add raw publisher, convert Context to singleton (#76)

* Add raw publisher

* Add raw publisher examples

* Add more tests. Convert Context into singleton to avoid memory corruption when creatin multiple Context instances
This commit is contained in:
Michael Hoy 2024-01-07 18:46:59 +08:00 committed by GitHub
parent 3d6936e70a
commit cb87b9c01c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 427 additions and 178 deletions

View File

@ -0,0 +1,28 @@
use r2r::QosProfile;
use r2r::WrappedTypesupport;
#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
let ctx = r2r::Context::create()?;
let mut node = r2r::Node::create(ctx, "testnode", "")?;
let duration = std::time::Duration::from_millis(2500);
let mut timer = node.create_wall_timer(duration)?;
let publisher =
node.create_publisher_untyped("/topic", "std_msgs/msg/String", QosProfile::default())?;
let handle = tokio::task::spawn_blocking(move || loop {
node.spin_once(std::time::Duration::from_millis(100));
});
for _ in 1..10 {
timer.tick().await?;
let msg = r2r::std_msgs::msg::String {
data: "hello from r2r".to_string(),
};
publisher.publish_raw(&msg.to_serialized_bytes()?)?;
}
handle.await?;
Ok(())
}

View File

@ -2,6 +2,7 @@ use std::ffi::CStr;
use std::ffi::CString; use std::ffi::CString;
use std::fmt::Debug; use std::fmt::Debug;
use std::ops::{Deref, DerefMut}; use std::ops::{Deref, DerefMut};
use std::sync::OnceLock;
use std::sync::{Arc, Mutex}; use std::sync::{Arc, Mutex};
use crate::error::*; use crate::error::*;
@ -29,9 +30,22 @@ macro_rules! check_rcl_ret {
unsafe impl Send for Context {} unsafe impl Send for Context {}
// Safety: Context is just a Arc<Mutex<..>> wrapper around ContextHandle
// so it should be safe to access from different threads
unsafe impl Sync for Context {}
// Memory corruption (double free and others) was observed creating multiple
// `Context` objects in a single thread
//
// To reproduce, run the tests from `tokio_testing` or `tokio_test_raw`
// without this OnceLock
static CONTEXT: OnceLock<Result<Context>> = OnceLock::new();
impl Context { impl Context {
/// Create a ROS context. /// Create a ROS context.
pub fn create() -> Result<Context> { pub fn create() -> Result<Context> {
CONTEXT.get_or_init(|| {
let mut ctx: Box<rcl_context_t> = unsafe { Box::new(rcl_get_zero_initialized_context()) }; let mut ctx: Box<rcl_context_t> = unsafe { Box::new(rcl_get_zero_initialized_context()) };
// argc/v // argc/v
let args = std::env::args() let args = std::env::args()
@ -73,6 +87,7 @@ impl Context {
} else { } else {
Err(Error::RCL_RET_ERROR) // TODO Err(Error::RCL_RET_ERROR) // TODO
} }
}).clone()
} }
/// Check if the ROS context is valid. /// Check if the ROS context is valid.

View File

@ -11,7 +11,7 @@ pub type Result<T> = std::result::Result<T, Error>;
/// These values are mostly copied straight from the RCL headers, but /// These values are mostly copied straight from the RCL headers, but
/// some are specific to r2r, such as `GoalCancelRejected` which does /// some are specific to r2r, such as `GoalCancelRejected` which does
/// not have an analogue in the rcl. /// not have an analogue in the rcl.
#[derive(Error, Debug)] #[derive(Error, Clone, Debug)]
pub enum Error { pub enum Error {
#[error("RCL_RET_OK")] #[error("RCL_RET_OK")]
RCL_RET_OK, RCL_RET_OK,

View File

@ -791,7 +791,9 @@ impl Node {
Ok(p) Ok(p)
} }
/// Create a ROS publisher with a type given at runtime. /// Create a ROS publisher with a type given at runtime, where the data may either be
/// supplied as JSON (using the `publish` method) or a pre-serialized ROS message
/// (i.e. &[u8], using the `publish_raw` method).
pub fn create_publisher_untyped( pub fn create_publisher_untyped(
&mut self, topic: &str, topic_type: &str, qos_profile: QosProfile, &mut self, topic: &str, topic_type: &str, qos_profile: QosProfile,
) -> Result<PublisherUntyped> { ) -> Result<PublisherUntyped> {

View File

@ -202,6 +202,43 @@ impl PublisherUntyped {
} }
} }
/// Publish an pre-serialized ROS message represented by a `&[u8]`.
///
/// It is up to the user to make sure data is a valid ROS serialized message.
pub fn publish_raw(&self, data: &[u8]) -> Result<()> {
// TODO should this be an unsafe function? I'm not sure what happens if the data is malformed ..
// upgrade to actual ref. if still alive
let publisher = self
.handle
.upgrade()
.ok_or(Error::RCL_RET_PUBLISHER_INVALID)?;
// Safety: Not retained beyond this function
let msg_buf = rcl_serialized_message_t {
buffer: data.as_ptr() as *mut u8,
buffer_length: data.len(),
buffer_capacity: data.len(),
// Since its read only, this should never be used ..
allocator: unsafe { rcutils_get_default_allocator() }
};
let result =
unsafe { rcl_publish_serialized_message(
&publisher.handle,
&msg_buf as *const rcl_serialized_message_t,
std::ptr::null_mut()
) };
if result == RCL_RET_OK as i32 {
Ok(())
} else {
log::error!("could not publish {}", result);
Err(Error::from_rcl_error(result))
}
}
/// Gets the number of external subscribers (i.e. it doesn't /// Gets the number of external subscribers (i.e. it doesn't
/// count subscribers from the same process). /// count subscribers from the same process).
pub fn get_inter_process_subscription_count(&self) -> Result<usize> { pub fn get_inter_process_subscription_count(&self) -> Result<usize> {

View File

@ -3,33 +3,41 @@ use std::time::Duration;
use r2r::QosProfile; use r2r::QosProfile;
const N_NODE_PER_CONTEXT: usize = 5;
const N_CONCURRENT_ROS_CONTEXT: usize = 2;
const N_TEARDOWN_CYCLES: usize = 2;
#[test] #[test]
// Let's create and drop a lot of node and publishers for a while to see that we can cope. // Let's create and drop a lot of node and publishers for a while to see that we can cope.
fn doesnt_crash() -> Result<(), Box<dyn std::error::Error>> { fn doesnt_crash() -> Result<(), Box<dyn std::error::Error>> {
let threads = (0..N_CONCURRENT_ROS_CONTEXT).map(|i_context| std::thread::spawn(move || {
for _i_cycle in 0..N_TEARDOWN_CYCLES {
// a global shared context. // a global shared context.
let ctx = r2r::Context::create()?; let ctx = r2r::Context::create().unwrap();
for c in 0..10 { for c in 0..10 {
let mut ths = Vec::new(); let mut ths = Vec::new();
// I have lowered this from 30 to 10 because cyclonedds can only handle a hard-coded number of // I have lowered this from 30 to (10 / N_CONCURRENT_ROS_CONTEXT) because cyclonedds can only handle a hard-coded number of
// publishers in threads. See // publishers in threads. See
// https://github.com/eclipse-cyclonedds/cyclonedds/blob/cd2136d9321212bd52fdc613f07bbebfddd90dec/src/core/ddsc/src/dds_init.c#L115 // https://github.com/eclipse-cyclonedds/cyclonedds/blob/cd2136d9321212bd52fdc613f07bbebfddd90dec/src/core/ddsc/src/dds_init.c#L115
for i in 0..10 { for i_node in 0..N_NODE_PER_CONTEXT {
// create concurrent nodes that max out the cpu // create concurrent nodes that max out the cpu
let ctx = ctx.clone(); let ctx = ctx.clone();
ths.push(thread::spawn(move || { ths.push(thread::spawn(move || {
let mut node = r2r::Node::create(ctx, &format!("testnode{}", i), "").unwrap(); let mut node = r2r::Node::create(ctx, &format!("testnode_{}_{}", i_context, i_node), "").unwrap();
// each with 10 publishers // each with 10 publishers
for _j in 0..10 { for _j in 0..10 {
let p = node let p = node
.create_publisher::<r2r::std_msgs::msg::String>( .create_publisher::<r2r::std_msgs::msg::String>(
&format!("/r2r{}", i), &format!("/r2r{}", i_node),
QosProfile::default(), QosProfile::default(),
) )
.unwrap(); .unwrap();
let to_send = r2r::std_msgs::msg::String { let to_send = r2r::std_msgs::msg::String {
data: format!("[node{}]: {}", i, c), data: format!("[node{}]: {}", i_node, c),
}; };
// move publisher to its own thread and publish as fast as we can // move publisher to its own thread and publish as fast as we can
@ -59,6 +67,14 @@ fn doesnt_crash() -> Result<(), Box<dyn std::error::Error>> {
t.join().unwrap(); t.join().unwrap();
} }
// println!("all threads done {}", c); // println!("all threads done {}", c);
}
}
}));
for thread in threads.into_iter() {
thread.join().unwrap();
} }
Ok(()) Ok(())

View File

@ -1,25 +1,36 @@
use futures::stream::StreamExt; use futures::stream::StreamExt;
use r2r::QosProfile; use r2r::QosProfile;
use tokio::task; use tokio::task;
use r2r::WrappedTypesupport;
#[tokio::test(flavor = "multi_thread")]
const N_CONCURRENT_ROS_CONTEXT: usize = 3;
const N_TEARDOWN_CYCLES: usize = 2;
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn tokio_subscribe_raw_testing() -> Result<(), Box<dyn std::error::Error>> { async fn tokio_subscribe_raw_testing() -> Result<(), Box<dyn std::error::Error>> {
let ctx = r2r::Context::create()?; let mut threads = futures::stream::FuturesUnordered::from_iter(
let mut node = r2r::Node::create(ctx, "testnode2", "")?; (0..N_CONCURRENT_ROS_CONTEXT).map(|i_context| tokio::spawn(async move {
// Iterate to check for memory corruption on node setup/teardown
for i_cycle in 0..N_TEARDOWN_CYCLES {
println!("tokio_subscribe_raw_testing iteration {i_cycle}");
let mut sub_int = node.subscribe_raw("/int", "std_msgs/msg/Int32", QosProfile::default())?; let ctx = r2r::Context::create().unwrap();
let mut node = r2r::Node::create(ctx, &format!("testnode2_{i_context}"), "").unwrap();
let mut sub_int = node.subscribe_raw("/int", "std_msgs/msg/Int32", QosProfile::default()).unwrap();
let mut sub_array = let mut sub_array =
node.subscribe_raw("/int_array", "std_msgs/msg/Int32MultiArray", QosProfile::default())?; node.subscribe_raw("/int_array", "std_msgs/msg/Int32MultiArray", QosProfile::default()).unwrap();
let pub_int = let pub_int =
node.create_publisher::<r2r::std_msgs::msg::Int32>("/int", QosProfile::default())?; node.create_publisher::<r2r::std_msgs::msg::Int32>("/int", QosProfile::default()).unwrap();
// Use an array as well since its a variable sized type // Use an array as well since its a variable sized type
let pub_array = node.create_publisher::<r2r::std_msgs::msg::Int32MultiArray>( let pub_array = node.create_publisher::<r2r::std_msgs::msg::Int32MultiArray>(
"/int_array", "/int_array",
QosProfile::default(), QosProfile::default(),
)?; ).unwrap();
task::spawn(async move { task::spawn(async move {
tokio::time::sleep(std::time::Duration::from_millis(100)).await; tokio::time::sleep(std::time::Duration::from_millis(100)).await;
@ -57,9 +68,115 @@ async fn tokio_subscribe_raw_testing() -> Result<(), Box<dyn std::error::Error>>
} }
}); });
sub_int_handle.await?; sub_int_handle.await.unwrap();
sub_array_handle.await?; sub_array_handle.await.unwrap();
handle.join().unwrap(); handle.join().unwrap();
println!("Going to drop tokio_subscribe_raw_testing iteration {i_cycle}");
}
})));
while let Some(thread) = threads.next().await {
thread.unwrap();
}
Ok(())
}
// Limit the number of threads to force threads to be reused
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn tokio_publish_raw_testing() -> Result<(), Box<dyn std::error::Error>> {
let mut threads = futures::stream::FuturesUnordered::from_iter(
(0..N_CONCURRENT_ROS_CONTEXT).map(|i_context| tokio::spawn(async move {
// Iterate to check for memory corruption on node setup/teardown
for i_cycle in 0..N_TEARDOWN_CYCLES {
println!("tokio_publish_raw_testing iteration {i_cycle}");
let ctx = r2r::Context::create().unwrap();
let mut node = r2r::Node::create(ctx, &format!("testnode3_{i_context}"), "").unwrap();
let mut sub_int = node.subscribe::<r2r::std_msgs::msg::Int32>("/int", QosProfile::default()).unwrap();
let mut sub_array =
node.subscribe::<r2r::std_msgs::msg::Int32MultiArray>("/int_array", QosProfile::default()).unwrap();
let pub_int = node.create_publisher_untyped(
"/int",
"std_msgs/msg/Int32",
QosProfile::default()
).unwrap();
// Use an array as well since its a variable sized type
let pub_array = node.create_publisher_untyped(
"/int_array",
"std_msgs/msg/Int32MultiArray",
QosProfile::default(),
).unwrap();
task::spawn(async move {
tokio::time::sleep(std::time::Duration::from_millis(100)).await;
(0..10).for_each(|i| {
pub_int
.publish_raw(&r2r::std_msgs::msg::Int32 { data: i }.to_serialized_bytes().unwrap())
.unwrap();
pub_array
.publish_raw(
&r2r::std_msgs::msg::Int32MultiArray {
layout: r2r::std_msgs::msg::MultiArrayLayout::default(),
data: vec![i],
}.to_serialized_bytes().unwrap()
)
.unwrap();
});
});
let sub_int_handle = task::spawn(async move {
while let Some(msg) = sub_int.next().await {
// Try to check for any possible corruption
msg.to_serialized_bytes().unwrap();
println!("Got int msg with value {}", msg.data);
assert!(msg.data >= 0);
assert!(msg.data < 10);
}
});
let sub_array_handle = task::spawn(async move {
while let Some(msg) = sub_array.next().await {
// Try to check for any possible corruption
msg.to_serialized_bytes().unwrap();
println!("Got array msg with value {:?}", msg.data);
assert_eq!(msg.data.len(), 1);
assert!(msg.data[0] >= 0);
assert!(msg.data[0] < 10);
}
});
let handle = std::thread::spawn(move || {
for _ in 1..=30 {
node.spin_once(std::time::Duration::from_millis(100));
}
});
sub_int_handle.await.unwrap();
sub_array_handle.await.unwrap();
handle.join().unwrap();
println!("Going to drop tokio_publish_raw_testing iteration {i_cycle}");
}
})));
while let Some(thread) = threads.next().await {
thread.unwrap();
}
Ok(()) Ok(())
} }

View File

@ -4,18 +4,31 @@ use r2r::QosProfile;
use std::sync::{Arc, Mutex}; use std::sync::{Arc, Mutex};
use tokio::task; use tokio::task;
#[tokio::test(flavor = "multi_thread")] const N_CONCURRENT_ROS_CONTEXT: usize = 3;
const N_TEARDOWN_CYCLES: usize = 2;
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn tokio_testing() -> Result<(), Box<dyn std::error::Error>> { async fn tokio_testing() -> Result<(), Box<dyn std::error::Error>> {
let ctx = r2r::Context::create()?;
let mut node = r2r::Node::create(ctx, "testnode", "")?; let mut threads = futures::stream::FuturesUnordered::from_iter(
(0..N_CONCURRENT_ROS_CONTEXT).map(|i_context| tokio::spawn(async move {
// Iterate to check for memory corruption on node setup/teardown
for i_cycle in 0..N_TEARDOWN_CYCLES {
println!("tokio_testing iteration {i_cycle}");
let ctx = r2r::Context::create().unwrap();
// let ctx = std::thread::spawn(|| r2r::Context::create().unwrap()).join().unwrap();
let mut node = r2r::Node::create(ctx, &format!("testnode_{i_context}"), "").unwrap();
let mut s_the_no = let mut s_the_no =
node.subscribe::<r2r::std_msgs::msg::Int32>("/the_no", QosProfile::default())?; node.subscribe::<r2r::std_msgs::msg::Int32>(&format!("/the_no_{i_context}"), QosProfile::default()).unwrap();
let mut s_new_no = let mut s_new_no =
node.subscribe::<r2r::std_msgs::msg::Int32>("/new_no", QosProfile::default())?; node.subscribe::<r2r::std_msgs::msg::Int32>(&format!("/new_no_{i_context}"), QosProfile::default()).unwrap();
let p_the_no = let p_the_no =
node.create_publisher::<r2r::std_msgs::msg::Int32>("/the_no", QosProfile::default())?; node.create_publisher::<r2r::std_msgs::msg::Int32>(&format!("/the_no_{i_context}"), QosProfile::default()).unwrap();
let p_new_no = let p_new_no =
node.create_publisher::<r2r::std_msgs::msg::Int32>("/new_no", QosProfile::default())?; node.create_publisher::<r2r::std_msgs::msg::Int32>(&format!("/new_no_{i_context}"), QosProfile::default()).unwrap();
let state = Arc::new(Mutex::new(0)); let state = Arc::new(Mutex::new(0));
task::spawn(async move { task::spawn(async move {
@ -23,6 +36,9 @@ async fn tokio_testing() -> Result<(), Box<dyn std::error::Error>> {
p_the_no p_the_no
.publish(&r2r::std_msgs::msg::Int32 { data: i }) .publish(&r2r::std_msgs::msg::Int32 { data: i })
.unwrap(); .unwrap();
println!("send {i}");
}); });
}); });
@ -33,23 +49,31 @@ async fn tokio_testing() -> Result<(), Box<dyn std::error::Error>> {
data: msg.data + 10, data: msg.data + 10,
}) })
.unwrap(); .unwrap();
println!("got {}, send {}", msg.data, msg.data + 10);
} }
}); });
let s = state.clone(); let s = state.clone();
task::spawn(async move { task::spawn(async move {
while let Some(msg) = s_new_no.next().await { while let Some(msg) = s_new_no.next().await {
println!("got {}", msg.data);
let i = msg.data; let i = msg.data;
if i == 19 {
*s.lock().unwrap() = 19; *s.lock().unwrap() = i;
}
} }
}); });
let handle = std::thread::spawn(move || { // std::thread::spawn doesn't work here anymore?
for _ in 1..=30 { let handle = task::spawn_blocking(move || {
for _ in 1..30 {
node.spin_once(std::time::Duration::from_millis(100)); node.spin_once(std::time::Duration::from_millis(100));
let x = state.lock().unwrap(); let x = state.lock().unwrap();
println!("rec {}", x);
if *x == 19 { if *x == 19 {
break; break;
} }
@ -57,7 +81,17 @@ async fn tokio_testing() -> Result<(), Box<dyn std::error::Error>> {
*state.lock().unwrap() *state.lock().unwrap()
}); });
let x = handle.join().unwrap(); let x = handle.await.unwrap();
assert_eq!(x, 19); assert_eq!(x, 19);
println!("tokio_testing finish iteration {i_cycle}");
}
})));
while let Some(thread) = threads.next().await {
thread.unwrap();
}
Ok(()) Ok(())
} }