11import uuid
22from datetime import date , datetime
3- from typing import Any , Dict , List , Optional , Tuple , Union
3+ from typing import Any , Dict , List , Literal , Optional , Tuple , Union
44
55import numpy as np
66import pandas as pd
77import pyarrow
8- from pydantic import StrictStr
9- from trino .auth import Authentication
8+ from pydantic import Field , FilePath , SecretStr , StrictBool , StrictStr , root_validator
9+ from trino .auth import (
10+ BasicAuthentication ,
11+ CertificateAuthentication ,
12+ JWTAuthentication ,
13+ KerberosAuthentication ,
14+ OAuth2Authentication ,
15+ )
1016
1117from feast .data_source import DataSource
1218from feast .errors import InvalidEntityType
3238from feast .usage import log_exceptions_and_usage
3339
3440
class BasicAuthModel(FeastConfigBaseModel):
    """Settings accepted for the ``basic`` (username/password) auth type."""

    # Login name; StrictStr rejects non-string input instead of coercing it.
    username: StrictStr

    # Held as a SecretStr so the value is masked when the model is printed.
    password: SecretStr
44+
45+
class KerberosAuthModel(FeastConfigBaseModel):
    """Settings accepted for the ``kerberos`` auth type.

    Every option is optional. Most are read from a hyphenated key, given by
    the field's ``alias``; the attribute names are the keys that end up in
    ``.dict()``. ``FilePath`` fields must point at an existing file.
    """

    config: Optional[FilePath] = Field(default=None, alias="config-file")
    service_name: Optional[StrictStr] = Field(default=None, alias="service-name")
    mutual_authentication: StrictBool = Field(
        default=False, alias="mutual-authentication"
    )
    force_preemptive: StrictBool = Field(default=False, alias="force-preemptive")
    hostname_override: Optional[StrictStr] = Field(
        default=None, alias="hostname-override"
    )
    sanitize_mutual_error_response: StrictBool = Field(
        default=True, alias="sanitize-mutual-error-response"
    )
    # Explicit default, consistent with the sibling fields, rather than
    # relying on an Optional annotation implying ``None``.
    principal: Optional[StrictStr] = None
    delegate: StrictBool = False
    ca_bundle: Optional[FilePath] = Field(default=None, alias="ca-bundle-file")
62+
63+
class JWTAuthModel(FeastConfigBaseModel):
    """Settings accepted for the ``jwt`` auth type."""

    # Held as a SecretStr so the token is masked when the model is printed.
    token: SecretStr
66+
67+
class CertificateAuthModel(FeastConfigBaseModel):
    """Settings accepted for the ``certificate`` auth type.

    Both paths are required and must point at existing files. They are read
    from the hyphenated keys given by each field's ``alias``.
    """

    # ``...`` marks the field as required. A ``None`` default would let a
    # config without a certificate validate and only fail later, when the
    # connection is attempted.
    cert: FilePath = Field(..., alias="cert-file")
    key: FilePath = Field(..., alias="key-file")
