@@ -586,6 +586,98 @@ def test_state_connect_6a():
586586 ]
587587
588588
589+ def test_state_connect_7 ():
590+ """two 'connected' states with multiple fields that are connected
591+ no explicit splitter for the second state
592+ """
593+ st1 = State (name = "NA" , splitter = "a" )
594+ st2 = State (name = "NB" , other_states = {"NA" : (st1 , ["x" , "y" ])})
595+ # should take into account that x, y come from the same task
596+ assert st2 .splitter == "_NA"
597+ assert st2 .splitter_rpn == ["NA.a" ]
598+ assert st2 .prev_state_splitter == st2 .splitter
599+ assert st2 .prev_state_splitter_rpn == st2 .splitter_rpn
600+ assert st2 .current_splitter is None
601+ assert st2 .current_splitter_rpn == []
602+
603+ st2 .prepare_states (inputs = {"NA.a" : [3 , 5 ]})
604+ assert st2 .group_for_inputs_final == {"NA.a" : 0 }
605+ assert st2 .groups_stack_final == [[0 ]]
606+ assert st2 .states_ind == [{"NA.a" : 0 }, {"NA.a" : 1 }]
607+ assert st2 .states_val == [{"NA.a" : 3 }, {"NA.a" : 5 }]
608+
609+ st2 .prepare_inputs ()
610+ # since x,y come from the same state, they should have the same index
611+ assert st2 .inputs_ind == [{"NB.x" : 0 , "NB.y" : 0 }, {"NB.x" : 1 , "NB.y" : 1 }]
612+
613+
614+ def test_state_connect_8 ():
615+ """three 'connected' states: NA -> NB -> NC; NA -> NC (only NA has its own splitter)
616+ pydra should recognize, that there is only one splitter - NA
617+ and it should give the same as the previous test
618+ """
619+ st1 = State (name = "NA" , splitter = "a" )
620+ st2 = State (name = "NB" , other_states = {"NA" : (st1 , "b" )})
621+ st3 = State (name = "NC" , other_states = {"NA" : (st1 , "x" ), "NB" : (st2 , "y" )})
622+ # x comes from NA and y comes from NB, but NB has only NA's splitter,
623+ # so it should be treated as both inputs are from NA state
624+ assert st3 .splitter == "_NA"
625+ assert st3 .splitter_rpn == ["NA.a" ]
626+ assert st3 .prev_state_splitter == st3 .splitter
627+ assert st3 .prev_state_splitter_rpn == st3 .splitter_rpn
628+ assert st3 .current_splitter is None
629+ assert st3 .current_splitter_rpn == []
630+
631+ st3 .prepare_states (inputs = {"NA.a" : [3 , 5 ]})
632+ assert st3 .group_for_inputs_final == {"NA.a" : 0 }
633+ assert st3 .groups_stack_final == [[0 ]]
634+ assert st3 .states_ind == [{"NA.a" : 0 }, {"NA.a" : 1 }]
635+ assert st3 .states_val == [{"NA.a" : 3 }, {"NA.a" : 5 }]
636+
637+ st3 .prepare_inputs ()
638+ # since x,y come from the same state (although y indirectly), they should have the same index
639+ assert st3 .inputs_ind == [{"NC.x" : 0 , "NC.y" : 0 }, {"NC.x" : 1 , "NC.y" : 1 }]
640+
641+
642+ @pytest .mark .xfail (
643+ reason = "doesn't recognize that NC.y has 4 elements (not independend on NC.x)"
644+ )
645+ def test_state_connect_9 ():
646+ """four 'connected' states: NA1 -> NB; NA2 -> NB, NA1 -> NC; NB -> NC
647+ pydra should recognize, that there is only one splitter - NA_1 and NA_2
648+
649+ """
650+ st1 = State (name = "NA_1" , splitter = "a" )
651+ st1a = State (name = "NA_2" , splitter = "a" )
652+ st2 = State (name = "NB" , other_states = {"NA_1" : (st1 , "b" ), "NA_2" : (st1a , "c" )})
653+ st3 = State (name = "NC" , other_states = {"NA_1" : (st1 , "x" ), "NB" : (st2 , "y" )})
654+ # x comes from NA_1 and y comes from NB, but NB has only NA_1/2's splitters,
655+ assert st3 .splitter == ["_NA_1" , "_NA_2" ]
656+ assert st3 .splitter_rpn == ["NA_1.a" , "NA_2.a" , "*" ]
657+ assert st3 .prev_state_splitter == st3 .splitter
658+ assert st3 .prev_state_splitter_rpn == st3 .splitter_rpn
659+ assert st3 .current_splitter is None
660+ assert st3 .current_splitter_rpn == []
661+
662+ st3 .prepare_states (inputs = {"NA_1.a" : [3 , 5 ], "NA_2.a" : [11 , 12 ]})
663+ assert st3 .group_for_inputs_final == {"NA_1.a" : 0 , "NA_2.a" : 1 }
664+ assert st3 .groups_stack_final == [[0 , 1 ]]
665+ assert st3 .states_ind == [
666+ {"NA_1.a" : 0 , "NA_2.a" : 0 },
667+ {"NA_1.a" : 0 , "NA_2.a" : 1 },
668+ {"NA_1.a" : 1 , "NA_2.a" : 0 },
669+ {"NA_1.a" : 1 , "NA_2.a" : 1 },
670+ ]
671+
672+ st3 .prepare_inputs ()
673+ assert st3 .inputs_ind == [
674+ {"NC.x" : 0 , "NC.y" : 0 },
675+ {"NC.x" : 0 , "NC.y" : 1 },
676+ {"NC.x" : 1 , "NC.y" : 2 },
677+ {"NC.x" : 1 , "NC.y" : 3 },
678+ ]
679+
680+
589681def test_state_connect_innerspl_1 ():
590682 """two 'connected' states: testing groups, prepare_states and prepare_inputs,
591683 the second state has an inner splitter, full splitter provided
@@ -605,7 +697,7 @@ def test_state_connect_innerspl_1():
605697 inputs = {"NA.a" : [3 , 5 ], "NB.b" : [[1 , 10 , 100 ], [2 , 20 , 200 ]]},
606698 cont_dim = {"NB.b" : 2 }, # will be treated as 2d container
607699 )
608- assert st2 .other_states ["NA" ][1 ] == "b"
700+ assert st2 .other_states ["NA" ][1 ] == [ "b" ]
609701 assert st2 .group_for_inputs_final == {"NA.a" : 0 , "NB.b" : 1 }
610702 assert st2 .groups_stack_final == [[0 ], [1 ]]
611703
@@ -653,7 +745,7 @@ def test_state_connect_innerspl_1a():
653745 assert st2 .current_splitter == "NB.b"
654746 assert st2 .current_splitter_rpn == ["NB.b" ]
655747
656- assert st2 .other_states ["NA" ][1 ] == "b"
748+ assert st2 .other_states ["NA" ][1 ] == [ "b" ]
657749
658750 st2 .prepare_states (
659751 inputs = {"NA.a" : [3 , 5 ], "NB.b" : [[1 , 10 , 100 ], [2 , 20 , 200 ]]},
@@ -717,7 +809,7 @@ def test_state_connect_innerspl_2():
717809 inputs = {"NA.a" : [3 , 5 ], "NB.b" : [[1 , 10 , 100 ], [2 , 20 , 200 ]], "NB.c" : [13 , 17 ]},
718810 cont_dim = {"NB.b" : 2 }, # will be treated as 2d container
719811 )
720- assert st2 .other_states ["NA" ][1 ] == "b"
812+ assert st2 .other_states ["NA" ][1 ] == [ "b" ]
721813 assert st2 .group_for_inputs_final == {"NA.a" : 0 , "NB.c" : 1 , "NB.b" : 2 }
722814 assert st2 .groups_stack_final == [[0 ], [1 , 2 ]]
723815
@@ -778,7 +870,7 @@ def test_state_connect_innerspl_2a():
778870
779871 assert st2 .splitter == ["_NA" , ["NB.b" , "NB.c" ]]
780872 assert st2 .splitter_rpn == ["NA.a" , "NB.b" , "NB.c" , "*" , "*" ]
781- assert st2 .other_states ["NA" ][1 ] == "b"
873+ assert st2 .other_states ["NA" ][1 ] == [ "b" ]
782874
783875 st2 .prepare_states (
784876 inputs = {"NA.a" : [3 , 5 ], "NB.b" : [[1 , 10 , 100 ], [2 , 20 , 200 ]], "NB.c" : [13 , 17 ]},
@@ -839,6 +931,7 @@ def test_state_connect_innerspl_3():
839931 the second state has one inner splitter and one 'normal' splitter
840932 the prev-state parts of the splitter have to be added
841933 """
934+
842935 st1 = State (name = "NA" , splitter = "a" )
843936 st2 = State (name = "NB" , splitter = ["c" , "b" ], other_states = {"NA" : (st1 , "b" )})
844937 st3 = State (name = "NC" , splitter = "d" , other_states = {"NB" : (st2 , "a" )})
@@ -986,8 +1079,8 @@ def test_state_connect_innerspl_4():
9861079
9871080 assert st3 .splitter == [["_NA" , "_NB" ], "NC.d" ]
9881081 assert st3 .splitter_rpn == ["NA.a" , "NB.b" , "NB.c" , "*" , "*" , "NC.d" , "*" ]
989- assert st3 .other_states ["NA" ][1 ] == "e"
990- assert st3 .other_states ["NB" ][1 ] == "f"
1082+ assert st3 .other_states ["NA" ][1 ] == [ "e" ]
1083+ assert st3 .other_states ["NB" ][1 ] == [ "f" ]
9911084
9921085 st3 .prepare_states (
9931086 inputs = {
@@ -1736,12 +1829,12 @@ def test_connect_splitters_exception_1(splitter, other_states):
17361829
17371830
17381831def test_connect_splitters_exception_2 ():
1739- st = State (
1740- name = "CN" ,
1741- splitter = "_NB" ,
1742- other_states = {"NA" : (State (name = "NA" , splitter = "a" ), "b" )},
1743- )
17441832 with pytest .raises (PydraStateError ) as excinfo :
1833+ st = State (
1834+ name = "CN" ,
1835+ splitter = "_NB" ,
1836+ other_states = {"NA" : (State (name = "NA" , splitter = "a" ), "b" )},
1837+ )
17451838 st .set_input_groups ()
17461839 assert "can't ask for splitter from NB" in str (excinfo .value )
17471840
0 commit comments