Skip to content

Commit f331358

Browse files
Qin Xuye authored and wjsi committed
Add FuseChunkData and FuseChunk to represent fused chunk (mars-project#412)
* Add FuseChunkData and FuseChunk to represent fused chunk, refactor serialization for all kinds of chunks
1 parent 91183bb commit f331358

File tree

14 files changed

+900
-80
lines changed

14 files changed

+900
-80
lines changed

mars/core.py

Lines changed: 68 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
from .utils import tokenize, AttributeDict, on_serialize_shape, \
2626
on_deserialize_shape, on_serialize_nsplits, is_eager_mode
2727
from .serialize import HasKey, ValueType, ProviderType, Serializable, AttributeAsDict, \
28-
TupleField, DictField, KeyField, BoolField, StringField
28+
TupleField, ListField, DictField, KeyField, BoolField, StringField, OneOfField
2929
from .tiles import Tileable, handler
3030
from .graph import DAG
3131

@@ -281,6 +281,73 @@ class Chunk(Entity):
281281
_allow_data_type_ = (ChunkData,)
282282

283283

284+
def _on_serialize_composed(composed):
285+
return [FuseChunkData.ChunkRef(c.data if isinstance(c, Entity) else c) for c in composed]
286+
287+
288+
def _on_deserialize_composed(refs):
289+
return [r.chunk for r in refs]
290+
291+
292+
class FuseChunkData(ChunkData):
    """Chunk data that carries a list of composed (fused) chunks.

    The composed members live in the ``composed`` list field. ``shape`` and,
    when present, ``dtype`` are read from ``_extra_params``.
    """

    class ChunkRef(Serializable):
        """Serializable holder for a single composed member.

        The member is stored in a one-of field, so exactly one of the listed
        chunk/tileable data types is held at a time.
        """

        _chunk = OneOfField('chunk', tensor_chunk='mars.tensor.core.TensorChunkData',
                            tensor='mars.tensor.core.TensorData',
                            dataframe_chunk='mars.dataframe.core.DataFrameChunkData',
                            dataframe='mars.dataframe.core.DataFrameData',
                            index_chunk='mars.dataframe.core.IndexChunkData',
                            index='mars.dataframe.core.IndexData',
                            series_chunk='mars.dataframe.core.SeriesChunkData',
                            series='mars.dataframe.core.SeriesData')

        @property
        def chunk(self):
            """The object held by this reference."""
            return self._chunk

    # Members are wrapped into ChunkRef objects on serialization and unwrapped
    # again on deserialization by the two module-level hooks.
    _composed = ListField('composed', ValueType.reference(ChunkRef),
                          on_serialize=_on_serialize_composed,
                          on_deserialize=_on_deserialize_composed)

    @classmethod
    def cls(cls, provider):
        """Return ``FuseChunkDef`` for the protobuf provider, else defer to the parent."""
        if provider.type == ProviderType.protobuf:
            from .serialize.protos.fusechunk_pb2 import FuseChunkDef
            return FuseChunkDef
        return super(FuseChunkData, cls).cls(provider)

    @property
    def params(self):
        """Properties needed to rebuild a new chunk.

        Always contains ``index``; ``dtype`` is added only when one is available.
        """
        p = {
            'index': self.index,
        }
        dtype = self.dtype
        # Compare against None explicitly: NumPy dtypes implement ``__len__``
        # (0 for non-structured dtypes), so a plain truth test may misreport a
        # dtype that is actually present.
        if dtype is not None:
            p['dtype'] = dtype
        return p

    @property
    def shape(self):
        """Shape stored in ``_extra_params`` or ``None`` when absent."""
        return self._extra_params.get('shape')

    @property
    def dtype(self):
        """Dtype stored in ``_extra_params``, falling back to ``self.op.dtype``."""
        # have dtype when the last compose is tensor
        dtype = self._extra_params.get('dtype')
        # Explicit None check for the same reason as in ``params``: the truth
        # value of a dtype object is not a reliable presence test.
        if dtype is not None:
            return dtype
        return self.op.dtype

    @property
    def nbytes(self):
        """Element count times item size; assumes shape and dtype are both available."""
        return np.prod(self.shape) * self.dtype.itemsize
341+
342+
343+
class FuseChunk(Chunk):
    """``Chunk`` subclass that declares ``FuseChunkData`` as its allowed data type."""

    # Data classes this entity may hold.
    _allow_data_type_ = (FuseChunkData,)
    # No instance attributes are added on top of ``Chunk``.
    __slots__ = ()
346+
347+
348+
# Fused-chunk classes (data object and its entity wrapper) grouped as a tuple
# so they can be passed to isinstance() checks.
FUSE_CHUNK_TYPE = (FuseChunkData, FuseChunk)
349+
350+
284351
class TileableData(SerializableWithKey, Tileable):
285352
__slots__ = '__weakref__', '_siblings', '_cix'
286353
_no_copy_attrs_ = SerializableWithKey._no_copy_attrs_ | {'_cix'}

mars/dataframe/core.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -260,6 +260,13 @@ class IndexChunkData(ChunkData):
260260
_dtype = DataTypeField('dtype')
261261
_index_value = ReferenceField('index_value', IndexValue)
262262

263+
@classmethod
def cls(cls, provider):
    """Return ``IndexChunkDef`` for the protobuf provider, else defer to the parent."""
    if provider.type != ProviderType.protobuf:
        return super(IndexChunkData, cls).cls(provider)
    # Imported lazily so the generated module is only loaded on the protobuf path.
    from ..serialize.protos.dataframe_pb2 import IndexChunkDef
    return IndexChunkDef
269+
263270
@property
264271
def dtype(self):
265272
return self._dtype
@@ -314,6 +321,13 @@ class SeriesChunkData(ChunkData):
314321
_dtype = DataTypeField('dtype')
315322
_index_value = ReferenceField('index_value', IndexValue)
316323

324+
@classmethod
def cls(cls, provider):
    """Return ``SeriesChunkDef`` for the protobuf provider, else defer to the parent."""
    if provider.type != ProviderType.protobuf:
        return super(SeriesChunkData, cls).cls(provider)
    # Imported lazily so the generated module is only loaded on the protobuf path.
    from ..serialize.protos.dataframe_pb2 import SeriesChunkDef
    return SeriesChunkDef
330+
317331
@property
318332
def dtype(self):
319333
return self._dtype
@@ -375,9 +389,16 @@ class DataFrameChunkData(ChunkData):
375389
_index_value = ReferenceField('index_value', IndexValue)
376390
_columns_value = ReferenceField('columns_value', IndexValue)
377391

392+
@classmethod
def cls(cls, provider):
    """Return ``DataFrameChunkDef`` for the protobuf provider, else defer to the parent."""
    if provider.type != ProviderType.protobuf:
        return super(DataFrameChunkData, cls).cls(provider)
    # Imported lazily so the generated module is only loaded on the protobuf path.
    from ..serialize.protos.dataframe_pb2 import DataFrameChunkDef
    return DataFrameChunkDef
398+
378399
@property
379400
def params(self):
380-
# params return the properties which useful to rebuild a new tileable object
401+
# params returns the properties that are useful for rebuilding a new chunk
381402
return {
382403
'shape': self.shape,
383404
'dtypes': self.dtypes,

mars/graph.pyx

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -320,6 +320,7 @@ cdef class DirectedGraph:
320320

321321
from .tensor.core import CHUNK_TYPE as TENSOR_CHUNK_TYPE, TENSOR_TYPE, Chunk, Tensor
322322
from .dataframe.core import CHUNK_TYPE as DATAFRAME_CHUNK_TYPE, DATAFRAME_TYPE, DataFrame
323+
from .core import FUSE_CHUNK_TYPE
323324

324325
level = None
325326

@@ -334,7 +335,7 @@ cdef class DirectedGraph:
334335
visited.add(c)
335336

336337
for node in self.iter_nodes():
337-
if isinstance(node, TENSOR_CHUNK_TYPE + DATAFRAME_CHUNK_TYPE):
338+
if isinstance(node, TENSOR_CHUNK_TYPE + DATAFRAME_CHUNK_TYPE + FUSE_CHUNK_TYPE):
338339
node = node.data if isinstance(node, Chunk) else node
339340
add_obj(node)
340341
if node.composed:
@@ -510,7 +511,12 @@ class SerializableGraphNode(Serializable):
510511
tensor_chunk='mars.tensor.core.TensorChunkData',
511512
tensor='mars.tensor.core.TensorData',
512513
dataframe_chunk='mars.dataframe.core.DataFrameChunkData',
513-
dataframe='mars.dataframe.core.DataFrameData')
514+
dataframe='mars.dataframe.core.DataFrameData',
515+
index_chunk='mars.dataframe.core.IndexChunkData',
516+
index='mars.dataframe.core.Index',
517+
series_chunk='mars.dataframe.core.SeriesChunkData',
518+
series='mars.dataframe.core.Series',
519+
fuse_chunk='mars.core.FuseChunkData')
514520

