Add kwargs to optimizer, scheduler and dataloader using function `acc… · huggingface/accelerate@8cb3ace
@@ -180,6 +180,7 @@ def load_accelerator_state(
180180process_index,
181181scaler=None,
182182map_location=None,
183+load_kwargs=None,
183184**load_model_func_kwargs,
184185):
185186"""
@@ -200,6 +201,8 @@ def load_accelerator_state(
200201 An optional *GradScaler* instance to load
201202 map_location (`str`, *optional*):
202203 What device to load the optimizer state onto. Should be one of either "cpu" or "on_device".
204+ load_kwargs (`dict`, *optional*):
205+ Additional arguments that can be passed to the `load` function.
203206 load_model_func_kwargs (`dict`, *optional*):
204207 Additional arguments that can be passed to the model's `load_state_dict` method.
205208@@ -217,6 +220,9 @@ def load_accelerator_state(
217220elif map_location == "on_device":
218221map_location = PartialState().device
219222223+if load_kwargs is None:
224+load_kwargs = {}
225+220226input_dir = Path(input_dir)
221227# Model states
222228for i, model in enumerate(models):
@@ -235,15 +241,15 @@ def load_accelerator_state(
235241for i, opt in enumerate(optimizers):
236242optimizer_name = f"{OPTIMIZER_NAME}.bin" if i == 0 else f"{OPTIMIZER_NAME}_{i}.bin"
237243input_optimizer_file = input_dir.joinpath(optimizer_name)
238-optimizer_state = load(input_optimizer_file, map_location=map_location)
244+optimizer_state = load(input_optimizer_file, map_location=map_location, **load_kwargs)
239245optimizers[i].load_state_dict(optimizer_state)
240246logger.info("All optimizer states loaded successfully")
241247242248# Scheduler states
243249for i, scheduler in enumerate(schedulers):
244250scheduler_name = f"{SCHEDULER_NAME}.bin" if i == 0 else f"{SCHEDULER_NAME}_{i}.bin"
245251input_scheduler_file = input_dir.joinpath(scheduler_name)
246-scheduler_state = load(input_scheduler_file)
252+scheduler_state = load(input_scheduler_file, **load_kwargs)
247253scheduler.load_state_dict(scheduler_state)
248254logger.info("All scheduler states loaded successfully")
249255@@ -261,7 +267,7 @@ def load_accelerator_state(
261267dataloader_state_dict_name = "dl_state_dict.bin" if i == 0 else f"dl_state_dict_{i}.bin"
262268input_dataloader_state_dict_file = input_dir.joinpath(dataloader_state_dict_name)
263269if input_dataloader_state_dict_file.exists():
264-state_dict = load(input_dataloader_state_dict_file)
270+state_dict = load(input_dataloader_state_dict_file, **load_kwargs)
265271dataloader.load_state_dict(state_dict)
266272logger.info("All dataloader sampler states loaded successfully")
267273