Skip to content

Commit 17231d0

Browse files
authored
Add entity column validations when getting historical features from bigquery (feast-dev#1614)
* Add entity column validations when getting historical features from bigquery

Signed-off-by: Achal Shah <[email protected]>

* make format

Signed-off-by: Achal Shah <[email protected]>

* Remove wrong file

Signed-off-by: Achal Shah <[email protected]>

* Add tests

Signed-off-by: Achal Shah <[email protected]>
1 parent e712782 commit 17231d0

File tree

3 files changed

+102
-2
lines changed

3 files changed

+102
-2
lines changed

sdk/python/feast/errors.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,3 +67,11 @@ def __init__(self, offline_store_name: str, data_source_name: str):
6767
super().__init__(
6868
f"Offline Store '{offline_store_name}' does not support data source '{data_source_name}'"
6969
)
70+
71+
72+
class FeastEntityDFMissingColumnsError(Exception):
    """Raised when a user-provided entity dataframe (or entity SQL query result)
    lacks columns required for historical feature retrieval.

    Args:
        expected: The full set of column names that must be present
            (join keys plus the event timestamp column).
        missing: The subset of ``expected`` that was not found.
    """

    def __init__(self, expected, missing):
        super().__init__(
            f"The entity dataframe you have provided must contain columns {expected}, "
            f"but {missing} were missing."
        )

sdk/python/feast/infra/offline_stores/bigquery.py

Lines changed: 52 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,13 @@
11
import time
22
from dataclasses import asdict, dataclass
33
from datetime import datetime, timedelta
4-
from typing import List, Optional, Union
4+
from typing import List, Optional, Set, Union
55

66
import pandas
77
import pyarrow
88
from jinja2 import BaseLoader, Environment
99

10+
from feast import errors
1011
from feast.data_source import BigQuerySource, DataSource
1112
from feast.errors import FeastProviderLoginError
1213
from feast.feature_view import FeatureView
@@ -87,13 +88,18 @@ def get_historical_features(
8788

8889
client = _get_bigquery_client()
8990

91+
expected_join_keys = _get_join_keys(project, feature_views, registry)
92+
9093
if type(entity_df) is str:
9194
entity_df_job = client.query(entity_df)
9295
entity_df_result = entity_df_job.result() # also starts job
9396

9497
entity_df_event_timestamp_col = _infer_event_timestamp_from_bigquery_query(
9598
entity_df_result
9699
)
100+
_assert_expected_columns_in_bigquery(
101+
expected_join_keys, entity_df_event_timestamp_col, entity_df_result
102+
)
97103

98104
entity_df_sql_table = f"`{entity_df_job.destination.project}.{entity_df_job.destination.dataset_id}.{entity_df_job.destination.table_id}`"
99105
elif isinstance(entity_df, pandas.DataFrame):
@@ -103,6 +109,10 @@ def get_historical_features(
103109

104110
assert isinstance(config.offline_store, BigQueryOfflineStoreConfig)
105111

112+
_assert_expected_columns_in_dataframe(
113+
expected_join_keys, entity_df_event_timestamp_col, entity_df
114+
)
115+
106116
table_id = _upload_entity_df_into_bigquery(
107117
config.project, config.offline_store.dataset, entity_df, client
108118
)
@@ -132,6 +142,47 @@ def get_historical_features(
132142
return job
133143

134144

145+
def _assert_expected_columns_in_dataframe(
    join_keys: Set[str], entity_df_event_timestamp_col: str, entity_df: pandas.DataFrame
):
    """Validate that ``entity_df`` contains every required column.

    The dataframe must carry every join key needed by the requested feature
    views plus the event timestamp column.

    Args:
        join_keys: Join key column names required by the requested feature views.
        entity_df_event_timestamp_col: Name of the event timestamp column.
        entity_df: The user-provided entity dataframe.

    Raises:
        errors.FeastEntityDFMissingColumnsError: If any required column is absent.
    """
    entity_df_columns = set(entity_df.columns.values)
    # Set union instead of copy-then-add: one expression, no mutation.
    expected_columns = join_keys | {entity_df_event_timestamp_col}

    missing_keys = expected_columns - entity_df_columns

    # Idiomatic truthiness: an empty set means nothing is missing.
    if missing_keys:
        raise errors.FeastEntityDFMissingColumnsError(expected_columns, missing_keys)
156+
157+
158+
def _assert_expected_columns_in_bigquery(
    join_keys: Set[str], entity_df_event_timestamp_col: str, entity_df_result
):
    """Validate that a BigQuery query result contains every required column.

    The result's schema must carry every join key needed by the requested
    feature views plus the event timestamp column.

    Args:
        join_keys: Join key column names required by the requested feature views.
        entity_df_event_timestamp_col: Name of the event timestamp column.
        entity_df_result: A BigQuery result whose ``schema`` is iterable and
            yields fields with a ``name`` attribute (e.g. a ``RowIterator``).

    Raises:
        errors.FeastEntityDFMissingColumnsError: If any required column is absent.
    """
    # Set comprehension instead of a loop that mutates an accumulator.
    entity_columns = {schema_field.name for schema_field in entity_df_result.schema}

    expected_columns = join_keys | {entity_df_event_timestamp_col}
    missing_keys = expected_columns - entity_columns

    # Idiomatic truthiness: an empty set means nothing is missing.
    if missing_keys:
        raise errors.FeastEntityDFMissingColumnsError(expected_columns, missing_keys)
172+
173+
174+
def _get_join_keys(
    project: str, feature_views: List[FeatureView], registry: Registry
) -> Set[str]:
    """Collect the distinct join keys of every entity referenced by the given
    feature views, resolving each entity name through the registry."""
    return {
        registry.get_entity(entity_name, project).join_key
        for feature_view in feature_views
        for entity_name in feature_view.entities
    }
184+
185+
135186
def _infer_event_timestamp_from_bigquery_query(entity_df_result) -> str:
136187
if any(
137188
schema_field.name == DEFAULT_ENTITY_DF_EVENT_TIMESTAMP_COL

sdk/python/tests/test_historical_retrieval.py

Lines changed: 42 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from pytz import utc
1515

1616
import feast.driver_test_data as driver_data
17-
from feast import utils
17+
from feast import errors, utils
1818
from feast.data_source import BigQuerySource, FileSource
1919
from feast.entity import Entity
2020
from feast.feature import Feature
@@ -454,6 +454,30 @@ def test_historical_features_from_bigquery_sources(
454454
check_dtype=False,
455455
)
456456

457+
timestamp_column = (
458+
"e_ts"
459+
if infer_event_timestamp_col
460+
else DEFAULT_ENTITY_DF_EVENT_TIMESTAMP_COL
461+
)
462+
463+
entity_df_query_with_invalid_join_key = (
464+
f"select order_id, driver_id, customer_id as customer, "
465+
f"order_is_success, {timestamp_column}, FROM {gcp_project}.{table_id}"
466+
)
467+
# Rename the join key; this should now raise an error.
468+
assertpy.assert_that(store.get_historical_features).raises(
469+
errors.FeastEntityDFMissingColumnsError
470+
).when_called_with(
471+
entity_df=entity_df_query_with_invalid_join_key,
472+
feature_refs=[
473+
"driver_stats:conv_rate",
474+
"driver_stats:avg_daily_trips",
475+
"customer_profile:current_balance",
476+
"customer_profile:avg_passenger_count",
477+
"customer_profile:lifetime_trip_count",
478+
],
479+
)
480+
457481
job_from_df = store.get_historical_features(
458482
entity_df=orders_df,
459483
feature_refs=[
@@ -465,6 +489,23 @@ def test_historical_features_from_bigquery_sources(
465489
],
466490
)
467491

492+
# Rename the join key; this should now raise an error.
493+
orders_df_with_invalid_join_key = orders_df.rename(
494+
{"customer_id": "customer"}, axis="columns"
495+
)
496+
assertpy.assert_that(store.get_historical_features).raises(
497+
errors.FeastEntityDFMissingColumnsError
498+
).when_called_with(
499+
entity_df=orders_df_with_invalid_join_key,
500+
feature_refs=[
501+
"driver_stats:conv_rate",
502+
"driver_stats:avg_daily_trips",
503+
"customer_profile:current_balance",
504+
"customer_profile:avg_passenger_count",
505+
"customer_profile:lifetime_trip_count",
506+
],
507+
)
508+
468509
# Make sure that custom dataset name is being used from the offline_store config
469510
if provider_type == "gcp_custom_offline_config":
470511
assertpy.assert_that(job_from_df.query).contains("foo.entity_df")

0 commit comments

Comments
 (0)