Skip to content

Commit 66cf6a4

Browse files
authored
Refactor Environment class and DataSourceCreator API, and use fixtures for datasets and data sources (feast-dev#1790)
* Fix API cruft from DataSourceCreator Signed-off-by: Achal Shah <[email protected]> * Remove the need for get_prefixed_table_name Signed-off-by: Achal Shah <[email protected]> * major refactor Signed-off-by: Achal Shah <[email protected]> * move start time Signed-off-by: Achal Shah <[email protected]> * Remove one dimension of variation to be added in later Signed-off-by: Achal Shah <[email protected]> * Fix default Signed-off-by: Achal Shah <[email protected]> * Fixups Signed-off-by: Achal Shah <[email protected]> * Fixups Signed-off-by: Achal Shah <[email protected]> * Fix up tests Signed-off-by: Achal Shah <[email protected]> * Add retries to execute_redshift_statement_async Signed-off-by: Achal Shah <[email protected]> * Add retries to execute_redshift_statement_async Signed-off-by: Achal Shah <[email protected]> * refactoooor Signed-off-by: Achal Shah <[email protected]> * remove retries Signed-off-by: Achal Shah <[email protected]> * Remove provider variation since they don't really play a big role Signed-off-by: Achal Shah <[email protected]> * Session scoped cache for test datasets and skipping older tests whose functionality is present in other universal tests Signed-off-by: Achal Shah <[email protected]> * make format Signed-off-by: Achal Shah <[email protected]> * make format Signed-off-by: Achal Shah <[email protected]> * remove import Signed-off-by: Achal Shah <[email protected]> * fix merge Signed-off-by: Achal Shah <[email protected]> * Use an enum for the stopping procedure instead of the bools Signed-off-by: Achal Shah <[email protected]> * Fix refs Signed-off-by: Achal Shah <[email protected]> * fix step Signed-off-by: Achal Shah <[email protected]> * WIP fixes Signed-off-by: Achal Shah <[email protected]> * Fix for feature inferencing Signed-off-by: Achal Shah <[email protected]> * C901 '_python_value_to_proto_value' is too complex :( Signed-off-by: Achal Shah <[email protected]> * Split out construct_test_repo and construct_universal_test_repo 
Signed-off-by: Achal Shah <[email protected]> * remove import Signed-off-by: Achal Shah <[email protected]> * add unsafe_hash Signed-off-by: Achal Shah <[email protected]> * Update testrepoconfig Signed-off-by: Achal Shah <[email protected]> * Update testrepoconfig Signed-off-by: Achal Shah <[email protected]> * Remove kwargs from construct_universal_test_environment Signed-off-by: Achal Shah <[email protected]> * Remove unneeded method Signed-off-by: Achal Shah <[email protected]> * Docs Signed-off-by: Achal Shah <[email protected]> * Kill skipped tests Signed-off-by: Achal Shah <[email protected]> * reorder Signed-off-by: Achal Shah <[email protected]> * add todo Signed-off-by: Achal Shah <[email protected]> * Split universal vs non data_source_cache Signed-off-by: Achal Shah <[email protected]> * make format Signed-off-by: Achal Shah <[email protected]> * WIP fixtures Signed-off-by: Achal Shah <[email protected]> * WIP Trying fixtures more effectively Signed-off-by: Achal Shah <[email protected]> * fix refs Signed-off-by: Achal Shah <[email protected]> * Fix refs Signed-off-by: Achal Shah <[email protected]> * Fix refs Signed-off-by: Achal Shah <[email protected]> * Fix refs Signed-off-by: Achal Shah <[email protected]> * fix historical tests Signed-off-by: Achal Shah <[email protected]> * renames Signed-off-by: Achal Shah <[email protected]> * CR updates Signed-off-by: Achal Shah <[email protected]> * use the actual ref to data source creators Signed-off-by: Achal Shah <[email protected]> * format Signed-off-by: Achal Shah <[email protected]> * unused imports' Signed-off-by: Achal Shah <[email protected]> * Add ids for pytest params Signed-off-by: Achal Shah <[email protected]>
1 parent ef7200a commit 66cf6a4

File tree

20 files changed

+629
-1195
lines changed

20 files changed

+629
-1195
lines changed

sdk/python/feast/feature_view.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -325,6 +325,28 @@ def infer_features_from_batch_source(self, config: RepoConfig):
325325
self.batch_source.created_timestamp_column,
326326
} | set(self.entities)
327327

