Skip to content

Commit 01db8cc

Browse files
feat: Add get online feature rpc to gprc server (feast-dev#3815)
Signed-off-by: Hai Nguyen <quanghai.ng1512@gmail.com>
1 parent 0151961 commit 01db8cc

File tree

5 files changed

+90
-20
lines changed

5 files changed

+90
-20
lines changed

protos/feast/serving/GrpcServer.proto

+7-4
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
syntax = "proto3";
22

3+
import "feast/serving/ServingService.proto";
4+
35
message PushRequest {
46
map<string, string> features = 1;
57
string stream_feature_view = 2;
@@ -8,7 +10,7 @@ message PushRequest {
810
}
911

1012
message PushResponse {
11-
bool status = 1;
13+
bool status = 1;
1214
}
1315

1416
message WriteToOnlineStoreRequest {
@@ -18,10 +20,11 @@ message WriteToOnlineStoreRequest {
1820
}
1921

2022
message WriteToOnlineStoreResponse {
21-
bool status = 1;
23+
bool status = 1;
2224
}
2325

2426
service GrpcFeatureServer {
25-
rpc Push (PushRequest) returns (PushResponse) {};
26-
rpc WriteToOnlineStore (WriteToOnlineStoreRequest) returns (WriteToOnlineStoreResponse);
27+
rpc Push (PushRequest) returns (PushResponse) {};
28+
rpc WriteToOnlineStore (WriteToOnlineStoreRequest) returns (WriteToOnlineStoreResponse);
29+
rpc GetOnlineFeatures (feast.serving.GetOnlineFeaturesRequest) returns (feast.serving.GetOnlineFeaturesResponse);
2730
}

protos/feast/serving/ServingService.proto

+2
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,8 @@ message GetOnlineFeaturesResponse {
105105
repeated FieldStatus statuses = 2;
106106
repeated google.protobuf.Timestamp event_timestamps = 3;
107107
}
108+
109+
bool status = 3;
108110
}
109111

110112
message GetOnlineFeaturesResponseMetadata {

sdk/python/feast/cli.py

+10-1
Original file line numberDiff line numberDiff line change
@@ -705,15 +705,24 @@ def serve_command(
705705
show_default=False,
706706
help="The maximum number of threads that can be used to execute the gRPC calls",
707707
)
708+
@click.option(
709+
"--registry_ttl_sec",
710+
"-r",
711+
help="Number of seconds after which the registry is refreshed",
712+
type=click.INT,
713+
default=5,
714+
show_default=True,
715+
)
708716
@click.pass_context
709717
def listen_command(
710718
ctx: click.Context,
711719
address: str,
712720
max_workers: int,
721+
registry_ttl_sec: int,
713722
):
714723
"""Start a gRPC feature server to ingest streaming features on given address"""
715724
store = create_feature_store(ctx)
716-
server = get_grpc_server(address, store, max_workers)
725+
server = get_grpc_server(address, store, max_workers, registry_ttl_sec)
717726
server.start()
718727
server.wait_for_termination()
719728

sdk/python/feast/infra/contrib/grpc_server.py

+61-8
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,14 @@
11
import logging
2+
import threading
23
from concurrent import futures
4+
from typing import Optional
35

46
import grpc
57
import pandas as pd
68
from grpc_health.v1 import health, health_pb2_grpc
79

810
from feast.data_source import PushMode
9-
from feast.errors import PushSourceNotFoundException
11+
from feast.errors import FeatureServiceNotFoundException, PushSourceNotFoundException
1012
from feast.feature_store import FeatureStore
1113
from feast.protos.feast.serving.GrpcServer_pb2 import (
1214
PushResponse,
@@ -16,6 +18,12 @@
1618
GrpcFeatureServerServicer,
1719
add_GrpcFeatureServerServicer_to_server,
1820
)
21+
from feast.protos.feast.serving.ServingService_pb2 import (
22+
GetOnlineFeaturesRequest,
23+
GetOnlineFeaturesResponse,
24+
)
25+
26+
logger = logging.getLogger(__name__)
1927

2028

2129
def parse(features):
@@ -28,10 +36,16 @@ def parse(features):
2836
class GrpcFeatureServer(GrpcFeatureServerServicer):
2937
fs: FeatureStore
3038

31-
def __init__(self, fs: FeatureStore):
39+
_shuting_down: bool = False
40+
_active_timer: Optional[threading.Timer] = None
41+
42+
def __init__(self, fs: FeatureStore, registry_ttl_sec: int = 5):
3243
self.fs = fs
44+
self.registry_ttl_sec = registry_ttl_sec
3345
super().__init__()
3446

47+
self._async_refresh()
48+
3549
def Push(self, request, context):
3650
try:
3751
df = parse(request.features)
@@ -53,19 +67,19 @@ def Push(self, request, context):
5367
to=to,
5468
)
5569
except PushSourceNotFoundException as e:
56-
logging.exception(str(e))
70+
logger.exception(str(e))
5771
context.set_code(grpc.StatusCode.INVALID_ARGUMENT)
5872
context.set_details(str(e))
5973
return PushResponse(status=False)
6074
except Exception as e:
61-
logging.exception(str(e))
75+
logger.exception(str(e))
6276
context.set_code(grpc.StatusCode.INTERNAL)
6377
context.set_details(str(e))
6478
return PushResponse(status=False)
6579
return PushResponse(status=True)
6680

6781
def WriteToOnlineStore(self, request, context):
68-
logging.warning(
82+
logger.warning(
6983
"write_to_online_store is deprecated. Please consider using Push instead"
7084
)
7185
try:
@@ -76,16 +90,55 @@ def WriteToOnlineStore(self, request, context):
7690
allow_registry_cache=request.allow_registry_cache,
7791
)
7892
except Exception as e:
79-
logging.exception(str(e))
93+
logger.exception(str(e))
8094
context.set_code(grpc.StatusCode.INTERNAL)
8195
context.set_details(str(e))
8296
return PushResponse(status=False)
8397
return WriteToOnlineStoreResponse(status=True)
8498

99+
def GetOnlineFeatures(self, request: GetOnlineFeaturesRequest, context):
100+
if request.HasField("feature_service"):
101+
logger.info(f"Requesting feature service: {request.feature_service}")
102+
try:
103+
features = self.fs.get_feature_service(
104+
request.feature_service, allow_cache=True
105+
)
106+
except FeatureServiceNotFoundException as e:
107+
logger.error(f"Feature service {request.feature_service} not found")
108+
context.set_code(grpc.StatusCode.INTERNAL)
109+
context.set_details(str(e))
110+
return GetOnlineFeaturesResponse()
111+
else:
112+
features = list(request.features.val)
113+
114+
result = self.fs._get_online_features(
115+
features,
116+
request.entities,
117+
request.full_feature_names,
118+
).proto
119+
120+
return result
121+
122+
def _async_refresh(self):
123+
self.fs.refresh_registry()
124+
if self._shuting_down:
125+
return
126+
self._active_timer = threading.Timer(self.registry_ttl_sec, self._async_refresh)
127+
self._active_timer.start()
85128

86-
def get_grpc_server(address: str, fs: FeatureStore, max_workers: int):
129+
130+
def get_grpc_server(
131+
address: str,
132+
fs: FeatureStore,
133+
max_workers: int,
134+
registry_ttl_sec: int,
135+
):
136+
logger.info(f"Initializing gRPC server on {address}")
87137
server = grpc.server(futures.ThreadPoolExecutor(max_workers=max_workers))
88-
add_GrpcFeatureServerServicer_to_server(GrpcFeatureServer(fs), server)
138+
add_GrpcFeatureServerServicer_to_server(
139+
GrpcFeatureServer(fs, registry_ttl_sec=registry_ttl_sec),
140+
server,
141+
)
89142
health_servicer = health.HealthServicer(
90143
experimental_non_blocking=True,
91144
experimental_thread_pool=futures.ThreadPoolExecutor(max_workers=max_workers),

sdk/python/feast/type_map.py

+10-7
Original file line numberDiff line numberDiff line change
@@ -428,12 +428,15 @@ def _python_value_to_proto_value(
428428
for value in values
429429
]
430430
if feast_value_type in PYTHON_SCALAR_VALUE_TYPE_TO_PROTO_VALUE:
431-
return [
432-
ProtoValue(**{field_name: func(value)})
433-
if not pd.isnull(value)
434-
else ProtoValue()
435-
for value in values
436-
]
431+
out = []
432+
for value in values:
433+
if isinstance(value, ProtoValue):
434+
out.append(value)
435+
elif not pd.isnull(value):
436+
out.append(ProtoValue(**{field_name: func(value)}))
437+
else:
438+
out.append(ProtoValue())
439+
return out
437440

438441
raise Exception(f"Unsupported data type: ${str(type(values[0]))}")
439442

@@ -746,7 +749,7 @@ def spark_to_feast_value_type(spark_type_as_str: str) -> ValueType:
746749
"array<timestamp>": ValueType.UNIX_TIMESTAMP_LIST,
747750
}
748751
# TODO: Find better way of doing this.
749-
if type(spark_type_as_str) != str or spark_type_as_str not in type_map:
752+
if not isinstance(spark_type_as_str, str) or spark_type_as_str not in type_map:
750753
return ValueType.NULL
751754
return type_map[spark_type_as_str.lower()]
752755

0 commit comments

Comments
 (0)