Skip to content

Commit cb27b24

Browse files
authored
Merge pull request feast-dev#6 from redhatHameed/update_client
added unit test for offline store remote client
2 parents 22afc10 + 889f89b commit cb27b24

File tree

3 files changed

+34
-5
lines changed

3 files changed

+34
-5
lines changed
Binary file not shown.

sdk/python/feast/infra/offline_stores/remote.py

+9-5
Original file line numberDiff line numberDiff line change
@@ -35,15 +35,13 @@ class RemoteOfflineStoreConfig(FeastConfigBaseModel):
3535
class RemoteRetrievalJob(RetrievalJob):
3636
def __init__(
3737
self,
38-
config: RepoConfig,
38+
client: fl.FlightClient,
3939
feature_refs: List[str],
4040
entity_df: Union[pd.DataFrame, str],
4141
# TODO add missing parameters from the OfflineStore API
4242
):
4343
# Initialize the client connection
44-
self.client = fl.connect(
45-
f"grpc://{config.offline_store.host}:{config.offline_store.port}"
46-
)
44+
self.client = client
4745
self.feature_refs = feature_refs
4846
self.entity_df = entity_df
4947

@@ -108,8 +106,14 @@ def get_historical_features(
108106
assert isinstance(config.offline_store, RemoteOfflineStoreConfig)
109107

110108
# TODO: extend RemoteRetrievalJob API with all method parameters
109+
110+
# Initialize the client connection
111+
client = fl.connect(
112+
f"grpc://{config.offline_store.host}:{config.offline_store.port}"
113+
)
114+
111115
return RemoteRetrievalJob(
112-
config=config, feature_refs=feature_refs, entity_df=entity_df
116+
client=client, feature_refs=feature_refs, entity_df=entity_df
113117
)
114118

115119
@staticmethod

sdk/python/tests/unit/infra/offline_stores/test_offline_store.py

+25
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,10 @@
3030
RedshiftOfflineStoreConfig,
3131
RedshiftRetrievalJob,
3232
)
33+
from feast.infra.offline_stores.remote import (
34+
RemoteOfflineStoreConfig,
35+
RemoteRetrievalJob,
36+
)
3337
from feast.infra.offline_stores.snowflake import (
3438
SnowflakeOfflineStoreConfig,
3539
SnowflakeRetrievalJob,
@@ -105,6 +109,7 @@ def metadata(self) -> Optional[RetrievalMetadata]:
105109
PostgreSQLRetrievalJob,
106110
SparkRetrievalJob,
107111
TrinoRetrievalJob,
112+
RemoteRetrievalJob,
108113
]
109114
)
110115
def retrieval_job(request, environment):
@@ -206,6 +211,26 @@ def retrieval_job(request, environment):
206211
config=environment.test_repo_config,
207212
full_feature_names=False,
208213
)
214+
elif request.param is RemoteRetrievalJob:
215+
offline_store_config = RemoteOfflineStoreConfig(
216+
type="remote",
217+
host="localhost",
218+
port=0,
219+
)
220+
environment.test_repo_config.offline_store = offline_store_config
221+
return RemoteRetrievalJob(
222+
client=MagicMock(),
223+
feature_refs=[
224+
"str:str",
225+
],
226+
entity_df=pd.DataFrame.from_dict(
227+
{
228+
"id": [1],
229+
"event_timestamp": ["datetime"],
230+
"val_to_add": [1],
231+
}
232+
),
233+
)
209234
else:
210235
return request.param()
211236

0 commit comments

Comments
 (0)