328+
if (
329+
self.batch_source.event_timestamp_column
330+
in self.batch_source.field_mapping
331+
):
332+
columns_to_exclude.add(
333+
self.batch_source.field_mapping[
334+
self.batch_source.event_timestamp_column
335+
]
336+
)
337+
if (
338+
self.batch_source.created_timestamp_column
339+
in self.batch_source.field_mapping
340+
):
341+
columns_to_exclude.add(
342+
self.batch_source.field_mapping[
343+
self.batch_source.created_timestamp_column
344+
]
345+
)
346+
for e in self.entities:
347+
if e in self.batch_source.field_mapping:
348+
columns_to_exclude.add(self.batch_source.field_mapping[e])
349+
328350
for (
329351
col_name,
330352
col_datatype,
@@ -335,7 +357,7 @@ def infer_features_from_batch_source(self, config: RepoConfig):
335357
):
336358
feature_name = (
337359
self.batch_source.field_mapping[col_name]
338-
if col_name in self.batch_source.field_mapping.keys()
360+
if col_name in self.batch_source.field_mapping
339361
else col_name
340362
)
341363
self.features.append(

sdk/python/feast/infra/offline_stores/bigquery.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,7 @@ def pull_latest_from_table_or_query(
8686
)
8787
WHERE _feast_row = 1
8888
"""
89+
8990
return BigQueryRetrievalJob(query=query, client=client, config=config)
9091

9192
@staticmethod

sdk/python/feast/infra/provider.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -242,6 +242,12 @@ def _get_column_names(
242242
reverse_field_mapping[col] if col in reverse_field_mapping.keys() else col
243243
for col in feature_names
244244
]
245+
246+
# We need to exclude join keys and timestamp columns from the list of features, after they are mapped to
247+
# their final column names via the `field_mapping` field of the source.
248+
_feature_names = set(feature_names) - set(join_keys)
249+
_feature_names = _feature_names - {event_timestamp_column, created_timestamp_column}
250+
feature_names = list(_feature_names)
245251
return (
246252
join_keys,
247253
feature_names,

sdk/python/feast/infra/utils/aws_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ class RedshiftStatementNotFinishedError(Exception):
8282

8383

8484
@retry(
85-
wait=wait_exponential(multiplier=0.1, max=30),
85+
wait=wait_exponential(multiplier=1, max=30),
8686
retry=retry_if_exception_type(RedshiftStatementNotFinishedError),
8787
)
8888
def wait_for_redshift_statement(redshift_data_client, statement: dict) -> None:

sdk/python/feast/type_map.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,13 @@
1313
# limitations under the License.
1414

1515
import re
16+
from datetime import datetime
1617
from typing import Any, Dict, Union
1718

1819
import numpy as np
1920
import pandas as pd
2021
from google.protobuf.json_format import MessageToDict
22+
from google.protobuf.timestamp_pb2 import Timestamp
2123

2224
from feast.protos.feast.types.Value_pb2 import (
2325
BoolList,
@@ -104,6 +106,8 @@ def python_type_to_feast_value_type(
104106
"int8": ValueType.INT32,
105107
"bool": ValueType.BOOL,
106108
"timedelta": ValueType.UNIX_TIMESTAMP,
109+
"Timestamp": ValueType.UNIX_TIMESTAMP,
110+
"datetime": ValueType.UNIX_TIMESTAMP,
107111
"datetime64[ns]": ValueType.UNIX_TIMESTAMP,
108112
"datetime64[ns, tz]": ValueType.UNIX_TIMESTAMP,
109113
"category": ValueType.STRING,
@@ -160,7 +164,8 @@ def _type_err(item, dtype):
160164
raise ValueError(f'Value "{item}" is of type {type(item)} not of type {dtype}')
161165

162166

163-
def _python_value_to_proto_value(feast_value_type, value) -> ProtoValue:
167+
# TODO(achals): Simplify this method and remove the noqa.
168+
def _python_value_to_proto_value(feast_value_type, value) -> ProtoValue: # noqa: C901
164169
"""
165170
Converts a Python (native, pandas) value to a Feast Proto Value based
166171
on a provided value type
@@ -281,6 +286,10 @@ def _python_value_to_proto_value(feast_value_type, value) -> ProtoValue:
281286
elif feast_value_type == ValueType.INT64:
282287
return ProtoValue(int64_val=int(value))
283288
elif feast_value_type == ValueType.UNIX_TIMESTAMP:
289+
if isinstance(value, datetime):
290+
return ProtoValue(int64_val=int(value.timestamp()))
291+
elif isinstance(value, Timestamp):
292+
return ProtoValue(int64_val=int(value.ToSeconds()))
284293
return ProtoValue(int64_val=int(value))
285294
elif feast_value_type == ValueType.FLOAT:
286295
return ProtoValue(float_val=float(value))

sdk/python/tests/conftest.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,16 @@
1818
import pandas as pd
1919
import pytest
2020

21+
from tests.data.data_creator import create_dataset
22+
from tests.integration.feature_repos.repo_configuration import (
23+
FULL_REPO_CONFIGS,
24+
Environment,
25+
construct_test_environment,
26+
construct_universal_data_sources,
27+
construct_universal_datasets,
28+
construct_universal_entities,
29+
)
30+
2131

2232
def pytest_configure(config):
2333
if platform in ["darwin", "windows"]:
@@ -87,3 +97,44 @@ def simple_dataset_2() -> pd.DataFrame:
8797
],
8898
}
8999
return pd.DataFrame.from_dict(data)
100+
101+
102+
@pytest.fixture(
    params=FULL_REPO_CONFIGS,
    scope="session",
    ids=list(map(str, FULL_REPO_CONFIGS)),
)
def environment(request):
    """Session-scoped test environment, parametrized over FULL_REPO_CONFIGS.

    The environment is produced by ``construct_test_environment`` for the
    current parameter and yielded to the test; leaving the context manager
    runs that helper's own cleanup.
    """
    repo_config = request.param
    with construct_test_environment(repo_config) as test_env:
        yield test_env
