
Commit 3d0abf2

Author: Tsotne Tabidze

Fix offline store (tz-naive & field_mapping issues) (feast-dev#1466)

* Fix offline store (tz-naive & field_mapping issues)
  Signed-off-by: Tsotne Tabidze <[email protected]>
* Rename test_materialize.py to test_offline_online_store_consistency.py
  Signed-off-by: Tsotne Tabidze <[email protected]>

1 parent a288dff · commit 3d0abf2
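
The tz-naive issue is worth illustrating: pandas refuses to compare a tz-naive timestamp with a tz-aware one, which is what broke point-in-time joins when entity dataframes and parquet data mixed the two. A minimal reproduction, not part of the commit itself (assumes only pandas and pytz):

import pandas as pd
import pytz

naive = pd.Timestamp("2021-04-12 10:00:00")             # no timezone attached
aware = pd.Timestamp("2021-04-12 10:00:00", tz="UTC")   # tz-aware

try:
    naive < aware
except TypeError as err:
    print(err)  # cannot compare tz-naive and tz-aware timestamps

# The commit's fix: default tz-naive timestamps to UTC before comparing
normalized = naive if naive.tz is not None else naive.replace(tzinfo=pytz.utc)
print(normalized <= aware)  # True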

File tree: 3 files changed, +100 −36 lines changed

sdk/python/feast/infra/offline_stores/bigquery.py (22 additions, 7 deletions)

@@ -1,7 +1,7 @@
 import time
 from dataclasses import asdict, dataclass
 from datetime import datetime, timedelta
-from typing import Dict, List, Optional, Union
+from typing import List, Optional, Union

 import pandas
 import pyarrow
@@ -130,9 +130,9 @@ class FeatureViewQueryContext:
     table_ref: str
     event_timestamp_column: str
     created_timestamp_column: str
-    field_mapping: Dict[str, str]
     query: str
     table_subquery: str
+    entity_selections: List[str]


 def _upload_entity_df_into_bigquery(project, entity_df) -> str:
@@ -178,9 +178,17 @@ def get_feature_view_query_context(
     query_context = []
     for feature_view, features in feature_views_to_feature_map.items():
         join_keys = []
+        entity_selections = []
+        reverse_field_mapping = {
+            v: k for k, v in feature_view.input.field_mapping.items()
+        }
         for entity_name in feature_view.entities:
             entity = registry.get_entity(entity_name, project)
             join_keys.append(entity.join_key)
+            join_key_column = reverse_field_mapping.get(
+                entity.join_key, entity.join_key
+            )
+            entity_selections.append(f"{join_key_column} AS {entity.join_key}")

         if isinstance(feature_view.ttl, timedelta):
             ttl_seconds = int(feature_view.ttl.total_seconds())
@@ -189,18 +197,25 @@ def get_feature_view_query_context(

         assert isinstance(feature_view.input, BigQuerySource)

+        event_timestamp_column = feature_view.input.event_timestamp_column
+        created_timestamp_column = feature_view.input.created_timestamp_column
+
         context = FeatureViewQueryContext(
             name=feature_view.name,
             ttl=ttl_seconds,
             entities=join_keys,
             features=features,
             table_ref=feature_view.input.table_ref,
-            event_timestamp_column=feature_view.input.event_timestamp_column,
-            created_timestamp_column=feature_view.input.created_timestamp_column,
+            event_timestamp_column=reverse_field_mapping.get(
+                event_timestamp_column, event_timestamp_column
+            ),
+            created_timestamp_column=reverse_field_mapping.get(
+                created_timestamp_column, created_timestamp_column
+            ),
             # TODO: Make created column optional and not hardcoded
-            field_mapping=feature_view.input.field_mapping,
             query=feature_view.input.query,
             table_subquery=feature_view.input.get_table_query_string(),
+            entity_selections=entity_selections,
         )
         query_context.append(context)
     return query_context
@@ -267,7 +282,7 @@ def build_point_in_time_query(
     {{ featureview.event_timestamp_column }} as event_timestamp,
     {{ featureview.event_timestamp_column }} as {{ featureview.name }}_feature_timestamp,
     {{ featureview.created_timestamp_column }} as created_timestamp,
-    {{ featureview.entities | join(', ')}},
+    {{ featureview.entity_selections | join(', ')}},
     false AS is_entity_table
 FROM {{ featureview.table_subquery }} WHERE {{ featureview.event_timestamp_column }} <= '{{ max_timestamp }}'
 {% if featureview.ttl == 0 %}{% else %}AND {{ featureview.event_timestamp_column }} >= Timestamp_sub(TIMESTAMP '{{ min_timestamp }}', interval {{ featureview.ttl }} second){% endif %}
@@ -308,7 +323,7 @@ def build_point_in_time_query(
     SELECT
         {{ featureview.event_timestamp_column }} as {{ featureview.name }}_feature_timestamp,
         {{ featureview.created_timestamp_column }} as created_timestamp,
-        {{ featureview.entities | join(', ')}},
+        {{ featureview.entity_selections | join(', ')}},
         {% for feature in featureview.features %}
         {{ feature }} as {{ featureview.name }}__{{ feature }}{% if loop.last %}{% else %}, {% endif %}
         {% endfor %}
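
The key change above replaces raw entity column references in the generated SQL with aliased selections driven by the source's field_mapping. A small standalone sketch of what reverse_field_mapping and entity_selections end up containing, using the mapping that appears in the test file below:

# field_mapping maps raw source columns to Feast names: the raw table has
# columns "ts_1" and "id", exposed to Feast as "ts" and "driver_id".
field_mapping = {"ts_1": "ts", "id": "driver_id"}

# Reversed, so a Feast name looks up its underlying source column.
reverse_field_mapping = {v: k for k, v in field_mapping.items()}

join_keys = ["driver_id"]
entity_selections = [
    f"{reverse_field_mapping.get(key, key)} AS {key}" for key in join_keys
]
print(entity_selections)  # ['id AS driver_id']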

sdk/python/feast/infra/offline_stores/file.py (27 additions, 3 deletions)

@@ -11,6 +11,7 @@
 from feast.infra.provider import (
     ENTITY_DF_EVENT_TIMESTAMP_COL,
     _get_requested_feature_views_to_features_dict,
+    _run_field_mapping,
 )
 from feast.registry import Registry
 from feast.repo_config import RepoConfig
@@ -55,6 +56,10 @@ def get_historical_features(
         # Create lazy function that is only called from the RetrievalJob object
         def evaluate_historical_retrieval():

+            # Make sure all event timestamp fields are tz-aware. We default tz-naive fields to UTC
+            entity_df[ENTITY_DF_EVENT_TIMESTAMP_COL] = entity_df[
+                ENTITY_DF_EVENT_TIMESTAMP_COL
+            ].apply(lambda x: x if x.tz is not None else x.replace(tzinfo=pytz.utc))
             # Sort entity dataframe prior to join, and create a copy to prevent modifying the original
             entity_df_with_features = entity_df.sort_values(
                 ENTITY_DF_EVENT_TIMESTAMP_COL
@@ -65,10 +70,29 @@ def evaluate_historical_retrieval():
                 event_timestamp_column = feature_view.input.event_timestamp_column
                 created_timestamp_column = feature_view.input.created_timestamp_column

-                # Read dataframe to join to entity dataframe
-                df_to_join = pd.read_parquet(feature_view.input.path).sort_values(
+                # Read offline parquet data in pyarrow format
+                table = pyarrow.parquet.read_table(feature_view.input.path)
+
+                # Rename columns by the field mapping dictionary if it exists
+                if feature_view.input.field_mapping is not None:
+                    table = _run_field_mapping(table, feature_view.input.field_mapping)
+
+                # Convert pyarrow table to pandas dataframe
+                df_to_join = table.to_pandas()
+
+                # Make sure all timestamp fields are tz-aware. We default tz-naive fields to UTC
+                df_to_join[event_timestamp_column] = df_to_join[
                     event_timestamp_column
-                )
+                ].apply(lambda x: x if x.tz is not None else x.replace(tzinfo=pytz.utc))
+                if created_timestamp_column:
+                    df_to_join[created_timestamp_column] = df_to_join[
+                        created_timestamp_column
+                    ].apply(
+                        lambda x: x if x.tz is not None else x.replace(tzinfo=pytz.utc)
+                    )
+
+                # Sort dataframe by the event timestamp column
+                df_to_join = df_to_join.sort_values(event_timestamp_column)

                 # Build a list of all the features we should select from this source
                 feature_names = []
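
_run_field_mapping is imported from feast.infra.provider, but its body is not part of this diff; presumably it renames pyarrow table columns according to the mapping. A hedged sketch of what such a helper could look like (an assumption, not the actual implementation):

import pyarrow

def run_field_mapping_sketch(table: pyarrow.Table, field_mapping: dict) -> pyarrow.Table:
    # Rename mapped columns, leave unmapped ones untouched.
    return table.rename_columns(
        [field_mapping.get(name, name) for name in table.column_names]
    )

As a side note, the per-element .apply(lambda ...) used for the UTC defaulting is correct but slow on large frames; for a column known to be tz-naive, a vectorized equivalent would be df_to_join[event_timestamp_column].dt.tz_localize("UTC").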

sdk/python/tests/test_materialize.py renamed to sdk/python/tests/test_offline_online_store_consistency.py (51 additions, 26 deletions)

@@ -47,7 +47,7 @@ def create_dataset() -> pd.DataFrame:
 def get_feature_view(data_source: Union[FileSource, BigQuerySource]) -> FeatureView:
     return FeatureView(
         name="test_bq_correctness",
-        entities=["driver_id"],
+        entities=["driver"],
         features=[Feature("value", ValueType.FLOAT)],
         ttl=timedelta(days=5),
         input=data_source,
@@ -83,20 +83,20 @@ def prep_bq_fs_and_fv(
         event_timestamp_column="ts",
         created_timestamp_column="created_ts",
         date_partition_column="",
-        field_mapping={"ts_1": "ts", "id": "driver_ident"},
+        field_mapping={"ts_1": "ts", "id": "driver_id"},
     )

     fv = get_feature_view(bigquery_source)
     e = Entity(
-        name="driver_id",
+        name="driver",
         description="id for driver",
-        join_key="driver_ident",
+        join_key="driver_id",
         value_type=ValueType.INT32,
     )
     with tempfile.TemporaryDirectory() as repo_dir_name:
         config = RepoConfig(
             registry=str(Path(repo_dir_name) / "registry.db"),
-            project=f"test_bq_correctness_{uuid.uuid4()}",
+            project=f"test_bq_correctness_{str(uuid.uuid4()).replace('-', '')}",
             provider="gcp",
         )
         fs = FeatureStore(config=config)
@@ -121,7 +121,10 @@ def prep_local_fs_and_fv() -> Iterator[Tuple[FeatureStore, FeatureView]]:
     )
     fv = get_feature_view(file_source)
     e = Entity(
-        name="driver_id", description="id for driver", value_type=ValueType.INT32
+        name="driver",
+        description="id for driver",
+        join_key="driver_id",
+        value_type=ValueType.INT32,
     )
     with tempfile.TemporaryDirectory() as repo_dir_name, tempfile.TemporaryDirectory() as data_dir_name:
         config = RepoConfig(
@@ -138,7 +141,34 @@ def prep_local_fs_and_fv() -> Iterator[Tuple[FeatureStore, FeatureView]]:
         yield fs, fv


-def run_materialization_test(fs: FeatureStore, fv: FeatureView) -> None:
+# Checks that both offline & online store values are as expected
+def check_offline_and_online_features(
+    fs: FeatureStore,
+    fv: FeatureView,
+    driver_id: int,
+    event_timestamp: datetime,
+    expected_value: float,
+) -> None:
+    # Check online store
+    response_dict = fs.get_online_features(
+        [f"{fv.name}:value"], [{"driver": driver_id}]
+    ).to_dict()
+    assert abs(response_dict[f"{fv.name}__value"][0] - expected_value) < 1e-6
+
+    # Check offline store
+    df = fs.get_historical_features(
+        entity_df=pd.DataFrame.from_dict(
+            {"driver_id": [driver_id], "event_timestamp": [event_timestamp]}
+        ),
+        feature_refs=[f"{fv.name}:value"],
+    ).to_df()
+
+    assert abs(df.to_dict()[f"{fv.name}__value"][0] - expected_value) < 1e-6
+
+
+def run_offline_online_store_consistency_test(
+    fs: FeatureStore, fv: FeatureView
+) -> None:
     now = datetime.utcnow()
     # Run materialize()
     # use both tz-naive & tz-aware timestamps to test that they're both correctly handled
@@ -147,38 +177,33 @@ def run_materialization_test(fs: FeatureStore, fv: FeatureView) -> None:
     fs.materialize(feature_views=[fv.name], start_date=start_date, end_date=end_date)

     # check result of materialize()
-    response_dict = fs.get_online_features(
-        [f"{fv.name}:value"], [{"driver_id": 1}]
-    ).to_dict()
-    assert abs(response_dict[f"{fv.name}__value"][0] - 0.3) < 1e-6
+    check_offline_and_online_features(
+        fs=fs, fv=fv, driver_id=1, event_timestamp=end_date, expected_value=0.3
+    )

     # check prior value for materialize_incremental()
-    response_dict = fs.get_online_features(
-        [f"{fv.name}:value"], [{"driver_id": 3}]
-    ).to_dict()
-    assert abs(response_dict[f"{fv.name}__value"][0] - 4) < 1e-6
+    check_offline_and_online_features(
+        fs=fs, fv=fv, driver_id=3, event_timestamp=end_date, expected_value=4
+    )

     # run materialize_incremental()
-    fs.materialize_incremental(
-        feature_views=[fv.name], end_date=now - timedelta(seconds=0),
-    )
+    fs.materialize_incremental(feature_views=[fv.name], end_date=now)

     # check result of materialize_incremental()
-    response_dict = fs.get_online_features(
-        [f"{fv.name}:value"], [{"driver_id": 3}]
-    ).to_dict()
-    assert abs(response_dict[f"{fv.name}__value"][0] - 5) < 1e-6
+    check_offline_and_online_features(
+        fs=fs, fv=fv, driver_id=3, event_timestamp=now, expected_value=5
+    )


 @pytest.mark.integration
 @pytest.mark.parametrize(
     "bq_source_type", ["query", "table"],
 )
-def test_bq_materialization(bq_source_type: str):
+def test_bq_offline_online_store_consistency(bq_source_type: str):
     with prep_bq_fs_and_fv(bq_source_type) as (fs, fv):
-        run_materialization_test(fs, fv)
+        run_offline_online_store_consistency_test(fs, fv)


-def test_local_materialization():
+def test_local_offline_online_store_consistency():
     with prep_local_fs_and_fv() as (fs, fv):
-        run_materialization_test(fs, fv)
+        run_offline_online_store_consistency_test(fs, fv)
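
One detail easy to miss in the renamed test: the project name now strips hyphens from the uuid, presumably because the gcp provider derives BigQuery identifiers from it and BigQuery dataset IDs only allow letters, digits, and underscores. For illustration:

import uuid

# Equivalently, uuid.uuid4().hex produces the same hyphen-free string.
project = f"test_bq_correctness_{str(uuid.uuid4()).replace('-', '')}"
print(project)  # e.g. test_bq_correctness_3f2504e04f8941d39a0c0305e82c3301

Also note the asymmetry the new helper encodes after the driver_id-to-driver rename: the online lookup keys on the entity name ({"driver": driver_id}), while the offline entity dataframe uses the join key column ("driver_id").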
