
Mitchish65 #388

Merged: dirkgr merged 72 commits into main from the mitchish65 branch on Jan 5, 2024

Conversation

@dirkgr dirkgr (Member) commented Dec 4, 2023

More checkpointing formats

The problem being solved here is how to restore the 65B model from an unsharded checkpoint. The existing way works, I think, only by accident, if it works at all. It's possible that it never worked, or that it only worked with sharding strategies that don't actually shard the model (like NO_SHARD or SHARD_GRAD_OP, or with wrapping_strategy=null).

So this new way uses FSDP.apply() to copy tensors from a state dict into the model. This part is pretty straightforward. That's what apply() is for.
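
For illustration, a minimal sketch of that idea, not the PR's actual code: the key matching is simplified (a real FSDP-wrapped model inserts wrapper prefixes into parameter names), and the check that every parameter really was restored is omitted.

```python
import torch
from torch import nn

def copy_state_dict_into_model(fsdp_model: nn.Module, state_dict: dict) -> None:
    # Record each submodule's fully qualified prefix up front, because the
    # function passed to apply() only receives the module itself.
    prefixes = {module: name for name, module in fsdp_model.named_modules()}

    def copy_params(module: nn.Module) -> None:
        prefix = prefixes[module]
        for param_name, param in module.named_parameters(recurse=False):
            key = f"{prefix}.{param_name}" if prefix else param_name
            if key in state_dict:
                with torch.no_grad():
                    param.copy_(state_dict[key])

    # FSDP overrides apply() to gather the full (unsharded) parameters before
    # calling the function on every submodule, then reshards them afterwards.
    fsdp_model.apply(copy_params)
```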

The part that isn't straightforward is the detour through the safetensors format. Safetensors is brilliant: it lets you create a Dict[str, Tensor] where the tensors are memory-mapped files. It loads a 500 GB file in seconds (because, of course, it doesn't actually read the tensor bytes until later). So this PR contains a script that reads an existing unsharded checkpoint (in .pt format) and writes it to disk in safetensors format (.safetensors). This can be done on CPU, though it needs a lot of memory. When reading a .pt file, we check whether a .safetensors file with the same name exists, and if so, we load that instead.
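
A rough sketch of the round trip, assuming the checkpoint is already a flat Dict[str, Tensor]; the file names and the tensor key below are placeholders, not paths from this PR:

```python
import torch
from safetensors import safe_open
from safetensors.torch import save_file

# One-time conversion, done on CPU; this needs enough RAM to hold the
# whole state dict at once.
state_dict = torch.load("model.pt", map_location="cpu")
save_file(state_dict, "model.safetensors")

# Opening the result later is near-instant: the file is memory-mapped, and a
# tensor's bytes are only read from disk once that tensor is actually used.
with safe_open("model.safetensors", framework="pt", device="cpu") as f:
    keys = f.keys()                                   # metadata only
    weight = f.get_tensor("transformer.wte.weight")   # hypothetical key name
```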

One more problem is that state dicts are not Dict[str, Tensor]. State dicts can contain inner dicts, and optimizer state dicts contain even more crazy stuff. So there is a mapper in this PR that maps crazy state dicts to well-formed Dict[str, Tensor]s and back. This sacrifices human interpretability of the files, but retains the lazy-loading memmap goodness from safetensors.
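
The flattening direction might look roughly like this. It is a sketch of the idea only; the PR's actual encoding also has to survive keys that contain dots, integer keys in optimizer state, and the mapping back:

```python
import torch

def flatten(obj, prefix=""):
    """Flatten a nested state dict into a pure Dict[str, Tensor], plus a side
    dict for non-tensor leaves (step counts, hyperparameters, ...)."""
    tensors, extras = {}, {}
    if isinstance(obj, dict):
        for key, value in obj.items():
            sub_tensors, sub_extras = flatten(value, f"{prefix}{key}.")
            tensors.update(sub_tensors)
            extras.update(sub_extras)
    elif isinstance(obj, torch.Tensor):
        tensors[prefix.rstrip(".")] = obj
    else:
        extras[prefix.rstrip(".")] = obj
    return tensors, extras
```

In this sketch, the tensor half is what would be written to the .safetensors file, and the small non-tensor remainder would have to be stored alongside it.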

TODO

  • Verify that this works with the 7B on Cirrascale.
  • Verify that this works with the 65B on LUMI.

@dirkgr dirkgr (Member, Author) commented Dec 6, 2023

With the latest fix, I can load an unsharded checkpoint this way. It takes a long time, and it's very inefficient (because every rank will load everything but then discard most of what it just loaded), but it does work.

Unfortunately, I seem to have broken how the optimizer state works, so this is not complete yet.

@dirkgr dirkgr (Member, Author) commented Dec 6, 2023

In fact, with the way this works, we can now load a model that's so big it wouldn't fit into CPU memory. Not that we need to do this. But we could.

@dirkgr dirkgr marked this pull request as ready for review January 4, 2024 01:43
@dirkgr dirkgr requested a review from AkshitaB January 4, 2024 18:10
@epwalsh epwalsh (Member) left a comment

LGTM!

One thing to consider is integrating the safetensors conversion into Shane's checkpoint management script. But let's leave that for another day.

fsdp_model.load_state_dict(state_dict_to_load)
del state_dict_to_load
with torch.no_grad():
    # fill everything with NaN, so we can check afterwards that every parameter has been restored
A Member commented on the snippet above: 💯

@dirkgr dirkgr merged commit df19554 into main Jan 5, 2024
@dirkgr dirkgr deleted the mitchish65 branch January 5, 2024 18:39