108+
109+
110+
@pytest.fixture(scope="session")
def universal_data_sources(environment):
    """Build the universal entities, datasets and data sources for ``environment``.

    Yields an ``(entities, datasets, datasources)`` tuple. The datasets cover
    the environment's ``start_date``..``end_date`` window and the data sources
    are created through the environment's data source creator.
    """
    universal_entities = construct_universal_entities()
    universal_datasets = construct_universal_datasets(
        entities=universal_entities,
        start_time=environment.start_date,
        end_time=environment.end_date,
    )
    universal_sources = construct_universal_data_sources(
        datasets=universal_datasets,
        data_source_creator=environment.data_source_creator,
    )

    yield universal_entities, universal_datasets, universal_sources

    # Teardown phase of the fixture: release whatever the creator set up.
    environment.data_source_creator.teardown()
123+
124+
125+
@pytest.fixture(scope="session")
def e2e_data_sources(environment: Environment):
    """Create the dataset and the matching data source for the e2e test.

    Yields ``(df, data_source)``. The data source is created with the feature
    store's project name as second argument and maps column ``ts_1`` to ``ts``.
    """
    dataset = create_dataset()
    project_name = environment.feature_store.project
    source = environment.data_source_creator.create_data_source(
        dataset, project_name, field_mapping={"ts_1": "ts"}
    )

    yield dataset, source

    # Teardown phase of the fixture: release whatever the creator set up.
    environment.data_source_creator.teardown()
135+
136+
137+
@pytest.fixture(
    params=FULL_REPO_CONFIGS, scope="session", ids=[str(c) for c in FULL_REPO_CONFIGS]
)
def type_test_environment(request):
    """Session-scoped test environment for the type tests.

    Parametrized over every entry of FULL_REPO_CONFIGS. Explicit ``ids`` are
    supplied (the string form of each config) so the generated test names are
    readable and consistent with the ``environment`` fixture.
    """
    with construct_test_environment(request.param) as e:
        yield e

sdk/python/tests/integration/e2e/test_universal_e2e.py

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -3,23 +3,21 @@
33
from typing import Optional
44

55
import pandas as pd
6+
import pytest
67
from pytz import utc
78

89
from feast import FeatureStore, FeatureView
9-
from tests.integration.feature_repos.test_repo_configuration import (
10-
Environment,
11-
parametrize_e2e_test,
12-
)
1310
from tests.integration.feature_repos.universal.entities import driver
1411
from tests.integration.feature_repos.universal.feature_views import driver_feature_view
1512

1613

17-
@parametrize_e2e_test
18-
def test_e2e_consistency(test_environment: Environment):
19-
fs, fv = (
20-
test_environment.feature_store,
21-
driver_feature_view(test_environment.data_source),
22-
)
14+
@pytest.mark.integration
15+
@pytest.mark.parametrize("infer_features", [True, False])
16+
def test_e2e_consistency(environment, e2e_data_sources, infer_features):
17+
fs = environment.feature_store
18+
df, data_source = e2e_data_sources
19+
fv = driver_feature_view(data_source=data_source, infer_features=infer_features)
20+
2321
entity = driver()
2422
fs.apply([fv, entity])
2523

