Parameter name conflict with FSDP2 and torch.compile
Using an accelerate config with

```yaml
fsdp_config:
  fsdp_version: 2
dynamo_config:
  dynamo_backend: INDUCTOR
```

causes a failure in the lines below: `old_named_params` holds the original model's parameter names, while `new_named_params` returns the keys of an `OptimizedModule`, which are prefixed with `_orig_mod.`. This leads to a `KeyError` when constructing the mapping:
```python
new_named_params = self._get_named_parameters(*result)

if fsdp2_should_fix_optimizer and self.state.fsdp_plugin.activation_checkpointing:
    new_named_params = {
        k.replace("._checkpoint_wrapped_module", ""): v for k, v in new_named_params.items()
    }

# 3. building a map from the first to the second
mapping = {p: new_named_params[n] for n, p in old_named_params.items()}
```
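For context on where the prefix comes from: `torch.compile` wraps the model in an `OptimizedModule` that stores the original module under the attribute `_orig_mod`, so every parameter name gains an `_orig_mod.` prefix. A minimal standalone sketch (outside the accelerate code path above):

```python
import torch
import torch.nn as nn

# torch.compile returns an OptimizedModule that holds the original module
# under `_orig_mod`, so parameter names gain an "_orig_mod." prefix.
model = nn.Linear(2, 2)
old_names = sorted(n for n, _ in model.named_parameters())

compiled = torch.compile(model)
new_names = sorted(n for n, _ in compiled.named_parameters())

print(old_names)  # ['bias', 'weight']
print(new_names)  # ['_orig_mod.bias', '_orig_mod.weight']

# Looking up the new parameters by their old names therefore fails,
# which is exactly the KeyError in the mapping construction above.
new_named_params = dict(compiled.named_parameters())
assert "weight" not in new_named_params
```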
A naive fix of adding

```python
new_named_params = {
    k.replace("_orig_mod.", ""): v
    for k, v in new_named_params.items()
    if k.startswith("_orig_mod")
}
```

before constructing the mapping resolves the error for me.
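The effect of the stripping step can be sketched with plain dicts (stand-in names and values below, not the real parameter objects):

```python
# Stand-in dicts mimicking the name mismatch; values represent parameters.
old_named_params = {"weight": "p_old_w", "bias": "p_old_b"}
new_named_params = {"_orig_mod.weight": "p_new_w", "_orig_mod.bias": "p_new_b"}

# Without stripping, the lookup below would raise KeyError: 'weight'.
new_named_params = {
    k.replace("_orig_mod.", "", 1): v
    for k, v in new_named_params.items()
    if k.startswith("_orig_mod")
}
mapping = {p: new_named_params[n] for n, p in old_named_params.items()}
print(mapping)  # {'p_old_w': 'p_new_w', 'p_old_b': 'p_new_b'}
```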