|
9 | 9 | import pandas as pd
|
10 | 10 | import psutil
|
11 | 11 | from dateutil import parser
|
12 |
| -from fastapi import Depends, FastAPI, HTTPException, Request, Response, status |
| 12 | +from fastapi import Depends, FastAPI, Request, Response, status |
13 | 13 | from fastapi.logger import logger
|
| 14 | +from fastapi.responses import JSONResponse |
14 | 15 | from google.protobuf.json_format import MessageToDict
|
15 | 16 | from prometheus_client import Gauge, start_http_server
|
16 | 17 | from pydantic import BaseModel
|
|
19 | 20 | from feast import proto_json, utils
|
20 | 21 | from feast.constants import DEFAULT_FEATURE_SERVER_REGISTRY_TTL
|
21 | 22 | from feast.data_source import PushMode
|
22 |
| -from feast.errors import FeatureViewNotFoundException, PushSourceNotFoundException |
| 23 | +from feast.errors import ( |
| 24 | + FeastError, |
| 25 | + FeatureViewNotFoundException, |
| 26 | +) |
23 | 27 | from feast.permissions.action import WRITE, AuthzedAction
|
24 | 28 | from feast.permissions.security_manager import assert_permissions
|
25 | 29 | from feast.permissions.server.rest import inject_user_details
|
@@ -101,187 +105,163 @@ async def lifespan(app: FastAPI):
|
101 | 105 | async def get_body(request: Request):
|
102 | 106 | return await request.body()
|
103 | 107 |
|
104 |
| - # TODO RBAC: complete the dependencies for the other endpoints |
105 | 108 | @app.post(
|
106 | 109 | "/get-online-features",
|
107 | 110 | dependencies=[Depends(inject_user_details)],
|
108 | 111 | )
|
109 | 112 | def get_online_features(body=Depends(get_body)):
|
110 |
| - try: |
111 |
| - body = json.loads(body) |
112 |
| - full_feature_names = body.get("full_feature_names", False) |
113 |
| - entity_rows = body["entities"] |
114 |
| - # Initialize parameters for FeatureStore.get_online_features(...) call |
115 |
| - if "feature_service" in body: |
116 |
| - feature_service = store.get_feature_service( |
117 |
| - body["feature_service"], allow_cache=True |
| 113 | + body = json.loads(body) |
| 114 | + full_feature_names = body.get("full_feature_names", False) |
| 115 | + entity_rows = body["entities"] |
| 116 | + # Initialize parameters for FeatureStore.get_online_features(...) call |
| 117 | + if "feature_service" in body: |
| 118 | + feature_service = store.get_feature_service( |
| 119 | + body["feature_service"], allow_cache=True |
| 120 | + ) |
| 121 | + assert_permissions( |
| 122 | + resource=feature_service, actions=[AuthzedAction.READ_ONLINE] |
| 123 | + ) |
| 124 | + features = feature_service |
| 125 | + else: |
| 126 | + features = body["features"] |
| 127 | + all_feature_views, all_on_demand_feature_views = ( |
| 128 | + utils._get_feature_views_to_use( |
| 129 | + store.registry, |
| 130 | + store.project, |
| 131 | + features, |
| 132 | + allow_cache=True, |
| 133 | + hide_dummy_entity=False, |
118 | 134 | )
|
| 135 | + ) |
| 136 | + for feature_view in all_feature_views: |
119 | 137 | assert_permissions(
|
120 |
| - resource=feature_service, actions=[AuthzedAction.READ_ONLINE] |
| 138 | + resource=feature_view, actions=[AuthzedAction.READ_ONLINE] |
121 | 139 | )
|
122 |
| - features = feature_service |
123 |
| - else: |
124 |
| - features = body["features"] |
125 |
| - all_feature_views, all_on_demand_feature_views = ( |
126 |
| - utils._get_feature_views_to_use( |
127 |
| - store.registry, |
128 |
| - store.project, |
129 |
| - features, |
130 |
| - allow_cache=True, |
131 |
| - hide_dummy_entity=False, |
132 |
| - ) |
| 140 | + for od_feature_view in all_on_demand_feature_views: |
| 141 | + assert_permissions( |
| 142 | + resource=od_feature_view, actions=[AuthzedAction.READ_ONLINE] |
133 | 143 | )
|
134 |
| - for feature_view in all_feature_views: |
135 |
| - assert_permissions( |
136 |
| - resource=feature_view, actions=[AuthzedAction.READ_ONLINE] |
137 |
| - ) |
138 |
| - for od_feature_view in all_on_demand_feature_views: |
139 |
| - assert_permissions( |
140 |
| - resource=od_feature_view, actions=[AuthzedAction.READ_ONLINE] |
141 |
| - ) |
142 |
| - |
143 |
| - response_proto = store.get_online_features( |
144 |
| - features=features, |
145 |
| - entity_rows=entity_rows, |
146 |
| - full_feature_names=full_feature_names, |
147 |
| - ).proto |
148 |
| - |
149 |
| - # Convert the Protobuf object to JSON and return it |
150 |
| - return MessageToDict( |
151 |
| - response_proto, preserving_proto_field_name=True, float_precision=18 |
152 |
| - ) |
153 |
| - except Exception as e: |
154 |
| - # Print the original exception on the server side |
155 |
| - logger.exception(traceback.format_exc()) |
156 |
| - # Raise HTTPException to return the error message to the client |
157 |
| - raise HTTPException(status_code=500, detail=str(e)) |
| 144 | + |
| 145 | + response_proto = store.get_online_features( |
| 146 | + features=features, |
| 147 | + entity_rows=entity_rows, |
| 148 | + full_feature_names=full_feature_names, |
| 149 | + ).proto |
| 150 | + |
| 151 | + # Convert the Protobuf object to JSON and return it |
| 152 | + return MessageToDict( |
| 153 | + response_proto, preserving_proto_field_name=True, float_precision=18 |
| 154 | + ) |
158 | 155 |
|
159 | 156 | @app.post("/push", dependencies=[Depends(inject_user_details)])
|
160 | 157 | def push(body=Depends(get_body)):
|
161 |
| - try: |
162 |
| - request = PushFeaturesRequest(**json.loads(body)) |
163 |
| - df = pd.DataFrame(request.df) |
164 |
| - actions = [] |
165 |
| - if request.to == "offline": |
166 |
| - to = PushMode.OFFLINE |
167 |
| - actions = [AuthzedAction.WRITE_OFFLINE] |
168 |
| - elif request.to == "online": |
169 |
| - to = PushMode.ONLINE |
170 |
| - actions = [AuthzedAction.WRITE_ONLINE] |
171 |
| - elif request.to == "online_and_offline": |
172 |
| - to = PushMode.ONLINE_AND_OFFLINE |
173 |
| - actions = WRITE |
174 |
| - else: |
175 |
| - raise ValueError( |
176 |
| - f"{request.to} is not a supported push format. Please specify one of these ['online', 'offline', 'online_and_offline']." |
177 |
| - ) |
178 |
| - |
179 |
| - from feast.data_source import PushSource |
| 158 | + request = PushFeaturesRequest(**json.loads(body)) |
| 159 | + df = pd.DataFrame(request.df) |
| 160 | + actions = [] |
| 161 | + if request.to == "offline": |
| 162 | + to = PushMode.OFFLINE |
| 163 | + actions = [AuthzedAction.WRITE_OFFLINE] |
| 164 | + elif request.to == "online": |
| 165 | + to = PushMode.ONLINE |
| 166 | + actions = [AuthzedAction.WRITE_ONLINE] |
| 167 | + elif request.to == "online_and_offline": |
| 168 | + to = PushMode.ONLINE_AND_OFFLINE |
| 169 | + actions = WRITE |
| 170 | + else: |
| 171 | + raise ValueError( |
| 172 | + f"{request.to} is not a supported push format. Please specify one of these ['online', 'offline', 'online_and_offline']." |
| 173 | + ) |
180 | 174 |
|
181 |
| - all_fvs = store.list_feature_views( |
182 |
| - allow_cache=request.allow_registry_cache |
183 |
| - ) + store.list_stream_feature_views( |
184 |
| - allow_cache=request.allow_registry_cache |
| 175 | + from feast.data_source import PushSource |
| 176 | + |
| 177 | + all_fvs = store.list_feature_views( |
| 178 | + allow_cache=request.allow_registry_cache |
| 179 | + ) + store.list_stream_feature_views(allow_cache=request.allow_registry_cache) |
| 180 | + fvs_with_push_sources = { |
| 181 | + fv |
| 182 | + for fv in all_fvs |
| 183 | + if ( |
| 184 | + fv.stream_source is not None |
| 185 | + and isinstance(fv.stream_source, PushSource) |
| 186 | + and fv.stream_source.name == request.push_source_name |
185 | 187 | )
|
186 |
| - fvs_with_push_sources = { |
187 |
| - fv |
188 |
| - for fv in all_fvs |
189 |
| - if ( |
190 |
| - fv.stream_source is not None |
191 |
| - and isinstance(fv.stream_source, PushSource) |
192 |
| - and fv.stream_source.name == request.push_source_name |
193 |
| - ) |
194 |
| - } |
| 188 | + } |
195 | 189 |
|
196 |
| - for feature_view in fvs_with_push_sources: |
197 |
| - assert_permissions(resource=feature_view, actions=actions) |
| 190 | + for feature_view in fvs_with_push_sources: |
| 191 | + assert_permissions(resource=feature_view, actions=actions) |
198 | 192 |
|
199 |
| - store.push( |
200 |
| - push_source_name=request.push_source_name, |
201 |
| - df=df, |
202 |
| - allow_registry_cache=request.allow_registry_cache, |
203 |
| - to=to, |
204 |
| - ) |
205 |
| - except PushSourceNotFoundException as e: |
206 |
| - # Print the original exception on the server side |
207 |
| - logger.exception(traceback.format_exc()) |
208 |
| - # Raise HTTPException to return the error message to the client |
209 |
| - raise HTTPException(status_code=422, detail=str(e)) |
210 |
| - except Exception as e: |
211 |
| - # Print the original exception on the server side |
212 |
| - logger.exception(traceback.format_exc()) |
213 |
| - # Raise HTTPException to return the error message to the client |
214 |
| - raise HTTPException(status_code=500, detail=str(e)) |
| 193 | + store.push( |
| 194 | + push_source_name=request.push_source_name, |
| 195 | + df=df, |
| 196 | + allow_registry_cache=request.allow_registry_cache, |
| 197 | + to=to, |
| 198 | + ) |
215 | 199 |
|
216 | 200 | @app.post("/write-to-online-store", dependencies=[Depends(inject_user_details)])
|
217 | 201 | def write_to_online_store(body=Depends(get_body)):
|
| 202 | + request = WriteToFeatureStoreRequest(**json.loads(body)) |
| 203 | + df = pd.DataFrame(request.df) |
| 204 | + feature_view_name = request.feature_view_name |
| 205 | + allow_registry_cache = request.allow_registry_cache |
218 | 206 | try:
|
219 |
| - request = WriteToFeatureStoreRequest(**json.loads(body)) |
220 |
| - df = pd.DataFrame(request.df) |
221 |
| - feature_view_name = request.feature_view_name |
222 |
| - allow_registry_cache = request.allow_registry_cache |
223 |
| - try: |
224 |
| - feature_view = store.get_stream_feature_view( |
225 |
| - feature_view_name, allow_registry_cache=allow_registry_cache |
226 |
| - ) |
227 |
| - except FeatureViewNotFoundException: |
228 |
| - feature_view = store.get_feature_view( |
229 |
| - feature_view_name, allow_registry_cache=allow_registry_cache |
230 |
| - ) |
231 |
| - |
232 |
| - assert_permissions( |
233 |
| - resource=feature_view, actions=[AuthzedAction.WRITE_ONLINE] |
| 207 | + feature_view = store.get_stream_feature_view( |
| 208 | + feature_view_name, allow_registry_cache=allow_registry_cache |
234 | 209 | )
|
235 |
| - store.write_to_online_store( |
236 |
| - feature_view_name=feature_view_name, |
237 |
| - df=df, |
238 |
| - allow_registry_cache=allow_registry_cache, |
| 210 | + except FeatureViewNotFoundException: |
| 211 | + feature_view = store.get_feature_view( |
| 212 | + feature_view_name, allow_registry_cache=allow_registry_cache |
239 | 213 | )
|
240 |
| - except Exception as e: |
241 |
| - # Print the original exception on the server side |
242 |
| - logger.exception(traceback.format_exc()) |
243 |
| - # Raise HTTPException to return the error message to the client |
244 |
| - raise HTTPException(status_code=500, detail=str(e)) |
| 214 | + |
| 215 | + assert_permissions(resource=feature_view, actions=[AuthzedAction.WRITE_ONLINE]) |
| 216 | + store.write_to_online_store( |
| 217 | + feature_view_name=feature_view_name, |
| 218 | + df=df, |
| 219 | + allow_registry_cache=allow_registry_cache, |
| 220 | + ) |
245 | 221 |
|
246 | 222 | @app.get("/health")
|
247 | 223 | def health():
|
248 | 224 | return Response(status_code=status.HTTP_200_OK)
|
249 | 225 |
|
250 | 226 | @app.post("/materialize", dependencies=[Depends(inject_user_details)])
|
251 | 227 | def materialize(body=Depends(get_body)):
|
252 |
| - try: |
253 |
| - request = MaterializeRequest(**json.loads(body)) |
254 |
| - for feature_view in request.feature_views: |
255 |
| - assert_permissions( |
256 |
| - resource=feature_view, actions=[AuthzedAction.WRITE_ONLINE] |
257 |
| - ) |
258 |
| - store.materialize( |
259 |
| - utils.make_tzaware(parser.parse(request.start_ts)), |
260 |
| - utils.make_tzaware(parser.parse(request.end_ts)), |
261 |
| - request.feature_views, |
| 228 | + request = MaterializeRequest(**json.loads(body)) |
| 229 | + for feature_view in request.feature_views: |
| 230 | + assert_permissions( |
| 231 | + resource=feature_view, actions=[AuthzedAction.WRITE_ONLINE] |
262 | 232 | )
|
263 |
| - except Exception as e: |
264 |
| - # Print the original exception on the server side |
265 |
| - logger.exception(traceback.format_exc()) |
266 |
| - # Raise HTTPException to return the error message to the client |
267 |
| - raise HTTPException(status_code=500, detail=str(e)) |
| 233 | + store.materialize( |
| 234 | + utils.make_tzaware(parser.parse(request.start_ts)), |
| 235 | + utils.make_tzaware(parser.parse(request.end_ts)), |
| 236 | + request.feature_views, |
| 237 | + ) |
268 | 238 |
|
269 | 239 | @app.post("/materialize-incremental", dependencies=[Depends(inject_user_details)])
|
270 | 240 | def materialize_incremental(body=Depends(get_body)):
|
271 |
| - try: |
272 |
| - request = MaterializeIncrementalRequest(**json.loads(body)) |
273 |
| - for feature_view in request.feature_views: |
274 |
| - assert_permissions( |
275 |
| - resource=feature_view, actions=[AuthzedAction.WRITE_ONLINE] |
276 |
| - ) |
277 |
| - store.materialize_incremental( |
278 |
| - utils.make_tzaware(parser.parse(request.end_ts)), request.feature_views |
| 241 | + request = MaterializeIncrementalRequest(**json.loads(body)) |
| 242 | + for feature_view in request.feature_views: |
| 243 | + assert_permissions( |
| 244 | + resource=feature_view, actions=[AuthzedAction.WRITE_ONLINE] |
| 245 | + ) |
| 246 | + store.materialize_incremental( |
| 247 | + utils.make_tzaware(parser.parse(request.end_ts)), request.feature_views |
| 248 | + ) |
| 249 | + |
| 250 | + @app.exception_handler(Exception) |
| 251 | + async def rest_exception_handler(request: Request, exc: Exception): |
| 252 | + # Print the original exception on the server side |
| 253 | + logger.exception(traceback.format_exc()) |
| 254 | + |
| 255 | + if isinstance(exc, FeastError): |
| 256 | + return JSONResponse( |
| 257 | + status_code=exc.http_status_code(), |
| 258 | + content=exc.to_error_detail(), |
| 259 | + ) |
| 260 | + else: |
| 261 | + return JSONResponse( |
| 262 | + status_code=500, |
| 263 | + content=str(exc), |
279 | 264 | )
|
280 |
| - except Exception as e: |
281 |
| - # Print the original exception on the server side |
282 |
| - logger.exception(traceback.format_exc()) |
283 |
| - # Raise HTTPException to return the error message to the client |
284 |
| - raise HTTPException(status_code=500, detail=str(e)) |
285 | 265 |
|
286 | 266 | return app
|
287 | 267 |
|
|
0 commit comments