Skip to content

Commit ce45062

Browse files
committed
add timezone to type definition
Signed-off-by: pyalex <[email protected]>
1 parent 12a8459 commit ce45062

File tree

1 file changed

+8
-5
lines changed

1 file changed

+8
-5
lines changed

sdk/python/feast/embedded_go/type_map.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,13 @@
11
from typing import List
22

33
import pyarrow as pa
4+
import pytz
45

56
from feast.protos.feast.types import Value_pb2
67
from feast.types import Array, PrimitiveFeastType
78

9+
PA_TIMESTAMP_TYPE = pa.timestamp("s", tz=pytz.UTC)
10+
811
ARROW_TYPE_TO_PROTO_FIELD = {
912
pa.int32(): "int32_val",
1013
pa.int64(): "int64_val",
@@ -13,7 +16,7 @@
1316
pa.bool_(): "bool_val",
1417
pa.string(): "string_val",
1518
pa.binary(): "bytes_val",
16-
pa.timestamp("s"): "unix_timestamp_val",
19+
PA_TIMESTAMP_TYPE: "unix_timestamp_val",
1720
}
1821

1922
ARROW_LIST_TYPE_TO_PROTO_FIELD = {
@@ -24,7 +27,7 @@
2427
pa.bool_(): "bool_list_val",
2528
pa.string(): "string_list_val",
2629
pa.binary(): "bytes_list_val",
27-
pa.timestamp("s"): "unix_timestamp_list_val",
30+
PA_TIMESTAMP_TYPE: "unix_timestamp_list_val",
2831
}
2932

3033
ARROW_LIST_TYPE_TO_PROTO_LIST_CLASS = {
@@ -35,7 +38,7 @@
3538
pa.bool_(): Value_pb2.BoolList,
3639
pa.string(): Value_pb2.StringList,
3740
pa.binary(): Value_pb2.BytesList,
38-
pa.timestamp("s"): Value_pb2.Int64List,
41+
PA_TIMESTAMP_TYPE: Value_pb2.Int64List,
3942
}
4043

4144
FEAST_TYPE_TO_ARROW_TYPE = {
@@ -66,7 +69,7 @@ def arrow_array_to_array_of_proto(
6669
proto_list_class = ARROW_LIST_TYPE_TO_PROTO_LIST_CLASS[arrow_type.value_type]
6770
proto_field_name = ARROW_LIST_TYPE_TO_PROTO_FIELD[arrow_type.value_type]
6871

69-
if arrow_type.value_type == pa.timestamp("s"):
72+
if arrow_type.value_type == PA_TIMESTAMP_TYPE:
7073
arrow_array = arrow_array.cast(pa.list_(pa.int64()))
7174

7275
for v in arrow_array.tolist():
@@ -76,7 +79,7 @@ def arrow_array_to_array_of_proto(
7679
else:
7780
proto_field_name = ARROW_TYPE_TO_PROTO_FIELD[arrow_type]
7881

79-
if arrow_type == pa.timestamp("s"):
82+
if arrow_type == PA_TIMESTAMP_TYPE:
8083
arrow_array = arrow_array.cast(pa.int64())
8184

8285
for v in arrow_array.tolist():

0 commit comments

Comments
 (0)