Mitchish65 #388
Conversation
With the latest fix, I can load an unsharded checkpoint this way. It takes a long time, and it's very inefficient (because every rank loads everything and then discards most of what it just loaded), but it does work. Unfortunately, I seem to have broken how the optimizer state works, so this is not complete.
In fact, with the way this works, we can now load a model that's so big it wouldn't fit into CPU memory. Not that we need to do this. But we could.
For example, suppose you want to keep a checkpoint every 1000 steps, but you also want to save a temporary checkpoint every 100 steps in case your job fails. In that case you would set `save_interval=1000` and `save_interval_ephemeral=100`.
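Concretely, the policy amounts to something like this sketch (`save_checkpoint` and the training loop are illustrative stand-ins, not the trainer's actual API):

```python
save_interval = 1000            # permanent checkpoints, kept around
save_interval_ephemeral = 100   # temporary checkpoints, each replacing the last

def save_checkpoint(step: int, ephemeral: bool) -> None:
    # Stand-in for the real checkpointing call.
    kind = "ephemeral" if ephemeral else "permanent"
    print(f"step {step}: saving {kind} checkpoint")

for step in range(1, 2001):
    # ... one training step ...
    if step % save_interval == 0:
        save_checkpoint(step, ephemeral=False)
    elif step % save_interval_ephemeral == 0:
        save_checkpoint(step, ephemeral=True)
```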
…ints. But also still contains the old code.
LGTM!
One thing to consider is integrating the safetensors conversion into Shane's checkpoint management script. But let's leave that for another day.
```python
fsdp_model.load_state_dict(state_dict_to_load)
del state_dict_to_load
with torch.no_grad():
    # fill everything with NaN, so we can check afterwards that every parameter has been restored
```
💯
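The fill-then-verify idea being praised here, reduced to a standalone toy (a plain `nn.Linear` instead of FSDP, with the restore compressed into a single `load_state_dict` call): poison every parameter with NaN before restoring, then assert that no NaN survived, so any parameter the checkpoint failed to overwrite is caught immediately.

```python
import torch
import torch.nn as nn

model = nn.Linear(4, 4)
state_dict_to_load = {k: torch.randn_like(v) for k, v in model.state_dict().items()}

# Poison the parameters first; anything not overwritten stays NaN.
with torch.no_grad():
    for param in model.parameters():
        param.fill_(float("nan"))

model.load_state_dict(state_dict_to_load)

# Verify that the restore touched every parameter.
for name, param in model.named_parameters():
    assert not torch.isnan(param).any(), f"{name} was not restored"
```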
More checkpointing formats
The problem that's being solved here is: how do we restore the 65B model from an unsharded checkpoint? The existing way works, I think, only by accident, if it works at all. It's possible that this never worked, or maybe only worked with sharding strategies that don't actually shard the model (like `NO_SHARD` or `SHARD_GRAD_OP`, or `wrapping_strategy=null`).
So this new way uses `FSDP.apply()` to copy tensors from a state dict into the model. This part is pretty straightforward; that's what `apply()` is for.
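Roughly, the mechanism looks like the following toy (a plain `nn.Module` here, and the helper names are mine, not the PR's; the point is that on an FSDP-wrapped model the same `apply()` call materializes each unit's full parameters and writes modifications back to the shards):

```python
import torch
import torch.nn as nn

def copy_from_state_dict(model: nn.Module, state_dict: dict) -> None:
    # Precompute each module's fully-qualified name so we can look up
    # its parameters in the flat state dict.
    name_of = {id(m): n for n, m in model.named_modules()}

    def copy_params(module: nn.Module) -> None:
        prefix = name_of[id(module)]
        prefix = prefix + "." if prefix else ""
        with torch.no_grad():
            for pname, param in module.named_parameters(recurse=False):
                param.copy_(state_dict[prefix + pname])

    # On an FSDP-wrapped model, apply() runs copy_params with the full
    # parameters gathered, then writes the changes back to the shards.
    model.apply(copy_params)

model = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 4))
checkpoint = {k: torch.randn_like(v) for k, v in model.state_dict().items()}
copy_from_state_dict(model, checkpoint)
```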
The part that isn't straightforward is the detour through the safetensors format. Safetensors is brilliant. It lets you create a `Dict[str, Tensor]` where the tensors are memory-mapped files. It loads up a 500GB file in seconds (because, of course, it doesn't actually read the tensor bytes until later). So this PR contains a script that can read an existing unsharded checkpoint (in `.pt` format) and write it to disk in safetensors format (`.safetensors`). This can be done on CPU, though you need a lot of memory to do it. When reading a `.pt` file, we check whether there happens to be a `.safetensors` file with the same name, and if so, we load that instead.
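For reference, the conversion and the lazy read might look roughly like this with the safetensors library (paths and the example key are placeholders, and this assumes the checkpoint is already a flat `Dict[str, Tensor]`; the next paragraph is about what happens when it isn't):

```python
import torch
from safetensors import safe_open
from safetensors.torch import save_file

# One-time conversion: this is the step that needs a lot of CPU memory,
# because torch.load materializes the whole checkpoint at once.
state_dict = torch.load("model.pt", map_location="cpu")
save_file({k: v.contiguous() for k, v in state_dict.items()}, "model.safetensors")

# Reading back is cheap: opening the file memory-maps it, and tensor
# bytes are only touched when a particular key is actually requested.
with safe_open("model.safetensors", framework="pt", device="cpu") as f:
    weight = f.get_tensor("transformer.wte.weight")  # placeholder key
```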
One more problem is that state dicts are not `Dict[str, Tensor]`. State dicts can contain inner dicts, and optimizer state dicts contain even more crazy stuff. So there is a mapper in this PR that maps crazy state dicts to well-formed `Dict[str, Tensor]`s and back. This sacrifices human interpretability of the files, but retains the lazy-loading memmap goodness from safetensors.
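In spirit, the mapper does something like the following (my illustration of the idea, not the PR's actual encoding, and it only handles nested dicts of tensors, not the other odd value types an optimizer state dict carries): encode the path through the nesting into the key, so the state dict becomes flat on the way out and can be rebuilt on the way back in.

```python
import torch

def flatten(tree: dict, prefix: str = "") -> dict:
    # Collapse nested dicts into flat "a/b/c" keys.
    flat = {}
    for key, value in tree.items():
        path = prefix + str(key)
        if isinstance(value, dict):
            flat.update(flatten(value, path + "/"))
        else:
            flat[path] = value
    return flat

def unflatten(flat: dict) -> dict:
    # Rebuild the nesting from the encoded keys.
    tree: dict = {}
    for path, value in flat.items():
        *parents, leaf = path.split("/")
        node = tree
        for part in parents:
            node = node.setdefault(part, {})
        node[leaf] = value
    return tree

nested = {"state": {"0": {"exp_avg": torch.zeros(4)}}}
flat = flatten(nested)  # {"state/0/exp_avg": tensor([0., 0., 0., 0.])}
restored = unflatten(flat)
assert torch.equal(restored["state"]["0"]["exp_avg"], nested["state"]["0"]["exp_avg"])
```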
TODO