Skip to content

Commit a1ff129

Browse files
authored
feat: Faiss and In memory store (#4464)
* add faiss & in memory online store Signed-off-by: cmuhao <sduxuhao@gmail.com> * add faiss & in memory online store Signed-off-by: cmuhao <sduxuhao@gmail.com> * add faiss & in memory online store Signed-off-by: cmuhao <sduxuhao@gmail.com> * add faiss & in memory online store Signed-off-by: cmuhao <sduxuhao@gmail.com> * add faiss & in memory online store Signed-off-by: cmuhao <sduxuhao@gmail.com> * add faiss & in memory online store Signed-off-by: cmuhao <sduxuhao@gmail.com> * add faiss & in memory online store Signed-off-by: cmuhao <sduxuhao@gmail.com> * add faiss & in memory online store Signed-off-by: cmuhao <sduxuhao@gmail.com> * add faiss & in memory online store Signed-off-by: cmuhao <sduxuhao@gmail.com> * add faiss & in memory online store Signed-off-by: cmuhao <sduxuhao@gmail.com> * add faiss & in memory online store Signed-off-by: cmuhao <sduxuhao@gmail.com> * add faiss & in memory online store Signed-off-by: cmuhao <sduxuhao@gmail.com> * add dependency Signed-off-by: cmuhao <sduxuhao@gmail.com> * update package name Signed-off-by: cmuhao <sduxuhao@gmail.com> --------- Signed-off-by: cmuhao <sduxuhao@gmail.com>
1 parent 9ca1452 commit a1ff129

File tree

2 files changed

+240
-0
lines changed

2 files changed

+240
-0
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,236 @@
1+
import logging
2+
from datetime import datetime
3+
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple
4+
5+
import faiss
6+
import numpy as np
7+
from google.protobuf.timestamp_pb2 import Timestamp
8+
9+
from feast import Entity, FeatureView, RepoConfig
10+
from feast.infra.key_encoding_utils import serialize_entity_key
11+
from feast.infra.online_stores.online_store import OnlineStore
12+
from feast.protos.feast.types.EntityKey_pb2 import EntityKey
13+
from feast.protos.feast.types.Value_pb2 import Value
14+
from feast.repo_config import FeastConfigBaseModel
15+
16+
17+
class FaissOnlineStoreConfig(FeastConfigBaseModel):
    """Configuration for the faiss-backed online store."""

    # Dimensionality of the stored feature vectors.
    dimension: int
    # Filesystem path for persisting the faiss index.
    # NOTE(review): not read anywhere in this module — persistence appears unimplemented; confirm.
    index_path: str
    # Faiss index type; only "IVFFlat" is actually built by FaissOnlineStore.update().
    index_type: str = "IVFFlat"
    # Number of inverted-list cells (faiss "nlist") used when constructing the IVFFlat index.
    nlist: int = 100
22+
23+
24+
class InMemoryStore:
    """In-process bookkeeping for the faiss store.

    Holds the list of feature names and the mapping from a serialized
    entity key (hex string) to its row id in the companion faiss index.
    """

    def __init__(self):
        # Names of the features whose values make up each stored vector.
        self.feature_names: List[str] = []
        # Serialized entity key -> row id in the faiss index.
        self.entity_keys: Dict[str, int] = {}

    def update(self, feature_names: List[str], entity_keys: Dict[str, int]):
        """Replace both the feature-name list and the key->row-id mapping wholesale."""
        self.feature_names = feature_names
        self.entity_keys = entity_keys

    def delete(self, entity_keys: List[str]):
        """Drop the given serialized keys from the mapping; unknown keys are ignored."""
        for key in entity_keys:
            # pop with a default is a no-op for absent keys.
            self.entity_keys.pop(key, None)

    def read(self, entity_keys: List[str]) -> List[Optional[int]]:
        """Return the row id for each serialized key, or None where the key is absent."""
        lookup = self.entity_keys.get
        return [lookup(key) for key in entity_keys]

    def teardown(self):
        """Reset to the freshly-constructed state (rebinds new objects rather than clearing in place)."""
        self.feature_names, self.entity_keys = [], {}
44+
45+
46+
class FaissOnlineStore(OnlineStore):
    """Online store backed by an in-process faiss IVFFlat index.

    Feature vectors live in the faiss index; the mapping from serialized
    entity key to faiss row id lives in an InMemoryStore. Nothing is
    persisted to disk in this implementation.
    """

    # NOTE(review): these are class attributes, so every FaissOnlineStore
    # instance in the process shares one index and one key mapping —
    # confirm a single shared store is intended; otherwise move these
    # into __init__.
    _index: Optional[faiss.IndexIVFFlat] = None
    _in_memory_store: InMemoryStore = InMemoryStore()
    _config: Optional[FaissOnlineStoreConfig] = None
    _logger: logging.Logger = logging.getLogger(__name__)

    def _get_index(self, config: RepoConfig) -> faiss.IndexIVFFlat:
        """Return the initialized faiss index.

        Raises:
            ValueError: if update() has not yet built the index/config.
        """
        # `config` is accepted for interface symmetry but not used here.
        if self._index is None or self._config is None:
            raise ValueError("Index is not initialized")
        return self._index

    def update(
        self,
        config: RepoConfig,
        tables_to_delete: Sequence[FeatureView],
        tables_to_keep: Sequence[FeatureView],
        entities_to_delete: Sequence[Entity],
        entities_to_keep: Sequence[Entity],
        partial: bool,
    ):
        """(Re)build the faiss index for the kept feature views.

        Only the first kept feature view's feature list determines the
        vector dimension; tables_to_delete and the entity arguments are
        ignored by this implementation.
        """
        feature_views = tables_to_keep
        if not feature_views:
            return

        # One vector component per feature of the first kept view.
        feature_names = [f.name for f in feature_views[0].features]
        dimension = len(feature_names)

        self._config = FaissOnlineStoreConfig(**config.online_store.dict())
        # A missing index or a non-partial update triggers a full rebuild,
        # discarding any previously written vectors and key mappings.
        if self._index is None or not partial:
            quantizer = faiss.IndexFlatL2(dimension)
            self._index = faiss.IndexIVFFlat(quantizer, dimension, self._config.nlist)
            # NOTE(review): the IVF quantizer is trained on random vectors,
            # not on real feature data — cluster quality is arbitrary until
            # trained on representative vectors; confirm this bootstrap is
            # acceptable.
            self._index.train(
                np.random.rand(self._config.nlist * 100, dimension).astype(np.float32)
            )
            self._in_memory_store = InMemoryStore()

        # The key->row-id mapping is reset to empty even on partial updates.
        self._in_memory_store.update(feature_names, {})

    def teardown(
        self,
        config: RepoConfig,
        tables: Sequence[FeatureView],
        entities: Sequence[Entity],
    ):
        """Drop the faiss index and reset the in-memory bookkeeping."""
        self._index = None
        self._in_memory_store.teardown()

    def online_read(
        self,
        config: RepoConfig,
        table: FeatureView,
        entity_keys: List[EntityKey],
        requested_features: Optional[List[str]] = None,
    ) -> List[Tuple[Optional[datetime], Optional[Dict[str, Value]]]]:
        """Read back feature vectors for the given entity keys.

        Returns one (event_timestamp, feature_dict) pair per key, in input
        order; (None, None) for keys not present in the index.
        Event timestamps are always None — write timestamps are not stored.
        `requested_features` is ignored; all stored features are returned.
        """
        if self._index is None:
            return [(None, None)] * len(entity_keys)

        results: List[Tuple[Optional[datetime], Optional[Dict[str, Any]]]] = []
        for entity_key in entity_keys:
            serialized_key = serialize_entity_key(
                entity_key, config.entity_key_serialization_version
            ).hex()
            idx = self._in_memory_store.entity_keys.get(serialized_key, -1)
            if idx == -1:
                results.append((None, None))
            else:
                # NOTE(review): IndexIVFFlat.reconstruct generally requires a
                # direct map (make_direct_map) to look up by sequential id —
                # verify this does not raise on IVF indexes as built above.
                feature_vector = self._index.reconstruct(int(idx))
                # Every stored component is surfaced as a double_val Value.
                feature_dict = {
                    name: Value(double_val=value)
                    for name, value in zip(
                        self._in_memory_store.feature_names, feature_vector
                    )
                }
                results.append((None, feature_dict))
        return results

    def online_write_batch(
        self,
        config: RepoConfig,
        table: FeatureView,
        data: List[Tuple[EntityKey, Dict[str, Value], datetime, Optional[datetime]]],
        progress: Optional[Callable[[int], Any]],
    ) -> None:
        """Write a batch of feature vectors, replacing rows for keys seen before.

        Each row's vector is assembled from `feature_dict` in the order of
        the stored feature names (double_val fields). Event/created
        timestamps in `data` are discarded.
        """
        if self._index is None:
            # Best-effort: writes before update() are dropped with a warning.
            self._logger.warning("Index is not initialized. Skipping write operation.")
            return

        feature_vectors = []
        serialized_keys = []

        for entity_key, feature_dict, _, _ in data:
            serialized_key = serialize_entity_key(
                entity_key, config.entity_key_serialization_version
            ).hex()
            # KeyError here if a stored feature name is missing from the row.
            feature_vector = np.array(
                [
                    feature_dict[name].double_val
                    for name in self._in_memory_store.feature_names
                ],
                dtype=np.float32,
            )

            feature_vectors.append(feature_vector)
            serialized_keys.append(serialized_key)

        feature_vectors_array = np.array(feature_vectors)

        # Rows previously written for these keys are removed before re-adding.
        existing_indices = [
            self._in_memory_store.entity_keys.get(sk, -1) for sk in serialized_keys
        ]
        mask = np.array(existing_indices) != -1
        if np.any(mask):
            # NOTE(review): remove_ids on a plain IVF index compacts it,
            # shifting the sequential ids of untouched rows — the row ids
            # cached in _in_memory_store for OTHER keys may become stale
            # after this call. Verify against faiss semantics.
            self._index.remove_ids(
                np.array([idx for idx in existing_indices if idx != -1])
            )

        # New rows are assumed to occupy ids [ntotal, ntotal + batch) in order.
        new_indices = np.arange(
            self._index.ntotal, self._index.ntotal + len(feature_vectors_array)
        )
        self._index.add(feature_vectors_array)

        for sk, idx in zip(serialized_keys, new_indices):
            self._in_memory_store.entity_keys[sk] = idx

        # Progress callback is invoked once with the full batch size.
        if progress:
            progress(len(data))

    def retrieve_online_documents(
        self,
        config: RepoConfig,
        table: FeatureView,
        requested_feature: str,
        embedding: List[float],
        top_k: int,
        distance_metric: Optional[str] = None,
    ) -> List[
        Tuple[
            Optional[datetime],
            Optional[Value],
            Optional[Value],
            Optional[Value],
        ]
    ]:
        """Nearest-neighbor search over the stored vectors.

        Returns up to top_k tuples of (timestamp, feature_value,
        vector_value, distance_value). `requested_feature` and
        `distance_metric` are ignored; distance is whatever the index
        was built with (L2 quantizer above).
        """
        if self._index is None:
            self._logger.warning("Index is not initialized. Returning empty result.")
            return []

        query_vector = np.array(embedding, dtype=np.float32).reshape(1, -1)
        distances, indices = self._index.search(query_vector, top_k)

        results: List[
            Tuple[
                Optional[datetime],
                Optional[Value],
                Optional[Value],
                Optional[Value],
            ]
        ] = []
        for i, idx in enumerate(indices[0]):
            # faiss pads with -1 when fewer than top_k neighbors exist.
            if idx == -1:
                continue

            feature_vector = self._index.reconstruct(int(idx))

            # The returned timestamp is the retrieval time, not a stored
            # event timestamp (none are kept).
            timestamp = Timestamp()
            timestamp.GetCurrentTime()

            # Both the "feature" and "vector" slots carry the same
            # comma-joined string encoding of the reconstructed vector.
            feature_value = Value(string_val=",".join(map(str, feature_vector)))
            vector_value = Value(string_val=",".join(map(str, feature_vector)))
            distance_value = Value(float_val=distances[0][i])

            results.append(
                (
                    timestamp.ToDatetime(),
                    feature_value,
                    vector_value,
                    distance_value,
                )
            )

        return results

    async def online_read_async(
        self,
        config: RepoConfig,
        table: FeatureView,
        entity_keys: List[EntityKey],
        requested_features: Optional[List[str]] = None,
    ) -> List[Tuple[Optional[datetime], Optional[Dict[str, Value]]]]:
        """Async variant of online_read — not supported by this store.

        Raises:
            NotImplementedError: always.
        """
        # Implement async read if needed
        raise NotImplementedError("Async read is not implemented for FaissOnlineStore")

setup.py

+4
Original file line numberDiff line numberDiff line change
@@ -144,6 +144,8 @@
144144

145145
MSSQL_REQUIRED = ["ibis-framework[mssql]>=9.0.0,<10"]
146146

147+
FAISS_REQUIRED = ["faiss-cpu>=1.7.0,<2"]
148+
147149
CI_REQUIRED = (
148150
[
149151
"build",
@@ -210,6 +212,7 @@
210212
+ SQLITE_VEC_REQUIRED
211213
+ SINGLESTORE_REQUIRED
212214
+ OPENTELEMETRY
215+
+ FAISS_REQUIRED
213216
)
214217

215218
DOCS_REQUIRED = CI_REQUIRED
@@ -279,6 +282,7 @@
279282
"sqlite_vec": SQLITE_VEC_REQUIRED,
280283
"singlestore": SINGLESTORE_REQUIRED,
281284
"opentelemetry": OPENTELEMETRY,
285+
"faiss": FAISS_REQUIRED,
282286
},
283287
include_package_data=True,
284288
license="Apache",

0 commit comments

Comments
 (0)