add support for custom function for reducing the batch size by winglian · Pull Request #3071 · huggingface/accelerate
What does this PR do?
When using `auto_find_batch_size` in transformers, there is currently no way to compensate for a reduced batch size by adjusting `gradient_accumulation_steps` or the learning rate when a new batch size is tried. This PR adds the ability to pass a callable that transformers can set in order to manage updating the `TrainingArguments` whenever the batch size is reduced.
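To illustrate the idea, here is a minimal sketch of the pattern. This is not the PR's actual diff: `find_executable_batch_size` is accelerate's real decorator name, but the `reduce_batch_size_fn` parameter name, the bare `RuntimeError` stand-in for OOM detection, and the `rescale` callback are assumptions for illustration. The point is that the reduction hook can update training state (e.g. gradient accumulation) as a side effect, keeping the effective batch size constant:

```python
import functools


def find_executable_batch_size(function=None, starting_batch_size=128,
                               reduce_batch_size_fn=None):
    """Retry `function` with smaller batch sizes until it succeeds.

    `reduce_batch_size_fn` (hypothetical name) receives the failing batch
    size and returns the next one to try; a caller such as transformers can
    use it to also rescale gradient_accumulation_steps or the learning rate.
    Defaults to halving, which mirrors accelerate's existing behavior.
    """
    if reduce_batch_size_fn is None:
        reduce_batch_size_fn = lambda bs: bs // 2

    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            batch_size = starting_batch_size
            while batch_size > 0:
                try:
                    return func(batch_size, *args, **kwargs)
                except RuntimeError:  # stand-in for real CUDA OOM detection
                    batch_size = reduce_batch_size_fn(batch_size)
            raise RuntimeError("No executable batch size found.")
        return wrapper

    if function is not None:
        return decorator(function)
    return decorator


# Example: halve the batch size but double gradient accumulation so the
# effective batch size stays constant (the compensation this PR enables).
training_args = {"grad_accum": 1}


def rescale(bs):
    training_args["grad_accum"] *= 2
    return bs // 2


@find_executable_batch_size(starting_batch_size=32, reduce_batch_size_fn=rescale)
def train(batch_size):
    if batch_size > 8:  # simulate OOM above batch size 8
        raise RuntimeError("out of memory")
    return batch_size


result = train()  # → 8, with training_args["grad_accum"] == 4
```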
Fixes # (issue)
Before submitting
- [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
- [ ] Did you read the contributor guideline, Pull Request section?
- [ ] Was this discussed/approved via a Github issue or the forum? Please add a link to it if that's the case.
- [ ] Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
- [ ] Did you write any new necessary tests?
Who can review?
@muellerzr
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.