1
+ import logging
1
2
import os
2
3
from typing import List
3
4
7
8
from bytewax .execution import cluster_main
8
9
from bytewax .inputs import ManualInputConfig
9
10
from bytewax .outputs import ManualOutputConfig
10
- from tqdm import tqdm
11
11
12
12
from feast import FeatureStore , FeatureView , RepoConfig
13
13
from feast .utils import _convert_arrow_to_proto , _run_pyarrow_field_mapping
14
14
15
+ logger = logging .getLogger (__name__ )
15
16
DEFAULT_BATCH_SIZE = 1000
16
17
17
18
@@ -29,14 +30,20 @@ def __init__(
29
30
self .feature_view = feature_view
30
31
self .worker_index = worker_index
31
32
self .paths = paths
33
+ self .mini_batch_size = int (
34
+ os .getenv ("BYTEWAX_MINI_BATCH_SIZE" , DEFAULT_BATCH_SIZE )
35
+ )
32
36
33
37
self ._run_dataflow ()
34
38
35
39
def process_path(self, path):
    """Load the Parquet dataset at ``path`` and return its record batches.

    Every fragment of the dataset is read into a table, and that table is
    split into record batches using ``self.mini_batch_size`` as the
    ``max_chunksize`` limit.

    Args:
        path: Dataset location handed to ``pq.ParquetDataset``.

    Returns:
        A list with the record batches of all fragments, in fragment order.
    """
    logger.info(f"Processing path { path } ")
    parquet_dataset = pq.ParquetDataset(path, use_legacy_dataset=False)

    collected = []
    for piece in parquet_dataset.fragments:
        fragment_table = piece.to_table()
        # to_batches() already returns a list, so extend with it directly.
        collected.extend(
            fragment_table.to_batches(max_chunksize=self.mini_batch_size)
        )

    return collected
@@ -45,40 +52,26 @@ def input_builder(self, worker_index, worker_count, _state):
45
52
return [(None , self .paths [self .worker_index ])]
46
53
47
54
def output_builder(self, worker_index, worker_count):
    """Build the output function that writes one mini-batch to the online store.

    Args:
        worker_index: Index of the calling worker (not used by this builder).
        worker_count: Total number of workers (not used by this builder).

    Returns:
        A callable accepting a single record batch. It wraps the batch in a
        table, applies the batch source's field mapping when one is set,
        converts the table with ``_convert_arrow_to_proto`` and passes the
        result to the provider's ``online_write_batch``.
    """
    # The join-key map depends only on the feature view, not on the batch
    # being written, so build it once here instead of once per mini-batch.
    join_key_to_value_type = {
        entity.name: entity.dtype.to_value_type()
        for entity in self.feature_view.entity_columns
    }

    def output_fn(mini_batch):
        table: pa.Table = pa.Table.from_batches([mini_batch])

        field_mapping = self.feature_view.batch_source.field_mapping
        if field_mapping is not None:
            table = _run_pyarrow_field_mapping(table, field_mapping)

        rows_to_write = _convert_arrow_to_proto(
            table, self.feature_view, join_key_to_value_type
        )
        # The provider is looked up on every call, as before; whether
        # _get_provider() caches its result is not visible from here.
        self.feature_store._get_provider().online_write_batch(
            config=self.config,
            table=self.feature_view,
            data=rows_to_write,
            progress=None,
        )

    return output_fn
84
77
0 commit comments