Lines changed: 192 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,192 @@
1+
import tempfile
2+
import uuid
3+
from contextlib import contextmanager
4+
from dataclasses import dataclass, field
5+
from datetime import datetime, timedelta
6+
from pathlib import Path
7+
from typing import Any, Dict, Iterator, List, Optional, Type, Union
8+
9+
import pandas as pd
10+
11+
from feast import FeatureStore, FeatureView, RepoConfig, driver_test_data
12+
from feast.data_source import DataSource
13+
from tests.integration.feature_repos.universal.data_source_creator import (
14+
DataSourceCreator,
15+
)
16+
from tests.integration.feature_repos.universal.data_sources.bigquery import (
17+
BigQueryDataSourceCreator,
18+
)
19+
from tests.integration.feature_repos.universal.data_sources.file import (
20+
FileDataSourceCreator,
21+
)
22+
from tests.integration.feature_repos.universal.data_sources.redshift import (
23+
RedshiftDataSourceCreator,
24+
)
25+
from tests.integration.feature_repos.universal.feature_views import (
26+
create_customer_daily_profile_feature_view,
27+
create_driver_hourly_stats_feature_view,
28+
)
29+
30+
31+
@dataclass(frozen=True)
class IntegrationTestRepoConfig:
    """
    Immutable bundle of every parameter an individual integration test may want to vary.
    """

    # Provider name and online store setting handed to the RepoConfig under test.
    provider: str = "local"
    online_store: Union[str, Dict] = "sqlite"

    # DataSourceCreator class; instantiated with the project name per environment.
    offline_store_creator: Type[DataSourceCreator] = FileDataSourceCreator

    # Behavioural toggles consumed by the tests themselves.
    full_feature_names: bool = True
    infer_event_timestamp_col: bool = True
    infer_features: bool = False
45+
46+
47+
# Online store settings shared by several of the configurations below.
DYNAMO_CONFIG = {"type": "dynamodb", "region": "us-west-2"}
REDIS_CONFIG = {"type": "redis", "connection_string": "localhost:6379,db=0"}

# Every provider / offline store / online store combination exercised by the
# integration tests, in a fixed order: local, then GCP, then AWS.
FULL_REPO_CONFIGS: List[IntegrationTestRepoConfig] = [
    # Local configurations
    IntegrationTestRepoConfig(),
    IntegrationTestRepoConfig(online_store=REDIS_CONFIG),
    # GCP configurations
    *(
        IntegrationTestRepoConfig(
            provider="gcp",
            offline_store_creator=BigQueryDataSourceCreator,
            online_store=gcp_online_store,
        )
        for gcp_online_store in ("datastore", REDIS_CONFIG)
    ),
    # AWS configurations
    *(
        IntegrationTestRepoConfig(
            provider="aws",
            offline_store_creator=RedshiftDataSourceCreator,
            online_store=aws_online_store,
        )
        for aws_online_store in (DYNAMO_CONFIG, REDIS_CONFIG)
    ),
]
76+
77+
78+
def construct_universal_entities() -> Dict[str, List[Any]]:
    """Return the fixed customer and driver id lists used by the universal tests.

    Customers are the integers 1001..1109 and drivers 5001..5109 (inclusive).
    """
    customer_ids = [*range(1001, 1110)]
    driver_ids = [*range(5001, 5110)]
    return {"customer": customer_ids, "driver": driver_ids}
80+
81+
82+
def construct_universal_datasets(
    entities: Dict[str, List[Any]], start_time: datetime, end_time: datetime
) -> Dict[str, pd.DataFrame]:
    """Generate the customer, driver and orders dataframes for the given entities.

    The customer and driver frames span ``start_time``..``end_time``. The
    orders frame holds 1000 orders spread over one year either side of
    ``end_time``.
    """
    one_year = timedelta(days=365)
    return {
        "customer": driver_test_data.create_customer_daily_profile_df(
            entities["customer"], start_time, end_time
        ),
        "driver": driver_test_data.create_driver_hourly_stats_df(
            entities["driver"], start_time, end_time
        ),
        "orders": driver_test_data.create_orders_df(
            customers=entities["customer"],
            drivers=entities["driver"],
            start_date=end_time - one_year,
            end_date=end_time + one_year,
            order_count=1000,
        ),
    }
