1use crate::config::{ComponentConfig, Node, DEFAULT_KEYFRAME_INTERVAL};
6use crate::config::{CuConfig, CuGraph, NodeId};
7use crate::copperlist::{CopperList, CopperListState, CuListsManager};
8use crate::cutask::{BincodeAdapter, Freezable};
9use crate::monitoring::CuMonitor;
10use cu29_clock::{ClockProvider, RobotClock};
11use cu29_log_runtime::LoggerRuntime;
12use cu29_traits::CuResult;
13use cu29_traits::WriteStream;
14use cu29_traits::{CopperListTuple, CuError};
15use cu29_unifiedlog::UnifiedLoggerWrite;
16use std::sync::{Arc, Mutex};
17
18use bincode::error::EncodeError;
19use bincode::{encode_into_std_write, Decode, Encode};
20use petgraph::prelude::*;
21use petgraph::visit::VisitMap;
22use petgraph::visit::Visitable;
23use std::fmt::Debug;
24
/// Bundle of the unified logger, the structured-logging runtime and the robot clock.
///
/// NOTE(review): not referenced elsewhere in this module; presumably assembled by the
/// application bootstrap code — confirm against callers.
pub struct CopperContext {
    /// Writer of the unified log, shared behind a mutex.
    pub unified_logger: Arc<Mutex<UnifiedLoggerWrite>>,
    /// Structured-logging runtime.
    pub logger_runtime: LoggerRuntime,
    /// Robot clock.
    pub clock: RobotClock,
}
31
/// Pairs the fixed-size pool of copper lists with an optional logger for finished lists.
///
/// `NBCL` is the pool capacity (see `available_copper_lists`).
pub struct CopperListsManager<P: CopperListTuple, const NBCL: usize> {
    /// Pool of in-flight copper lists.
    pub inner: CuListsManager<P, NBCL>,
    /// Destination for copper lists once processing is done; `None` skips serialization.
    pub logger: Option<Box<dyn WriteStream<CopperList<P>>>>,
}
38
39impl<P: CopperListTuple, const NBCL: usize> CopperListsManager<P, NBCL> {
40 pub fn end_of_processing(&mut self, culistid: u32) {
41 let mut is_top = true;
42 let mut nb_done = 0;
43 self.inner.iter_mut().for_each(|cl| {
44 if cl.id == culistid && cl.get_state() == CopperListState::Processing {
45 cl.change_state(CopperListState::DoneProcessing);
46 }
47 if is_top && cl.get_state() == CopperListState::DoneProcessing {
50 if let Some(logger) = &mut self.logger {
51 cl.change_state(CopperListState::BeingSerialized);
52 logger.log(cl).unwrap();
53 }
54 cl.change_state(CopperListState::Free);
55 nb_done += 1;
56 } else {
57 is_top = false;
58 }
59 });
60 for _ in 0..nb_done {
61 let _ = self.inner.pop();
62 }
63 }
64
65 pub fn available_copper_lists(&self) -> usize {
66 NBCL - self.inner.len()
67 }
68}
69
/// Accumulates frozen task states into keyframes at a fixed copper-list interval.
pub struct KeyFramesManager {
    /// Keyframe currently being filled.
    inner: KeyFrame,

    /// Keyframe logger; keyframes are only taken while this is `Some` (see `is_keyframe`).
    logger: Option<Box<dyn WriteStream<KeyFrame>>>,

