1
+ import os
1
2
from typing import List
2
3
3
4
import pyarrow as pa
11
12
from feast import FeatureStore , FeatureView , RepoConfig
12
13
from feast .utils import _convert_arrow_to_proto , _run_pyarrow_field_mapping
13
14
15
+ DEFAULT_BATCH_SIZE = 1000
16
+
14
17
15
18
class BytewaxMaterializationDataflow :
16
19
def __init__ (
@@ -44,6 +47,11 @@ def input_builder(self, worker_index, worker_count, _state):
44
47
return
45
48
46
49
def output_builder (self , worker_index , worker_count ):
50
def yield_batch(iterable, batch_size):
    """Yield consecutive mini-batches sliced from a sequence.

    Args:
        iterable: Sequence to split. Despite the name, it must support
            ``len()`` and slicing, because batches are taken by slice.
        batch_size: Maximum number of items per batch; must be >= 1.

    Yields:
        Slices of ``iterable`` holding at most ``batch_size`` items each,
        in order. The final slice may be shorter. Nothing is yielded for
        an empty sequence.

    Raises:
        ValueError: If ``batch_size`` is less than 1. Because this is a
            generator, the error surfaces when iteration starts.
    """
    # A negative step makes range() empty, which would silently drop every
    # row; a zero step fails with an unrelated-looking range() error.
    # Reject both explicitly so a bad configuration is loud.
    if batch_size < 1:
        raise ValueError(
            f"batch_size must be a positive integer, got {batch_size!r}"
        )
    for start in range(0, len(iterable), batch_size):
        yield iterable[start : start + batch_size]
54
+
47
55
def output_fn (batch ):
48
56
table = pa .Table .from_batches ([batch ])
49
57
@@ -62,12 +70,17 @@ def output_fn(batch):
62
70
)
63
71
provider = self .feature_store ._get_provider ()
64
72
with tqdm (total = len (rows_to_write )) as progress :
65
- provider .online_write_batch (
66
- config = self .config ,
67
- table = self .feature_view ,
68
- data = rows_to_write ,
69
- progress = progress .update ,
73
+ # break rows_to_write to mini-batches
74
+ batch_size = int (
75
+ os .getenv ("BYTEWAX_MINI_BATCH_SIZE" , DEFAULT_BATCH_SIZE )
70
76
)
77
+ for mini_batch in yield_batch (rows_to_write , batch_size ):
78
+ provider .online_write_batch (
79
+ config = self .config ,
80
+ table = self .feature_view ,
81
+ data = mini_batch ,
82
+ progress = progress .update ,
83
+ )
71
84
72
85
return output_fn
73
86
0 commit comments