29 changes: 24 additions & 5 deletions cluster/sdk/python/feast_spark/constants.py
@@ -39,12 +39,12 @@ class ConfigOptions(metaclass=ConfigMeta):

#: Spark Job launcher. The choice of storage is connected to the choice of SPARK_LAUNCHER.
#:
#: Options: "standalone", "dataproc", "emr"
#: Options: "standalone", "dataproc", "emr", "databricks"
SPARK_LAUNCHER: Optional[str] = None

#: Feast Spark Job ingestion jobs staging location. The choice of storage is connected to the choice of SPARK_LAUNCHER.
#:
#: Eg. gs://some-bucket/output/, s3://some-bucket/output/, file:///data/subfolder/
#: Eg. gs://some-bucket/output/, s3://some-bucket/output/, file:///data/subfolder/, dbfs:/mnt/subfolder
SPARK_STAGING_LOCATION: Optional[str] = None

#: Feast Spark Job ingestion jar file. The choice of storage is connected to the choice of SPARK_LAUNCHER.
@@ -94,7 +94,7 @@ class ConfigOptions(metaclass=ConfigMeta):
SPARK_K8S_JOB_TEMPLATE_PATH = None

# Synapse dev url
AZURE_SYNAPSE_DEV_URL: Optional[str] = None

# Synapse pool name
AZURE_SYNAPSE_POOL_NAME: Optional[str] = None
@@ -110,10 +110,29 @@ class ConfigOptions(metaclass=ConfigMeta):

# Azure EventHub Connection String (with Kafka API). See more details here:
# https://docs.microsoft.com/en-us/azure/event-hubs/apache-kafka-migration-guide
# Code Sample is here:
# https://github.com/Azure/azure-event-hubs-for-kafka/blob/master/tutorials/spark/sparkConsumer.scala
AZURE_EVENTHUB_KAFKA_CONNECTION_STRING = ""



# Databricks: Access Token
DATABRICKS_ACCESS_TOKEN: Optional[str] = None

# Databricks: Host (full https:// URL of the Databricks workspace)
DATABRICKS_HOST_URL: Optional[str] = None

# Databricks: Common Cluster Id
DATABRICKS_COMMON_CLUSTER_ID: Optional[str] = None

# Databricks: Dedicated Streaming Cluster Id (optional dedicated cluster for streaming use cases)
DATABRICKS_STREAMING_CLUSTER_ID: Optional[str] = None

# Databricks: Maximum runs to retrieve
DATABRICKS_MAXIMUM_RUNS_TO_RETRIEVE: Optional[str] = None

# Databricks: Mounted Storage Path (Ex: /mnt/)
DATABRICKS_MOUNTED_STORAGE_PATH: Optional[str] = None

#: File format of historical retrieval features
HISTORICAL_FEATURE_OUTPUT_FORMAT: str = "parquet"

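For orientation, a minimal sketch of how the new Databricks options fit together when the Databricks launcher is selected. Every value below is a placeholder (hypothetical workspace URL, token, cluster ids, and mount path), not taken from this change:

# Illustrative values only; all identifiers are placeholders
databricks_options = {
    "SPARK_LAUNCHER": "databricks",
    "SPARK_STAGING_LOCATION": "dbfs:/mnt/feast-staging/",
    "DATABRICKS_HOST_URL": "https://adb-1234567890123456.7.azuredatabricks.net",
    "DATABRICKS_ACCESS_TOKEN": "dapi-0123456789abcdef",
    "DATABRICKS_COMMON_CLUSTER_ID": "0123-456789-abcde123",
    "DATABRICKS_STREAMING_CLUSTER_ID": "0123-456789-abcde456",  # optional, streaming jobs only
    "DATABRICKS_MAXIMUM_RUNS_TO_RETRIEVE": "25",
    "DATABRICKS_MOUNTED_STORAGE_PATH": "/mnt/feast-staging/",
}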
@@ -623,10 +623,18 @@ def filter_feature_table_by_time_range(
def _read_and_verify_entity_df_from_source(
spark: SparkSession, source: Source
) -> DataFrame:
spark_path = source.spark_path

# Handle read for databricks mounted storage
if mounted_staging_location is not None:
print(f"Using Databricks Mounted path:{mounted_staging_location}")
relative_path = source.spark_path.rsplit("/", 1)[1] if not source.spark_path.endswith(
"/") else source.spark_path.rsplit("/", 2)[1]
spark_path = mounted_staging_location + relative_path
entity_df = (
spark.read.format(source.spark_format)
.options(**source.spark_read_options)
.load(source.spark_path)
.load(spark_path)
)

mapped_entity_df = _map_column(entity_df, source.field_mapping)
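The mounted-storage branch above re-roots the final segment of the staged path under the Databricks mount point. A standalone sketch of that rewrite (the helper name and example paths are illustrative, not part of the change):

def to_mounted_path(spark_path: str, mounted_staging_location: str) -> str:
    # Keep only the final path segment, with or without a trailing slash
    relative_path = (
        spark_path.rsplit("/", 2)[1]
        if spark_path.endswith("/")
        else spark_path.rsplit("/", 1)[1]
    )
    return mounted_staging_location + relative_path

assert to_mounted_path("s3://bucket/staging/entities.parquet", "/mnt/staging/") == "/mnt/staging/entities.parquet"
assert to_mounted_path("s3://bucket/staging/entities/", "/mnt/staging/") == "/mnt/staging/entities"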
@@ -860,6 +868,10 @@ def _get_args():
parser.add_argument(
"--destination", type=str, help="Retrieval result destination in json string"
)
parser.add_argument(
"--mounted_staging_location", type=str, default="",
help="DBFS mounted staging path for reading the entity source (Databricks only)"
)
parser.add_argument("--checkpoint", type=str, help="Spark Checkpoint location")
return parser.parse_args()

@@ -882,13 +894,19 @@ def json_b64_decode(s: str) -> Any:
return json.loads(b64decode(s.encode("ascii")))


def b64_decode(obj: str) -> str:
return str(b64decode(obj.encode("ascii")).decode("ascii"))


if __name__ == "__main__":
spark = SparkSession.builder.getOrCreate()
args = _get_args()
feature_tables_conf = json_b64_decode(args.feature_tables)
feature_tables_sources_conf = json_b64_decode(args.feature_tables_sources)
entity_source_conf = json_b64_decode(args.entity_source)
destination_conf = json_b64_decode(args.destination)
# Decode once; an empty value means no mounted staging location
mounted_staging_location = b64_decode(args.mounted_staging_location) or None
if args.checkpoint:
spark.sparkContext.setCheckpointDir(args.checkpoint)
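Since the retrieval job decodes --mounted_staging_location with b64_decode, the launcher is expected to pass the value base64-encoded. A hypothetical encoding of that argument (the path is a placeholder):

from base64 import b64encode

mounted_path = "/mnt/feast-staging/"
encoded = b64encode(mounted_path.encode("ascii")).decode("ascii")
job_args = ["--mounted_staging_location", encoded]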

@@ -903,4 +921,7 @@ def json_b64_decode(s: str) -> Any:
except Exception as e:
logger.exception(e)
raise e
spark.stop()

# Databricks clusters do not allow stopping the SparkSession,
# so only stop it for non-Databricks runs
if mounted_staging_location is None:
spark.stop()
21 changes: 21 additions & 0 deletions cluster/sdk/python/feast_spark/pyspark/launcher.py
@@ -94,13 +94,34 @@ def _synapse_launcher(config: Config) -> JobLauncher:
executors=int(config.get(opt.AZURE_SYNAPSE_EXECUTORS))
)

def _databricks_launcher(config: Config) -> JobLauncher:
from feast_spark.pyspark.launchers import databricks

staging_location = config.get(opt.SPARK_STAGING_LOCATION)
staging_uri = urlparse(staging_location)

return databricks.DatabricksJobLauncher(
databricks_access_token=config.get(opt.DATABRICKS_ACCESS_TOKEN),
databricks_host_url=config.get(opt.DATABRICKS_HOST_URL),
staging_client=get_staging_client(staging_uri.scheme, config),
databricks_common_cluster_id=config.get(opt.DATABRICKS_COMMON_CLUSTER_ID),
databricks_streaming_cluster_id=config.get(opt.DATABRICKS_STREAMING_CLUSTER_ID, None),
databricks_max_active_jobs_to_retrieve=config.getint(opt.DATABRICKS_MAXIMUM_RUNS_TO_RETRIEVE, None),
mounted_staging_location=config.get(opt.DATABRICKS_MOUNTED_STORAGE_PATH, None),
azure_account_name=config.get(opt.AZURE_BLOB_ACCOUNT_NAME, None),
azure_account_key=config.get(opt.AZURE_BLOB_ACCOUNT_ACCESS_KEY, None),
staging_location=staging_location
)



_launchers = {
"standalone": _standalone_launcher,
"dataproc": _dataproc_launcher,
"emr": _emr_launcher,
"k8s": _k8s_launcher,
"synapse": _synapse_launcher,
"databricks": _databricks_launcher,
}
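A sketch of how a registry like _launchers is typically consumed: the configured SPARK_LAUNCHER value selects the factory. The function below reuses Config, JobLauncher, and opt from launcher.py and is illustrative, not the actual resolver:

def resolve_launcher_sketch(config: Config) -> JobLauncher:
    # e.g. SPARK_LAUNCHER = "databricks" selects _databricks_launcher
    return _launchers[config.get(opt.SPARK_LAUNCHER)](config)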


@@ -0,0 +1,13 @@
from .databricks import (
DatabricksBatchIngestionJob,
DatabricksJobLauncher,
DatabricksRetrievalJob,
DatabricksStreamIngestionJob,
)

__all__ = [
"DatabricksRetrievalJob",
"DatabricksBatchIngestionJob",
"DatabricksStreamIngestionJob",
"DatabricksJobLauncher",
]
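Assuming this is the __init__.py of the new databricks launchers subpackage (consistent with the "from feast_spark.pyspark.launchers import databricks" import in launcher.py), the re-exports above let callers import the classes directly, for example:

from feast_spark.pyspark.launchers.databricks import (
    DatabricksJobLauncher,
    DatabricksRetrievalJob,
)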