@@ -15,6 +15,7 @@
 import contextlib
 import itertools
 import logging
+from collections import OrderedDict
 from datetime import datetime
 from typing import Any, Callable, Dict, List, Literal, Optional, Sequence, Tuple, Union

@@ -26,6 +27,7 @@
 from feast.infra.online_stores.helpers import compute_entity_id
 from feast.infra.online_stores.online_store import OnlineStore
 from feast.infra.supported_async_methods import SupportedAsyncMethods
+from feast.infra.utils.aws_utils import dynamo_write_items_async
 from feast.protos.feast.core.DynamoDBTable_pb2 import (
     DynamoDBTable as DynamoDBTableProto,
 )
@@ -103,7 +105,7 @@ async def close(self):

     @property
     def async_supported(self) -> SupportedAsyncMethods:
-        return SupportedAsyncMethods(read=True)
+        return SupportedAsyncMethods(read=True, write=True)

     def update(
         self,
@@ -238,6 +240,42 @@ def online_write_batch(
         )
         self._write_batch_non_duplicates(table_instance, data, progress, config)

+    async def online_write_batch_async(
+        self,
+        config: RepoConfig,
+        table: FeatureView,
+        data: List[
+            Tuple[EntityKeyProto, Dict[str, ValueProto], datetime, Optional[datetime]]
+        ],
+        progress: Optional[Callable[[int], Any]],
+    ) -> None:
+        """
+        Writes a batch of feature rows to the online store asynchronously.
+
+        If a tz-naive timestamp is passed to this method, it is assumed to be UTC.
+
+        Args:
+            config: The config for the current feature store.
+            table: Feature view to which these feature rows correspond.
+            data: A list of quadruplets containing feature data. Each quadruplet contains
+                an entity key, a dict containing feature values, an event timestamp for the
+                row, and the created timestamp for the row if it exists.
+            progress: Function to be called once a batch of rows is written to the online
+                store, used to show progress.
+        """
+        online_config = config.online_store
+        assert isinstance(online_config, DynamoDBOnlineStoreConfig)
+
+        table_name = _get_table_name(online_config, config, table)
+        items = [
+            _to_client_write_item(config, entity_key, features, timestamp)
+            for entity_key, features, timestamp, _ in _latest_data_to_write(data)
+        ]
+        client = await _get_aiodynamodb_client(
+            online_config.region, online_config.max_pool_connections
+        )
+        await dynamo_write_items_async(client, table_name, items)
+
     def online_read(
         self,
         config: RepoConfig,
@@ -419,19 +457,10 @@ def _write_batch_non_duplicates(
         """Deduplicate write batch request items on ``entity_id`` primary key."""
         with table_instance.batch_writer(overwrite_by_pkeys=["entity_id"]) as batch:
             for entity_key, features, timestamp, created_ts in data:
-                entity_id = compute_entity_id(
-                    entity_key,
-                    entity_key_serialization_version=config.entity_key_serialization_version,
-                )
                 batch.put_item(
-                    Item={
-                        "entity_id": entity_id,  # PartitionKey
-                        "event_ts": str(utils.make_tzaware(timestamp)),
-                        "values": {
-                            k: v.SerializeToString()
-                            for k, v in features.items()  # Serialized Features
-                        },
-                    }
+                    Item=_to_resource_write_item(
+                        config, entity_key, features, timestamp
+                    )
                 )
                 if progress:
                     progress(1)
@@ -675,3 +704,45 @@ def _get_dynamodb_resource(self, region: str, endpoint_url: Optional[str] = None
             region, endpoint_url
         )
         return self._dynamodb_resource
+
+
+def _to_resource_write_item(config, entity_key, features, timestamp):
+    entity_id = compute_entity_id(
+        entity_key,
+        entity_key_serialization_version=config.entity_key_serialization_version,
+    )
+    return {
+        "entity_id": entity_id,  # PartitionKey
+        "event_ts": str(utils.make_tzaware(timestamp)),
+        "values": {
+            k: v.SerializeToString()
+            for k, v in features.items()  # Serialized Features
+        },
+    }
+
+
+def _to_client_write_item(config, entity_key, features, timestamp):
+    entity_id = compute_entity_id(
+        entity_key,
+        entity_key_serialization_version=config.entity_key_serialization_version,
+    )
+    return {
+        "entity_id": {"S": entity_id},  # PartitionKey
+        "event_ts": {"S": str(utils.make_tzaware(timestamp))},
+        "values": {
+            "M": {
+                k: {"B": v.SerializeToString()}
+                for k, v in features.items()  # Serialized Features
+            }
+        },
+    }
+
+
+def _latest_data_to_write(
+    data: List[
+        Tuple[EntityKeyProto, Dict[str, ValueProto], datetime, Optional[datetime]]
+    ],
+):
+    as_hashable = ((d[0].SerializeToString(), d) for d in data)
+    sorted_data = sorted(as_hashable, key=lambda ah: (ah[0], ah[1][2]))
+    return (v for _, v in OrderedDict((ah[0], ah[1]) for ah in sorted_data).items())
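
Two item builders are needed because the sync path writes through the boto3 resource API (`Table.put_item`, native Python values), while the new async path writes through the low-level client API, which expects attribute-value encoded items (`{"S": ...}` for strings, `{"B": ...}` for bytes, `{"M": ...}` for maps). `_latest_data_to_write` plays the same role for the async path that `overwrite_by_pkeys=["entity_id"]` plays for the sync batch writer: a single DynamoDB BatchWriteItem request must not contain duplicate keys, so only the row with the greatest event timestamp per entity key is kept. A small worked sketch of that behavior (the proto values are illustrative, not part of this diff):

    from datetime import datetime

    from feast.protos.feast.types.EntityKey_pb2 import EntityKey as EntityKeyProto
    from feast.protos.feast.types.Value_pb2 import Value as ValueProto

    key = EntityKeyProto(join_keys=["driver_id"], entity_values=[ValueProto(int64_val=42)])
    older = (key, {"trips": ValueProto(int64_val=1)}, datetime(2024, 1, 1), None)
    newer = (key, {"trips": ValueProto(int64_val=2)}, datetime(2024, 1, 2), None)

    # Input order does not matter: rows are sorted by (serialized key, event_ts)
    # and the OrderedDict keeps the last, i.e. latest, row per entity key.
    assert list(_latest_data_to_write([newer, older])) == [newer]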