@@ -30,9 +30,9 @@ class LocalClusterSession(object):
3030 def __init__ (self , endpoint , ** kwargs ):
3131 self ._session_id = uuid .uuid4 ()
3232 self ._endpoint = endpoint
33- # dict structure: {tensor_key -> graph_key, tensor_ids }
34- # dict value is a tuple object which records graph key and tensor id
35- self ._executed_tensors = dict ()
33+ # dict structure: {tileable_key -> graph_key, tileable_ids }
34+ # dict value is a tuple object which records graph key and tilable id
35+ self ._executed_tileables = dict ()
3636 self ._api = MarsAPI (self ._endpoint )
3737
3838 # create session on the cluster side
@@ -51,35 +51,35 @@ def endpoint(self, endpoint):
5151 self ._endpoint = endpoint
5252 self ._api = MarsAPI (self ._endpoint )
5353
54- def _get_tensor_graph_key (self , tensor_key ):
55- return self ._executed_tensors [ tensor_key ][0 ]
54+ def _get_tileable_graph_key (self , tileable_key ):
55+ return self ._executed_tileables [ tileable_key ][0 ]
5656
57- def _set_tensor_graph_key (self , tensor , graph_key ):
58- tensor_key = tensor .key
59- tensor_id = tensor .id
60- if tensor_key in self ._executed_tensors :
61- self ._executed_tensors [ tensor_key ][1 ].add (tensor_id )
57+ def _set_tileable_graph_key (self , tileable , graph_key ):
58+ tileable_key = tileable .key
59+ tileable_id = tileable .id
60+ if tileable_key in self ._executed_tileables :
61+ self ._executed_tileables [ tileable_key ][1 ].add (tileable_id )
6262 else :
63- self ._executed_tensors [ tensor_key ] = graph_key , {tensor_id }
63+ self ._executed_tileables [ tileable_key ] = graph_key , {tileable_id }
6464
65- def _update_tensor_shape (self , tensor ):
66- graph_key = self ._get_tensor_graph_key ( tensor .key )
67- new_nsplits = self ._api .get_tensor_nsplits (self ._session_id , graph_key , tensor .key )
68- tensor ._update_shape (tuple (sum (nsplit ) for nsplit in new_nsplits ))
69- tensor .nsplits = new_nsplits
65+ def _update_tileable_shape (self , tileable ):
66+ graph_key = self ._get_tileable_graph_key ( tileable .key )
67+ new_nsplits = self ._api .get_tileable_nsplits (self ._session_id , graph_key , tileable .key )
68+ tileable ._update_shape (tuple (sum (nsplit ) for nsplit in new_nsplits ))
69+ tileable .nsplits = new_nsplits
7070
71- def run (self , * tensors , ** kw ):
71+ def run (self , * tileables , ** kw ):
7272 timeout = kw .pop ('timeout' , - 1 )
7373 fetch = kw .pop ('fetch' , True )
7474 compose = kw .pop ('compose' , True )
7575 if kw :
7676 raise TypeError ('run got unexpected key arguments {0}' .format (', ' .join (kw .keys ())))
7777
78- # those executed tensors should fetch data directly, submit the others
79- run_tensors = [t for t in tensors if t .key not in self ._executed_tensors ]
78+ # those executed tileables should fetch data directly, submit the others
79+ run_tileables = [t for t in tileables if t .key not in self ._executed_tileables ]
8080
81- graph = build_graph (run_tensors , executed_keys = list (self ._executed_tensors .keys ()))
82- targets = [t .key for t in run_tensors ]
81+ graph = build_graph (run_tileables , executed_keys = list (self ._executed_tileables .keys ()))
82+ targets = [t .key for t in run_tileables ]
8383 graph_key = uuid .uuid4 ()
8484
8585 # submit graph to local cluster
@@ -100,40 +100,40 @@ def run(self, *tensors, **kw):
100100 if 0 < timeout < time .time () - exec_start_time :
101101 raise TimeoutError
102102
103- for t in tensors :
104- self ._set_tensor_graph_key (t , graph_key )
103+ for t in tileables :
104+ self ._set_tileable_graph_key (t , graph_key )
105105
106106 if not fetch :
107107 return
108108 else :
109- return self .fetch (* tensors )
109+ return self .fetch (* tileables )
110110
111- def fetch (self , * tensors ):
111+ def fetch (self , * tileables ):
112112 futures = []
113- for tensor in tensors :
114- key = tensor .key
113+ for tileable in tileables :
114+ key = tileable .key
115115
116- if key not in self ._executed_tensors :
117- raise ValueError ('Cannot fetch the unexecuted tensor ' )
116+ if key not in self ._executed_tileables :
117+ raise ValueError ('Cannot fetch the unexecuted tileable ' )
118118
119- graph_key = self ._get_tensor_graph_key ( tensor .key )
119+ graph_key = self ._get_tileable_graph_key ( tileable .key )
120120 compressions = dataserializer .get_supported_compressions ()
121121 future = self ._api .fetch_data (self ._session_id , graph_key , key , compressions , wait = False )
122122 futures .append (future )
123123 return [dataserializer .loads (f .result ()) for f in futures ]
124124
125125 def decref (self , * keys ):
126- for tensor_key , tensor_id in keys :
127- if tensor_key not in self ._executed_tensors :
126+ for tileable_key , tileable_id in keys :
127+ if tileable_key not in self ._executed_tileables :
128128 continue
129- graph_key , ids = self ._executed_tensors [ tensor_key ]
130- if tensor_id in ids :
131- ids .remove (tensor_id )
132- # for those same key tensors , do decref only when all those tensors are garbage collected
129+ graph_key , ids = self ._executed_tileables [ tileable_key ]
130+ if tileable_id in ids :
131+ ids .remove (tileable_id )
132+ # for those same key tileables , do decref only when all those tileables are garbage collected
133133 if len (ids ) != 0 :
134134 continue
135- self ._api .delete_data (self ._session_id , graph_key , tensor_key )
136- del self ._executed_tensors [ tensor_key ]
135+ self ._api .delete_data (self ._session_id , graph_key , tileable_key )
136+ del self ._executed_tileables [ tileable_key ]
137137
138138 def __enter__ (self ):
139139 return self
0 commit comments