11import logging
22from dataclasses import dataclass
33from datetime import datetime , timedelta
4- from typing import Any , Dict , List , Tuple , Union
4+ from typing import Any , Dict , List , Optional , Tuple , Union
55
66import numpy as np
77import pandas as pd
1212from feast .types import (
1313 Array ,
1414 Bool ,
15+ FeastType ,
1516 Float32 ,
1617 Float64 ,
1718 Int32 ,
@@ -42,20 +43,15 @@ def test_entity_inference_types_match(environment, entity_type):
4243 destination_name = f"entity_type_{ entity_type .name .lower ()} " ,
4344 field_mapping = {"ts_1" : "ts" },
4445 )
45- fv = create_feature_view (
46- f"fv_entity_type_{ entity_type .name .lower ()} " ,
47- feature_dtype = "int32" ,
48- feature_is_list = False ,
49- has_empty_list = False ,
46+ fv = driver_feature_view (
5047 data_source = data_source ,
48+ name = f"fv_entity_type_{ entity_type .name .lower ()} " ,
49+ infer_entities = True , # Forces entity inference by not including a field for the entity.
50+ dtype = _get_feast_type ("int32" , False ),
5151 entity_type = entity_type ,
5252 )
5353
54- # TODO(felixwang9817): Refactor this by finding a better way to force type inference.
55- # Override the schema and entity_columns to force entity inference.
5654 entity = driver ()
57- fv .schema = list (filter (lambda x : x .name != entity .join_key , fv .schema ))
58- fv .entity_columns = []
5955 fs .apply ([fv , entity ])
6056
6157 entity_type_to_expected_inferred_entity_type = {
@@ -88,12 +84,10 @@ def test_feature_get_historical_features_types_match(
8884 config , data_source , fv = offline_types_test_fixtures
8985 fs = environment .feature_store
9086 entity = driver ()
91- fv = create_feature_view (
92- "get_historical_features_types_match" ,
93- config .feature_dtype ,
94- config .feature_is_list ,
95- config .has_empty_list ,
96- data_source ,
87+ fv = driver_feature_view (
88+ data_source = data_source ,
89+ name = "get_historical_features_types_match" ,
90+ dtype = _get_feast_type (config .feature_dtype , config .feature_is_list ),
9791 )
9892 fs .apply ([fv , entity ])
9993
@@ -139,12 +133,10 @@ def test_feature_get_online_features_types_match(
139133):
140134 config , data_source , fv = online_types_test_fixtures
141135 entity = driver ()
142- fv = create_feature_view (
143- "get_online_features_types_match" ,
144- config .feature_dtype ,
145- config .feature_is_list ,
146- config .has_empty_list ,
147- data_source ,
136+ fv = driver_feature_view (
137+ data_source = data_source ,
138+ name = "get_online_features_types_match" ,
139+ dtype = _get_feast_type (config .feature_dtype , config .feature_is_list ),
148140 )
149141 fs = environment .feature_store
150142 features = [fv .name + ":value" ]
@@ -188,14 +180,8 @@ def test_feature_get_online_features_types_match(
188180 assert isinstance (feature , expected_dtype )
189181
190182
191- def create_feature_view (
192- name ,
193- feature_dtype ,
194- feature_is_list ,
195- has_empty_list ,
196- data_source ,
197- entity_type = Int64 ,
198- ):
183+ def _get_feast_type (feature_dtype : str , feature_is_list : bool ) -> FeastType :
184+ dtype : Optional [FeastType ] = None
199185 if feature_is_list is True :
200186 if feature_dtype == "int32" :
201187 dtype = Array (Int32 )
@@ -218,10 +204,8 @@ def create_feature_view(
218204 dtype = Bool
219205 elif feature_dtype == "datetime" :
220206 dtype = UnixTimestamp
221-
222- return driver_feature_view (
223- data_source , name = name , dtype = dtype , entity_type = entity_type
224- )
207+ assert dtype
208+ return dtype
225209
226210
227211def assert_expected_historical_feature_types (
@@ -388,12 +372,10 @@ def get_fixtures(request, environment):
388372 destination_name = destination_name ,
389373 field_mapping = {"ts_1" : "ts" },
390374 )
391- fv = create_feature_view (
392- destination_name ,
393- config .feature_dtype ,
394- config .feature_is_list ,
395- config .has_empty_list ,
396- data_source ,
375+ fv = driver_feature_view (
376+ data_source = data_source ,
377+ name = destination_name ,
378+ dtype = _get_feast_type (config .feature_dtype , config .feature_is_list ),
397379 )
398380
399381 return config , data_source , fv
0 commit comments