[FSDP2] Fix memory spike with `cpu_ram_efficient_loading=True` by S1ro1 · Pull Request #3482 · huggingface/accelerate
With initial implementation of FSDP2, there is an issue if cpu_ram_efficient_loading is set. See this comment for details. This usually isn't an issue, but if the model is large enough, it can lead to memory spike that is ~2x the size of peak.
This PR fixes the spike first found in #3474. Providing some numerics below:
Loss parity between cpu_ram_efficient_loading=true/false
wandb link
Memory snapshots:
Old (cpu_ram_efficient_loading=True)
New (cpu_ram_efficient_loading=True)
Baseline (cpu_ram_efficient_loading=False)
(The discrepancy between new/baseline in the semi-big rectangle appearing first/in the middle is caused by difference in order of evaluation - top-level first vs bottom-level first)
Possibly moving get_non_persistent_buffer_fqns and making it private could be an upgrade, this function is quite specific to this use-case.


