1
- from typing import List , Any
1
+ from typing import Any , Dict , List , Optional , Tuple
2
2
import json
3
3
import os
4
4
from pathlib import Path
@@ -14,17 +14,21 @@ def get_all_files() -> List[str]:
14
14
return [str (x ) for x in sources ]
15
15
16
16
17
- def calculate_shards ( all_files : List [str ], num_shards : int = 20 ) :
17
+ def read_metadata () -> Dict [str , Any ] :
18
18
with (REPO_BASE_DIR / ".jenkins" / "metadata.json" ).open () as fp :
19
- metadata = json .load (fp )
20
- sharded_files = [(0.0 , []) for _ in range (num_shards )]
19
+ return json .load (fp )
21
20
22
- def get_duration (file ):
21
+
22
+ def calculate_shards (all_files : List [str ], num_shards : int = 20 ) -> List [List [str ]]:
23
+ sharded_files : List [Tuple [float , List [str ]]] = [(0.0 , []) for _ in range (num_shards )]
24
+ metadata = read_metadata ()
25
+
26
+ def get_duration (file : str ) -> int :
23
27
# tutorials not listed in the metadata.json file usually take
24
28
# <3min to run, so we'll default to 1min if it's not listed
25
29
return metadata .get (file , {}).get ("duration" , 60 )
26
30
27
- def get_needs_machine (file ) :
31
+ def get_needs_machine (file : str ) -> Optional [ str ] :
28
32
return metadata .get (file , {}).get ("needs" , None )
29
33
30
34
def add_to_shard (i , filename ):
@@ -55,9 +59,19 @@ def add_to_shard(i, filename):
55
59
return [x [1 ] for x in sharded_files ]
56
60
57
61
58
- def remove_other_files (all_files , files_to_run ) -> None :
62
+ def compute_files_to_keep (files_to_run : List [str ]) -> List [str ]:
63
+ metadata = read_metadata ()
64
+ files_to_keep = list (files_to_run )
65
+ for file in files_to_run :
66
+ extra_files = metadata .get (file , {}).get ("extra_files" , [])
67
+ files_to_keep .extend (extra_files )
68
+ return files_to_keep
69
+
70
+
71
+ def remove_other_files (all_files , files_to_keep ) -> None :
72
+
59
73
for file in all_files :
60
- if file not in files_to_run :
74
+ if file not in files_to_keep :
61
75
remove_runnable_code (file , file )
62
76
63
77
@@ -76,7 +90,7 @@ def main() -> None:
76
90
all_files = get_all_files ()
77
91
files_to_run = calculate_shards (all_files , num_shards = args .num_shards )[args .shard_num ]
78
92
if not args .dry_run :
79
- remove_other_files (all_files , files_to_run )
93
+ remove_other_files (all_files , compute_files_to_keep ( files_to_run ) )
80
94
stripped_file_names = [Path (x ).stem for x in files_to_run ]
81
95
print (" " .join (stripped_file_names ))
82
96
0 commit comments