use crate::config::{Cnx, CuConfig, NodeId};
use crate::config::{ComponentConfig, Node};
use crate::copperlist::{CopperList, CopperListState, CuListsManager};
use crate::monitoring::CuMonitor;
use cu29_clock::{ClockProvider, RobotClock};
use cu29_log_runtime::LoggerRuntime;
use cu29_traits::CopperListTuple;
use cu29_traits::CuResult;
use cu29_traits::WriteStream;
use cu29_unifiedlog::UnifiedLoggerWrite;
use std::sync::{Arc, Mutex};

use petgraph::prelude::*;
use petgraph::visit::VisitMap;
use petgraph::visit::Visitable;
use std::fmt::Debug;

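/// Bundles the shared services a Copper application needs at runtime: the unified
/// logger sink, the structured-log runtime and the robot clock.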
pub struct CopperContext {
    pub unified_logger: Arc<Mutex<UnifiedLoggerWrite>>,
    pub logger_runtime: LoggerRuntime,
    pub clock: RobotClock,
}

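/// This is the main structure of the runtime.
/// `CT` is the tuple of all the task instances, `P` is the tuple of all the message
/// payloads forming a copper list, `M` is the monitor implementation and `NBCL` is
/// the number of copper lists preallocated by the runtime.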
pub struct CuRuntime<CT, P: CopperListTuple, M: CuMonitor, const NBCL: usize> {
    /// The tuple of all the task instances.
    pub tasks: CT,

    /// The monitor attached to the runtime.
    pub monitor: M,

    /// Pool of preallocated copper lists recycled by the runtime.
    pub copper_lists_manager: CuListsManager<P, NBCL>,

    /// The base clock the runtime uses to timestamp its work.
    pub clock: RobotClock,

    /// Optional sink where completed copper lists are logged.
    logger: Option<Box<dyn WriteStream<CopperList<P>>>>,
}

impl<CT, P: CopperListTuple, M: CuMonitor, const NBCL: usize> ClockProvider
    for CuRuntime<CT, P, M, NBCL>
{
    fn get_clock(&self) -> RobotClock {
        self.clock.clone()
    }
}

impl<CT, P: CopperListTuple + 'static, M: CuMonitor, const NBCL: usize> CuRuntime<CT, P, M, NBCL> {
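    /// Builds the runtime from a `CuConfig`: collects each node's `ComponentConfig`,
    /// instantiates the tasks and the monitor through the provided closures, and keeps
    /// the copper-list logger unless task logging is explicitly disabled in the config.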
    pub fn new(
        clock: RobotClock,
        config: &CuConfig,
        tasks_instanciator: impl Fn(Vec<Option<&ComponentConfig>>) -> CuResult<CT>,
        monitor_instanciator: impl Fn(&CuConfig) -> M,
        logger: impl WriteStream<CopperList<P>> + 'static,
    ) -> CuResult<Self> {
        let all_instances_configs: Vec<Option<&ComponentConfig>> = config
            .get_all_nodes(None)
            .iter()
            .map(|(_, node)| node.get_instance_config())
            .collect();
        let tasks = tasks_instanciator(all_instances_configs)?;

        let monitor = monitor_instanciator(config);

        // Keep the copper-list logger only if task logging is enabled
        // (or if no logging section is present in the config at all).
        let logger_: Option<Box<dyn WriteStream<CopperList<P>>>> =
            if let Some(logging_config) = &config.logging {
                if logging_config.enable_task_logging {
                    Some(Box::new(logger))
                } else {
                    None
                }
            } else {
                Some(Box::new(logger))
            };

        let runtime = Self {
            tasks,
            monitor,
            copper_lists_manager: CuListsManager::new(),
            clock,
            logger: logger_,
        };

        Ok(runtime)
    }

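    /// Number of copper lists still available in the preallocated pool.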
    pub fn available_copper_lists(&self) -> usize {
        NBCL - self.copper_lists_manager.len()
    }

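    /// Marks the copper list `culistid` as done processing, then serializes (if a
    /// logger is set) and frees every done list sitting contiguously at the top of
    /// the pool.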
    pub fn end_of_processing(&mut self, culistid: u32) {
        let mut is_top = true;
        let mut nb_done = 0;
        self.copper_lists_manager.iter_mut().for_each(|cl| {
            if cl.id == culistid && cl.get_state() == CopperListState::Processing {
                cl.change_state(CopperListState::DoneProcessing);
            }
            // If we have a contiguous run of done copper lists at the top of the pool,
            // serialize them and mark them free.
            if is_top && cl.get_state() == CopperListState::DoneProcessing {
                if let Some(logger) = &mut self.logger {
                    cl.change_state(CopperListState::BeingSerialized);
                    logger.log(cl).unwrap();
                }
                cl.change_state(CopperListState::Free);
                nb_done += 1;
            } else {
                is_top = false;
            }
        });
        for _ in 0..nb_done {
            let _ = self.copper_lists_manager.pop();
        }
    }
}

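/// The type of a task, derived from its connections in the config graph:
/// a source has no incoming connection, a sink has no outgoing connection,
/// everything else is a regular task.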
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
pub enum CuTaskType {
    Source,
    Regular,
    Sink,
}

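/// One step in the execution plan: the task to run plus the copper-list slots
/// (index and message type) used for its inputs and output.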
pub struct CuExecutionStep {
    /// Id of the node in the config graph.
    pub node_id: NodeId,
    /// The config node holding the instantiation parameters of the task.
    pub node: Node,
    /// Whether this task is a source, a regular task or a sink.
    pub task_type: CuTaskType,

    /// The (copper-list index, message type) pairs this task reads as inputs.
    pub input_msg_indices_types: Vec<(u32, String)>,

    /// The (copper-list index, message type) this task writes as output, if any.
    pub output_msg_index_type: Option<(u32, String)>,
}

impl Debug for CuExecutionStep {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(format!(" CuExecutionStep: Node Id: {}\n", self.node_id).as_str())?;
        f.write_str(format!(" task_type: {:?}\n", self.node.get_type()).as_str())?;
        f.write_str(format!(" task: {:?}\n", self.task_type).as_str())?;
        f.write_str(
            format!(
                " input_msg_types: {:?}\n",
                self.input_msg_indices_types
            )
            .as_str(),
        )?;
        f.write_str(
            format!(" output_msg_type: {:?}\n", self.output_msg_index_type).as_str(),
        )?;
        Ok(())
    }
}

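/// A sequence of execution units to run in a loop; `loop_count` bounds the number
/// of iterations, with `None` meaning the loop is unbounded.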
pub struct CuExecutionLoop {
    pub steps: Vec<CuExecutionUnit>,
    pub loop_count: Option<u32>,
}
impl Debug for CuExecutionLoop {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("CuExecutionLoop:\n")?;
        for step in &self.steps {
            match step {
                CuExecutionUnit::Step(step) => {
                    step.fmt(f)?;
                }
                CuExecutionUnit::Loop(l) => {
                    l.fmt(f)?;
                }
            }
        }

        f.write_str(format!(" count: {:?}", self.loop_count).as_str())?;
        Ok(())
    }
}

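/// An execution unit is either a single step or a nested loop of steps.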
#[derive(Debug)]
pub enum CuExecutionUnit {
    Step(CuExecutionStep),
    Loop(CuExecutionLoop),
}

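/// Looks up, recursively through the plan built so far, the output (copper-list
/// index, message type) of the step belonging to `node_id`, if it was planned yet.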
fn find_output_index_type_from_nodeid(
    node_id: NodeId,
    steps: &Vec<CuExecutionUnit>,
) -> Option<(u32, String)> {
    for step in steps {
        match step {
            CuExecutionUnit::Loop(loop_unit) => {
                if let Some(index) = find_output_index_type_from_nodeid(node_id, &loop_unit.steps) {
                    return Some(index);
                }
            }
            CuExecutionUnit::Step(step) => {
                if step.node_id == node_id {
                    return step.output_msg_index_type.clone();
                }
            }
        }
    }
    None
}

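/// Determines the task type of a node from the graph topology (see `CuTaskType`).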
pub fn find_task_type_for_id(
    graph: &StableDiGraph<Node, Cnx, NodeId>,
    node_id: NodeId,
) -> CuTaskType {
    if graph.neighbors_directed(node_id.into(), Incoming).count() == 0 {
        CuTaskType::Source
    } else if graph.neighbors_directed(node_id.into(), Outgoing).count() == 0 {
        CuTaskType::Sink
    } else {
        CuTaskType::Regular
    }
}

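/// Returns the id of the config edge connecting the step found at `plan_id` in the
/// plan to `output_node_id`. Panics if that step has not been planned yet or if no
/// such edge exists.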
fn find_edge_with_plan_input_id(
    plan: &[CuExecutionUnit],
    config: &CuConfig,
    plan_id: u32,
    output_node_id: NodeId,
) -> usize {
    let input_node = plan
        .get(plan_id as usize)
        .expect("Input step should've been added to plan before the step that receives the input");
    let CuExecutionUnit::Step(input_step) = input_node else {
        panic!("Expected input to be from a step, not a loop");
    };
    let input_node_id = input_step.node_id;

    config
        .get_graph(None)
        .unwrap()
        .edges_connecting(input_node_id.into(), output_node_id.into())
        .map(|edge| edge.id().index())
        .next()
        .expect("An edge connecting the input to the output should exist")
}

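/// Sorts a step's inputs by the id of the connection that carries them, so the
/// input order matches the order of the connections in the config.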
fn sort_inputs_by_cnx_id(
    input_msg_indices_types: &mut [(u32, String)],
    plan: &[CuExecutionUnit],
    config: &CuConfig,
    curr_node_id: NodeId,
) {
    input_msg_indices_types.sort_by(|(a_index, _), (b_index, _)| {
        let a_edge_id = find_edge_with_plan_input_id(plan, config, *a_index, curr_node_id);
        let b_edge_id = find_edge_with_plan_input_id(plan, config, *b_index, curr_node_id);
        a_edge_id.cmp(&b_edge_id)
    });
}
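
/// Explores the graph breadth-first from `starting_point` and appends every task whose
/// inputs are already planned to `plan`. Returns the next free copper-list index and
/// whether at least one task of this branch was added to the plan.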
fn plan_tasks_tree_branch(
    config: &CuConfig,
    mut next_culist_output_index: u32,
    starting_point: NodeId,
    plan: &mut Vec<CuExecutionUnit>,
) -> (u32, bool) {
    #[cfg(feature = "macro_debug")]
    eprintln!("-- starting branch from node {starting_point}");

    let graph = config.get_graph(None).unwrap();
    let mut visitor = Bfs::new(&graph, starting_point.into());
    let mut handled = false;

    while let Some(node) = visitor.next(&graph) {
        let id = node.index() as NodeId;
        let node_ref = config.graphs.get_node(id, None).unwrap();
        #[cfg(feature = "macro_debug")]
        eprintln!(" Visiting node: {node_ref:?}");

        let mut input_msg_indices_types: Vec<(u32, String)> = Vec::new();
        let output_msg_index_type: Option<(u32, String)>;
        let task_type = find_task_type_for_id(graph, id);

        match task_type {
            CuTaskType::Source => {
                #[cfg(feature = "macro_debug")]
                eprintln!(" → Source node, assign output index {next_culist_output_index}");
                output_msg_index_type = Some((
                    next_culist_output_index,
                    graph
                        .edge_weight(EdgeIndex::new(config.get_src_edges(id, None).unwrap()[0]))
                        .unwrap()
                        .msg
                        .clone(),
                ));
                next_culist_output_index += 1;
            }
            CuTaskType::Sink => {
                let parents: Vec<NodeIndex> =
                    graph.neighbors_directed(id.into(), Incoming).collect();
                #[cfg(feature = "macro_debug")]
                eprintln!(" → Sink with parents: {parents:?}");
                for parent in &parents {
                    let pid = parent.index() as NodeId;
                    let index_type = find_output_index_type_from_nodeid(pid, plan);
                    if let Some(index_type) = index_type {
                        #[cfg(feature = "macro_debug")]
                        eprintln!(" ✓ Input from {pid} ready: {index_type:?}");
                        input_msg_indices_types.push(index_type);
                    } else {
                        #[cfg(feature = "macro_debug")]
                        eprintln!(" ✗ Input from {pid} not ready, returning");
                        return (next_culist_output_index, handled);
                    }
                }
                output_msg_index_type = Some((next_culist_output_index, "()".to_string()));
                next_culist_output_index += 1;
            }
            CuTaskType::Regular => {
                let parents: Vec<NodeIndex> =
                    graph.neighbors_directed(id.into(), Incoming).collect();
                #[cfg(feature = "macro_debug")]
                eprintln!(" → Regular task with parents: {parents:?}");
                for parent in &parents {
                    let pid = parent.index() as NodeId;
                    let index_type = find_output_index_type_from_nodeid(pid, plan);
                    if let Some(index_type) = index_type {
                        #[cfg(feature = "macro_debug")]
                        eprintln!(" ✓ Input from {pid} ready: {index_type:?}");
                        input_msg_indices_types.push(index_type);
                    } else {
                        #[cfg(feature = "macro_debug")]
                        eprintln!(" ✗ Input from {pid} not ready, returning");
                        return (next_culist_output_index, handled);
                    }
                }
                output_msg_index_type = Some((
                    next_culist_output_index,
                    graph
                        .edge_weight(EdgeIndex::new(config.get_src_edges(id, None).unwrap()[0]))
                        .unwrap()
                        .msg
                        .clone(),
                ));
                next_culist_output_index += 1;
            }
        }

        sort_inputs_by_cnx_id(&mut input_msg_indices_types, plan, config, id);

        if let Some(pos) = plan
            .iter()
            .position(|step| matches!(step, CuExecutionUnit::Step(s) if s.node_id == id))
        {
            #[cfg(feature = "macro_debug")]
            eprintln!(" → Already in plan, modifying existing step");
            let mut step = plan.remove(pos);
            if let CuExecutionUnit::Step(ref mut s) = step {
                s.input_msg_indices_types = input_msg_indices_types;
            }
            plan.push(step);
        } else {
            #[cfg(feature = "macro_debug")]
            eprintln!(" → New step added to plan");
            let step = CuExecutionStep {
                node_id: id,
                node: node_ref.clone(),
                task_type,
                input_msg_indices_types,
                output_msg_index_type,
            };
            plan.push(CuExecutionUnit::Step(step));
        }

        handled = true;
    }

    #[cfg(feature = "macro_debug")]
    eprintln!("-- finished branch from node {starting_point} with handled={handled}");
    (next_culist_output_index, handled)
}

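/// Computes the execution plan from the config graph: starting from the source tasks,
/// walks the graph breadth-first and plans each task once all of its inputs are
/// available, producing an unbounded `CuExecutionLoop`.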
pub fn compute_runtime_plan(config: &CuConfig) -> CuResult<CuExecutionLoop> {
    #[cfg(feature = "macro_debug")]
    eprintln!("[runtime plan]");
    let graph = config.get_graph(None).unwrap();
    let visited = graph.visit_map();
    let mut plan = Vec::new();
    let mut next_culist_output_index = 0u32;

    let graph = config.get_graph(None).unwrap();
    let mut queue: std::collections::VecDeque<NodeId> = graph
        .node_indices()
        .filter(|&node| find_task_type_for_id(graph, node.index() as NodeId) == CuTaskType::Source)
        .map(|node| node.index() as NodeId)
        .collect();

    #[cfg(feature = "macro_debug")]
    eprintln!("Initial source nodes: {queue:?}");

    while let Some(start_node) = queue.pop_front() {
        if visited.is_visited(&start_node) {
            #[cfg(feature = "macro_debug")]
            eprintln!("→ Skipping already visited source {start_node}");
            continue;
        }

        #[cfg(feature = "macro_debug")]
        eprintln!("→ Starting BFS from source {start_node}");
        let graph = config.get_graph(None).unwrap();
        let mut bfs = Bfs::new(graph, start_node.into());

        while let Some(node_index) = bfs.next(graph) {
            let node_id = node_index.index() as NodeId;
            let already_in_plan = plan
                .iter()
                .any(|unit| matches!(unit, CuExecutionUnit::Step(s) if s.node_id == node_id));
            if already_in_plan {
                #[cfg(feature = "macro_debug")]
                eprintln!(" → Node {node_id} already planned, skipping");
                continue;
            }

            #[cfg(feature = "macro_debug")]
            eprintln!(" Planning from node {node_id}");
            let (new_index, handled) =
                plan_tasks_tree_branch(config, next_culist_output_index, node_id, &mut plan);
            next_culist_output_index = new_index;

            if !handled {
                #[cfg(feature = "macro_debug")]
                eprintln!(" ✗ Node {node_id} was not handled, skipping enqueue of neighbors");
                continue;
            }

            #[cfg(feature = "macro_debug")]
            eprintln!(" ✓ Node {node_id} handled successfully, enqueueing neighbors");
            let graph = config.get_graph(None).unwrap();
            for neighbor in graph.neighbors(node_index) {
                if !visited.is_visited(&neighbor) {
                    let nid = neighbor.index() as NodeId;
                    #[cfg(feature = "macro_debug")]
                    eprintln!(" → Enqueueing neighbor {nid}");
                    queue.push_back(nid);
                }
            }
        }
    }

    Ok(CuExecutionLoop {
        steps: plan,
        loop_count: None,
    })
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::Node;
    use crate::cutask::CuSinkTask;
    use crate::cutask::{CuSrcTask, Freezable};
    use crate::monitoring::NoMonitor;
    use bincode::Encode;

    pub struct TestSource {}

    impl Freezable for TestSource {}

    impl CuSrcTask<'_> for TestSource {
        type Output = ();
        fn new(_config: Option<&ComponentConfig>) -> CuResult<Self>
        where
            Self: Sized,
        {
            Ok(Self {})
        }

        fn process(&mut self, _clock: &RobotClock, _empty_msg: Self::Output) -> CuResult<()> {
            Ok(())
        }
    }

    pub struct TestSink {}

    impl Freezable for TestSink {}

    impl CuSinkTask<'_> for TestSink {
        type Input = ();

        fn new(_config: Option<&ComponentConfig>) -> CuResult<Self>
        where
            Self: Sized,
        {
            Ok(Self {})
        }

        fn process(&mut self, _clock: &RobotClock, _input: Self::Input) -> CuResult<()> {
            Ok(())
        }
    }

    // Fake task and message tuples, just enough to satisfy the runtime's generics in the tests.
    type Tasks = (TestSource, TestSink);
    type Msgs = ((),);

    fn tasks_instanciator(all_instances_configs: Vec<Option<&ComponentConfig>>) -> CuResult<Tasks> {
        Ok((
            TestSource::new(all_instances_configs[0])?,
            TestSink::new(all_instances_configs[1])?,
        ))
    }

    fn monitor_instanciator(_config: &CuConfig) -> NoMonitor {
        NoMonitor {}
    }

    #[derive(Debug)]
    struct FakeWriter {}

    impl<E: Encode> WriteStream<E> for FakeWriter {
        fn log(&mut self, _obj: &E) -> CuResult<()> {
            Ok(())
        }
    }

    #[test]
    fn test_runtime_instantiation() {
        let mut config = CuConfig::default();
        let graph = config.get_graph_mut(None).unwrap();
        graph.add_node(Node::new("a", "TestSource"));
        graph.add_node(Node::new("b", "TestSink"));
        config.connect(0, 1, "()").unwrap();
        let runtime = CuRuntime::<Tasks, Msgs, NoMonitor, 2>::new(
            RobotClock::default(),
            &config,
            tasks_instanciator,
            monitor_instanciator,
            FakeWriter {},
        );
        assert!(runtime.is_ok());
    }

    #[test]
    fn test_copperlists_manager_lifecycle() {
        let mut config = CuConfig::default();
        let graph = config.get_graph_mut(None).unwrap();
        graph.add_node(Node::new("a", "TestSource"));
        graph.add_node(Node::new("b", "TestSink"));
        config.connect(0, 1, "()").unwrap();
        let mut runtime = CuRuntime::<Tasks, Msgs, NoMonitor, 2>::new(
            RobotClock::default(),
            &config,
            tasks_instanciator,
            monitor_instanciator,
            FakeWriter {},
        )
        .unwrap();

        // Create a first copper list and mark it as being processed.
        {
            let copperlists = &mut runtime.copper_lists_manager;
            let culist0 = copperlists
                .create()
                .expect("Ran out of space for copper lists");
            let id = culist0.id;
            assert_eq!(id, 0);
            culist0.change_state(CopperListState::Processing);
            assert_eq!(runtime.available_copper_lists(), 1);
        }

        // Create a second copper list: the pool (NBCL = 2) is now full.
        {
            let copperlists = &mut runtime.copper_lists_manager;
            let culist1 = copperlists
                .create()
                .expect("Ran out of space for copper lists");
            let id = culist1.id;
            assert_eq!(id, 1);
            culist1.change_state(CopperListState::Processing);
            assert_eq!(runtime.available_copper_lists(), 0);
        }

        // A third copper list cannot be created while the pool is exhausted.
        {
            let copperlists = &mut runtime.copper_lists_manager;
            let culist2 = copperlists.create();
            assert!(culist2.is_none());
            assert_eq!(runtime.available_copper_lists(), 0);
        }

        // The newest list (id 1) is done: it is at the top of the pool, so it is freed.
        runtime.end_of_processing(1);
        assert_eq!(runtime.available_copper_lists(), 1);

        // Reuse the freed slot for a third copper list.
        {
            let copperlists = &mut runtime.copper_lists_manager;
            let culist2 = copperlists
                .create()
                .expect("Ran out of space for copper lists");
            let id = culist2.id;
            assert_eq!(id, 2);
            culist2.change_state(CopperListState::Processing);
            assert_eq!(runtime.available_copper_lists(), 0);
        }

        // The oldest list (id 0) is done, but id 2 above it is still processing,
        // so nothing can be freed yet.
        runtime.end_of_processing(0);
        assert_eq!(runtime.available_copper_lists(), 0);

        // Once the top list (id 2) is done as well, both done lists are freed.
        runtime.end_of_processing(2);
        assert_eq!(runtime.available_copper_lists(), 2);
    }

    #[test]
    fn test_runtime_task_input_order() {
        let mut config = CuConfig::default();
        let src1_id = config.add_node(Node::new("a", "Source1"), None).unwrap();
        let src2_id = config.add_node(Node::new("b", "Source2"), None).unwrap();
        let sink_id = config.add_node(Node::new("c", "Sink"), None).unwrap();

        assert_eq!(src1_id, 0);
        assert_eq!(src2_id, 1);

        // Connect src2 first, then src1: the connection order, not the node id order,
        // should drive the sink's input order.
        let src1_type = "src1_type";
        let src2_type = "src2_type";
        config.connect(src2_id, sink_id, src2_type).unwrap();
        config.connect(src1_id, sink_id, src1_type).unwrap();

        let src1_edge_id = *config
            .get_src_edges(src1_id, None)
            .unwrap()
            .first()
            .unwrap();
        let src2_edge_id = *config
            .get_src_edges(src2_id, None)
            .unwrap()
            .first()
            .unwrap();
        assert_eq!(src1_edge_id, 1);
        assert_eq!(src2_edge_id, 0);

        let runtime = compute_runtime_plan(&config).unwrap();
        let sink_step = runtime
            .steps
            .iter()
            .find_map(|step| match step {
                CuExecutionUnit::Step(step) if step.node_id == sink_id => Some(step),
                _ => None,
            })
            .unwrap();

        // The sink's inputs must be ordered by connection id: src2 was connected first,
        // so its message comes first.
        assert_eq!(sink_step.input_msg_indices_types[0].1, src2_type);
        assert_eq!(sink_step.input_msg_indices_types[1].1, src1_type);
    }

    #[test]
    fn test_runtime_plan_diamond_case1() {
        // Diamond: cam0 feeds the broadcast sink both directly and through inf0.
        let mut config = CuConfig::default();
        let cam0_id = config
            .add_node(Node::new("cam0", "tasks::IntegerSrcTask"), None)
            .unwrap();
        let inf0_id = config
            .add_node(Node::new("inf0", "tasks::Integer2FloatTask"), None)
            .unwrap();
        let broadcast_id = config
            .add_node(Node::new("broadcast", "tasks::MergingSinkTask"), None)
            .unwrap();

        // Case 1: the direct cam0 → broadcast connection is declared first.
        config.connect(cam0_id, broadcast_id, "i32").unwrap();
        config.connect(cam0_id, inf0_id, "i32").unwrap();
        config.connect(inf0_id, broadcast_id, "f32").unwrap();

        let edge_cam0_to_broadcast = *config
            .get_src_edges(cam0_id, None)
            .unwrap()
            .first()
            .unwrap();
        let edge_cam0_to_inf0 = config.get_src_edges(cam0_id, None).unwrap()[1];

        assert_eq!(edge_cam0_to_inf0, 0);
        assert_eq!(edge_cam0_to_broadcast, 1);

        let runtime = compute_runtime_plan(&config).unwrap();
        let broadcast_step = runtime
            .steps
            .iter()
            .find_map(|step| match step {
                CuExecutionUnit::Step(step) if step.node_id == broadcast_id => Some(step),
                _ => None,
            })
            .unwrap();

        assert_eq!(broadcast_step.input_msg_indices_types[0].1, "i32");
        assert_eq!(broadcast_step.input_msg_indices_types[1].1, "f32");
    }

    #[test]
    fn test_runtime_plan_diamond_case2() {
        // Same diamond as case 1, but the cam0 → inf0 connection is declared first.
        let mut config = CuConfig::default();
        let cam0_id = config
            .add_node(Node::new("cam0", "tasks::IntegerSrcTask"), None)
            .unwrap();
        let inf0_id = config
            .add_node(Node::new("inf0", "tasks::Integer2FloatTask"), None)
            .unwrap();
        let broadcast_id = config
            .add_node(Node::new("broadcast", "tasks::MergingSinkTask"), None)
            .unwrap();

        config.connect(cam0_id, inf0_id, "i32").unwrap();
        config.connect(cam0_id, broadcast_id, "i32").unwrap();
        config.connect(inf0_id, broadcast_id, "f32").unwrap();

        let edge_cam0_to_inf0 = *config
            .get_src_edges(cam0_id, None)
            .unwrap()
            .first()
            .unwrap();
        let edge_cam0_to_broadcast = config.get_src_edges(cam0_id, None).unwrap()[1];

        assert_eq!(edge_cam0_to_broadcast, 0);
        assert_eq!(edge_cam0_to_inf0, 1);

        let runtime = compute_runtime_plan(&config).unwrap();
        let broadcast_step = runtime
            .steps
            .iter()
            .find_map(|step| match step {
                CuExecutionUnit::Step(step) if step.node_id == broadcast_id => Some(step),
                _ => None,
            })
            .unwrap();

        assert_eq!(broadcast_step.input_msg_indices_types[0].1, "i32");
        assert_eq!(broadcast_step.input_msg_indices_types[1].1, "f32");
    }
}