Skip to content

Commit 5a83c6e

Browse files
chore: Force entity inference without modifying fv.schema (feast-dev#3448)
* Remove unnecessary wrapper `create_feature_view` Signed-off-by: Felix Wang <[email protected]> * Add `infer_entities` option to `driver_feature_view` Signed-off-by: Felix Wang <[email protected]> * Force entity inference Signed-off-by: Felix Wang <[email protected]> Signed-off-by: Felix Wang <[email protected]>
1 parent 73930f6 commit 5a83c6e

File tree

2 files changed

+24
-41
lines changed

2 files changed

+24
-41
lines changed

sdk/python/tests/integration/feature_repos/universal/feature_views.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
def driver_feature_view(
2727
data_source: DataSource,
2828
name="test_correctness",
29+
infer_entities: bool = False,
2930
infer_features: bool = False,
3031
dtype: FeastType = Float32,
3132
entity_type: FeastType = Int64,
@@ -34,7 +35,7 @@ def driver_feature_view(
3435
return FeatureView(
3536
name=name,
3637
entities=[d],
37-
schema=[Field(name=d.join_key, dtype=entity_type)]
38+
schema=([] if infer_entities else [Field(name=d.join_key, dtype=entity_type)])
3839
+ ([] if infer_features else [Field(name="value", dtype=dtype)]),
3940
ttl=timedelta(days=5),
4041
source=data_source,

sdk/python/tests/integration/registration/test_universal_types.py

Lines changed: 22 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import logging
22
from dataclasses import dataclass
33
from datetime import datetime, timedelta
4-
from typing import Any, Dict, List, Tuple, Union
4+
from typing import Any, Dict, List, Optional, Tuple, Union
55

66
import numpy as np
77
import pandas as pd
@@ -12,6 +12,7 @@
1212
from feast.types import (
1313
Array,
1414
Bool,
15+
FeastType,
1516
Float32,
1617
Float64,
1718
Int32,
@@ -42,20 +43,15 @@ def test_entity_inference_types_match(environment, entity_type):
4243
destination_name=f"entity_type_{entity_type.name.lower()}",
4344
field_mapping={"ts_1": "ts"},
4445
)
45-
fv = create_feature_view(
46-
f"fv_entity_type_{entity_type.name.lower()}",
47-
feature_dtype="int32",
48-
feature_is_list=False,
49-
has_empty_list=False,
46+
fv = driver_feature_view(
5047
data_source=data_source,
48+
name=f"fv_entity_type_{entity_type.name.lower()}",
49+
infer_entities=True, # Forces entity inference by not including a field for the entity.
50+
dtype=_get_feast_type("int32", False),
5151
entity_type=entity_type,
5252
)
5353

54-
# TODO(felixwang9817): Refactor this by finding a better way to force type inference.
55-
# Override the schema and entity_columns to force entity inference.
5654
entity = driver()
57-
fv.schema = list(filter(lambda x: x.name != entity.join_key, fv.schema))
58-
fv.entity_columns = []
5955
fs.apply([fv, entity])
6056

6157
entity_type_to_expected_inferred_entity_type = {
@@ -88,12 +84,10 @@ def test_feature_get_historical_features_types_match(
8884
config, data_source, fv = offline_types_test_fixtures
8985
fs = environment.feature_store
9086
entity = driver()
91-
fv = create_feature_view(
92-
"get_historical_features_types_match",
93-
config.feature_dtype,
94-
config.feature_is_list,
95-
config.has_empty_list,
96-
data_source,
87+
fv = driver_feature_view(
88+
data_source=data_source,
89+
name="get_historical_features_types_match",
90+
dtype=_get_feast_type(config.feature_dtype, config.feature_is_list),
9791
)
9892
fs.apply([fv, entity])
9993

@@ -139,12 +133,10 @@ def test_feature_get_online_features_types_match(
139133
):
140134
config, data_source, fv = online_types_test_fixtures
141135
entity = driver()
142-
fv = create_feature_view(
143-
"get_online_features_types_match",
144-
config.feature_dtype,
145-
config.feature_is_list,
146-
config.has_empty_list,
147-
data_source,
136+
fv = driver_feature_view(
137+
data_source=data_source,
138+
name="get_online_features_types_match",
139+
dtype=_get_feast_type(config.feature_dtype, config.feature_is_list),
148140
)
149141
fs = environment.feature_store
150142
features = [fv.name + ":value"]
@@ -188,14 +180,8 @@ def test_feature_get_online_features_types_match(
188180
assert isinstance(feature, expected_dtype)
189181

190182

191-
def create_feature_view(
192-
name,
193-
feature_dtype,
194-
feature_is_list,
195-
has_empty_list,
196-
data_source,
197-
entity_type=Int64,
198-
):
183+
def _get_feast_type(feature_dtype: str, feature_is_list: bool) -> FeastType:
184+
dtype: Optional[FeastType] = None
199185
if feature_is_list is True:
200186
if feature_dtype == "int32":
201187
dtype = Array(Int32)
@@ -218,10 +204,8 @@ def create_feature_view(
218204
dtype = Bool
219205
elif feature_dtype == "datetime":
220206
dtype = UnixTimestamp
221-
222-
return driver_feature_view(
223-
data_source, name=name, dtype=dtype, entity_type=entity_type
224-
)
207+
assert dtype
208+
return dtype
225209

226210

227211
def assert_expected_historical_feature_types(
@@ -388,12 +372,10 @@ def get_fixtures(request, environment):
388372
destination_name=destination_name,
389373
field_mapping={"ts_1": "ts"},
390374
)
391-
fv = create_feature_view(
392-
destination_name,
393-
config.feature_dtype,
394-
config.feature_is_list,
395-
config.has_empty_list,
396-
data_source,
375+
fv = driver_feature_view(
376+
data_source=data_source,
377+
name=destination_name,
378+
dtype=_get_feast_type(config.feature_dtype, config.feature_is_list),
397379
)
398380

399381
return config, data_source, fv

0 commit comments

Comments
 (0)