    /// A copper list whose id is a multiple of this interval is a keyframe.
    keyframe_interval: u32,
}
81
82impl KeyFramesManager {
83 fn is_keyframe(&self, culistid: u32) -> bool {
84 self.logger.is_some() && culistid % self.keyframe_interval == 0
85 }
86
87 pub fn reset(&mut self, culistid: u32) {
88 if self.is_keyframe(culistid) {
89 self.inner.reset(culistid);
90 }
91 }
92
93 pub fn freeze_task(&mut self, culistid: u32, task: &impl Freezable) -> CuResult<usize> {
94 if self.is_keyframe(culistid) {
95 if self.inner.culistid != culistid {
96 panic!("Freezing task for a different culistid");
97 }
98 self.inner
99 .add_frozen_task(task)
100 .map_err(|e| CuError::from(format!("Failed to serialize task: {e}")))
101 } else {
102 Ok(0)
103 }
104 }
105}
106
/// Runtime state of a Copper application.
///
/// `CT` is the task container built by the tasks instanciator given to `CuRuntime::new`,
/// `P` the copper list payload tuple, `M` the monitor, and `NBCL` the number of copper lists.
pub struct CuRuntime<CT, P: CopperListTuple, M: CuMonitor, const NBCL: usize> {
    /// Clock handed out through `ClockProvider::get_clock`.
    pub clock: RobotClock,
    /// Task instances, in the shape produced by the tasks instanciator.
    pub tasks: CT,

    /// Monitor built from the configuration.
    pub monitor: M,

    /// Copper list pool plus its optional logger.
    pub copperlists_manager: CopperListsManager<P, NBCL>,

    /// Keyframe accumulator plus its optional logger.
    pub keyframes_manager: KeyFramesManager,
}
126
127impl<CT, P: CopperListTuple, M: CuMonitor, const NBCL: usize> ClockProvider
129 for CuRuntime<CT, P, M, NBCL>
130{
131 fn get_clock(&self) -> RobotClock {
132 self.clock.clone()
133 }
134}
135
/// Frozen task states captured for one copper list, stored as bincode bytes.
#[derive(Encode, Decode)]
pub struct KeyFrame {
    /// Id of the copper list this keyframe was reset for.
    pub culistid: u32,
    /// Concatenated bincode encodings of the frozen tasks, in the order they were added.
    pub serialized_tasks: Vec<u8>,
}
146
147impl KeyFrame {
148 fn new() -> Self {
149 KeyFrame {
150 culistid: 0,
151 serialized_tasks: Vec::new(),
152 }
153 }
154
155 fn reset(&mut self, culistid: u32) {
157 self.culistid = culistid;
158 self.serialized_tasks.clear();
159 }
160
161 fn add_frozen_task(&mut self, task: &impl Freezable) -> Result<usize, EncodeError> {
163 let config = bincode::config::standard();
164 encode_into_std_write(BincodeAdapter(task), &mut self.serialized_tasks, config)
165 }
166}
167
168impl<CT, P: CopperListTuple + 'static, M: CuMonitor, const NBCL: usize> CuRuntime<CT, P, M, NBCL> {
169 pub fn new(
170 clock: RobotClock,
171 config: &CuConfig,
172 mission: Option<&str>,
173 tasks_instanciator: impl Fn(Vec<Option<&ComponentConfig>>) -> CuResult<CT>,
174 monitor_instanciator: impl Fn(&CuConfig) -> M,
175 copperlists_logger: impl WriteStream<CopperList<P>> + 'static,
176 keyframes_logger: impl WriteStream<KeyFrame> + 'static,
177 ) -> CuResult<Self> {
178 let graph = config.get_graph(mission)?;
179 let all_instances_configs: Vec<Option<&ComponentConfig>> = graph
180 .get_all_nodes()
181 .iter()
182 .map(|(_, node)| node.get_instance_config())
183 .collect();
184 let tasks = tasks_instanciator(all_instances_configs)?;
185
186 let monitor = monitor_instanciator(config);
187
188 let (copperlists_logger, keyframes_logger, keyframe_interval) = match &config.logging {
189 Some(logging_config) if logging_config.enable_task_logging => (
190 Some(Box::new(copperlists_logger) as Box<dyn WriteStream<CopperList<P>>>),
191 Some(Box::new(keyframes_logger) as Box<dyn WriteStream<KeyFrame>>),
192 logging_config.keyframe_interval.unwrap(), ),
194 Some(_) => (None, None, 0), None => (
196 Some(Box::new(copperlists_logger) as Box<dyn WriteStream<CopperList<P>>>),
198 Some(Box::new(keyframes_logger) as Box<dyn WriteStream<KeyFrame>>),
199 DEFAULT_KEYFRAME_INTERVAL,
200 ),
201 };
202
203 let copperlists_manager = CopperListsManager {
204 inner: CuListsManager::new(),
205 logger: copperlists_logger,
206 };
207
208 let keyframes_manager = KeyFramesManager {
209 inner: KeyFrame::new(),
210 logger: keyframes_logger,
211 keyframe_interval,
212 };
213
214 let runtime = Self {
215 tasks,
216 monitor,
217 clock,
218 copperlists_manager,
219 keyframes_manager,
220 };
221
222 Ok(runtime)
223 }
224}
225
/// Role of a task in the graph, as derived from its connections by `find_task_type_for_id`.
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
pub enum CuTaskType {
    /// Node without incoming connections.
    Source,
    /// Node with both incoming and outgoing connections.
    Regular,
    /// Node with incoming connections but no outgoing one.
    Sink,
}
236
/// One task invocation in an execution plan.
pub struct CuExecutionStep {
    /// Id of the node in the configuration graph.
    pub node_id: NodeId,
    /// Copy of the node's configuration.
    pub node: Node,
    /// Role of the task (source, regular or sink).
    pub task_type: CuTaskType,