100+
101+
102+
def construct_universal_data_sources(
    datasets: Dict[str, pd.DataFrame], data_source_creator: DataSourceCreator
) -> Dict[str, DataSource]:
    """Create one data source per universal dataset via ``data_source_creator``.

    Sources are created in the order customer, driver, orders. Each uses
    ``event_timestamp`` / ``created`` as its timestamp columns, and the result
    is keyed by the same names as ``datasets``.
    """
    # (dataset key, destination name) pairs, in creation order.
    destinations = (
        ("customer", "customer_profile"),
        ("driver", "driver_hourly"),
        ("orders", "orders"),
    )
    return {
        key: data_source_creator.create_data_source(
            datasets[key],
            destination_name=destination,
            event_timestamp_column="event_timestamp",
            created_timestamp_column="created",
        )
        for key, destination in destinations
    }
124+
125+
126+
def construct_universal_feature_views(
    data_sources: Dict[str, DataSource],
) -> Dict[str, FeatureView]:
    """Build the customer and driver feature views on top of ``data_sources``.

    Only the ``"customer"`` and ``"driver"`` entries of ``data_sources`` are
    read; the result is keyed by those same two names.
    """
    customer_view = create_customer_daily_profile_feature_view(
        data_sources["customer"]
    )
    driver_view = create_driver_hourly_stats_feature_view(data_sources["driver"])
    return {"customer": customer_view, "driver": driver_view}
135+
136+
137+
@dataclass
class Environment:
    """Everything a test needs to talk to one configured feature store.

    ``end_date`` defaults to the current time truncated to the hour, computed
    when the instance is created (not when the module is imported).
    ``start_date`` is always derived as seven days before ``end_date`` and
    cannot be passed to the constructor.
    """

    name: str
    test_repo_config: IntegrationTestRepoConfig
    feature_store: FeatureStore
    data_source_creator: DataSourceCreator

    # default_factory so that each Environment gets a fresh timestamp; a plain
    # ``default=datetime.now()`` would be frozen at class-definition time.
    end_date: datetime = field(
        default_factory=lambda: datetime.now().replace(
            microsecond=0, second=0, minute=0
        )
    )
    # Derived in __post_init__; declared so it is a proper dataclass field.
    start_date: datetime = field(init=False)

    def __post_init__(self):
        self.start_date = self.end_date - timedelta(days=7)
150+
151+
152+
def table_name_from_data_source(ds: DataSource) -> Optional[str]:
    """Return the table name carried by ``ds``, or ``None`` if it has none.

    The ``table_ref`` attribute wins over ``table``; whichever attribute is
    present first is returned as-is, even when its value is ``None``.
    """
    for attribute in ("table_ref", "table"):
        if hasattr(ds, attribute):
            return getattr(ds, attribute)
    return None
158+
159+
160+
@contextmanager
def construct_test_environment(
    test_repo_config: IntegrationTestRepoConfig,
    test_suite_name: str = "integration_test",
) -> Iterator[Environment]:
    """Context manager yielding an Environment built from ``test_repo_config``.

    A project name is formed from ``test_suite_name`` plus eight characters of
    a random UUID. The offline store creator from the config is instantiated
    with that name, and a FeatureStore is created whose registry and repo path
    live in a temporary directory.

    On exit the feature store's ``teardown()`` is always called, after which
    the temporary directory is removed.

    Args:
        test_repo_config: provider / online store / offline store creator to use.
        test_suite_name: prefix for the generated project name.

    Yields:
        The constructed Environment.
    """
    project = f"{test_suite_name}_{str(uuid.uuid4()).replace('-', '')[:8]}"

    offline_creator: DataSourceCreator = test_repo_config.offline_store_creator(project)

    offline_store_config = offline_creator.create_offline_store_config()
    online_store = test_repo_config.online_store

    with tempfile.TemporaryDirectory() as repo_dir_name:
        config = RepoConfig(
            registry=str(Path(repo_dir_name) / "registry.db"),
            project=project,
            provider=test_repo_config.provider,
            offline_store=offline_store_config,
            online_store=online_store,
            repo_path=repo_dir_name,
        )
        fs = FeatureStore(config=config)
        environment = Environment(
            name=project,
            test_repo_config=test_repo_config,
            feature_store=fs,
            data_source_creator=offline_creator,
        )

        try:
            yield environment
        finally:
            # Run even when the test body raised, so store resources are released.
            fs.teardown()

0 commit comments

Comments
 (0)