71+
72+
# Lookup table keyed by the auth ``type`` string. Each entry pairs the model
# that validates the user-supplied settings ("auth_model") with the trino
# authentication class built from them ("trino_auth"). "oauth2" has no
# settings model.
CLASSES_BY_AUTH_TYPE = {
    "kerberos": dict(
        auth_model=KerberosAuthModel,
        trino_auth=KerberosAuthentication,
    ),
    "basic": dict(
        auth_model=BasicAuthModel,
        trino_auth=BasicAuthentication,
    ),
    "jwt": dict(
        auth_model=JWTAuthModel,
        trino_auth=JWTAuthentication,
    ),
    "oauth2": dict(
        auth_model=None,
        trino_auth=OAuth2Authentication,
    ),
    "certificate": dict(
        auth_model=CertificateAuthModel,
        trino_auth=CertificateAuthentication,
    ),
}
95+
96+
class AuthConfig(FeastConfigBaseModel):
    """Authentication section of the Trino offline store configuration.

    ``type`` selects the mechanism and ``config`` carries its settings.
    ``config`` may only be omitted for ``oauth2``.
    """

    type: Literal["kerberos", "basic", "jwt", "oauth2", "certificate"]
    config: Optional[Dict[StrictStr, Any]] = None

    @root_validator
    def config_only_nullable_for_oauth2(cls, values):
        """Reject a missing ``config`` for every auth type except ``oauth2``.

        Raises:
            ValueError: if ``config`` is null and ``type`` is not ``oauth2``.
        """
        # A field that failed its own validation is absent from ``values``.
        # Return early so the field's own error is reported, instead of a
        # KeyError escaping from this validator.
        if "type" not in values or "config" not in values:
            return values

        auth_type = values["type"]
        if auth_type != "oauth2" and values["config"] is None:
            raise ValueError(f"config cannot be null for auth type '{auth_type}'")

        return values

    def to_trino_auth(self):
        """Build the trino authentication object for this configuration.

        Returns:
            An instance of the ``trino_auth`` class registered for
            ``self.type`` in ``CLASSES_BY_AUTH_TYPE``.
        """
        import os  # local import: only needed to normalise path values below

        auth_type = self.type
        trino_auth_cls = CLASSES_BY_AUTH_TYPE[auth_type]["trino_auth"]

        # oauth2 has no settings model, so it is constructed without arguments.
        if auth_type == "oauth2":
            return trino_auth_cls()

        model_cls = CLASSES_BY_AUTH_TYPE[auth_type]["auth_model"]
        model = model_cls(**self.config)

        # ``model.dict()`` still holds pydantic wrapper types. A SecretStr
        # stringifies to its mask rather than the secret, so unwrap it, and
        # hand paths over as plain strings.
        kwargs = {}
        for name, value in model.dict().items():
            if isinstance(value, SecretStr):
                value = value.get_secret_value()
            elif isinstance(value, os.PathLike):
                value = os.fspath(value)
            kwargs[name] = value

        return trino_auth_cls(**kwargs)
