add support for custom function for reducing the batch size by winglian · Pull Request #3071 · huggingface/accelerate

What does this PR do?

When using `auto_find_batch_size` in transformers, there is currently no way to compensate with `gradient_accumulation_steps` or the learning rate when a new batch size is tried. This PR adds support for a callable that transformers can set to manage updating the `TrainingArguments` each time the batch size is reduced.
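A minimal sketch of the pattern this enables is below. Instead of only halving the batch size on OOM, a caller-supplied function decides the next batch size and can compensate other hyperparameters at the same time. The names here (`reduce_fn`, `TrainingState`, `find_executable_batch_size`'s exact signature) are illustrative assumptions, not the PR's actual API:

```python
from dataclasses import dataclass


@dataclass
class TrainingState:
    # Hypothetical stand-in for the fields of transformers' TrainingArguments
    # that a reduction callback would want to adjust together.
    per_device_batch_size: int
    gradient_accumulation_steps: int


def default_reduce(state: TrainingState) -> TrainingState:
    # Halve the batch size and double accumulation so the effective batch
    # size (batch_size * accumulation_steps) stays constant.
    return TrainingState(
        per_device_batch_size=state.per_device_batch_size // 2,
        gradient_accumulation_steps=state.gradient_accumulation_steps * 2,
    )


def find_executable_batch_size(train_step, state, reduce_fn=default_reduce):
    """Retry train_step, shrinking the batch via reduce_fn on OOM-style errors."""
    while True:
        if state.per_device_batch_size < 1:
            raise RuntimeError("No executable batch size found")
        try:
            return train_step(state)
        except MemoryError:  # stand-in for torch.cuda.OutOfMemoryError
            state = reduce_fn(state)


# Usage: a fake train step that only "fits" at batch sizes <= 8.
def fake_train_step(state):
    if state.per_device_batch_size > 8:
        raise MemoryError
    return state


final = find_executable_batch_size(fake_train_step, TrainingState(32, 1))
# 32 -> 16 -> 8 succeeds, with accumulation compensated: 1 -> 2 -> 4.
```

The key point is that the reduction logic lives in a callable the caller controls, so transformers can plug in a function that rewrites its `TrainingArguments` consistently rather than relying on accelerate's fixed halving behavior.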

Fixes # (issue)

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a GitHub issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

@muellerzr
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.