|
9 | 9 | import pandas as pd |
10 | 10 | import psutil |
11 | 11 | from dateutil import parser |
12 | | -from fastapi import Depends, FastAPI, HTTPException, Request, Response, status |
| 12 | +from fastapi import Depends, FastAPI, Request, Response, status |
13 | 13 | from fastapi.logger import logger |
| 14 | +from fastapi.responses import JSONResponse |
14 | 15 | from google.protobuf.json_format import MessageToDict |
15 | 16 | from prometheus_client import Gauge, start_http_server |
16 | 17 | from pydantic import BaseModel |
|
19 | 20 | from feast import proto_json, utils |
20 | 21 | from feast.constants import DEFAULT_FEATURE_SERVER_REGISTRY_TTL |
21 | 22 | from feast.data_source import PushMode |
22 | | -from feast.errors import FeatureViewNotFoundException, PushSourceNotFoundException |
| 23 | +from feast.errors import ( |
| 24 | + FeastError, |
| 25 | + FeatureViewNotFoundException, |
| 26 | +) |
23 | 27 | from feast.permissions.action import WRITE, AuthzedAction |
24 | 28 | from feast.permissions.security_manager import assert_permissions |
25 | 29 | from feast.permissions.server.rest import inject_user_details |
@@ -101,187 +105,163 @@ async def lifespan(app: FastAPI): |
101 | 105 | async def get_body(request: Request): |
102 | 106 | return await request.body() |
103 | 107 |
|
104 | | - # TODO RBAC: complete the dependencies for the other endpoints |
105 | 108 | @app.post( |
106 | 109 | "/get-online-features", |
107 | 110 | dependencies=[Depends(inject_user_details)], |
108 | 111 | ) |
109 | 112 | def get_online_features(body=Depends(get_body)): |
110 | | - try: |
111 | | - body = json.loads(body) |
112 | | - full_feature_names = body.get("full_feature_names", False) |
113 | | - entity_rows = body["entities"] |
114 | | - # Initialize parameters for FeatureStore.get_online_features(...) call |
115 | | - if "feature_service" in body: |
116 | | - feature_service = store.get_feature_service( |
117 | | - body["feature_service"], allow_cache=True |
| 113 | + body = json.loads(body) |
| 114 | + full_feature_names = body.get("full_feature_names", False) |
| 115 | + entity_rows = body["entities"] |
| 116 | + # Initialize parameters for FeatureStore.get_online_features(...) call |
| 117 | + if "feature_service" in body: |
| 118 | + feature_service = store.get_feature_service( |
| 119 | + body["feature_service"], allow_cache=True |
| 120 | + ) |
| 121 | + assert_permissions( |
| 122 | + resource=feature_service, actions=[AuthzedAction.READ_ONLINE] |
| 123 | + ) |
| 124 | + features = feature_service |
| 125 | + else: |
| 126 | + features = body["features"] |
| 127 | + all_feature_views, all_on_demand_feature_views = ( |
| 128 | + utils._get_feature_views_to_use( |
| 129 | + store.registry, |
| 130 | + store.project, |
| 131 | + features, |
| 132 | + allow_cache=True, |
| 133 | + hide_dummy_entity=False, |
118 | 134 | ) |
| 135 | + ) |
| 136 | + for feature_view in all_feature_views: |
119 | 137 | assert_permissions( |
120 | | - resource=feature_service, actions=[AuthzedAction.READ_ONLINE] |
| 138 | + resource=feature_view, actions=[AuthzedAction.READ_ONLINE] |
121 | 139 | ) |
122 | | - features = feature_service |
123 | | - else: |
124 | | - features = body["features"] |
125 | | - all_feature_views, all_on_demand_feature_views = ( |
126 | | - utils._get_feature_views_to_use( |
127 | | - store.registry, |
128 | | - store.project, |
129 | | - features, |
130 | | - allow_cache=True, |
131 | | - hide_dummy_entity=False, |
132 | | - ) |
| 140 | + for od_feature_view in all_on_demand_feature_views: |
| 141 | + assert_permissions( |
| 142 | + resource=od_feature_view, actions=[AuthzedAction.READ_ONLINE] |
133 | 143 | ) |
134 | | - for feature_view in all_feature_views: |
135 | | - assert_permissions( |
136 | | - resource=feature_view, actions=[AuthzedAction.READ_ONLINE] |
137 | | - ) |
138 | | - for od_feature_view in all_on_demand_feature_views: |
139 | | - assert_permissions( |
140 | | - resource=od_feature_view, actions=[AuthzedAction.READ_ONLINE] |
141 | | - ) |
142 | | - |
143 | | - response_proto = store.get_online_features( |
144 | | - features=features, |
145 | | - entity_rows=entity_rows, |
146 | | - full_feature_names=full_feature_names, |
147 | | - ).proto |
148 | | - |
149 | | - # Convert the Protobuf object to JSON and return it |
150 | | - return MessageToDict( |
151 | | - response_proto, preserving_proto_field_name=True, float_precision=18 |
152 | | - ) |
153 | | - except Exception as e: |
154 | | - # Print the original exception on the server side |
155 | | - logger.exception(traceback.format_exc()) |
156 | | - # Raise HTTPException to return the error message to the client |
157 | | - raise HTTPException(status_code=500, detail=str(e)) |
| 144 | + |
| 145 | + response_proto = store.get_online_features( |
| 146 | + features=features, |
| 147 | + entity_rows=entity_rows, |
| 148 | + full_feature_names=full_feature_names, |
| 149 | + ).proto |
| 150 | + |
| 151 | + # Convert the Protobuf object to JSON and return it |
| 152 | + return MessageToDict( |
| 153 | + response_proto, preserving_proto_field_name=True, float_precision=18 |
| 154 | + ) |
158 | 155 |
|
159 | 156 | @app.post("/push", dependencies=[Depends(inject_user_details)]) |
160 | 157 | def push(body=Depends(get_body)): |
161 | | - try: |
162 | | - request = PushFeaturesRequest(**json.loads(body)) |
163 | | - df = pd.DataFrame(request.df) |
164 | | - actions = [] |
165 | | - if request.to == "offline": |
166 | | - to = PushMode.OFFLINE |
167 | | - actions = [AuthzedAction.WRITE_OFFLINE] |
168 | | - elif request.to == "online": |
169 | | - to = PushMode.ONLINE |
170 | | - actions = [AuthzedAction.WRITE_ONLINE] |
171 | | - elif request.to == "online_and_offline": |
172 | | - to = PushMode.ONLINE_AND_OFFLINE |
173 | | - actions = WRITE |
174 | | - else: |
175 | | - raise ValueError( |
176 | | - f"{request.to} is not a supported push format. Please specify one of these ['online', 'offline', 'online_and_offline']." |
177 | | - ) |
178 | | - |
179 | | - from feast.data_source import PushSource |
| 158 | + request = PushFeaturesRequest(**json.loads(body)) |
| 159 | + df = pd.DataFrame(request.df) |
| 160 | + actions = [] |
| 161 | + if request.to == "offline": |
| 162 | + to = PushMode.OFFLINE |
| 163 | + actions = [AuthzedAction.WRITE_OFFLINE] |
| 164 | + elif request.to == "online": |
| 165 | + to = PushMode.ONLINE |
| 166 | + actions = [AuthzedAction.WRITE_ONLINE] |
| 167 | + elif request.to == "online_and_offline": |
| 168 | + to = PushMode.ONLINE_AND_OFFLINE |
| 169 | + actions = WRITE |
| 170 | + else: |
| 171 | + raise ValueError( |
| 172 | + f"{request.to} is not a supported push format. Please specify one of these ['online', 'offline', 'online_and_offline']." |
| 173 | + ) |
180 | 174 |
|
181 | | - all_fvs = store.list_feature_views( |
182 | | - allow_cache=request.allow_registry_cache |
183 | | - ) + store.list_stream_feature_views( |
184 | | - allow_cache=request.allow_registry_cache |
| 175 | + from feast.data_source import PushSource |
| 176 | + |
| 177 | + all_fvs = store.list_feature_views( |
| 178 | + allow_cache=request.allow_registry_cache |
| 179 | + ) + store.list_stream_feature_views(allow_cache=request.allow_registry_cache) |
| 180 | + fvs_with_push_sources = { |
| 181 | + fv |
| 182 | + for fv in all_fvs |
| 183 | + if ( |
| 184 | + fv.stream_source is not None |
| 185 | + and isinstance(fv.stream_source, PushSource) |
| 186 | + and fv.stream_source.name == request.push_source_name |
185 | 187 | ) |
186 | | - fvs_with_push_sources = { |
187 | | - fv |
188 | | - for fv in all_fvs |
189 | | - if ( |
190 | | - fv.stream_source is not None |
191 | | - and isinstance(fv.stream_source, PushSource) |
192 | | - and fv.stream_source.name == request.push_source_name |
193 | | - ) |
194 | | - } |
| 188 | + } |
195 | 189 |
|
196 | | - for feature_view in fvs_with_push_sources: |
197 | | - assert_permissions(resource=feature_view, actions=actions) |
| 190 | + for feature_view in fvs_with_push_sources: |
| 191 | + assert_permissions(resource=feature_view, actions=actions) |
198 | 192 |
|
199 | | - store.push( |
200 | | - push_source_name=request.push_source_name, |
201 | | - df=df, |
202 | | - allow_registry_cache=request.allow_registry_cache, |
203 | | - to=to, |
204 | | - ) |
205 | | - except PushSourceNotFoundException as e: |
206 | | - # Print the original exception on the server side |
207 | | - logger.exception(traceback.format_exc()) |
208 | | - # Raise HTTPException to return the error message to the client |
209 | | - raise HTTPException(status_code=422, detail=str(e)) |
210 | | - except Exception as e: |
211 | | - # Print the original exception on the server side |
212 | | - logger.exception(traceback.format_exc()) |
213 | | - # Raise HTTPException to return the error message to the client |
214 | | - raise HTTPException(status_code=500, detail=str(e)) |
| 193 | + store.push( |
| 194 | + push_source_name=request.push_source_name, |
| 195 | + df=df, |
| 196 | + allow_registry_cache=request.allow_registry_cache, |
| 197 | + to=to, |
| 198 | + ) |
215 | 199 |
|
216 | 200 | @app.post("/write-to-online-store", dependencies=[Depends(inject_user_details)]) |
217 | 201 | def write_to_online_store(body=Depends(get_body)): |
| 202 | + request = WriteToFeatureStoreRequest(**json.loads(body)) |
| 203 | + df = pd.DataFrame(request.df) |
| 204 | + feature_view_name = request.feature_view_name |
| 205 | + allow_registry_cache = request.allow_registry_cache |
218 | 206 | try: |
219 | | - request = WriteToFeatureStoreRequest(**json.loads(body)) |
220 | | - df = pd.DataFrame(request.df) |
221 | | - feature_view_name = request.feature_view_name |
222 | | - allow_registry_cache = request.allow_registry_cache |
223 | | - try: |
224 | | - feature_view = store.get_stream_feature_view( |
225 | | - feature_view_name, allow_registry_cache=allow_registry_cache |
226 | | - ) |
227 | | - except FeatureViewNotFoundException: |
228 | | - feature_view = store.get_feature_view( |
229 | | - feature_view_name, allow_registry_cache=allow_registry_cache |
230 | | - ) |
231 | | - |
232 | | - assert_permissions( |
233 | | - resource=feature_view, actions=[AuthzedAction.WRITE_ONLINE] |
| 207 | + feature_view = store.get_stream_feature_view( |
| 208 | + feature_view_name, allow_registry_cache=allow_registry_cache |
234 | 209 | ) |
235 | | - store.write_to_online_store( |
236 | | - feature_view_name=feature_view_name, |
237 | | - df=df, |
238 | | - allow_registry_cache=allow_registry_cache, |
| 210 | + except FeatureViewNotFoundException: |
| 211 | + feature_view = store.get_feature_view( |
| 212 | + feature_view_name, allow_registry_cache=allow_registry_cache |
239 | 213 | ) |
240 | | - except Exception as e: |
241 | | - # Print the original exception on the server side |
242 | | - logger.exception(traceback.format_exc()) |
243 | | - # Raise HTTPException to return the error message to the client |
244 | | - raise HTTPException(status_code=500, detail=str(e)) |
| 214 | + |
| 215 | + assert_permissions(resource=feature_view, actions=[AuthzedAction.WRITE_ONLINE]) |
| 216 | + store.write_to_online_store( |
| 217 | + feature_view_name=feature_view_name, |
| 218 | + df=df, |
| 219 | + allow_registry_cache=allow_registry_cache, |
| 220 | + ) |
245 | 221 |
|
246 | 222 | @app.get("/health") |
247 | 223 | def health(): |
248 | 224 | return Response(status_code=status.HTTP_200_OK) |
249 | 225 |
|
250 | 226 | @app.post("/materialize", dependencies=[Depends(inject_user_details)]) |
251 | 227 | def materialize(body=Depends(get_body)): |
252 | | - try: |
253 | | - request = MaterializeRequest(**json.loads(body)) |
254 | | - for feature_view in request.feature_views: |
255 | | - assert_permissions( |
256 | | - resource=feature_view, actions=[AuthzedAction.WRITE_ONLINE] |
257 | | - ) |
258 | | - store.materialize( |
259 | | - utils.make_tzaware(parser.parse(request.start_ts)), |
260 | | - utils.make_tzaware(parser.parse(request.end_ts)), |
261 | | - request.feature_views, |
| 228 | + request = MaterializeRequest(**json.loads(body)) |
| 229 | + for feature_view in request.feature_views: |
| 230 | + assert_permissions( |
| 231 | + resource=feature_view, actions=[AuthzedAction.WRITE_ONLINE] |
262 | 232 | ) |
263 | | - except Exception as e: |
264 | | - # Print the original exception on the server side |
265 | | - logger.exception(traceback.format_exc()) |
266 | | - # Raise HTTPException to return the error message to the client |
267 | | - raise HTTPException(status_code=500, detail=str(e)) |
| 233 | + store.materialize( |
| 234 | + utils.make_tzaware(parser.parse(request.start_ts)), |
| 235 | + utils.make_tzaware(parser.parse(request.end_ts)), |
| 236 | + request.feature_views, |
| 237 | + ) |
268 | 238 |
|
269 | 239 | @app.post("/materialize-incremental", dependencies=[Depends(inject_user_details)]) |
270 | 240 | def materialize_incremental(body=Depends(get_body)): |
271 | | - try: |
272 | | - request = MaterializeIncrementalRequest(**json.loads(body)) |
273 | | - for feature_view in request.feature_views: |
274 | | - assert_permissions( |
275 | | - resource=feature_view, actions=[AuthzedAction.WRITE_ONLINE] |
276 | | - ) |
277 | | - store.materialize_incremental( |
278 | | - utils.make_tzaware(parser.parse(request.end_ts)), request.feature_views |
| 241 | + request = MaterializeIncrementalRequest(**json.loads(body)) |
| 242 | + for feature_view in request.feature_views: |
| 243 | + assert_permissions( |
| 244 | + resource=feature_view, actions=[AuthzedAction.WRITE_ONLINE] |
| 245 | + ) |
| 246 | + store.materialize_incremental( |
| 247 | + utils.make_tzaware(parser.parse(request.end_ts)), request.feature_views |
| 248 | + ) |
| 249 | + |
| 250 | + @app.exception_handler(Exception) |
| 251 | + async def rest_exception_handler(request: Request, exc: Exception): |
| 252 | + # Print the original exception on the server side |
| 253 | + logger.exception(traceback.format_exc()) |
| 254 | + |
| 255 | + if isinstance(exc, FeastError): |
| 256 | + return JSONResponse( |
| 257 | + status_code=exc.http_status_code(), |
| 258 | + content=exc.to_error_detail(), |
| 259 | + ) |
| 260 | + else: |
| 261 | + return JSONResponse( |
| 262 | + status_code=500, |
| 263 | + content=str(exc), |
279 | 264 | ) |
280 | | - except Exception as e: |
281 | | - # Print the original exception on the server side |
282 | | - logger.exception(traceback.format_exc()) |
283 | | - # Raise HTTPException to return the error message to the client |
284 | | - raise HTTPException(status_code=500, detail=str(e)) |
285 | 265 |
|
286 | 266 | return app |
287 | 267 |
|
|
0 commit comments