Skip to content

Commit 7df287e

Browse files
feat: Adding features field to retrieve_online_features to return mor… (feast-dev#4869)
1 parent dbc9207 commit 7df287e

File tree

11 files changed

+91
-41
lines changed

11 files changed

+91
-41
lines changed

sdk/python/feast/feature_store.py

+41-15
Original file line numberDiff line numberDiff line change
@@ -1753,9 +1753,10 @@ async def get_online_features_async(
17531753

17541754
def retrieve_online_documents(
17551755
self,
1756-
feature: str,
1756+
feature: Optional[str],
17571757
query: Union[str, List[float]],
17581758
top_k: int,
1759+
features: Optional[List[str]] = None,
17591760
distance_metric: Optional[str] = None,
17601761
) -> OnlineResponse:
17611762
"""
@@ -1765,6 +1766,7 @@ def retrieve_online_documents(
17651766
feature: The list of document features that should be retrieved from the online document store. These features can be
17661767
specified either as a list of string document feature references or as a feature service. String feature
17671768
references must have format "feature_view:feature", e.g., "document_fv:document_embeddings".
1769+
features: The list of features that should be retrieved from the online store.
17681770
query: The query to retrieve the closest document features for.
17691771
top_k: The number of closest document features to retrieve.
17701772
distance_metric: The distance metric to use for retrieval.
@@ -1773,18 +1775,44 @@ def retrieve_online_documents(
17731775
raise ValueError(
17741776
"Using embedding functionality is not supported for document retrieval. Please embed the query before calling retrieve_online_documents."
17751777
)
1778+
feature_list: List[str] = (
1779+
features
1780+
if features is not None
1781+
else ([feature] if feature is not None else [])
1782+
)
1783+
17761784
(
17771785
available_feature_views,
17781786
_,
17791787
) = utils._get_feature_views_to_use(
17801788
registry=self._registry,
17811789
project=self.project,
1782-
features=[feature],
1790+
features=feature_list,
17831791
allow_cache=True,
17841792
hide_dummy_entity=False,
17851793
)
1794+
if features:
1795+
feature_view_set = set()
1796+
for feature in features:
1797+
feature_view_name = feature.split(":")[0]
1798+
feature_view = self.get_feature_view(feature_view_name)
1799+
feature_view_set.add(feature_view.name)
1800+
if len(feature_view_set) > 1:
1801+
raise ValueError(
1802+
"Document retrieval only supports a single feature view."
1803+
)
1804+
requested_feature = None
1805+
requested_features = [
1806+
f.split(":")[1] for f in features if isinstance(f, str) and ":" in f
1807+
]
1808+
else:
1809+
requested_feature = (
1810+
feature.split(":")[1] if isinstance(feature, str) else feature
1811+
)
1812+
requested_features = [requested_feature] if requested_feature else []
1813+
17861814
requested_feature_view_name = (
1787-
feature.split(":")[0] if isinstance(feature, str) else feature
1815+
feature.split(":")[0] if feature else list(feature_view_set)[0]
17881816
)
17891817
for feature_view in available_feature_views:
17901818
if feature_view.name == requested_feature_view_name:
@@ -1793,14 +1821,15 @@ def retrieve_online_documents(
17931821
raise ValueError(
17941822
f"Feature view {requested_feature_view} not found in the registry."
17951823
)
1796-
requested_feature = (
1797-
feature.split(":")[1] if isinstance(feature, str) else feature
1798-
)
1824+
1825+
requested_feature_view = available_feature_views[0]
1826+
17991827
provider = self._get_provider()
18001828
document_features = self._retrieve_from_online_store(
18011829
provider,
18021830
requested_feature_view,
18031831
requested_feature,
1832+
requested_features,
18041833
query,
18051834
top_k,
18061835
distance_metric,
@@ -1822,6 +1851,7 @@ def retrieve_online_documents(
18221851
document_feature_vals = [feature[4] for feature in document_features]
18231852
document_feature_distance_vals = [feature[5] for feature in document_features]
18241853
online_features_response = GetOnlineFeaturesResponse(results=[])
1854+
requested_feature = requested_feature or requested_features[0]
18251855
utils._populate_result_rows_from_columnar(
18261856
online_features_response=online_features_response,
18271857
data={
@@ -1836,7 +1866,8 @@ def _retrieve_from_online_store(
18361866
self,
18371867
provider: Provider,
18381868
table: FeatureView,
1839-
requested_feature: str,
1869+
requested_feature: Optional[str],
1870+
requested_features: Optional[List[str]],
18401871
query: List[float],
18411872
top_k: int,
18421873
distance_metric: Optional[str],
@@ -1852,6 +1883,7 @@ def _retrieve_from_online_store(
18521883
config=self.config,
18531884
table=table,
18541885
requested_feature=requested_feature,
1886+
requested_features=requested_features,
18551887
query=query,
18561888
top_k=top_k,
18571889
distance_metric=distance_metric,
@@ -1952,19 +1984,13 @@ def serve_ui(
19521984
)
19531985

19541986
def serve_registry(
1955-
self,
1956-
port: int,
1957-
tls_key_path: str = "",
1958-
tls_cert_path: str = "",
1987+
self, port: int, tls_key_path: str = "", tls_cert_path: str = ""
19591988
) -> None:
19601989
"""Start registry server locally on a given port."""
19611990
from feast import registry_server
19621991

19631992
registry_server.start_server(
1964-
self,
1965-
port=port,
1966-
tls_key_path=tls_key_path,
1967-
tls_cert_path=tls_cert_path,
1993+
self, port=port, tls_key_path=tls_key_path, tls_cert_path=tls_cert_path
19681994
)
19691995

19701996
def serve_offline(

sdk/python/feast/infra/key_encoding_utils.py

+16-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
import struct
2-
from typing import List, Tuple
2+
from typing import List, Tuple, Union
3+
4+
from google.protobuf.internal.containers import RepeatedScalarFieldContainer
35

46
from feast.protos.feast.types.EntityKey_pb2 import EntityKey as EntityKeyProto
57
from feast.protos.feast.types.Value_pb2 import Value as ValueProto
@@ -163,3 +165,16 @@ def get_list_val_str(val):
163165
if val.HasField(accept_type):
164166
return str(getattr(val, accept_type).val)
165167
return None
168+
169+
170+
def serialize_f32(
171+
vector: Union[RepeatedScalarFieldContainer[float], List[float]], vector_length: int
172+
) -> bytes:
173+
"""serializes a list of floats into a compact "raw bytes" format"""
174+
return struct.pack(f"{vector_length}f", *vector)
175+
176+
177+
def deserialize_f32(byte_vector: bytes, vector_length: int) -> List[float]:
178+
"""deserializes a list of floats from a compact "raw bytes" format"""
179+
num_floats = vector_length // 4 # 4 bytes per float
180+
return list(struct.unpack(f"{num_floats}f", byte_vector))

sdk/python/feast/infra/online_stores/elasticsearch_online_store/elasticsearch.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -213,7 +213,8 @@ def retrieve_online_documents(
213213
self,
214214
config: RepoConfig,
215215
table: FeatureView,
216-
requested_feature: str,
216+
requested_feature: Optional[str],
217+
requested_features: Optional[List[str]],
217218
embedding: List[float],
218219
top_k: int,
219220
*args,

sdk/python/feast/infra/online_stores/faiss_online_store.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -176,7 +176,8 @@ def retrieve_online_documents(
176176
self,
177177
config: RepoConfig,
178178
table: FeatureView,
179-
requested_feature: str,
179+
requested_feature: Optional[str],
180+
requested_featres: Optional[List[str]],
180181
embedding: List[float],
181182
top_k: int,
182183
distance_metric: Optional[str] = None,

sdk/python/feast/infra/online_stores/online_store.py

+7-1
Original file line numberDiff line numberDiff line change
@@ -390,7 +390,8 @@ def retrieve_online_documents(
390390
self,
391391
config: RepoConfig,
392392
table: FeatureView,
393-
requested_feature: str,
393+
requested_feature: Optional[str],
394+
requested_features: Optional[List[str]],
394395
embedding: List[float],
395396
top_k: int,
396397
distance_metric: Optional[str] = None,
@@ -411,6 +412,7 @@ def retrieve_online_documents(
411412
config: The config for the current feature store.
412413
table: The feature view whose feature values should be read.
413414
requested_feature: The name of the feature whose embeddings should be used for retrieval.
415+
requested_features: The list of features whose embeddings should be used for retrieval.
414416
embedding: The embeddings to use for retrieval.
415417
top_k: The number of documents to retrieve.
416418
@@ -419,6 +421,10 @@ def retrieve_online_documents(
419421
where the first item is the event timestamp for the row, and the second item is a dict of feature
420422
name to embeddings.
421423
"""
424+
if not requested_feature and not requested_features:
425+
raise ValueError(
426+
"Either requested_feature or requested_features must be specified"
427+
)
422428
raise NotImplementedError(
423429
f"Online store {self.__class__.__name__} does not support online retrieval"
424430
)

sdk/python/feast/infra/online_stores/postgres_online_store/postgres.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -347,7 +347,8 @@ def retrieve_online_documents(
347347
self,
348348
config: RepoConfig,
349349
table: FeatureView,
350-
requested_feature: str,
350+
requested_feature: Optional[str],
351+
requested_features: Optional[List[str]],
351352
embedding: List[float],
352353
top_k: int,
353354
distance_metric: Optional[str] = "L2",
@@ -366,6 +367,7 @@ def retrieve_online_documents(
366367
config: Feast configuration object
367368
table: FeatureView object as the table to search
368369
requested_feature: The requested feature as the column to search
370+
requested_features: The list of features whose embeddings should be used for retrieval.
369371
embedding: The query embedding to search for
370372
top_k: The number of items to return
371373
distance_metric: The distance metric to use for the search.

sdk/python/feast/infra/online_stores/qdrant_online_store/qdrant.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -248,7 +248,8 @@ def retrieve_online_documents(
248248
self,
249249
config: RepoConfig,
250250
table: FeatureView,
251-
requested_feature: str,
251+
requested_feature: Optional[str],
252+
requested_features: Optional[List[str]],
252253
embedding: List[float],
253254
top_k: int,
254255
distance_metric: Optional[str] = "cosine",

sdk/python/feast/infra/online_stores/sqlite.py

+8-18
Original file line numberDiff line numberDiff line change
@@ -15,19 +15,20 @@
1515
import logging
1616
import os
1717
import sqlite3
18-
import struct
1918
import sys
2019
from datetime import date, datetime
2120
from pathlib import Path
22-
from typing import Any, Callable, Dict, List, Literal, Optional, Sequence, Tuple, Union
21+
from typing import Any, Callable, Dict, List, Literal, Optional, Sequence, Tuple
2322

24-
from google.protobuf.internal.containers import RepeatedScalarFieldContainer
2523
from pydantic import StrictStr
2624

2725
from feast import Entity
2826
from feast.feature_view import FeatureView
2927
from feast.infra.infra_object import SQLITE_INFRA_OBJECT_CLASS_TYPE, InfraObject
30-
from feast.infra.key_encoding_utils import serialize_entity_key
28+
from feast.infra.key_encoding_utils import (
29+
serialize_entity_key,
30+
serialize_f32,
31+
)
3132
from feast.infra.online_stores.online_store import OnlineStore
3233
from feast.infra.online_stores.vector_store import VectorStoreConfig
3334
from feast.protos.feast.core.InfraObject_pb2 import InfraObject as InfraObjectProto
@@ -330,7 +331,8 @@ def retrieve_online_documents(
330331
self,
331332
config: RepoConfig,
332333
table: FeatureView,
333-
requested_feature: str,
334+
requested_feature: Optional[str],
335+
requested_featuers: Optional[List[str]],
334336
embedding: List[float],
335337
top_k: int,
336338
distance_metric: Optional[str] = None,
@@ -432,6 +434,7 @@ def retrieve_online_documents(
432434
_build_retrieve_online_document_record(
433435
entity_key,
434436
string_value if string_value else b"",
437+
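# NOTE(review): possible bug - this passes `embedding` (the query-vector parameter of this method, unless rebound in lines not shown here) rather than a value read from the stored row; confirm.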
435438
embedding,
436439
distance,
437440
event_ts,
@@ -459,19 +462,6 @@ def _table_id(project: str, table: FeatureView) -> str:
459462
return f"{project}_{table.name}"
460463

461464

462-
def serialize_f32(
463-
vector: Union[RepeatedScalarFieldContainer[float], List[float]], vector_length: int
464-
) -> bytes:
465-
"""serializes a list of floats into a compact "raw bytes" format"""
466-
return struct.pack(f"{vector_length}f", *vector)
467-
468-
469-
def deserialize_f32(byte_vector: bytes, vector_length: int) -> List[float]:
470-
"""deserializes a list of floats from a compact "raw bytes" format"""
471-
num_floats = vector_length // 4 # 4 bytes per float
472-
return list(struct.unpack(f"{num_floats}f", byte_vector))
473-
474-
475465
class SqliteTable(InfraObject):
476466
"""
477467
A Sqlite table managed by Feast.

sdk/python/feast/infra/passthrough_provider.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -294,7 +294,8 @@ def retrieve_online_documents(
294294
self,
295295
config: RepoConfig,
296296
table: FeatureView,
297-
requested_feature: str,
297+
requested_feature: Optional[str],
298+
requested_features: Optional[List[str]],
298299
query: List[float],
299300
top_k: int,
300301
distance_metric: Optional[str] = None,
@@ -305,6 +306,7 @@ def retrieve_online_documents(
305306
config,
306307
table,
307308
requested_feature,
309+
requested_features,
308310
query,
309311
top_k,
310312
distance_metric,

sdk/python/feast/infra/provider.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -419,7 +419,8 @@ def retrieve_online_documents(
419419
self,
420420
config: RepoConfig,
421421
table: FeatureView,
422-
requested_feature: str,
422+
requested_feature: Optional[str],
423+
requested_features: Optional[List[str]],
423424
query: List[float],
424425
top_k: int,
425426
distance_metric: Optional[str] = None,
@@ -440,6 +441,7 @@ def retrieve_online_documents(
440441
config: The config for the current feature store.
441442
table: The feature view whose embeddings should be searched.
442443
requested_feature: the requested document feature name.
444+
requested_features: the requested document feature names.
443445
query: The query embedding to search for.
444446
top_k: The number of documents to return.
445447

sdk/python/feast/utils.py

+4
Original file line numberDiff line numberDiff line change
@@ -1192,6 +1192,10 @@ def _utc_now() -> datetime:
11921192
return datetime.now(tz=timezone.utc)
11931193

11941194

1195+
def _serialize_vector_to_float_list(vector: List[float]) -> ValueProto:
1196+
return ValueProto(float_list_val=FloatListProto(val=vector))
1197+
1198+
11951199
def _build_retrieve_online_document_record(
11961200
entity_key: Union[str, bytes],
11971201
feature_value: Union[str, bytes],

0 commit comments

Comments
 (0)