cu29_runtime/
curuntime.rs

//! `CuRuntime` is the heart of what Copper runs on the robot.
//! It is exposed to the user via the `copper_runtime` macro, which injects it as a field in their application struct.
//!
4
5use crate::config::{Cnx, CuConfig, NodeId};
6use crate::config::{ComponentConfig, Node};
7use crate::copperlist::{CopperList, CopperListState, CuListsManager};
8use crate::monitoring::CuMonitor;
9use cu29_clock::{ClockProvider, RobotClock};
10use cu29_log_runtime::LoggerRuntime;
11use cu29_traits::CopperListTuple;
12use cu29_traits::CuResult;
13use cu29_traits::WriteStream;
14use cu29_unifiedlog::UnifiedLoggerWrite;
15use std::sync::{Arc, Mutex};
16
17use petgraph::prelude::*;
18use petgraph::visit::VisitMap;
19use petgraph::visit::Visitable;
20use std::fmt::Debug;
21
22/// Just a simple struct to hold the various bits needed to run a Copper application.
23pub struct CopperContext {
24    pub unified_logger: Arc<Mutex<UnifiedLoggerWrite>>,
25    pub logger_runtime: LoggerRuntime,
26    pub clock: RobotClock,
27}
28
29/// This is the main structure that will be injected as a member of the Application struct.
30/// CT is the tuple of all the tasks in order of execution.
31/// CL is the type of the copper list, representing the input/output messages for all the tasks.
32pub struct CuRuntime<CT, P: CopperListTuple, M: CuMonitor, const NBCL: usize> {
33    /// The tuple of all the tasks in order of execution.
34    pub tasks: CT,
35
36    pub monitor: M,
37
38    /// Copper lists hold in order all the input/output messages for all the tasks.
39    pub copper_lists_manager: CuListsManager<P, NBCL>,
40
41    /// The base clock the runtime will be using to record time.
42    pub clock: RobotClock, // TODO: remove public at some point
43
44    /// Logger
45    logger: Option<Box<dyn WriteStream<CopperList<P>>>>,
46}
47
48/// To be able to share the clock we make the runtime a clock provider.
49impl<CT, P: CopperListTuple, M: CuMonitor, const NBCL: usize> ClockProvider
50    for CuRuntime<CT, P, M, NBCL>
51{
52    fn get_clock(&self) -> RobotClock {
53        self.clock.clone()
54    }
55}
56
57impl<CT, P: CopperListTuple + 'static, M: CuMonitor, const NBCL: usize> CuRuntime<CT, P, M, NBCL> {
58    pub fn new(
59        clock: RobotClock,
60        config: &CuConfig,
61        tasks_instanciator: impl Fn(Vec<Option<&ComponentConfig>>) -> CuResult<CT>,
62        monitor_instanciator: impl Fn(&CuConfig) -> M,
63        logger: impl WriteStream<CopperList<P>> + 'static,
64    ) -> CuResult<Self> {
65        let all_instances_configs: Vec<Option<&ComponentConfig>> = config
66            .get_all_nodes(None) // FIXME(gbin): Multimission support
67            .iter()
68            .map(|(_, node)| node.get_instance_config())
69            .collect();
70        let tasks = tasks_instanciator(all_instances_configs)?;
71
72        let monitor = monitor_instanciator(config);
73
74        // Needed to declare type explicitly as `cargo check` was failing without it
75        let logger_: Option<Box<dyn WriteStream<CopperList<P>>>> =
76            if let Some(logging_config) = &config.logging {
77                if logging_config.enable_task_logging {
78                    Some(Box::new(logger))
79                } else {
80                    None
81                }
82            } else {
83                Some(Box::new(logger))
84            };
85
86        let runtime = Self {
87            tasks,
88            monitor,
89            copper_lists_manager: CuListsManager::new(), // placeholder
90            clock,
91            logger: logger_,
92        };
93
94        Ok(runtime)
95    }
96
97    pub fn available_copper_lists(&self) -> usize {
98        NBCL - self.copper_lists_manager.len()
99    }
100
101    pub fn end_of_processing(&mut self, culistid: u32) {
102        let mut is_top = true;
103        let mut nb_done = 0;
104        self.copper_lists_manager.iter_mut().for_each(|cl| {
105            if cl.id == culistid && cl.get_state() == CopperListState::Processing {
106                cl.change_state(CopperListState::DoneProcessing);
107            }
108            // if we have a series of copper lists that are done processing at the top of the circular buffer
109            // serialize them all and Free them.
110            if is_top && cl.get_state() == CopperListState::DoneProcessing {
111                if let Some(logger) = &mut self.logger {
112                    cl.change_state(CopperListState::BeingSerialized);
113                    logger.log(cl).unwrap();
114                }
115                cl.change_state(CopperListState::Free);
116                nb_done += 1;
117            } else {
118                is_top = false;
119            }
120        });
121        for _ in 0..nb_done {
122            let _ = self.copper_lists_manager.pop();
123        }
124    }
125}
126
/// The three kinds of Copper tasks.
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
pub enum CuTaskType {
    /// Only produces output messages (usually used for drivers).
    Source,
    /// Processes input messages and produces output messages, more like a compute node.
    Regular,
    /// Only consumes input messages (usually used for actuators).
    Sink,
}
137
138/// This structure represents a step in the execution plan.
139pub struct CuExecutionStep {
140    /// NodeId: node id of the task to execute
141    pub node_id: NodeId,
142    /// Node: node instance
143    pub node: Node,
144    /// CuTaskType: type of the task
145    pub task_type: CuTaskType,
146
147    /// the indices in the copper list of the input messages and their types
148    pub input_msg_indices_types: Vec<(u32, String)>,
149
150    /// the index in the copper list of the output message and its type
151    pub output_msg_index_type: Option<(u32, String)>,
152}
153
154impl Debug for CuExecutionStep {
155    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
156        f.write_str(format!("   CuExecutionStep: Node Id: {}\n", self.node_id).as_str())?;
157        f.write_str(format!("                  task_type: {:?}\n", self.node.get_type()).as_str())?;
158        f.write_str(format!("                       task: {:?}\n", self.task_type).as_str())?;
159        f.write_str(
160            format!(
161                "              input_msg_types: {:?}\n",
162                self.input_msg_indices_types
163            )
164            .as_str(),
165        )?;
166        f.write_str(
167            format!("       output_msg_type: {:?}\n", self.output_msg_index_type).as_str(),
168        )?;
169        Ok(())
170    }
171}
172
173/// This structure represents a loop in the execution plan.
174/// It is used to represent a sequence of Execution units (loop or steps) that are executed
175/// multiple times.
176/// if loop_count is None, the loop is infinite.
177pub struct CuExecutionLoop {
178    pub steps: Vec<CuExecutionUnit>,
179    pub loop_count: Option<u32>,
180}
181
182impl Debug for CuExecutionLoop {
183    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
184        f.write_str("CuExecutionLoop:\n")?;
185        for step in &self.steps {
186            match step {
187                CuExecutionUnit::Step(step) => {
188                    step.fmt(f)?;
189                }
190                CuExecutionUnit::Loop(l) => {
191                    l.fmt(f)?;
192                }
193            }
194        }
195
196        f.write_str(format!("   count: {:?}", self.loop_count).as_str())?;
197        Ok(())
198    }
199}
200
201/// This structure represents a step in the execution plan.
202#[derive(Debug)]
203pub enum CuExecutionUnit {
204    Step(CuExecutionStep),
205    Loop(CuExecutionLoop),
206}
207
208fn find_output_index_type_from_nodeid(
209    node_id: NodeId,
210    steps: &Vec<CuExecutionUnit>,
211) -> Option<(u32, String)> {
212    for step in steps {
213        match step {
214            CuExecutionUnit::Loop(loop_unit) => {
215                if let Some(index) = find_output_index_type_from_nodeid(node_id, &loop_unit.steps) {
216                    return Some(index);
217                }
218            }
219            CuExecutionUnit::Step(step) => {
220                if step.node_id == node_id {
221                    return step.output_msg_index_type.clone();
222                }
223            }
224        }
225    }
226    None
227}
228
229pub fn find_task_type_for_id(
230    graph: &StableDiGraph<Node, Cnx, NodeId>,
231    node_id: NodeId,
232) -> CuTaskType {
233    if graph.neighbors_directed(node_id.into(), Incoming).count() == 0 {
234        CuTaskType::Source
235    } else if graph.neighbors_directed(node_id.into(), Outgoing).count() == 0 {
236        CuTaskType::Sink
237    } else {
238        CuTaskType::Regular
239    }
240}
241
242/// This function gets the input node by using the input step plan id, to get the edge that
243/// connects the input to the output in the config graph
244fn find_edge_with_plan_input_id(
245    plan: &[CuExecutionUnit],
246    config: &CuConfig,
247    plan_id: u32,
248    output_node_id: NodeId,
249) -> usize {
250    let input_node = plan
251        .get(plan_id as usize)
252        .expect("Input step should've been added to plan before the step that receives the input");
253    let CuExecutionUnit::Step(input_step) = input_node else {
254        panic!("Expected input to be from a step, not a loop");
255    };
256    let input_node_id = input_step.node_id;
257
258    config
259        .get_graph(None)
260        .unwrap() // FIXME(gbin): Error handling and multimission
261        .edges_connecting(input_node_id.into(), output_node_id.into())
262        .map(|edge| edge.id().index())
263        .next()
264        .expect("An edge connecting the input to the output should exist")
265}
266
267/// The connection id used here is the index of the config graph edge that equates to the wanted
268/// connection
269fn sort_inputs_by_cnx_id(
270    input_msg_indices_types: &mut [(u32, String)],
271    plan: &[CuExecutionUnit],
272    config: &CuConfig,
273    curr_node_id: NodeId,
274) {
275    input_msg_indices_types.sort_by(|(a_index, _), (b_index, _)| {
276        let a_edge_id = find_edge_with_plan_input_id(plan, config, *a_index, curr_node_id);
277        let b_edge_id = find_edge_with_plan_input_id(plan, config, *b_index, curr_node_id);
278        a_edge_id.cmp(&b_edge_id)
279    });
280}
281/// Explores a subbranch and build the partial plan out of it.
282fn plan_tasks_tree_branch(
283    config: &CuConfig,
284    mut next_culist_output_index: u32,
285    starting_point: NodeId,
286    plan: &mut Vec<CuExecutionUnit>,
287) -> (u32, bool) {
288    #[cfg(feature = "macro_debug")]
289    eprintln!("-- starting branch from node {starting_point}");
290
291    let graph = config.get_graph(None).unwrap(); // FIXME(gbin): Error handling and multimission
292    let mut visitor = Bfs::new(&graph, starting_point.into());
293    let mut handled = false;
294
295    while let Some(node) = visitor.next(&graph) {
296        let id = node.index() as NodeId;
297        let node_ref = config.graphs.get_node(id, None).unwrap();
298        #[cfg(feature = "macro_debug")]
299        eprintln!("  Visiting node: {node_ref:?}");
300
301        let mut input_msg_indices_types: Vec<(u32, String)> = Vec::new();
302        let output_msg_index_type: Option<(u32, String)>;
303        let task_type = find_task_type_for_id(graph, id);
304
305        match task_type {
306            CuTaskType::Source => {
307                #[cfg(feature = "macro_debug")]
308                eprintln!("    → Source node, assign output index {next_culist_output_index}");
309                output_msg_index_type = Some((
310                    next_culist_output_index,
311                    graph
312                        .edge_weight(EdgeIndex::new(config.get_src_edges(id, None).unwrap()[0])) // FIXME(gbin): Error handling and multimission
313                        .unwrap()
314                        .msg
315                        .clone(),
316                ));
317                next_culist_output_index += 1;
318            }
319            CuTaskType::Sink => {
320                let parents: Vec<NodeIndex> =
321                    graph.neighbors_directed(id.into(), Incoming).collect();
322                #[cfg(feature = "macro_debug")]
323                eprintln!("    → Sink with parents: {parents:?}");
324                for parent in &parents {
325                    let pid = parent.index() as NodeId;
326                    let index_type = find_output_index_type_from_nodeid(pid, plan);
327                    if let Some(index_type) = index_type {
328                        #[cfg(feature = "macro_debug")]
329                        eprintln!("      ✓ Input from {pid} ready: {index_type:?}");
330                        input_msg_indices_types.push(index_type);
331                    } else {
332                        #[cfg(feature = "macro_debug")]
333                        eprintln!("      ✗ Input from {pid} not ready, returning");
334                        return (next_culist_output_index, handled);
335                    }
336                }
337                output_msg_index_type = Some((next_culist_output_index, "()".to_string()));
338                next_culist_output_index += 1;
339            }
340            CuTaskType::Regular => {
341                let parents: Vec<NodeIndex> =
342                    graph.neighbors_directed(id.into(), Incoming).collect();
343                #[cfg(feature = "macro_debug")]
344                eprintln!("    → Regular task with parents: {parents:?}");
345                for parent in &parents {
346                    let pid = parent.index() as NodeId;
347                    let index_type = find_output_index_type_from_nodeid(pid, plan);
348                    if let Some(index_type) = index_type {
349                        #[cfg(feature = "macro_debug")]
350                        eprintln!("      ✓ Input from {pid} ready: {index_type:?}");
351                        input_msg_indices_types.push(index_type);
352                    } else {
353                        #[cfg(feature = "macro_debug")]
354                        eprintln!("      ✗ Input from {pid} not ready, returning");
355                        return (next_culist_output_index, handled);
356                    }
357                }
358                output_msg_index_type = Some((
359                    next_culist_output_index,
360                    graph
361                        .edge_weight(EdgeIndex::new(config.get_src_edges(id, None).unwrap()[0])) // FIXME(gbin): Error handling and multimission
362                        .unwrap()
363                        .msg
364                        .clone(),
365                ));
366                next_culist_output_index += 1;
367            }
368        }
369
370        sort_inputs_by_cnx_id(&mut input_msg_indices_types, plan, config, id);
371
372        if let Some(pos) = plan
373            .iter()
374            .position(|step| matches!(step, CuExecutionUnit::Step(s) if s.node_id == id))
375        {
376            #[cfg(feature = "macro_debug")]
377            eprintln!("    → Already in plan, modifying existing step");
378            let mut step = plan.remove(pos);
379            if let CuExecutionUnit::Step(ref mut s) = step {
380                s.input_msg_indices_types = input_msg_indices_types;
381            }
382            plan.push(step);
383        } else {
384            #[cfg(feature = "macro_debug")]
385            eprintln!("    → New step added to plan");
386            let step = CuExecutionStep {
387                node_id: id,
388                node: node_ref.clone(),
389                task_type,
390                input_msg_indices_types,
391                output_msg_index_type,
392            };
393            plan.push(CuExecutionUnit::Step(step));
394        }
395
396        handled = true;
397    }
398
399    #[cfg(feature = "macro_debug")]
400    eprintln!("-- finished branch from node {starting_point} with handled={handled}");
401    (next_culist_output_index, handled)
402}
403
404/// This is the main heuristics to compute an execution plan at compilation time.
405/// TODO: Make that heuristic pluggable.
406pub fn compute_runtime_plan(config: &CuConfig) -> CuResult<CuExecutionLoop> {
407    #[cfg(feature = "macro_debug")]
408    eprintln!("[runtime plan]");
409    let graph = config.get_graph(None).unwrap(); // FIXME(gbin): Error handling and multimission
410    let visited = graph.visit_map();
411    let mut plan = Vec::new();
412    let mut next_culist_output_index = 0u32;
413
414    let graph = config.get_graph(None).unwrap(); // FIXME(gbin): Error handling and multimission
415    let mut queue: std::collections::VecDeque<NodeId> = graph
416        .node_indices()
417        .filter(|&node| find_task_type_for_id(graph, node.index() as NodeId) == CuTaskType::Source)
418        .map(|node| node.index() as NodeId)
419        .collect();
420
421    #[cfg(feature = "macro_debug")]
422    eprintln!("Initial source nodes: {queue:?}");
423
424    while let Some(start_node) = queue.pop_front() {
425        if visited.is_visited(&start_node) {
426            #[cfg(feature = "macro_debug")]
427            eprintln!("→ Skipping already visited source {start_node}");
428            continue;
429        }
430
431        #[cfg(feature = "macro_debug")]
432        eprintln!("→ Starting BFS from source {start_node}");
433        let graph = config.get_graph(None).unwrap(); // FIXME(gbin): Error handling and multimission
434        let mut bfs = Bfs::new(graph, start_node.into());
435
436        while let Some(node_index) = bfs.next(graph) {
437            let node_id = node_index.index() as NodeId;
438            let already_in_plan = plan
439                .iter()
440                .any(|unit| matches!(unit, CuExecutionUnit::Step(s) if s.node_id == node_id));
441            if already_in_plan {
442                #[cfg(feature = "macro_debug")]
443                eprintln!("    → Node {node_id} already planned, skipping");
444                continue;
445            }
446
447            #[cfg(feature = "macro_debug")]
448            eprintln!("    Planning from node {node_id}");
449            let (new_index, handled) =
450                plan_tasks_tree_branch(config, next_culist_output_index, node_id, &mut plan);
451            next_culist_output_index = new_index;
452
453            if !handled {
454                #[cfg(feature = "macro_debug")]
455                eprintln!("    ✗ Node {node_id} was not handled, skipping enqueue of neighbors");
456                continue;
457            }
458
459            #[cfg(feature = "macro_debug")]
460            eprintln!("    ✓ Node {node_id} handled successfully, enqueueing neighbors");
461            let graph = config.get_graph(None).unwrap(); // FIXME(gbin): Error handling and multimission
462            for neighbor in graph.neighbors(node_index) {
463                if !visited.is_visited(&neighbor) {
464                    let nid = neighbor.index() as NodeId;
465                    #[cfg(feature = "macro_debug")]
466                    eprintln!("      → Enqueueing neighbor {nid}");
467                    queue.push_back(nid);
468                }
469            }
470        }
471    }
472
473    Ok(CuExecutionLoop {
474        steps: plan,
475        loop_count: None,
476    })
477}
478
// Unit tests for the runtime construction, the copper list lifecycle and the planner.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::Node;
    use crate::cutask::CuSinkTask;
    use crate::cutask::{CuSrcTask, Freezable};
    use crate::monitoring::NoMonitor;
    use bincode::Encode;

    /// Source task that does nothing.
    pub struct TestSource {}

    impl Freezable for TestSource {}

    impl CuSrcTask<'_> for TestSource {
        type Output = ();
        fn new(_config: Option<&ComponentConfig>) -> CuResult<Self>
        where
            Self: Sized,
        {
            Ok(Self {})
        }

        fn process(&mut self, _clock: &RobotClock, _empty_msg: Self::Output) -> CuResult<()> {
            Ok(())
        }
    }

    /// Sink task that does nothing.
    pub struct TestSink {}

    impl Freezable for TestSink {}

    impl CuSinkTask<'_> for TestSink {
        type Input = ();

        fn new(_config: Option<&ComponentConfig>) -> CuResult<Self>
        where
            Self: Sized,
        {
            Ok(Self {})
        }

        fn process(&mut self, _clock: &RobotClock, _input: Self::Input) -> CuResult<()> {
            Ok(())
        }
    }

    // Those should be generated by the derive macro
    type Tasks = (TestSource, TestSink);
    type Msgs = ((),);

    /// The runtime flavor used by the tests: room for two copper lists only.
    type TestRuntime = CuRuntime<Tasks, Msgs, NoMonitor, 2>;

    fn tasks_instanciator(all_instances_configs: Vec<Option<&ComponentConfig>>) -> CuResult<Tasks> {
        Ok((
            TestSource::new(all_instances_configs[0])?,
            TestSink::new(all_instances_configs[1])?,
        ))
    }

    fn monitor_instanciator(_config: &CuConfig) -> NoMonitor {
        NoMonitor {}
    }

    /// Log sink that accepts everything and stores nothing.
    #[derive(Debug)]
    struct FakeWriter {}

    impl<E: Encode> WriteStream<E> for FakeWriter {
        fn log(&mut self, _obj: &E) -> CuResult<()> {
            Ok(())
        }
    }

    /// Configuration with a `TestSource` node "a" connected to a `TestSink` node "b".
    fn source_to_sink_config() -> CuConfig {
        let mut config = CuConfig::default();
        let graph = config.get_graph_mut(None).unwrap();
        graph.add_node(Node::new("a", "TestSource"));
        graph.add_node(Node::new("b", "TestSink"));
        config.connect(0, 1, "()").unwrap();
        config
    }

    /// Builds the test runtime on top of `config`.
    fn new_test_runtime(config: &CuConfig) -> CuResult<TestRuntime> {
        TestRuntime::new(
            RobotClock::default(),
            config,
            tasks_instanciator,
            monitor_instanciator,
            FakeWriter {},
        )
    }

    /// Creates the next copper list, checks its id and flags it as being processed.
    fn start_copper_list(runtime: &mut TestRuntime, expected_id: u32) {
        let culist = runtime
            .copper_lists_manager
            .create()
            .expect("Ran out of space for copper lists"); // FIXME: error handling.
        assert_eq!(culist.id, expected_id);
        culist.change_state(CopperListState::Processing);
    }

    /// Returns the top-level step planned for `node_id`, panicking if there is none.
    fn step_for_node(plan: &CuExecutionLoop, node_id: NodeId) -> &CuExecutionStep {
        plan.steps
            .iter()
            .find_map(|unit| match unit {
                CuExecutionUnit::Step(step) if step.node_id == node_id => Some(step),
                _ => None,
            })
            .unwrap()
    }

    /// Configuration holding the three unconnected nodes of the diamond tests.
    /// Returns the config and the ids of `cam0`, `inf0` and `broadcast`, in that order.
    fn diamond_nodes() -> (CuConfig, NodeId, NodeId, NodeId) {
        let mut config = CuConfig::default();
        let cam0_id = config
            .add_node(Node::new("cam0", "tasks::IntegerSrcTask"), None)
            .unwrap();
        let inf0_id = config
            .add_node(Node::new("inf0", "tasks::Integer2FloatTask"), None)
            .unwrap();
        let broadcast_id = config
            .add_node(Node::new("broadcast", "tasks::MergingSinkTask"), None)
            .unwrap();
        (config, cam0_id, inf0_id, broadcast_id)
    }

    #[test]
    fn test_runtime_instantiation() {
        let config = source_to_sink_config();
        let runtime = new_test_runtime(&config);
        assert!(runtime.is_ok());
    }

    #[test]
    fn test_copperlists_manager_lifecycle() {
        let config = source_to_sink_config();
        let mut runtime = new_test_runtime(&config).unwrap();

        // Now emulates the generated runtime
        start_copper_list(&mut runtime, 0);
        assert_eq!(runtime.available_copper_lists(), 1);

        start_copper_list(&mut runtime, 1);
        assert_eq!(runtime.available_copper_lists(), 0);

        // Both slots are taken: a third copper list cannot be created.
        assert!(runtime.copper_lists_manager.create().is_none());
        assert_eq!(runtime.available_copper_lists(), 0);

        // Finishing copper list #1 is expected to release exactly one slot.
        runtime.end_of_processing(1);
        assert_eq!(runtime.available_copper_lists(), 1);

        // Readd a CL
        start_copper_list(&mut runtime, 2);
        assert_eq!(runtime.available_copper_lists(), 0);

        // Free out of order, the #0 first
        runtime.end_of_processing(0);
        // Should not free up the top of the stack
        assert_eq!(runtime.available_copper_lists(), 0);

        // Free up the top of the stack
        runtime.end_of_processing(2);
        // This should free up 2 CLs

        assert_eq!(runtime.available_copper_lists(), 2);
    }

    #[test]
    fn test_runtime_task_input_order() {
        let mut config = CuConfig::default();
        let src1_id = config.add_node(Node::new("a", "Source1"), None).unwrap();
        let src2_id = config.add_node(Node::new("b", "Source2"), None).unwrap();
        let sink_id = config.add_node(Node::new("c", "Sink"), None).unwrap();

        assert_eq!(src1_id, 0);
        assert_eq!(src2_id, 1);

        // note that the source2 connection is before the source1
        let src1_type = "src1_type";
        let src2_type = "src2_type";
        config.connect(src2_id, sink_id, src2_type).unwrap();
        config.connect(src1_id, sink_id, src1_type).unwrap();

        let src1_edge_id = *config
            .get_src_edges(src1_id, None)
            .unwrap()
            .first()
            .unwrap();
        let src2_edge_id = *config
            .get_src_edges(src2_id, None)
            .unwrap()
            .first()
            .unwrap();
        // the edge id depends on the order the connection is created, not
        // on the node id, and that is what determines the input order
        assert_eq!(src1_edge_id, 1);
        assert_eq!(src2_edge_id, 0);

        let runtime = compute_runtime_plan(&config).unwrap();
        let sink_step = step_for_node(&runtime, sink_id);

        // since the src2 connection was added before src1 connection, the src2 type should be
        // first
        assert_eq!(sink_step.input_msg_indices_types[0].1, src2_type);
        assert_eq!(sink_step.input_msg_indices_types[1].1, src1_type);
    }

    #[test]
    fn test_runtime_plan_diamond_case1() {
        // more complex topology that tripped the scheduler
        let (mut config, cam0_id, inf0_id, broadcast_id) = diamond_nodes();

        // case 1 order
        config.connect(cam0_id, broadcast_id, "i32").unwrap();
        config.connect(cam0_id, inf0_id, "i32").unwrap();
        config.connect(inf0_id, broadcast_id, "f32").unwrap();

        let edge_cam0_to_broadcast = *config
            .get_src_edges(cam0_id, None)
            .unwrap()
            .first()
            .unwrap();
        let edge_cam0_to_inf0 = config.get_src_edges(cam0_id, None).unwrap()[1];

        assert_eq!(edge_cam0_to_inf0, 0);
        assert_eq!(edge_cam0_to_broadcast, 1);

        let runtime = compute_runtime_plan(&config).unwrap();
        let broadcast_step = step_for_node(&runtime, broadcast_id);

        assert_eq!(broadcast_step.input_msg_indices_types[0].1, "i32");
        assert_eq!(broadcast_step.input_msg_indices_types[1].1, "f32");
    }

    #[test]
    fn test_runtime_plan_diamond_case2() {
        // more complex topology that tripped the scheduler variation 2
        let (mut config, cam0_id, inf0_id, broadcast_id) = diamond_nodes();

        // case 2 order
        config.connect(cam0_id, inf0_id, "i32").unwrap();
        config.connect(cam0_id, broadcast_id, "i32").unwrap();
        config.connect(inf0_id, broadcast_id, "f32").unwrap();

        let edge_cam0_to_inf0 = *config
            .get_src_edges(cam0_id, None)
            .unwrap()
            .first()
            .unwrap();
        let edge_cam0_to_broadcast = config.get_src_edges(cam0_id, None).unwrap()[1];

        assert_eq!(edge_cam0_to_broadcast, 0);
        assert_eq!(edge_cam0_to_inf0, 1);

        let runtime = compute_runtime_plan(&config).unwrap();
        let broadcast_step = step_for_node(&runtime, broadcast_id);

        assert_eq!(broadcast_step.input_msg_indices_types[0].1, "i32");
        assert_eq!(broadcast_step.input_msg_indices_types[1].1, "f32");
    }
}