@@ -1,4 +1,6 @@
+import os
 import tempfile
+import uuid
 import warnings
 from datetime import datetime
 from typing import Any, Callable, Dict, List, Optional, Tuple, Union
@@ -13,6 +15,7 @@
 from pyspark import SparkConf
 from pyspark.sql import SparkSession
 from pytz import utc
+from feast.infra.utils import aws_utils

 from feast import FeatureView, OnDemandFeatureView
 from feast.data_source import DataSource
@@ -46,6 +49,12 @@ class SparkOfflineStoreConfig(FeastConfigBaseModel):
4649 """ Configuration overlay for the spark session """
4750 # sparksession is not serializable and we dont want to pass it around as an argument
4851
52+ staging_location : Optional [StrictStr ] = None
53+ """ Remote path for batch materialization jobs"""
54+
55+ region : Optional [StrictStr ] = None
56+ """ AWS Region if applicable for s3-based staging locations"""
57+
4958
5059class SparkOfflineStore (OfflineStore ):
5160 @staticmethod
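A minimal sketch of how these new options could be set programmatically (in practice they normally come from feature_store.yaml); the bucket name and region below are hypothetical:

    # Hypothetical configuration values, for illustration only.
    offline_config = SparkOfflineStoreConfig(
        staging_location="s3://my-bucket/feast-staging",
        region="us-west-2",
    )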
@@ -105,6 +114,7 @@ def pull_latest_from_table_or_query(
         return SparkRetrievalJob(
             spark_session=spark_session,
             query=query,
+            config=config,
             full_feature_names=False,
             on_demand_feature_views=None,
         )
@@ -129,6 +139,7 @@ def get_historical_features(
129139 "Some functionality may still be unstable so functionality can change in the future." ,
130140 RuntimeWarning ,
131141 )
142+
132143 spark_session = get_spark_session_or_start_new_with_repoconfig (
133144 store_config = config .offline_store
134145 )
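For reference, the session helper used here can also be driven directly from the offline store config; a small sketch reusing the hypothetical offline_config from above:

    # Returns the active SparkSession, or starts a new one using the
    # config's spark_conf overlay.
    spark = get_spark_session_or_start_new_with_repoconfig(
        store_config=offline_config
    )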
@@ -192,6 +203,7 @@ def get_historical_features(
                 min_event_timestamp=entity_df_event_timestamp_range[0],
                 max_event_timestamp=entity_df_event_timestamp_range[1],
             ),
+            config=config,
         )

     @staticmethod
@@ -286,7 +298,10 @@ def pull_all_from_table_or_query(
286298 """
287299
288300 return SparkRetrievalJob (
289- spark_session = spark_session , query = query , full_feature_names = False
301+ spark_session = spark_session ,
302+ query = query ,
303+ full_feature_names = False ,
304+ config = config ,
290305 )
291306
292307
@@ -296,6 +311,7 @@ def __init__(
         spark_session: SparkSession,
         query: str,
         full_feature_names: bool,
+        config: RepoConfig,
         on_demand_feature_views: Optional[List[OnDemandFeatureView]] = None,
         metadata: Optional[RetrievalMetadata] = None,
     ):
@@ -305,6 +321,7 @@ def __init__(
         self._full_feature_names = full_feature_names
         self._on_demand_feature_views = on_demand_feature_views or []
         self._metadata = metadata
+        self._config = config

     @property
     def full_feature_names(self) -> bool:
@@ -342,6 +359,53 @@ def persist(self, storage: SavedDatasetStorage, allow_overwrite: bool = False):
             raise ValueError("Cannot persist, table_name is not defined")
         self.to_spark_df().createOrReplaceTempView(table_name)

+    def supports_remote_storage_export(self) -> bool:
+        return self._config.offline_store.staging_location is not None
+
+    def to_remote_storage(self) -> List[str]:
+        """Currently only works for local and s3-based staging locations"""
+        if self.supports_remote_storage_export():
+            sdf: pyspark.sql.DataFrame = self.to_spark_df()
+
+            if self._config.offline_store.staging_location.startswith("file://"):
+                # strip the file:// scheme so os.path and os.listdir operate on a plain path
+                local_file_staging_location = os.path.abspath(
+                    self._config.offline_store.staging_location[len("file://") :]
+                )
+
+                # write to staging location
+                output_uri = os.path.join(
+                    str(local_file_staging_location), str(uuid.uuid4())
+                )
+                sdf.write.parquet(output_uri)
+
+                return _list_files_in_folder(output_uri)
+            elif self._config.offline_store.staging_location.startswith("s3://"):
+                # Spark's Hadoop S3 connector expects the s3a:// scheme
+                spark_compatible_s3_staging_location = (
+                    self._config.offline_store.staging_location.replace(
+                        "s3://", "s3a://"
+                    )
+                )
+
+                # write to staging location
+                output_uri = os.path.join(
+                    str(spark_compatible_s3_staging_location), str(uuid.uuid4())
+                )
+                sdf.write.parquet(output_uri)
+
+                return aws_utils.list_s3_files(
+                    self._config.offline_store.region, output_uri
+                )
+            else:
+                raise NotImplementedError(
+                    "to_remote_storage is only implemented for file:// and s3:// uri schemes"
+                )
+        else:
+            raise NotImplementedError(
+                "to_remote_storage requires a staging_location to be configured"
+            )
+
     @property
     def metadata(self) -> Optional[RetrievalMetadata]:
         """
@@ -444,6 +508,17 @@ def _format_datetime(t: datetime) -> str:
     return dt


+def _list_files_in_folder(folder: str) -> List[str]:
+    """List the full paths of the files in a folder (non-recursive)"""
+    files = []
+    for file in os.listdir(folder):
+        filename = os.path.join(folder, file)
+        if os.path.isfile(filename):
+            files.append(filename)
+
+    return files
+
+
 def _cast_data_frame(
     df_new: pyspark.sql.DataFrame, df_existing: pyspark.sql.DataFrame
 ) -> pyspark.sql.DataFrame: