Add kwargs to optimizer, scheduler and dataloader using function `acc… · huggingface/accelerate@8cb3ace

@@ -180,6 +180,7 @@ def load_accelerator_state(

180180

process_index,

181181

scaler=None,

182182

map_location=None,

183+

load_kwargs=None,

183184

**load_model_func_kwargs,

184185

):

185186

"""

@@ -200,6 +201,8 @@ def load_accelerator_state(

200201

An optional *GradScaler* instance to load

201202

map_location (`str`, *optional*):

202203

What device to load the optimizer state onto. Should be one of either "cpu" or "on_device".

204+

load_kwargs (`dict`, *optional*):

205+

Additional arguments that can be passed to the `load` function.

203206

load_model_func_kwargs (`dict`, *optional*):

204207

Additional arguments that can be passed to the model's `load_state_dict` method.

205208

@@ -217,6 +220,9 @@ def load_accelerator_state(

217220

elif map_location == "on_device":

218221

map_location = PartialState().device

219222

223+

if load_kwargs is None:

224+

load_kwargs = {}

225+

220226

input_dir = Path(input_dir)

221227

# Model states

222228

for i, model in enumerate(models):

@@ -235,15 +241,15 @@ def load_accelerator_state(

235241

for i, opt in enumerate(optimizers):

236242

optimizer_name = f"{OPTIMIZER_NAME}.bin" if i == 0 else f"{OPTIMIZER_NAME}_{i}.bin"

237243

input_optimizer_file = input_dir.joinpath(optimizer_name)

238-

optimizer_state = load(input_optimizer_file, map_location=map_location)

244+

optimizer_state = load(input_optimizer_file, map_location=map_location, **load_kwargs)

239245

optimizers[i].load_state_dict(optimizer_state)

240246

logger.info("All optimizer states loaded successfully")

241247

242248

# Scheduler states

243249

for i, scheduler in enumerate(schedulers):

244250

scheduler_name = f"{SCHEDULER_NAME}.bin" if i == 0 else f"{SCHEDULER_NAME}_{i}.bin"

245251

input_scheduler_file = input_dir.joinpath(scheduler_name)

246-

scheduler_state = load(input_scheduler_file)

252+

scheduler_state = load(input_scheduler_file, **load_kwargs)

247253

scheduler.load_state_dict(scheduler_state)

248254

logger.info("All scheduler states loaded successfully")

249255

@@ -261,7 +267,7 @@ def load_accelerator_state(

261267

dataloader_state_dict_name = "dl_state_dict.bin" if i == 0 else f"dl_state_dict_{i}.bin"

262268

input_dataloader_state_dict_file = input_dir.joinpath(dataloader_state_dict_name)

263269

if input_dataloader_state_dict_file.exists():

264-

state_dict = load(input_dataloader_state_dict_file)

270+

state_dict = load(input_dataloader_state_dict_file, **load_kwargs)

265271

dataloader.load_state_dict(state_dict)

266272

logger.info("All dataloader sampler states loaded successfully")

267273