Skip to content

Commit ef1884f

Browse files
authored
Fix ValueType.UNIX_TIMESTAMP conversions (feast-dev#2219)
* Handle `np.datetime64` to `ValueType.UNIX_TIMESTAMP` conversion

  Signed-off-by: Judah Rand <[email protected]>

* Add `datetime` feature to tests

  Signed-off-by: Judah Rand <[email protected]>

* Fix `datetime` features in `type_map.py`

  Signed-off-by: Judah Rand <[email protected]>
1 parent 6f1174a commit ef1884f

File tree

3 files changed

+52
-12
lines changed

3 files changed

+52
-12
lines changed

sdk/python/feast/type_map.py

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -214,7 +214,7 @@ def _type_err(item, dtype):
214214
ValueType.UNIX_TIMESTAMP_LIST: (
215215
Int64List,
216216
"int64_list_val",
217-
[np.int64, np.int32, int],
217+
[np.datetime64, np.int64, np.int32, int, datetime, Timestamp],
218218
),
219219
ValueType.STRING_LIST: (StringList, "string_list_val", [np.str_, str]),
220220
ValueType.BOOL_LIST: (BoolList, "bool_list_val", [np.bool_, bool]),
@@ -274,6 +274,24 @@ def _python_value_to_proto_value(
274274
)
275275
raise _type_err(first_invalid, valid_types[0])
276276

277+
if feast_value_type == ValueType.UNIX_TIMESTAMP_LIST:
278+
converted_values = []
279+
for value in values:
280+
converted_sub_values = []
281+
for sub_value in value:
282+
if isinstance(sub_value, datetime):
283+
converted_sub_values.append(int(sub_value.timestamp()))
284+
elif isinstance(sub_value, Timestamp):
285+
converted_sub_values.append(int(sub_value.ToSeconds()))
286+
elif isinstance(sub_value, np.datetime64):
287+
converted_sub_values.append(
288+
sub_value.astype("datetime64[s]").astype("int")
289+
)
290+
else:
291+
converted_sub_values.append(sub_value)
292+
converted_values.append(converted_sub_values)
293+
values = converted_values
294+
277295
return [
278296
ProtoValue(**{field_name: proto_type(val=value)}) # type: ignore
279297
if value is not None
@@ -292,6 +310,11 @@ def _python_value_to_proto_value(
292310
return [
293311
ProtoValue(int64_val=int(value.ToSeconds())) for value in values
294312
]
313+
elif isinstance(sample, np.datetime64):
314+
return [
315+
ProtoValue(int64_val=value.astype("datetime64[s]").astype("int"))
316+
for value in values
317+
]
295318
return [ProtoValue(int64_val=int(value)) for value in values]
296319

297320
if feast_value_type in PYTHON_SCALAR_VALUE_TYPE_TO_PROTO_VALUE:

sdk/python/tests/data/data_creator.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,13 @@ def get_feature_values_for_dtype(
6060
"float": [1.0, None, 3.0, 4.0, 5.0],
6161
"string": ["1", None, "3", "4", "5"],
6262
"bool": [True, None, False, True, False],
63+
"datetime": [
64+
datetime(1980, 1, 1),
65+
None,
66+
datetime(1981, 1, 1),
67+
datetime(1982, 1, 1),
68+
datetime(1982, 1, 1),
69+
],
6370
}
6471
non_list_val = dtype_map[dtype]
6572
if is_list:

sdk/python/tests/integration/registration/test_universal_types.py

Lines changed: 21 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import logging
2+
import re
23
from dataclasses import dataclass
34
from datetime import datetime, timedelta
45
from typing import Any, Dict, List, Tuple, Union
@@ -28,6 +29,7 @@ def populate_test_configs(offline: bool):
2829
(ValueType.INT64, "int64"),
2930
(ValueType.STRING, "float"),
3031
(ValueType.STRING, "bool"),
32+
(ValueType.INT32, "datetime"),
3133
]
3234
configs: List[TypeTestConfig] = []
3335
for test_repo_config in FULL_REPO_CONFIGS:
@@ -232,6 +234,7 @@ def test_feature_get_online_features_types_match(online_types_test_fixtures):
232234
"float": float,
233235
"string": str,
234236
"bool": bool,
237+
"datetime": int,
235238
}
236239
expected_dtype = feature_list_dtype_to_expected_online_response_value_type[
237240
config.feature_dtype
@@ -258,6 +261,8 @@ def create_feature_view(
258261
value_type = ValueType.FLOAT_LIST
259262
elif feature_dtype == "bool":
260263
value_type = ValueType.BOOL_LIST
264+
elif feature_dtype == "datetime":
265+
value_type = ValueType.UNIX_TIMESTAMP_LIST
261266
else:
262267
if feature_dtype == "int32":
263268
value_type = ValueType.INT32
@@ -267,6 +272,8 @@ def create_feature_view(
267272
value_type = ValueType.FLOAT
268273
elif feature_dtype == "bool":
269274
value_type = ValueType.BOOL
275+
elif feature_dtype == "datetime":
276+
value_type = ValueType.UNIX_TIMESTAMP
270277

271278
return driver_feature_view(data_source, name=name, value_type=value_type,)
272279

@@ -281,6 +288,7 @@ def assert_expected_historical_feature_types(
281288
"float": (pd.api.types.is_float_dtype,),
282289
"string": (pd.api.types.is_string_dtype,),
283290
"bool": (pd.api.types.is_bool_dtype, pd.api.types.is_object_dtype),
291+
"datetime": (pd.api.types.is_datetime64_any_dtype,),
284292
}
285293
dtype_checkers = feature_dtype_to_expected_historical_feature_dtype[feature_dtype]
286294
assert any(
@@ -309,6 +317,7 @@ def assert_feature_list_types(
309317
bool,
310318
np.bool_,
311319
), # Can be `np.bool_` if from `np.array` rather than `list`
320+
"datetime": np.datetime64,
312321
}
313322
expected_dtype = feature_list_dtype_to_expected_historical_feature_list_dtype[
314323
feature_dtype
@@ -330,22 +339,23 @@ def assert_expected_arrow_types(
330339
historical_features_arrow = historical_features.to_arrow()
331340
print(historical_features_arrow)
332341
feature_list_dtype_to_expected_historical_feature_arrow_type = {
333-
"int32": "int64",
334-
"int64": "int64",
335-
"float": "double",
336-
"string": "string",
337-
"bool": "bool",
342+
"int32": r"int64",
343+
"int64": r"int64",
344+
"float": r"double",
345+
"string": r"string",
346+
"bool": r"bool",
347+
"datetime": r"timestamp\[.+\]",
338348
}
339349
arrow_type = feature_list_dtype_to_expected_historical_feature_arrow_type[
340350
feature_dtype
341351
]
342352
if feature_is_list:
343-
assert (
344-
str(historical_features_arrow.schema.field_by_name("value").type)
345-
== f"list<item: {arrow_type}>"
353+
assert re.match(
354+
f"list<item: {arrow_type}>",
355+
str(historical_features_arrow.schema.field_by_name("value").type),
346356
)
347357
else:
348-
assert (
349-
str(historical_features_arrow.schema.field_by_name("value").type)
350-
== arrow_type
358+
assert re.match(
359+
arrow_type,
360+
str(historical_features_arrow.schema.field_by_name("value").type),
351361
)

0 commit comments

Comments
 (0)