    /// CopperList slots `(index, message type)` this task reads, ordered by connection id.
    pub input_msg_indices_types: Vec<(u32, String)>,

    /// CopperList slot `(index, message type)` this task writes; sinks get a `"()"` slot.
    pub output_msg_index_type: Option<(u32, String)>,
}
252
253impl Debug for CuExecutionStep {
254 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
255 f.write_str(format!(" CuExecutionStep: Node Id: {}\n", self.node_id).as_str())?;
256 f.write_str(format!(" task_type: {:?}\n", self.node.get_type()).as_str())?;
257 f.write_str(format!(" task: {:?}\n", self.task_type).as_str())?;
258 f.write_str(
259 format!(
260 " input_msg_types: {:?}\n",
261 self.input_msg_indices_types
262 )
263 .as_str(),
264 )?;
265 f.write_str(
266 format!(" output_msg_type: {:?}\n", self.output_msg_index_type).as_str(),
267 )?;
268 Ok(())
269 }
270}
271
/// Ordered sequence of execution units, optionally repeated.
pub struct CuExecutionLoop {
    /// Steps and nested loops, in execution order.
    pub steps: Vec<CuExecutionUnit>,
    /// Optional iteration count; the top-level plan built by `compute_runtime_plan` uses `None`.
    pub loop_count: Option<u32>,
}
280
281impl Debug for CuExecutionLoop {
282 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
283 f.write_str("CuExecutionLoop:\n")?;
284 for step in &self.steps {
285 match step {
286 CuExecutionUnit::Step(step) => {
287 step.fmt(f)?;
288 }
289 CuExecutionUnit::Loop(l) => {
290 l.fmt(f)?;
291 }
292 }
293 }
294
295 f.write_str(format!(" count: {:?}", self.loop_count).as_str())?;
296 Ok(())
297 }
298}
299
/// Element of an execution plan: a single task step or a nested loop of units.
#[derive(Debug)]
pub enum CuExecutionUnit {
    /// A single task invocation.
    Step(CuExecutionStep),
    /// A nested sequence of units.
    Loop(CuExecutionLoop),
}
306
307fn find_output_index_type_from_nodeid(
308 node_id: NodeId,
309 steps: &Vec<CuExecutionUnit>,
310) -> Option<(u32, String)> {
311 for step in steps {
312 match step {
313 CuExecutionUnit::Loop(loop_unit) => {
314 if let Some(index) = find_output_index_type_from_nodeid(node_id, &loop_unit.steps) {
315 return Some(index);
316 }
317 }
318 CuExecutionUnit::Step(step) => {
319 if step.node_id == node_id {
320 return step.output_msg_index_type.clone();
321 }
322 }
323 }
324 }
325 None
326}
327
328pub fn find_task_type_for_id(graph: &CuGraph, node_id: NodeId) -> CuTaskType {
329 if graph.0.neighbors_directed(node_id.into(), Incoming).count() == 0 {
330 CuTaskType::Source
331 } else if graph.0.neighbors_directed(node_id.into(), Outgoing).count() == 0 {
332 CuTaskType::Sink
333 } else {
334 CuTaskType::Regular
335 }
336}
337
338fn find_edge_with_plan_input_id(
341 plan: &[CuExecutionUnit],
342 graph: &CuGraph,
343 plan_id: u32,
344 output_node_id: NodeId,
345) -> usize {
346 let input_node = plan
347 .get(plan_id as usize)
348 .expect("Input step should've been added to plan before the step that receives the input");
349 let CuExecutionUnit::Step(input_step) = input_node else {
350 panic!("Expected input to be from a step, not a loop");
351 };
352 let input_node_id = input_step.node_id;
353
354 graph
355 .0
356 .edges_connecting(input_node_id.into(), output_node_id.into())
357 .map(|edge| edge.id().index())
358 .next()
359 .expect("An edge connecting the input to the output should exist")
360}
361
362fn sort_inputs_by_cnx_id(
365 input_msg_indices_types: &mut [(u32, String)],
366 plan: &[CuExecutionUnit],
367 graph: &CuGraph,
368 curr_node_id: NodeId,
369) {
370 input_msg_indices_types.sort_by(|(a_index, _), (b_index, _)| {
371 let a_edge_id = find_edge_with_plan_input_id(plan, graph, *a_index, curr_node_id);
372 let b_edge_id = find_edge_with_plan_input_id(plan, graph, *b_index, curr_node_id);
373 a_edge_id.cmp(&b_edge_id)
374 });
375}
376fn plan_tasks_tree_branch(
378 graph: &CuGraph,
379 mut next_culist_output_index: u32,
380 starting_point: NodeId,
381 plan: &mut Vec<CuExecutionUnit>,
382) -> (u32, bool) {
383 #[cfg(feature = "macro_debug")]
384 eprintln!("-- starting branch from node {starting_point}");
385
386 let mut visitor = Bfs::new(&graph.0, starting_point.into());
387 let mut handled = false;
388
389 while let Some(node) = visitor.next(&graph.0) {
390 let id = node.index() as NodeId;
391 let node_ref = graph.get_node(id).unwrap();
392 #[cfg(feature = "macro_debug")]
393 eprintln!(" Visiting node: {node_ref:?}");
394
395 let mut input_msg_indices_types: Vec<(u32, String)> = Vec::new();
396 let output_msg_index_type: Option<(u32, String)>;
397 let task_type = find_task_type_for_id(graph, id);
398
399 match task_type {
400 CuTaskType::Source => {
401 #[cfg(feature = "macro_debug")]
402 eprintln!(" → Source node, assign output index {next_culist_output_index}");
403 output_msg_index_type = Some((
404 next_culist_output_index,
405 graph
406 .0
407 .edge_weight(EdgeIndex::new(graph.get_src_edges(id).unwrap()[0]))
408 .unwrap() .msg
410 .clone(),
411 ));
412 next_culist_output_index += 1;
413 }
414 CuTaskType::Sink => {
415 let parents: Vec<NodeIndex> =
416 graph.0.neighbors_directed(id.into(), Incoming).collect();
417 #[cfg(feature = "macro_debug")]
418 eprintln!(" → Sink with parents: {parents:?}");
419 for parent in &parents {
420 let pid = parent.index() as NodeId;
421 let index_type = find_output_index_type_from_nodeid(pid, plan);
422 if let Some(index_type) = index_type {
423 #[cfg(feature = "macro_debug")]
424 eprintln!(" ✓ Input from {pid} ready: {index_type:?}");
425 input_msg_indices_types.push(index_type);
426 } else {
427 #[cfg(feature = "macro_debug")]
428 eprintln!(" ✗ Input from {pid} not ready, returning");
429 return (next_culist_output_index, handled);
430 }
431 }
432 output_msg_index_type = Some((next_culist_output_index, "()".to_string()));
433 next_culist_output_index += 1;
434 }
435 CuTaskType::Regular => {
436 let parents: Vec<NodeIndex> =
437 graph.0.neighbors_directed(id.into(), Incoming).collect();
438 #[cfg(feature = "macro_debug")]
439 eprintln!(" → Regular task with parents: {parents:?}");
440 for parent in &parents {
441 let pid = parent.index() as NodeId;
442 let index_type = find_output_index_type_from_nodeid(pid, plan);
443 if let Some(index_type) = index_type {
444 #[cfg(feature = "macro_debug")]
445 eprintln!(" ✓ Input from {pid} ready: {index_type:?}");
446 input_msg_indices_types.push(index_type);
447 } else {
448 #[cfg(feature = "macro_debug")]
449 eprintln!(" ✗ Input from {pid} not ready, returning");
450 return (next_culist_output_index, handled);
451 }
452 }
453 output_msg_index_type = Some((
454 next_culist_output_index,
455 graph
456 .0
457 .edge_weight(EdgeIndex::new(graph.get_src_edges(id).unwrap()[0])) .unwrap()
459 .msg
460 .clone(),
461 ));
462 next_culist_output_index += 1;
463 }
464 }
465
466 sort_inputs_by_cnx_id(&mut input_msg_indices_types, plan, graph, id);
467
468 if let Some(pos) = plan
469 .iter()
470 .position(|step| matches!(step, CuExecutionUnit::Step(s) if s.node_id == id))
471 {
472 #[cfg(feature = "macro_debug")]
473 eprintln!(" → Already in plan, modifying existing step");
474 let mut step = plan.remove(pos);
475 if let CuExecutionUnit::Step(ref mut s) = step {
476 s.input_msg_indices_types = input_msg_indices_types;
477 }
478 plan.push(step);
479 } else {
480 #[cfg(feature = "macro_debug")]
481 eprintln!(" → New step added to plan");
482 let step = CuExecutionStep {
483 node_id: id,
484 node: node_ref.clone(),
485 task_type,
486 input_msg_indices_types,
487 output_msg_index_type,
488 };
489 plan.push(CuExecutionUnit::Step(step));
490 }
491
492 handled = true;
493 }
494
495 #[cfg(feature = "macro_debug")]
496 eprintln!("-- finished branch from node {starting_point} with handled={handled}");
497 (next_culist_output_index, handled)
498}
499
/// Computes the execution order of the tasks in `graph` and assigns CopperList output slots.
///
/// Every source node seeds a work queue. For each queued node, the graph is walked
/// breadth-first. Each node that is not yet in the plan is handed to `plan_tasks_tree_branch`,
/// which appends steps for as long as their inputs are available. The downstream neighbors of
/// nodes where planning succeeded are queued again, so nodes that were waiting on other
/// parents can be reached by a later walk.
///
/// # Errors
/// No error path exists in this body; it always returns `Ok`.
pub fn compute_runtime_plan(graph: &CuGraph) -> CuResult<CuExecutionLoop> {
    #[cfg(feature = "macro_debug")]
    eprintln!("[runtime plan]");
    // NOTE(review): `visited` is never marked, so every `is_visited` check below is false.
    // Deduplication actually comes from the `already_in_plan` check.
    let visited = graph.0.visit_map();
    let mut plan = Vec::new();
    // Next free CopperList output slot; advanced by `plan_tasks_tree_branch`.
    let mut next_culist_output_index = 0u32;

    // Seed the work queue with every source task.
    let mut queue: std::collections::VecDeque<NodeId> = graph
        .node_indices()
        .iter()
        .filter(|&node| find_task_type_for_id(graph, node.index() as NodeId) == CuTaskType::Source)
        .map(|node| node.index() as NodeId)
        .collect();

    #[cfg(feature = "macro_debug")]
    eprintln!("Initial source nodes: {queue:?}");

    while let Some(start_node) = queue.pop_front() {
        if visited.is_visited(&start_node) {
            #[cfg(feature = "macro_debug")]
            eprintln!("→ Skipping already visited source {start_node}");
            continue;
        }

        #[cfg(feature = "macro_debug")]
        eprintln!("→ Starting BFS from source {start_node}");
        let mut bfs = Bfs::new(&graph.0, start_node.into());

        while let Some(node_index) = bfs.next(&graph.0) {
            let node_id = node_index.index() as NodeId;
            // Each node gets at most one step in the plan.
            let already_in_plan = plan
                .iter()
                .any(|unit| matches!(unit, CuExecutionUnit::Step(s) if s.node_id == node_id));
            if already_in_plan {
                #[cfg(feature = "macro_debug")]
                eprintln!(" → Node {node_id} already planned, skipping");
                continue;
            }

            #[cfg(feature = "macro_debug")]
            eprintln!(" Planning from node {node_id}");
            let (new_index, handled) =
                plan_tasks_tree_branch(graph, next_culist_output_index, node_id, &mut plan);
            next_culist_output_index = new_index;

            // Nothing could be planned from here (a parent is not planned yet); the node may
            // be reached again by a later walk.
            if !handled {
                #[cfg(feature = "macro_debug")]
                eprintln!(" ✗ Node {node_id} was not handled, skipping enqueue of neighbors");
                continue;
            }

            #[cfg(feature = "macro_debug")]
            eprintln!(" ✓ Node {node_id} handled successfully, enqueueing neighbors");
            // Queue the downstream neighbors so nodes waiting on inputs get another chance.
            for neighbor in graph.0.neighbors(node_index) {
                if !visited.is_visited(&neighbor) {
                    let nid = neighbor.index() as NodeId;
                    #[cfg(feature = "macro_debug")]
                    eprintln!(" → Enqueueing neighbor {nid}");
                    queue.push_back(nid);
                }
            }
        }
    }

    Ok(CuExecutionLoop {
        steps: plan,
        loop_count: None,
    })
}
571
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::Node;
    use crate::cutask::CuSinkTask;
    use crate::cutask::{CuSrcTask, Freezable};
    use crate::monitoring::NoMonitor;
    use bincode::Encode;

    /// Stateless source task used to build the two-node test graph.
    pub struct TestSource {}

    impl Freezable for TestSource {}

    impl CuSrcTask<'_> for TestSource {
        type Output = ();
        fn new(_config: Option<&ComponentConfig>) -> CuResult<Self>
        where
            Self: Sized,
        {
            Ok(Self {})
        }

        fn process(&mut self, _clock: &RobotClock, _empty_msg: Self::Output) -> CuResult<()> {
            Ok(())
        }
    }

    /// Stateless sink task used to build the two-node test graph.
    pub struct TestSink {}

    impl Freezable for TestSink {}

    impl CuSinkTask<'_> for TestSink {
        type Input = ();

        fn new(_config: Option<&ComponentConfig>) -> CuResult<Self>
        where
            Self: Sized,
        {
            Ok(Self {})
        }

        fn process(&mut self, _clock: &RobotClock, _input: Self::Input) -> CuResult<()> {
            Ok(())
        }
    }

    // Task tuple and copper list payload for the source -> sink graph used below.
    type Tasks = (TestSource, TestSink);
    type Msgs = ((),);

    // Builds the tasks in node order: node 0 is the source, node 1 the sink.
    fn tasks_instanciator(all_instances_configs: Vec<Option<&ComponentConfig>>) -> CuResult<Tasks> {
        Ok((
            TestSource::new(all_instances_configs[0])?,
            TestSink::new(all_instances_configs[1])?,
        ))
    }

    fn monitor_instanciator(_config: &CuConfig) -> NoMonitor {
        NoMonitor {}
    }

    /// Writer that accepts and discards everything, so tests need no log file.
    #[derive(Debug)]
    struct FakeWriter {}

    impl<E: Encode> WriteStream<E> for FakeWriter {
        fn log(&mut self, _obj: &E) -> CuResult<()> {
            Ok(())
        }
    }

    // A minimal source -> sink configuration yields a runtime without error.
    #[test]
    fn test_runtime_instantiation() {
        let mut config = CuConfig::default();
        let graph = config.get_graph_mut(None).unwrap();
        graph.add_node(Node::new("a", "TestSource")).unwrap();
        graph.add_node(Node::new("b", "TestSink")).unwrap();
        graph.connect(0, 1, "()").unwrap();
        let runtime = CuRuntime::<Tasks, Msgs, NoMonitor, 2>::new(
            RobotClock::default(),
            &config,
            None,
            tasks_instanciator,
            monitor_instanciator,
            FakeWriter {},
            FakeWriter {},
        );
        assert!(runtime.is_ok());
    }

    // Exercises creation and release of copper lists in a pool of 2.
    #[test]
    fn test_copperlists_manager_lifecycle() {
        let mut config = CuConfig::default();
        let graph = config.get_graph_mut(None).unwrap();
        graph.add_node(Node::new("a", "TestSource")).unwrap();
        graph.add_node(Node::new("b", "TestSink")).unwrap();
        graph.connect(0, 1, "()").unwrap();
        let mut runtime = CuRuntime::<Tasks, Msgs, NoMonitor, 2>::new(
            RobotClock::default(),
            &config,
            None,
            tasks_instanciator,
            monitor_instanciator,
            FakeWriter {},
            FakeWriter {},
        )
        .unwrap();

        {
            // First list gets id 0 and takes one of the two slots.
            let copperlists = &mut runtime.copperlists_manager;
            let culist0 = copperlists
                .inner
                .create()
                .expect("Ran out of space for copper lists");
            let id = culist0.id;
            assert_eq!(id, 0);
            culist0.change_state(CopperListState::Processing);
            assert_eq!(copperlists.available_copper_lists(), 1);
        }

        {
            // Second list gets id 1 and fills the pool.
            let copperlists = &mut runtime.copperlists_manager;
            let culist1 = copperlists
                .inner
                .create()
                .expect("Ran out of space for copper lists");
            let id = culist1.id;
            assert_eq!(id, 1);
            culist1.change_state(CopperListState::Processing);
            assert_eq!(copperlists.available_copper_lists(), 0);
        }

        {
            // The pool is full. Finishing list 1 (the most recently created) frees its slot
            // while list 0 is still processing.
            let copperlists = &mut runtime.copperlists_manager;
            let culist2 = copperlists.inner.create();
            assert!(culist2.is_none());
            assert_eq!(copperlists.available_copper_lists(), 0);
            copperlists.end_of_processing(1);
            assert_eq!(copperlists.available_copper_lists(), 1);
        }

        {
            // List 2 now sits ahead of list 0 in release order. Finishing 0 alone releases
            // nothing; finishing 2 then releases both.
            let copperlists = &mut runtime.copperlists_manager;
            let culist2 = copperlists
                .inner
                .create()
                .expect("Ran out of space for copper lists");
            let id = culist2.id;
            assert_eq!(id, 2);
            culist2.change_state(CopperListState::Processing);
            assert_eq!(copperlists.available_copper_lists(), 0);
            copperlists.end_of_processing(0);
            assert_eq!(copperlists.available_copper_lists(), 0);

            copperlists.end_of_processing(2);
            assert_eq!(copperlists.available_copper_lists(), 2);
        }
    }

    // Sink inputs must follow connection ids, not node ids: src2 is connected first.
    #[test]
    fn test_runtime_task_input_order() {
        let mut config = CuConfig::default();
        let graph = config.get_graph_mut(None).unwrap();
        let src1_id = graph.add_node(Node::new("a", "Source1")).unwrap();
        let src2_id = graph.add_node(Node::new("b", "Source2")).unwrap();
        let sink_id = graph.add_node(Node::new("c", "Sink")).unwrap();

        assert_eq!(src1_id, 0);
        assert_eq!(src2_id, 1);

        let src1_type = "src1_type";
        let src2_type = "src2_type";
        graph.connect(src2_id, sink_id, src2_type).unwrap();
        graph.connect(src1_id, sink_id, src1_type).unwrap();

        let src1_edge_id = *graph.get_src_edges(src1_id).unwrap().first().unwrap();
        let src2_edge_id = *graph.get_src_edges(src2_id).unwrap().first().unwrap();
        assert_eq!(src1_edge_id, 1);
        assert_eq!(src2_edge_id, 0);

        let runtime = compute_runtime_plan(graph).unwrap();
        let sink_step = runtime
            .steps
            .iter()
            .find_map(|step| match step {
                CuExecutionUnit::Step(step) if step.node_id == sink_id => Some(step),
                _ => None,
            })
            .unwrap();

        assert_eq!(sink_step.input_msg_indices_types[0].1, src2_type);
        assert_eq!(sink_step.input_msg_indices_types[1].1, src1_type);
    }

    // Diamond: cam0 feeds broadcast directly and through inf0, with the direct link connected
    // first. The direct "i32" input must come before the "f32" one.
    #[test]
    fn test_runtime_plan_diamond_case1() {
        let mut config = CuConfig::default();
        let graph = config.get_graph_mut(None).unwrap();
        let cam0_id = graph
            .add_node(Node::new("cam0", "tasks::IntegerSrcTask"))
            .unwrap();
        let inf0_id = graph
            .add_node(Node::new("inf0", "tasks::Integer2FloatTask"))
            .unwrap();
        let broadcast_id = graph
            .add_node(Node::new("broadcast", "tasks::MergingSinkTask"))
            .unwrap();

        graph.connect(cam0_id, broadcast_id, "i32").unwrap();
        graph.connect(cam0_id, inf0_id, "i32").unwrap();
        graph.connect(inf0_id, broadcast_id, "f32").unwrap();

        // NOTE(review): `get_src_edges` appears to list the most recently added edge first, so
        // these two names look swapped relative to the `connect` order above — confirm.
        let edge_cam0_to_broadcast = *graph.get_src_edges(cam0_id).unwrap().first().unwrap();
        let edge_cam0_to_inf0 = graph.get_src_edges(cam0_id).unwrap()[1];

        assert_eq!(edge_cam0_to_inf0, 0);
        assert_eq!(edge_cam0_to_broadcast, 1);

        let runtime = compute_runtime_plan(graph).unwrap();
        let broadcast_step = runtime
            .steps
            .iter()
            .find_map(|step| match step {
                CuExecutionUnit::Step(step) if step.node_id == broadcast_id => Some(step),
                _ => None,
            })
            .unwrap();

        assert_eq!(broadcast_step.input_msg_indices_types[0].1, "i32");
        assert_eq!(broadcast_step.input_msg_indices_types[1].1, "f32");
    }

    // Same diamond, but cam0 -> inf0 is connected first; the input order must not change.
    #[test]
    fn test_runtime_plan_diamond_case2() {
        let mut config = CuConfig::default();
        let graph = config.get_graph_mut(None).unwrap();
        let cam0_id = graph
            .add_node(Node::new("cam0", "tasks::IntegerSrcTask"))
            .unwrap();
        let inf0_id = graph
            .add_node(Node::new("inf0", "tasks::Integer2FloatTask"))
            .unwrap();
        let broadcast_id = graph
            .add_node(Node::new("broadcast", "tasks::MergingSinkTask"))
            .unwrap();

        graph.connect(cam0_id, inf0_id, "i32").unwrap();
        graph.connect(cam0_id, broadcast_id, "i32").unwrap();
        graph.connect(inf0_id, broadcast_id, "f32").unwrap();

        // NOTE(review): as in case 1, these names look swapped relative to `connect` order.
        let edge_cam0_to_inf0 = *graph.get_src_edges(cam0_id).unwrap().first().unwrap();
        let edge_cam0_to_broadcast = graph.get_src_edges(cam0_id).unwrap()[1];

        assert_eq!(edge_cam0_to_broadcast, 0);
        assert_eq!(edge_cam0_to_inf0, 1);

        let runtime = compute_runtime_plan(graph).unwrap();
        let broadcast_step = runtime
            .steps
            .iter()
            .find_map(|step| match step {
                CuExecutionUnit::Step(step) if step.node_id == broadcast_id => Some(step),
                _ => None,
            })
            .unwrap();

        assert_eq!(broadcast_step.input_msg_indices_types[0].1, "i32");
        assert_eq!(broadcast_step.input_msg_indices_types[1].1, "f32");
    }
}