Skip to content

Commit 2107ce2

Browse files
authored
feat: Add to_remote_storage functionality to SparkOfflineStore (feast-dev#3175)
implement to_remote_storage method Signed-off-by: niklasvm <[email protected]> Signed-off-by: niklasvm <[email protected]>
1 parent 7c50ab5 commit 2107ce2

File tree

2 files changed

+80
-1
lines changed

2 files changed

+80
-1
lines changed

sdk/python/feast/infra/offline_stores/contrib/spark_offline_store/spark.py

Lines changed: 76 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
1+
import os
12
import tempfile
3+
import uuid
24
import warnings
35
from datetime import datetime
46
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
@@ -13,6 +15,7 @@
1315
from pyspark import SparkConf
1416
from pyspark.sql import SparkSession
1517
from pytz import utc
18+
from sdk.python.feast.infra.utils import aws_utils
1619

1720
from feast import FeatureView, OnDemandFeatureView
1821
from feast.data_source import DataSource
@@ -46,6 +49,12 @@ class SparkOfflineStoreConfig(FeastConfigBaseModel):
4649
""" Configuration overlay for the spark session """
4750
# sparksession is not serializable and we dont want to pass it around as an argument
4851

52+
staging_location: Optional[StrictStr] = None
53+
""" Remote path for batch materialization jobs"""
54+
55+
region: Optional[StrictStr] = None
56+
""" AWS Region if applicable for s3-based staging locations"""
57+
4958

5059
class SparkOfflineStore(OfflineStore):
5160
@staticmethod
@@ -105,6 +114,7 @@ def pull_latest_from_table_or_query(
105114
return SparkRetrievalJob(
106115
spark_session=spark_session,
107116
query=query,
117+
config=config,
108118
full_feature_names=False,
109119
on_demand_feature_views=None,
110120
)
@@ -129,6 +139,7 @@ def get_historical_features(
129139
"Some functionality may still be unstable so functionality can change in the future.",
130140
RuntimeWarning,
131141
)
142+
132143
spark_session = get_spark_session_or_start_new_with_repoconfig(
133144
store_config=config.offline_store
134145
)
@@ -192,6 +203,7 @@ def get_historical_features(
192203
min_event_timestamp=entity_df_event_timestamp_range[0],
193204
max_event_timestamp=entity_df_event_timestamp_range[1],
194205
),
206+
config=config,
195207
)
196208

197209
@staticmethod
@@ -286,7 +298,10 @@ def pull_all_from_table_or_query(
286298
"""
287299

288300
return SparkRetrievalJob(
289-
spark_session=spark_session, query=query, full_feature_names=False
301+
spark_session=spark_session,
302+
query=query,
303+
full_feature_names=False,
304+
config=config,
290305
)
291306

292307

@@ -296,6 +311,7 @@ def __init__(
296311
spark_session: SparkSession,
297312
query: str,
298313
full_feature_names: bool,
314+
config: RepoConfig,
299315
on_demand_feature_views: Optional[List[OnDemandFeatureView]] = None,
300316
metadata: Optional[RetrievalMetadata] = None,
301317
):
@@ -305,6 +321,7 @@ def __init__(
305321
self._full_feature_names = full_feature_names
306322
self._on_demand_feature_views = on_demand_feature_views or []
307323
self._metadata = metadata
324+
self._config = config
308325

309326
@property
310327
def full_feature_names(self) -> bool:
@@ -342,6 +359,53 @@ def persist(self, storage: SavedDatasetStorage, allow_overwrite: bool = False):
342359
raise ValueError("Cannot persist, table_name is not defined")
343360
self.to_spark_df().createOrReplaceTempView(table_name)
344361

362+
def supports_remote_storage_export(self) -> bool:
363+
return self._config.offline_store.staging_location is not None
364+
365+
def to_remote_storage(self) -> List[str]:
366+
"""Currently only works for local and s3-based staging locations"""
367+
if self.supports_remote_storage_export():
368+
369+
sdf: pyspark.sql.DataFrame = self.to_spark_df()
370+
371+
if self._config.offline_store.staging_location.startswith("file://"):
372+
local_file_staging_location = os.path.abspath(
373+
self._config.offline_store.staging_location
374+
)
375+
376+
# write to staging location
377+
output_uri = os.path.join(
378+
str(local_file_staging_location), str(uuid.uuid4())
379+
)
380+
sdf.write.parquet(output_uri)
381+
382+
return _list_files_in_folder(output_uri)
383+
elif self._config.offline_store.staging_location.startswith("s3://"):
384+
385+
spark_compatible_s3_staging_location = (
386+
self._config.offline_store.staging_location.replace(
387+
"s3://", "s3a://"
388+
)
389+
)
390+
391+
# write to staging location
392+
output_uri = os.path.join(
393+
str(spark_compatible_s3_staging_location), str(uuid.uuid4())
394+
)
395+
sdf.write.parquet(output_uri)
396+
397+
return aws_utils.list_s3_files(
398+
self._config.offline_store.region, output_uri
399+
)
400+
401+
else:
402+
raise NotImplementedError(
403+
"to_remote_storage is only implemented for file:// and s3:// uri schemes"
404+
)
405+
406+
else:
407+
raise NotImplementedError()
408+
345409
@property
346410
def metadata(self) -> Optional[RetrievalMetadata]:
347411
"""
@@ -444,6 +508,17 @@ def _format_datetime(t: datetime) -> str:
444508
return dt
445509

446510

511+
def _list_files_in_folder(folder):
512+
"""List full filenames in a folder"""
513+
files = []
514+
for file in os.listdir(folder):
515+
filename = os.path.join(folder, file)
516+
if os.path.isfile(filename):
517+
files.append(filename)
518+
519+
return files
520+
521+
447522
def _cast_data_frame(
448523
df_new: pyspark.sql.DataFrame, df_existing: pyspark.sql.DataFrame
449524
) -> pyspark.sql.DataFrame:

sdk/python/feast/infra/offline_stores/contrib/spark_offline_store/tests/data_source.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,10 @@ def create_offline_store_config(self):
5858
self.spark_offline_store_config = SparkOfflineStoreConfig()
5959
self.spark_offline_store_config.type = "spark"
6060
self.spark_offline_store_config.spark_conf = self.spark_conf
61+
self.spark_offline_store_config.staging_location = "file://" + str(
62+
tempfile.TemporaryDirectory()
63+
)
64+
self.spark_offline_store_config.region = "eu-west-1"
6165
return self.spark_offline_store_config
6266

6367
def create_data_source(

0 commit comments

Comments
 (0)