Skip to content

Commit f1df40a

Browse files
Implement entity aliasing for feature retrieval (feast-dev#1868)
* Initial commit for implementing shadow entities 8b Signed-off-by: Cody Lin <[email protected]> * fix small bug failing unit tests Signed-off-by: Cody Lin <[email protected]> * revert changes using FVProjection Signed-off-by: Cody Lin <[email protected]> * initial design for testing Signed-off-by: Cody Lin <[email protected]> * add join_key_map to FV proto definition Signed-off-by: Cody Lin <[email protected]> * handle join_key_map for proto conversions Signed-off-by: Cody Lin <[email protected]> * try to grab underlying changes Signed-off-by: Cody Lin <[email protected]> * Tests fixed after rebasing with dependencies Signed-off-by: Cody Lin <[email protected]> * Remove small unnecessary changes Signed-off-by: Cody Lin <[email protected]> * Remove small unnecessary changes Signed-off-by: Cody Lin <[email protected]> * Fix lint issues Signed-off-by: Cody Lin <[email protected]> * Fix redshift query error Signed-off-by: Cody Lin <[email protected]> * Fix cache error Signed-off-by: Cody Lin <[email protected]> * Fix benchmark test error Signed-off-by: Cody Lin <[email protected]> * remove unnecessary changes Signed-off-by: Cody Lin <[email protected]> * respond to David's comments Signed-off-by: Cody Lin <[email protected]> * Improve docstring, revert to entity_name_to_join_key_map Signed-off-by: Cody Lin <[email protected]> * Fix docstring test Signed-off-by: Cody Lin <[email protected]> * Remove code from previous design Signed-off-by: Cody Lin <[email protected]> * Improve comments and variable naming Signed-off-by: Cody Lin <[email protected]> * Use get with default Signed-off-by: Cody Lin <[email protected]> * Fix use get with default Signed-off-by: Cody Lin <[email protected]> Co-authored-by: David Y Liu <[email protected]>
1 parent b7d2d8f commit f1df40a

File tree

14 files changed

+528
-54
lines changed

14 files changed

+528
-54
lines changed

protos/feast/core/FeatureViewProjection.proto

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,4 +19,7 @@ message FeatureViewProjection {
1919

2020
// The features of the feature view that are a part of the feature reference.
2121
repeated FeatureSpecV2 feature_columns = 2;
22+
23+
// Map of entity join_key overrides: feature data join_key -> entity data join_key
24+
map<string,string> join_key_map = 4;
2225
}

sdk/python/feast/driver_test_data.py

Lines changed: 56 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
# This module generates dummy data to be used for tests and examples.
2+
import itertools
23
from enum import Enum
34

45
import numpy as np
@@ -29,22 +30,29 @@ def _convert_event_timestamp(event_timestamp: pd.Timestamp, t: EventTimestampTyp
2930

3031

3132
def create_orders_df(
32-
customers, drivers, start_date, end_date, order_count,
33+
customers, drivers, start_date, end_date, order_count, locations=None,
3334
) -> pd.DataFrame:
3435
"""
35-
Example df generated by this function:
36+
Example df generated by this function (if locations):
3637
37-
| order_id | driver_id | customer_id | order_is_success | event_timestamp |
38-
+----------+-----------+-------------+------------------+---------------------+
39-
| 100 | 5004 | 1007 | 0 | 2021-03-10 19:31:15 |
40-
| 101 | 5003 | 1006 | 0 | 2021-03-11 22:02:50 |
41-
| 102 | 5010 | 1005 | 0 | 2021-03-13 00:34:24 |
42-
| 103 | 5010 | 1001 | 1 | 2021-03-14 03:05:59 |
38+
| order_id | driver_id | customer_id | origin_id | destination_id | order_is_success | event_timestamp |
39+
+----------+-----------+-------------+-----------+----------------+------------------+---------------------+
40+
| 100 | 5004 | 1007 | 1 | 18 | 0 | 2021-03-10 19:31:15 |
41+
| 101 | 5003 | 1006 | 24 | 42 | 0 | 2021-03-11 22:02:50 |
42+
| 102 | 5010 | 1005 | 19 | 12 | 0 | 2021-03-13 00:34:24 |
43+
| 103 | 5010 | 1001 | 35 | 8 | 1 | 2021-03-14 03:05:59 |
4344
"""
4445
df = pd.DataFrame()
4546
df["order_id"] = [order_id for order_id in range(100, 100 + order_count)]
4647
df["driver_id"] = np.random.choice(drivers, order_count)
4748
df["customer_id"] = np.random.choice(customers, order_count)
49+
if locations:
50+
location_pairs = np.array(list(itertools.permutations(locations, 2)))
51+
locations_sample = location_pairs[
52+
np.random.choice(len(location_pairs), order_count)
53+
].T
54+
df["origin_id"] = locations_sample[0]
55+
df["destination_id"] = locations_sample[1]
4856
df["order_is_success"] = np.random.randint(0, 2, size=order_count).astype(np.int32)
4957
df[DEFAULT_ENTITY_DF_EVENT_TIMESTAMP_COL] = [
5058
_convert_event_timestamp(
@@ -180,6 +188,46 @@ def create_customer_daily_profile_df(customers, start_date, end_date) -> pd.Data
180188
return df_all_customers
181189

182190

191+
def create_location_stats_df(locations, start_date, end_date) -> pd.DataFrame:
192+
"""
193+
Example df generated by this function:
194+
195+
| event_timestamp | location_id | temperature | created |
196+
+------------------+-------------+-------------+------------------+
197+
| 2021-03-17 19:31 | 1 | 74 | 2021-03-24 19:38 |
198+
| 2021-03-17 20:31 | 24 | 63 | 2021-03-24 19:38 |
199+
| 2021-03-17 21:31 | 19 | 65 | 2021-03-24 19:38 |
200+
| 2021-03-17 22:31 | 35 | 86 | 2021-03-24 19:38 |
201+
"""
202+
df_hourly = pd.DataFrame(
203+
{
204+
"event_timestamp": [
205+
pd.Timestamp(dt, unit="ms", tz="UTC").round("ms")
206+
for dt in pd.date_range(
207+
start=start_date, end=end_date, freq="1H", closed="left"
208+
)
209+
]
210+
}
211+
)
212+
df_all_locations = pd.DataFrame()
213+
214+
for location in locations:
215+
df_hourly_copy = df_hourly.copy()
216+
df_hourly_copy["location_id"] = location
217+
df_all_locations = pd.concat([df_hourly_copy, df_all_locations])
218+
219+
df_all_locations.reset_index(drop=True, inplace=True)
220+
rows = df_all_locations["event_timestamp"].count()
221+
222+
df_all_locations["temperature"] = np.random.randint(50, 100, size=rows).astype(
223+
np.int32
224+
)
225+
226+
# TODO: Remove created timestamp in order to test whether it's really optional
227+
df_all_locations["created"] = pd.to_datetime(pd.Timestamp.now(tz=None).round("ms"))
228+
return df_all_locations
229+
230+
183231
def create_global_daily_stats_df(start_date, end_date) -> pd.DataFrame:
184232
"""
185233
Example df generated by this function:

sdk/python/feast/feature_store.py

Lines changed: 33 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -839,6 +839,20 @@ def get_online_features(
839839
entity_name_to_join_key_map = {}
840840
for entity in entities:
841841
entity_name_to_join_key_map[entity.name] = entity.join_key
842+
for feature_view in all_feature_views:
843+
for entity_name in feature_view.entities:
844+
entity = self._registry.get_entity(
845+
entity_name, self.project, allow_cache=True
846+
)
847+
# User directly uses join_key as the entity reference in the entity_rows for the
848+
# entity mapping case.
849+
entity_name = feature_view.projection.join_key_map.get(
850+
entity.join_key, entity.name
851+
)
852+
join_key = feature_view.projection.join_key_map.get(
853+
entity.join_key, entity.join_key
854+
)
855+
entity_name_to_join_key_map[entity_name] = join_key
842856

843857
needed_request_data_features = self._get_needed_request_data_features(
844858
grouped_odfv_refs
@@ -895,8 +909,12 @@ def get_online_features(
895909
] = GetOnlineFeaturesResponse.FieldStatus.PRESENT
896910

897911
for table, requested_features in grouped_refs:
912+
table_join_keys = [
913+
entity_name_to_join_key_map[entity_name]
914+
for entity_name in table.entities
915+
]
898916
self._populate_result_rows_from_feature_view(
899-
entity_name_to_join_key_map,
917+
table_join_keys,
900918
full_feature_names,
901919
provider,
902920
requested_features,
@@ -918,7 +936,7 @@ def get_online_features(
918936

919937
def _populate_result_rows_from_feature_view(
920938
self,
921-
entity_name_to_join_key_map: Dict[str, str],
939+
table_join_keys: List[str],
922940
full_feature_names: bool,
923941
provider: Provider,
924942
requested_features: List[str],
@@ -927,7 +945,7 @@ def _populate_result_rows_from_feature_view(
927945
union_of_entity_keys: List[EntityKeyProto],
928946
):
929947
entity_keys = _get_table_entity_keys(
930-
table, union_of_entity_keys, entity_name_to_join_key_map
948+
table, union_of_entity_keys, table_join_keys
931949
)
932950
read_rows = provider.online_read(
933951
config=self.config,
@@ -1045,8 +1063,8 @@ def _get_feature_views_to_use(
10451063
)
10461064
}
10471065

1048-
fvs_to_use, od_fvs_to_use = [], []
10491066
if isinstance(features, FeatureService):
1067+
fvs_to_use, od_fvs_to_use = [], []
10501068
for fv_name, projection in [
10511069
(projection.name, projection)
10521070
for projection in features.feature_view_projections
@@ -1137,10 +1155,12 @@ def _group_feature_refs(
11371155
""" Get list of feature views and corresponding feature names based on feature references"""
11381156

11391157
# view name to view proto
1140-
view_index = {view.name: view for view in all_feature_views}
1158+
view_index = {view.projection.name_to_use(): view for view in all_feature_views}
11411159

11421160
# on demand view to on demand view proto
1143-
on_demand_view_index = {view.name: view for view in all_on_demand_feature_views}
1161+
on_demand_view_index = {
1162+
view.projection.name_to_use(): view for view in all_on_demand_feature_views
1163+
}
11441164

11451165
# view name to feature names
11461166
views_features = defaultdict(list)
@@ -1168,15 +1188,19 @@ def _group_feature_refs(
11681188

11691189

11701190
def _get_table_entity_keys(
1171-
table: FeatureView, entity_keys: List[EntityKeyProto], join_key_map: Dict[str, str],
1191+
table: FeatureView, entity_keys: List[EntityKeyProto], table_join_keys: List[str]
11721192
) -> List[EntityKeyProto]:
1173-
table_join_keys = [join_key_map[entity_name] for entity_name in table.entities]
1193+
reverse_join_key_map = {
1194+
alias: original for original, alias in table.projection.join_key_map.items()
1195+
}
11741196
required_entities = OrderedDict.fromkeys(sorted(table_join_keys))
11751197
entity_key_protos = []
11761198
for entity_key in entity_keys:
11771199
required_entities_to_values = required_entities.copy()
11781200
for i in range(len(entity_key.join_keys)):
1179-
entity_name = entity_key.join_keys[i]
1201+
entity_name = reverse_join_key_map.get(
1202+
entity_key.join_keys[i], entity_key.join_keys[i]
1203+
)
11801204
entity_value = entity_key.entity_values[i]
11811205

11821206
if entity_name in required_entities_to_values:

sdk/python/feast/feature_view.py

Lines changed: 39 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,7 @@ def __init__(
126126
self.name = name
127127
self.entities = entities if entities else [DUMMY_ENTITY_NAME]
128128
self.features = _features
129-
self.tags = tags if tags is not None else {}
129+
self.tags = tags if tags else {}
130130

131131
if isinstance(ttl, Duration):
132132
self.ttl = timedelta(seconds=int(ttl.seconds))
@@ -238,6 +238,44 @@ def with_name(self, name: str):
238238

239239
return cp
240240

241+
def with_join_key_map(self, join_key_map: Dict[str, str]):
242+
"""
243+
Sets the join_key_map by returning a copy of this feature view with that field set.
244+
This join_key mapping operation is only used as part of query operations and will
245+
not modify the underlying FeatureView.
246+
247+
Args:
248+
join_key_map: A map of join keys in which the key is the join_key that
249+
corresponds with the feature data and the value is the join_key that corresponds with the entity data.
250+
251+
Returns:
252+
A copy of this FeatureView with the join_key_map replaced with the 'join_key_map' input.
253+
254+
Examples:
255+
Join a location feature data table to both the origin column and destination
256+
column of the entity data.
257+
258+
temperatures_feature_service = FeatureService(
259+
name="temperatures",
260+
features=[
261+
location_stats_feature_view
262+
.with_name("origin_stats")
263+
.with_join_key_map(
264+
{"location_id": "origin_id"}
265+
),
266+
location_stats_feature_view
267+
.with_name("destination_stats")
268+
.with_join_key_map(
269+
{"location_id": "destination_id"}
270+
),
271+
],
272+
)
273+
"""
274+
cp = self.__copy__()
275+
cp.projection.join_key_map = join_key_map
276+
277+
return cp
278+
241279
def with_projection(self, feature_view_projection: FeatureViewProjection):
242280
"""
243281
Sets the feature view projection by returning a copy of this feature view

sdk/python/feast/feature_view_projection.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import List, Optional
1+
from typing import Dict, List, Optional
22

33
from attr import dataclass
44

@@ -13,13 +13,16 @@ class FeatureViewProjection:
1313
name: str
1414
name_alias: Optional[str]
1515
features: List[Feature]
16+
join_key_map: Dict[str, str] = {}
1617

1718
def name_to_use(self):
1819
return self.name_alias or self.name
1920

2021
def to_proto(self):
2122
feature_reference_proto = FeatureViewProjectionProto(
22-
feature_view_name=self.name, feature_view_name_alias=self.name_alias
23+
feature_view_name=self.name,
24+
feature_view_name_alias=self.name_alias,
25+
join_key_map=self.join_key_map,
2326
)
2427
for feature in self.features:
2528
feature_reference_proto.feature_columns.append(feature.to_proto())
@@ -32,6 +35,7 @@ def from_proto(proto: FeatureViewProjectionProto):
3235
name=proto.feature_view_name,
3336
name_alias=proto.feature_view_name_alias,
3437
features=[],
38+
join_key_map=dict(proto.join_key_map),
3539
)
3640
for feature_column in proto.feature_columns:
3741
ref.features.append(Feature.from_proto(feature_column))

sdk/python/feast/infra/offline_stores/file.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -143,6 +143,11 @@ def evaluate_historical_retrieval():
143143
table = _run_field_mapping(
144144
table, feature_view.batch_source.field_mapping
145145
)
146+
# Rename entity columns by the join_key_map dictionary if it exists
147+
if feature_view.projection.join_key_map:
148+
table = _run_field_mapping(
149+
table, feature_view.projection.join_key_map
150+
)
146151

147152
# Convert pyarrow table to pandas dataframe. Note, if the underlying data has missing values,
148153
# pandas will convert those values to np.nan if the dtypes are numerical (floats, ints, etc.) or boolean
@@ -176,7 +181,9 @@ def evaluate_historical_retrieval():
176181
# double underscore as separator for consistency with other databases like BigQuery,
177182
# where there are very few characters available for use as separators
178183
if full_feature_names:
179-
formatted_feature_name = f"{feature_view.name}__{feature}"
184+
formatted_feature_name = (
185+
f"{feature_view.projection.name_to_use()}__{feature}"
186+
)
180187
else:
181188
formatted_feature_name = feature
182189
# Add the feature name to the list of columns
@@ -191,7 +198,10 @@ def evaluate_historical_retrieval():
191198
join_keys = []
192199
for entity_name in feature_view.entities:
193200
entity = registry.get_entity(entity_name, project)
194-
join_keys.append(entity.join_key)
201+
join_key = feature_view.projection.join_key_map.get(
202+
entity.join_key, entity.join_key
203+
)
204+
join_keys.append(join_key)
195205
right_entity_columns = join_keys
196206
right_entity_key_columns = [
197207
event_timestamp_column

sdk/python/feast/infra/offline_stores/offline_utils.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,10 @@ def get_expected_join_keys(
6464
entities = feature_view.entities
6565
for entity_name in entities:
6666
entity = registry.get_entity(entity_name, project)
67-
join_keys.add(entity.join_key)
67+
join_key = feature_view.projection.join_key_map.get(
68+
entity.join_key, entity.join_key
69+
)
70+
join_keys.add(join_key)
6871
return join_keys
6972

7073

@@ -113,11 +116,11 @@ def get_feature_view_query_context(
113116
}
114117
for entity_name in feature_view.entities:
115118
entity = registry.get_entity(entity_name, project)
116-
join_keys.append(entity.join_key)
117-
join_key_column = reverse_field_mapping.get(
119+
join_key = feature_view.projection.join_key_map.get(
118120
entity.join_key, entity.join_key
119121
)
120-
entity_selections.append(f"{join_key_column} AS {entity.join_key}")
122+
join_keys.append(join_key)
123+
entity_selections.append(f"{entity.join_key} AS {join_key}")
121124

122125
if isinstance(feature_view.ttl, timedelta):
123126
ttl_seconds = int(feature_view.ttl.total_seconds())

sdk/python/feast/infra/offline_stores/redshift.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -176,7 +176,7 @@ def query_generator() -> Iterator[str]:
176176
),
177177
drop_columns=["entity_timestamp"]
178178
+ [
179-
f"{feature_view.name}__entity_row_unique_id"
179+
f"{feature_view.projection.name_to_use()}__entity_row_unique_id"
180180
for feature_view in feature_views
181181
],
182182
)

sdk/python/tests/benchmarks/test_benchmark_universal_online_retrieval.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,11 @@
66
from tests.integration.feature_repos.repo_configuration import (
77
construct_universal_feature_views,
88
)
9-
from tests.integration.feature_repos.universal.entities import customer, driver
9+
from tests.integration.feature_repos.universal.entities import (
10+
customer,
11+
driver,
12+
location,
13+
)
1014

1115

1216
@pytest.mark.benchmark
@@ -24,7 +28,7 @@ def test_online_retrieval(environment, universal_data_sources, benchmark):
2428

2529
feast_objects = []
2630
feast_objects.extend(feature_views.values())
27-
feast_objects.extend([driver(), customer(), feature_service])
31+
feast_objects.extend([driver(), customer(), location(), feature_service])
2832
fs.apply(feast_objects)
2933
fs.materialize(environment.start_date, environment.end_date)
3034

0 commit comments

Comments
 (0)