22from collections import defaultdict
33
44
5- def prisms_algorithm (l ): # noqa: E741
5+ class Heap :
6+ def __init__ (self ):
7+ self .node_position = []
68
7- node_position = []
9+ def get_position (self , vertex ):
10+ return self .node_position [vertex ]
811
9- def get_position ( vertex ):
10- return node_position [vertex ]
12+ def set_position ( self , vertex , pos ):
13+ self . node_position [vertex ] = pos
1114
12- def set_position (vertex , pos ):
13- node_position [vertex ] = pos
14-
15- def top_to_bottom (heap , start , size , positions ):
15+ def top_to_bottom (self , heap , start , size , positions ):
1616 if start > size // 2 - 1 :
1717 return
1818 else :
1919 if 2 * start + 2 >= size :
20- m = 2 * start + 1
20+ smallest_child = 2 * start + 1
2121 else :
2222 if heap [2 * start + 1 ] < heap [2 * start + 2 ]:
23- m = 2 * start + 1
23+ smallest_child = 2 * start + 1
2424 else :
25- m = 2 * start + 2
26- if heap [m ] < heap [start ]:
27- temp , temp1 = heap [m ], positions [m ]
28- heap [m ], positions [m ] = heap [start ], positions [start ]
25+ smallest_child = 2 * start + 2
26+ if heap [smallest_child ] < heap [start ]:
27+ temp , temp1 = heap [smallest_child ], positions [smallest_child ]
28+ heap [smallest_child ], positions [smallest_child ] = (
29+ heap [start ],
30+ positions [start ],
31+ )
2932 heap [start ], positions [start ] = temp , temp1
3033
31- temp = get_position (positions [m ])
32- set_position (positions [m ], get_position (positions [start ]))
33- set_position (positions [start ], temp )
34+ temp = self .get_position (positions [smallest_child ])
35+ self .set_position (
36+ positions [smallest_child ], self .get_position (positions [start ])
37+ )
38+ self .set_position (positions [start ], temp )
3439
35- top_to_bottom (heap , m , size , positions )
40+ self . top_to_bottom (heap , smallest_child , size , positions )
3641
3742 # Update function if value of any node in min-heap decreases
38- def bottom_to_top (val , index , heap , position ):
43+ def bottom_to_top (self , val , index , heap , position ):
3944 temp = position [index ]
4045
4146 while index != 0 :
@@ -47,70 +52,88 @@ def bottom_to_top(val, index, heap, position):
4752 if val < heap [parent ]:
4853 heap [index ] = heap [parent ]
4954 position [index ] = position [parent ]
50- set_position (position [parent ], index )
55+ self . set_position (position [parent ], index )
5156 else :
5257 heap [index ] = val
5358 position [index ] = temp
54- set_position (temp , index )
59+ self . set_position (temp , index )
5560 break
5661 index = parent
5762 else :
5863 heap [0 ] = val
5964 position [0 ] = temp
60- set_position (temp , 0 )
65+ self . set_position (temp , 0 )
6166
62- def heapify (heap , positions ):
67+ def heapify (self , heap , positions ):
6368 start = len (heap ) // 2 - 1
6469 for i in range (start , - 1 , - 1 ):
65- top_to_bottom (heap , i , len (heap ), positions )
70+ self . top_to_bottom (heap , i , len (heap ), positions )
6671
67- def delete_minimum (heap , positions ):
72+ def delete_minimum (self , heap , positions ):
6873 temp = positions [0 ]
6974 heap [0 ] = sys .maxsize
70- top_to_bottom (heap , 0 , len (heap ), positions )
75+ self . top_to_bottom (heap , 0 , len (heap ), positions )
7176 return temp
7277
73- visited = [0 for i in range (len (l ))]
74- nbr_tv = [- 1 for i in range (len (l ))] # Neighboring Tree Vertex of selected vertex
78+
79+ def prisms_algorithm (adjacency_list ):
80+ """
81+ >>> adjacency_list = {0: [[1, 1], [3, 3]],
82+ ... 1: [[0, 1], [2, 6], [3, 5], [4, 1]],
83+ ... 2: [[1, 6], [4, 5], [5, 2]],
84+ ... 3: [[0, 3], [1, 5], [4, 1]],
85+ ... 4: [[1, 1], [2, 5], [3, 1], [5, 4]],
86+ ... 5: [[2, 2], [4, 4]]}
87+ >>> prisms_algorithm(adjacency_list)
88+ [(0, 1), (1, 4), (4, 3), (4, 5), (5, 2)]
89+ """
90+
91+ heap = Heap ()
92+
93+ visited = [0 ] * len (adjacency_list )
94+ nbr_tv = [- 1 ] * len (adjacency_list ) # Neighboring Tree Vertex of selected vertex
7595 # Minimum Distance of explored vertex with neighboring vertex of partial tree
7696 # formed in graph
7797 distance_tv = [] # Heap of Distance of vertices from their neighboring vertex
7898 positions = []
7999
80- for x in range (len (l )):
81- p = sys .maxsize
82- distance_tv .append (p )
83- positions .append (x )
84- node_position .append (x )
100+ for vertex in range (len (adjacency_list )):
101+ distance_tv .append (sys .maxsize )
102+ positions .append (vertex )
103+ heap .node_position .append (vertex )
85104
86105 tree_edges = []
87106 visited [0 ] = 1
88107 distance_tv [0 ] = sys .maxsize
89- for x in l [0 ]:
90- nbr_tv [x [ 0 ] ] = 0
91- distance_tv [x [ 0 ]] = x [ 1 ]
92- heapify (distance_tv , positions )
108+ for neighbor , distance in adjacency_list [0 ]:
109+ nbr_tv [neighbor ] = 0
110+ distance_tv [neighbor ] = distance
111+ heap . heapify (distance_tv , positions )
93112
94- for _ in range (1 , len (l )):
95- vertex = delete_minimum (distance_tv , positions )
113+ for _ in range (1 , len (adjacency_list )):
114+ vertex = heap . delete_minimum (distance_tv , positions )
96115 if visited [vertex ] == 0 :
97116 tree_edges .append ((nbr_tv [vertex ], vertex ))
98117 visited [vertex ] = 1
99- for v in l [vertex ]:
100- if visited [v [0 ]] == 0 and v [1 ] < distance_tv [get_position (v [0 ])]:
101- distance_tv [get_position (v [0 ])] = v [1 ]
102- bottom_to_top (v [1 ], get_position (v [0 ]), distance_tv , positions )
103- nbr_tv [v [0 ]] = vertex
118+ for neighbor , distance in adjacency_list [vertex ]:
119+ if (
120+ visited [neighbor ] == 0
121+ and distance < distance_tv [heap .get_position (neighbor )]
122+ ):
123+ distance_tv [heap .get_position (neighbor )] = distance
124+ heap .bottom_to_top (
125+ distance , heap .get_position (neighbor ), distance_tv , positions
126+ )
127+ nbr_tv [neighbor ] = vertex
104128 return tree_edges
105129
106130
107131if __name__ == "__main__" : # pragma: no cover
108132 # < --------- Prims Algorithm --------- >
109- n = int (input ("Enter number of vertices: " ).strip ())
110- e = int (input ("Enter number of edges: " ).strip ())
111- adjlist = defaultdict (list )
112- for x in range (e ):
113- l = [int (x ) for x in input ().strip ().split ()] # noqa: E741
114- adjlist [l [0 ]].append ([l [1 ], l [2 ]])
115- adjlist [l [1 ]].append ([l [0 ], l [2 ]])
116- print (prisms_algorithm (adjlist ))
133+ edges_number = int (input ("Enter number of edges: " ).strip ())
134+ adjacency_list = defaultdict (list )
135+ for _ in range (edges_number ):
136+ edge = [int (x ) for x in input ().strip ().split ()]
137+ adjacency_list [edge [0 ]].append ([edge [1 ], edge [2 ]])
138+ adjacency_list [edge [1 ]].append ([edge [0 ], edge [2 ]])
139+ print (prisms_algorithm (adjacency_list ))
0 commit comments