@@ -15,6 +15,7 @@
import boto3.session
import botocore.exceptions as boto_exceptions
import google.cloud.storage as gcs
+import omegaconf
import torch
import wandb
from boto3.s3.transfer import TransferConfig
@@ -622,6 +623,8 @@ class DeleteBadRunsConfig(StorageCleanerConfig):
@dataclass
class UnshardCheckpointsConfig(StorageCleanerConfig):
    latest_checkpoint_only: bool
+    delete_sharded_checkpoints: bool
+    checkpoint_num: Optional[int]


@dataclass
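Of the two new fields, `checkpoint_num` is mutually exclusive with `latest_checkpoint_only`, as the validation added below enforces. A minimal construction sketch, assuming `dry_run` and `temp_dir` are inherited from `StorageCleanerConfig` (they are populated that way in `perform_operation` further down); the values are illustrative:

config = UnshardCheckpointsConfig(
    dry_run=False,
    temp_dir="/tmp/storage_cleaner",  # hypothetical scratch directory
    latest_checkpoint_only=False,
    delete_sharded_checkpoints=True,
    checkpoint_num=556000,  # None would unshard every sharded checkpoint
)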
@@ -765,9 +768,13 @@ def delete_bad_runs(run_paths: List[str], config: DeleteBadRunsConfig):
    shutil.rmtree(config.temp_dir)


-def _is_sharded_checkpoint_dir(directory: str) -> bool:
+def _is_checkpoint_dir(directory: str) -> bool:
    storage = _get_storage_adapter_for_path(directory)
-    return storage.is_dir(directory) and re.match(r"step\d+$", Path(directory).name) is not None
+    return storage.is_dir(directory) and re.match(r"step\d+(-unsharded)?$", Path(directory).name) is not None
+
+
+def _is_sharded_checkpoint_dir(directory: str) -> bool:
+    return _is_checkpoint_dir(directory) and re.match(r"step\d+$", Path(directory).name) is not None


def _get_checkpoint_number(checkpoint_dir: str) -> int:
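To make the split concrete: the new `_is_checkpoint_dir` accepts both sharded and unsharded directory names, while the redefined `_is_sharded_checkpoint_dir` accepts only sharded ones. A quick illustration of the two regexes in isolation (the storage `is_dir` check is omitted; the directory names are hypothetical):

import re

CHECKPOINT_RE = re.compile(r"step\d+(-unsharded)?$")
SHARDED_RE = re.compile(r"step\d+$")

for name in ("step556000", "step556000-unsharded", "latest", "step12a"):
    print(name, bool(CHECKPOINT_RE.match(name)), bool(SHARDED_RE.match(name)))
# step556000           -> True, True   (sharded checkpoint)
# step556000-unsharded -> True, False  (unsharded checkpoint)
# latest, step12a      -> False, False (not checkpoints)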
@@ -781,17 +788,31 @@ def _get_checkpoint_number(checkpoint_dir: str) -> int:


def _get_sharded_checkpoint_dirs(
-    run_dir_storage: StorageAdapter, run_dir: str, run_dir_or_archive: str, latest_checkpoint_only: bool
+    run_dir_storage: StorageAdapter,
+    run_dir: str,
+    run_dir_or_archive: str,
+    latest_checkpoint_only: bool,
+    checkpoint_num: Optional[int] = None,
) -> List[str]:
    run_subdir_names = run_dir_storage.list_dirs(run_dir)
    run_subdirectories = list(map(lambda dir_name: os.path.join(run_dir, dir_name), run_subdir_names))
    sharded_checkpoint_directories = list(filter(_is_sharded_checkpoint_dir, run_subdirectories))

+    if latest_checkpoint_only and checkpoint_num is not None:
+        raise ValueError("Cannot set both 'latest_checkpoint_only' and 'checkpoint_num'")
+
    if latest_checkpoint_only:
        latest_checkpoint_directory = max(sharded_checkpoint_directories, default=None, key=_get_checkpoint_number)
        sharded_checkpoint_directories = (
            [latest_checkpoint_directory] if latest_checkpoint_directory is not None else []
        )
+    elif checkpoint_num is not None:
+        sharded_checkpoint_directories = [
+            sharded_checkpoint_dir
+            for sharded_checkpoint_dir in sharded_checkpoint_directories
+            if _get_checkpoint_number(sharded_checkpoint_dir) == checkpoint_num
+        ]
+        assert len(sharded_checkpoint_directories) <= 1

    log.info(
        "Found %d sharded checkpoint directories for %s", len(sharded_checkpoint_directories), run_dir_or_archive
@@ -844,13 +865,29 @@ def _unshard_checkpoint(
    sharding_output_dir = local_storage.create_temp_dir(directory=unsharding_config.temp_dir)

    try:
-        config = TrainConfig.load(Path(sharding_input_dir) / "config.yaml", validate_paths=False)
-        sharded_checkpoint_type = config.sharded_checkpointer
+        # `TrainConfig` is not backwards-compatible with all older checkpoints, so
+        # we need to load the yaml directly.
+        raw_config = om.load(str(Path(sharding_input_dir) / "config.yaml"))
+        assert isinstance(raw_config, omegaconf.DictConfig)
+
+        sharded_checkpoint_type_str = raw_config.get("sharded_checkpointer", "torch_legacy")
+        if sharded_checkpoint_type_str == "legacy":
+            # At some point, the enum string for ShardedCheckpointerType.torch_legacy was "legacy"
+            sharded_checkpoint_type_str = "torch_legacy"
+
+        sharded_checkpoint_type = ShardedCheckpointerType[sharded_checkpoint_type_str]
+
+        # The ShardedCheckpointers require a `TrainConfig` to be passed in, but
+        # legacy configs are not all compatible with this class. None of the config
+        # settings are needed for unsharding, so we pass in a dummy config instead.
+        # This is a hack, but decoupling unsharding from checkpoint saving/loading
+        # seems like overkill.
+        dummy_config = TrainConfig.new()
        checkpointer: Checkpointer
        if sharded_checkpoint_type == ShardedCheckpointerType.torch_legacy:
-            checkpointer = TorchLegacyShardedCheckpointer(config)
+            checkpointer = TorchLegacyShardedCheckpointer(dummy_config)
        elif sharded_checkpoint_type == ShardedCheckpointerType.local:
-            checkpointer = LocalShardedCheckpointer(config)
+            checkpointer = LocalShardedCheckpointer(dummy_config)
        else:
            raise NotImplementedError(sharded_checkpoint_type)

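The backwards-compatibility pattern above is worth noting: loading the yaml with omegaconf directly, then reading keys with a default, tolerates checkpoints written before a field existed in the schema. A standalone sketch of the same pattern, with a hypothetical checkpoint path:

from pathlib import Path

import omegaconf
from omegaconf import OmegaConf as om

# Load the raw config without validating it against the current TrainConfig schema.
raw_config = om.load(str(Path("checkpoints/step556000") / "config.yaml"))
assert isinstance(raw_config, omegaconf.DictConfig)

# .get() with a default tolerates configs that predate `sharded_checkpointer`.
checkpointer_type = raw_config.get("sharded_checkpointer", "torch_legacy")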
@@ -911,11 +948,14 @@ def _unshard_checkpoints(
):
    log.info("Starting unsharding checkpoints of run directory or archive %s", run_dir_or_archive)

+    if config.delete_sharded_checkpoints and _is_archive(run_dir_or_archive, run_storage):
+        raise ValueError("Cannot delete sharded checkpoints of run archive files")
+
    run_dir = _unarchive_if_archive(run_dir_or_archive, run_storage)
    run_dir_storage = _get_storage_adapter_for_path(run_dir)

    sharded_checkpoint_directories = _get_sharded_checkpoint_dirs(
-        run_dir_storage, run_dir, run_dir_or_archive, config.latest_checkpoint_only
+        run_dir_storage, run_dir, run_dir_or_archive, config.latest_checkpoint_only, config.checkpoint_num
    )
    for sharded_checkpoint_directory in sharded_checkpoint_directories:
        sharded_checkpoint_dir_name = Path(sharded_checkpoint_directory).name
@@ -947,6 +987,14 @@ def _unshard_checkpoints(
        log.info("Unsharding sharded checkpoint %s to %s", sharded_checkpoint_directory, dest_directory)
        _unshard_checkpoint(sharded_checkpoint_directory, dest_directory, run_dir, config)

+        if config.delete_sharded_checkpoints:
+            assert run_dir == run_dir_or_archive
+            if config.dry_run:
+                log.info("Would delete sharded checkpoint %s", sharded_checkpoint_directory)
+            else:
+                log.info("Deleting sharded checkpoint %s", sharded_checkpoint_directory)
+                run_dir_storage.delete_path(sharded_checkpoint_directory)
+

def unshard_run_checkpoints(run_path: str, checkpoints_dest_dir: str, config: UnshardCheckpointsConfig):
    storage = _get_storage_adapter_for_path(run_path)
@@ -1252,6 +1300,8 @@ def perform_operation(args: argparse.Namespace):
            dry_run=args.dry_run,
            temp_dir=temp_dir,
            latest_checkpoint_only=args.latest_checkpoint_only,
+            delete_sharded_checkpoints=args.delete_sharded_checkpoints,
+            checkpoint_num=args.checkpoint_num,
        )
        if args.run_path is not None:
            unshard_run_checkpoints(args.run_path, args.dest_dir, unshard_checkpoints_config)
@@ -1327,6 +1377,18 @@ def _add_unsharding_subparser(subparsers: _SubParsersAction):
        action="store_true",
        help="If set, only the latest checkpoint of each run (if sharded) is unsharded.",
    )
+    unsharding_runs_parser.add_argument(
+        "--delete_sharded",
+        dest="delete_sharded_checkpoints",
+        action="store_true",
+        help="If set, deletes sharded checkpoints after they have been successfully unsharded.",
+    )
+    unsharding_runs_parser.add_argument(
+        "--checkpoint_num",
+        type=int,
+        default=None,
+        help="If provided, unsharding is restricted to this checkpoint of the run.",
+    )


def _add_move_subparser(subparsers: _SubParsersAction):
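Putting the new flags together: a hypothetical invocation (the script path and `unshard` subcommand name are assumed, as they fall outside this diff) that unshards only checkpoint 556000 of a run and deletes the sharded copy on success:

python scripts/storage_cleaner.py unshard <run_path> <dest_dir> --checkpoint_num 556000 --delete_sharded

Per the validation added earlier, combining --checkpoint_num with --latest_checkpoint_only raises a ValueError, and --delete_sharded is rejected for archived runs.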