[XPU] Implemented 32bit optimizers in triton by YangKai0616 · Pull Request #1710 · bitsandbytes-foundation/bitsandbytes

Depends on #1692.

Implemented 32bit optimizers in triton to use of XPU devices.

The PR includes two implementations:

  1. Pure Torch implementation: utilizing torch.compile
  2. Pure Triton implementation: utilizing triton.jit

For the benchmarking on 4096*4096 shapes, the results are as follows:

Pure Torch implementation:

Torch step (eager): 1.075ms
BNB step: 0.516ms
Torch step (eager): 1.058ms
BNB step: 0.517ms
Torch step (eager): 1.080ms
BNB step: 0.527ms
Torch step (eager): 1.069ms
BNB step: 0.539ms
Torch step (eager): 1.034ms
BNB step: 0.526ms

Pure Triton implementation:

Torch step (eager): 1.034ms
BNB step: 0.524ms
Torch step (eager): 1.054ms
BNB step: 0.488ms
Torch step (eager): 1.031ms
BNB step: 0.526ms
Torch step (eager): 1.047ms
BNB step: 0.538ms
Torch step (eager): 1.045ms
BNB step: 0.489ms

For the benchmarking on 1024*1024 shapes, the results are as follows:
Pure Torch implementation:

Torch step (eager): 0.345ms
BNB step: 0.335ms
Torch step (eager): 0.354ms
BNB step: 0.226ms
Torch step (eager): 0.347ms
BNB step: 0.227ms
Torch step (eager): 0.358ms
BNB step: 0.232ms
Torch step (eager): 0.349ms
BNB step: 0.225ms

Pure Triton implementation:

Torch step (eager): 0.346ms
BNB step: 0.226ms
Torch step (eager): 0.337ms
BNB step: 0.216ms
Torch step (eager): 0.338ms
BNB step: 0.215ms
Torch step (eager): 0.333ms
BNB step: 0.226ms
Torch step (eager): 0.349ms
BNB step: 0.235ms

The test platform is Intel(R) Data Center GPU Max 1550. Test script reference #1692. Torch(eager) is 32bit optimizer from torch, BNB is 32bit optimizer.

Considering that the performance gap between torch.compile and Triton implementations is not significant, but triton's implementation compiles faster, and #1692 was implemented with Triton, this PR adopts the Triton version for submission.

Note:Currently, XPU does not support the allocation of memory buffers using a paging mechanism. Therefore, these tests are skipped in tests/test_optim.py::test_optimizer32bit. This functionality will be developed in the future to support full optimizer capabilities.