120+
121+
35122class TrinoOfflineStoreConfig (FeastConfigBaseModel ):
36123 """Offline store config for Trino"""
37124
@@ -47,6 +134,23 @@ class TrinoOfflineStoreConfig(FeastConfigBaseModel):
47134 catalog : StrictStr
48135 """ Catalog of the Trino cluster """
49136
137+ user : StrictStr
138+ """ User of the Trino cluster """
139+
140+ source : Optional [StrictStr ] = "trino-python-client"
141+ """ ID of the feast's Trino Python client, useful for debugging """
142+
143+ http_scheme : Literal ["http" , "https" ] = Field (default = "http" , alias = "http-scheme" )
144+ """ HTTP scheme that should be used while establishing a connection to the Trino cluster """
145+
146+ verify : StrictBool = Field (default = True , alias = "ssl-verify" )
147+ """ Whether the SSL certificate emitted by the Trino cluster should be verified or not """
148+
149+ extra_credential : Optional [StrictStr ] = Field (
150+ default = None , alias = "x-trino-extra-credential-header"
151+ )
152+ """ Specifies the HTTP header X-Trino-Extra-Credential, e.g. user1=pwd1, user2=pwd2 """
153+
50154 connector : Dict [str , str ]
51155 """
52156 Trino connector to use as well as potential extra parameters.
@@ -59,6 +163,16 @@ class TrinoOfflineStoreConfig(FeastConfigBaseModel):
59163 dataset : StrictStr = "feast"
60164 """ (optional) Trino Dataset name for temporary tables """
61165
166+ auth : Optional [AuthConfig ]
167+ """
168+ (optional) Authentication mechanism to use when connecting to Trino. Supported options are:
169+ - kerberos
170+ - basic
171+ - jwt
172+ - oauth2
173+ - certificate
174+ """
175+
62176
63177class TrinoRetrievalJob (RetrievalJob ):
64178 def __init__ (
@@ -162,9 +276,6 @@ def pull_latest_from_table_or_query(
162276 created_timestamp_column : Optional [str ],
163277 start_date : datetime ,
164278 end_date : datetime ,
165- user : Optional [str ] = None ,
166- auth : Optional [Authentication ] = None ,
167- http_scheme : Optional [str ] = None ,
168279 ) -> TrinoRetrievalJob :
169280 assert isinstance (config .offline_store , TrinoOfflineStoreConfig )
170281 assert isinstance (data_source , TrinoSource )
@@ -181,9 +292,7 @@ def pull_latest_from_table_or_query(
181292 timestamps .append (created_timestamp_column )
182293 timestamp_desc_string = " DESC, " .join (timestamps ) + " DESC"
183294 field_string = ", " .join (join_key_columns + feature_name_columns + timestamps )
184- client = _get_trino_client (
185- config = config , user = user , auth = auth , http_scheme = http_scheme
186- )
295+ client = _get_trino_client (config = config )
187296
188297 query = f"""
189298 SELECT
@@ -216,17 +325,12 @@ def get_historical_features(
216325 registry : Registry ,
217326 project : str ,
218327 full_feature_names : bool = False ,
219- user : Optional [str ] = None ,
220- auth : Optional [Authentication ] = None ,
221- http_scheme : Optional [str ] = None ,
222328 ) -> TrinoRetrievalJob :
223329 assert isinstance (config .offline_store , TrinoOfflineStoreConfig )
224330 for fv in feature_views :
225331 assert isinstance (fv .batch_source , TrinoSource )
226332
227- client = _get_trino_client (
228- config = config , user = user , auth = auth , http_scheme = http_scheme
229- )
333+ client = _get_trino_client (config = config )
230334
231335 table_reference = _get_table_reference_for_new_entity (
232336 catalog = config .offline_store .catalog ,
@@ -307,17 +411,12 @@ def pull_all_from_table_or_query(
307411 timestamp_field : str ,
308412 start_date : datetime ,
309413 end_date : datetime ,
310- user : Optional [str ] = None ,
311- auth : Optional [Authentication ] = None ,
312- http_scheme : Optional [str ] = None ,
313414 ) -> RetrievalJob :
314415 assert isinstance (config .offline_store , TrinoOfflineStoreConfig )
315416 assert isinstance (data_source , TrinoSource )
316417 from_expression = data_source .get_table_query_string ()
317418
318- client = _get_trino_client (
319- config = config , user = user , auth = auth , http_scheme = http_scheme
320- )
419+ client = _get_trino_client (config = config )
321420 field_string = ", " .join (
322421 join_key_columns + feature_name_columns + [timestamp_field ]
323422 )
@@ -378,21 +477,22 @@ def _upload_entity_df_and_get_entity_schema(
378477 # TODO: Ensure that the table expires after some time
379478
380479
381- def _get_trino_client (
382- config : RepoConfig ,
383- user : Optional [str ],
384- auth : Optional [Any ],
385- http_scheme : Optional [str ],
386- ) -> Trino :
387- client = Trino (
388- user = user ,
389- catalog = config .offline_store .catalog ,
def _get_trino_client(config: RepoConfig) -> Trino:
    """Create a Trino client from the offline store section of ``config``."""
    store = config.offline_store

    # Authentication is optional; without an ``auth`` section the client is
    # created with ``auth=None``.
    trino_auth = store.auth.to_trino_auth() if store.auth is not None else None

    connection_kwargs = dict(
        host=store.host,
        port=store.port,
        user=store.user,
        catalog=store.catalog,
        source=store.source,
        http_scheme=store.http_scheme,
        verify=store.verify,
        extra_credential=store.extra_credential,
        auth=trino_auth,
    )
    return Trino(**connection_kwargs)
395- return client
396496
397497
398498def _get_entity_df_event_timestamp_range (
0 commit comments