515521
@classmethod
516522
def cls(cls, provider):

mars/serialize/protos/chunk.proto

Lines changed: 0 additions & 20 deletions
This file was deleted.

mars/serialize/protos/dataframe.proto

Lines changed: 43 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,18 @@ syntax = "proto3";
22

33
import "mars/serialize/protos/value.proto";
44
import "mars/serialize/protos/indexvalue.proto";
5-
import "mars/serialize/protos/chunk.proto";
5+
6+
7+
// Protobuf definition for a single Index chunk.
message IndexChunkDef {
  string key = 1;
  repeated uint32 index = 2 [packed = true];
  repeated int64 shape = 3;
  Value op = 4;  // just store key here
  bool cached = 5;
  Value dtype = 6;
  Value extra_params = 8;
  string id = 9;
  // IndexChunkData declares an index_value field, so the message needs a slot
  // for it. Tag 11 mirrors SeriesChunkDef and DataFrameChunkDef.
  // NOTE(review): confirm the tag number before publishing the schema.
  IndexValue index_value = 11;
}
617

718

819
message IndexDef {
@@ -11,36 +22,64 @@ message IndexDef {
1122
Value dtype = 3;
1223
Value op = 4; // store operand's key and id
1324
Value nsplits = 5;
14-
repeated ChunkDef chunks = 6;
25+
repeated IndexChunkDef chunks = 6;
1526
Value extra_params = 7;
1627
string id = 8;
1728
IndexValue index_value = 9;
1829
}
1930

2031

32+
// Protobuf definition for a single Series chunk (fields listed in tag order).
message SeriesChunkDef {
  string key = 1;
  repeated uint32 index = 2 [packed = true];
  repeated int64 shape = 3;
  Value op = 4;  // just store key here
  bool cached = 5;
  Value dtype = 6;
  Value extra_params = 8;
  string id = 9;
  IndexValue index_value = 11;
}
43+
44+
2145
// Protobuf definition for a Series.
message SeriesDef {
  string key = 1;
  repeated int64 shape = 2;
  Value dtype = 3;
  Value op = 4;  // store operand's key and id
  Value nsplits = 5;
  // Series chunks serialize as SeriesChunkDef (the class returned by
  // SeriesChunkData.cls), not IndexChunkDef, which has no index_value field.
  // SeriesChunkDef contains every IndexChunkDef field with the same tags, so
  // previously written data still parses.
  repeated SeriesChunkDef chunks = 6;
  Value extra_params = 7;
  string id = 8;
  Value name = 9;
  IndexValue index_value = 10;
}
3357

3458

59+
// Protobuf definition for a single DataFrame chunk (fields listed in tag order).
message DataFrameChunkDef {
  string key = 1;
  repeated uint32 index = 2 [packed = true];
  repeated int64 shape = 3;
  Value op = 4;  // just store key here
  bool cached = 5;
  Value extra_params = 8;
  string id = 9;
  Value dtypes = 10;
  IndexValue index_value = 11;
  IndexValue columns_value = 12;
}
71+
72+
3573
// Protobuf definition for a DataFrame; its chunks are DataFrameChunkDef messages.
message DataFrameDef {
  string key = 1;
  repeated int64 shape = 2;
  Value dtypes = 3;
  Value op = 4;  // store operand's key and id
  Value nsplits = 5;
  repeated DataFrameChunkDef chunks = 6;
  Value extra_params = 7;
  string id = 8;
  IndexValue index_value = 9;
  IndexValue columns_value = 10;
}
85+

0 commit comments

Comments
 (0)