[XPU] Implemented 8bit optimizers in triton by Egor-Krivov · Pull Request #1692 · bitsandbytes-foundation/bitsandbytes
Benchmarking on 4096*4096 shapes I get about 2x performance gain from 8bit optimizers on GPU Max 1100.
Torch is 32bit optimizer from torch, BNB is 8bit optimizer:
Torch step (eager): 2.972ms
BNB step: 1.325ms
Torch step (eager): 2.954ms
BNB step: 1.308ms
Torch step (eager): 2.985ms
BNB step: 1.290ms
Torch step (eager): 2.957ms
BNB step: 1.283ms
Torch step (eager): 2.951ms
BNB step: 1.320ms
Torch step (eager): 2.943ms
BNB step: 1.250ms
Torch step (eager): 2.959ms
BNB step: 1.318ms
For small shapes difference is smaller (1024*9):
Torch step (eager): 0.257ms
BNB step: 0.256ms
Torch step (eager): 0.253ms
BNB step: 0.260ms
benchmark is based on optimizer tests:
import os
from os.path import join
import shutil
import time
import uuid
from lion_pytorch import Lion
import torch
import bitsandbytes as bnb
import bitsandbytes.functional as F
from bitsandbytes.utils import sync_gpu
# optim_name = "momentum8bit_blockwise"
# optim_name = "rmsprop8bit_blockwise"
# optim_name = "adagrad8bit_blockwise"
# optim_name = "adam8bit_blockwise"
optim_name = "ademamix8bit_blockwise"
# optim_name = "lion8bit_blockwise"
str2optimizers = {}
k = 20
## TODO: maybe remove these three.
str2optimizers["adam_pytorch"] = (None, torch.optim.Adam, bnb.optim.Adam)
# str2optimizers["lion_pytorch"] = (None, Lion, bnb.optim.Lion)
str2optimizers["momentum_pytorch"] = (
None,
lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9),
bnb.optim.Adam,
)
str2optimizers["adam"] = (torch.optim.Adam, bnb.optim.Adam)
str2optimizers["adam8bit_blockwise"] = (torch.optim.Adam, lambda pxx: bnb.optim.Adam8bit(pxx, block_wise=True))
str2optimizers["paged_adam"] = (torch.optim.Adam, bnb.optim.PagedAdam)
str2optimizers["paged_adamw"] = (torch.optim.AdamW, bnb.optim.PagedAdamW)
str2optimizers["paged_adam8bit_blockwise"] = (
torch.optim.Adam,
lambda pxx: bnb.optim.PagedAdam8bit(pxx, block_wise=True),
)
str2optimizers["paged_adamw8bit_blockwise"] = (
torch.optim.AdamW,
lambda pxx: bnb.optim.PagedAdamW8bit(pxx, block_wise=True),
)
str2optimizers["ademamix"] = (bnb.optim.ademamix._ReferenceAdEMAMix, bnb.optim.AdEMAMix)
str2optimizers["ademamix8bit_blockwise"] = (
bnb.optim.ademamix._ReferenceAdEMAMix,
lambda pxx: bnb.optim.AdEMAMix8bit(pxx),
)
str2optimizers["paged_ademamix"] = (bnb.optim.ademamix._ReferenceAdEMAMix, bnb.optim.PagedAdEMAMix)
str2optimizers["paged_ademamix8bit_blockwise"] = (
bnb.optim.ademamix._ReferenceAdEMAMix,
lambda pxx: bnb.optim.PagedAdEMAMix8bit(pxx),
)
str2optimizers["ademamix_scheduled"] = (
lambda pxx: bnb.optim.ademamix._ReferenceAdEMAMix(pxx, t_alpha=k, t_beta3=k),
lambda pxx: bnb.optim.AdEMAMix(pxx, t_alpha=k, t_beta3=k),
)
str2optimizers["paged_ademamix_scheduled"] = (
lambda pxx: bnb.optim.ademamix._ReferenceAdEMAMix(pxx, t_alpha=k, t_beta3=k),
lambda pxx: bnb.optim.PagedAdEMAMix(pxx, t_alpha=k, t_beta3=k),
)
str2optimizers["ademamix8bit_blockwise_scheduled"] = (
lambda pxx: bnb.optim.ademamix._ReferenceAdEMAMix(pxx, t_alpha=100, t_beta3=100),
lambda pxx: bnb.optim.AdEMAMix8bit(pxx, t_alpha=100, t_beta3=100),
)
str2optimizers["paged_ademamix8bit_blockwise_scheduled"] = (
lambda pxx: bnb.optim.ademamix._ReferenceAdEMAMix(pxx, t_alpha=100, t_beta3=100),
lambda pxx: bnb.optim.PagedAdEMAMix8bit(pxx, t_alpha=100, t_beta3=100),
)
str2optimizers["lion"] = (Lion, bnb.optim.Lion)
str2optimizers["paged_lion"] = (Lion, bnb.optim.PagedLion)
str2optimizers["lion8bit_blockwise"] = (Lion, lambda pxx: bnb.optim.Lion8bit(pxx, block_wise=True))
str2optimizers["paged_lion8bit_blockwise"] = (Lion, lambda pxx: bnb.optim.PagedLion8bit(pxx, block_wise=True))
str2optimizers["momentum"] = (
lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9),
lambda pxx: bnb.optim.SGD(pxx, 0.01, 0.9, block_wise=False),
)
str2optimizers["momentum8bit_blockwise"] = (
lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9),
lambda pxx: bnb.optim.SGD8bit(pxx, 0.01, 0.9, block_wise=True),
)
str2optimizers["rmsprop"] = (
lambda pxx: torch.optim.RMSprop(pxx, 0.01, 0.9),
lambda pxx: bnb.optim.RMSprop(pxx, 0.01, 0.9, block_wise=False),
)
str2optimizers["rmsprop8bit_blockwise"] = (
lambda pxx: torch.optim.RMSprop(pxx, 0.01, 0.9),
lambda pxx: bnb.optim.RMSprop8bit(pxx, 0.01, 0.9, block_wise=True),
)
str2statenames = {}
str2statenames["adam"] = [("exp_avg", "state1"), ("exp_avg_sq", "state2")]
str2statenames["paged_adamw"] = [("exp_avg", "state1"), ("exp_avg_sq", "state2")]
str2statenames["paged_adam"] = [("exp_avg", "state1"), ("exp_avg_sq", "state2")]
str2statenames["lion"] = [("exp_avg", "state1")]
str2statenames["paged_lion"] = [("exp_avg", "state1")]
str2statenames["momentum"] = [("momentum_buffer", "state1")]
str2statenames["lamb"] = [("exp_avg", "state1"), ("exp_avg_sq", "state2")]
str2statenames["rmsprop"] = [("square_avg", "state1")]
str2statenames["adam8bit_blockwise"] = [
("exp_avg", "state1", "qmap1", "absmax1"),
("exp_avg_sq", "state2", "qmap2", "absmax2"),
]
str2statenames["paged_adam8bit_blockwise"] = [
("exp_avg", "state1", "qmap1", "absmax1"),
("exp_avg_sq", "state2", "qmap2", "absmax2"),
]
str2statenames["paged_adamw8bit_blockwise"] = [
("exp_avg", "state1", "qmap1", "absmax1"),
("exp_avg_sq", "state2", "qmap2", "absmax2"),
]
str2statenames["momentum8bit_blockwise"] = [("momentum_buffer", "state1", "qmap1", "absmax1")]
str2statenames["rmsprop8bit_blockwise"] = [("square_avg", "state1", "qmap1", "absmax1")]
str2statenames["lion8bit_blockwise"] = [("exp_avg", "state1", "qmap1", "absmax1")]
str2statenames["paged_lion8bit_blockwise"] = [("exp_avg", "state1", "qmap1", "absmax1")]
str2statenames["ademamix"] = str2statenames["ademamix_scheduled"] = [("m1_m2", "state1"), ("nu", "state2")]
str2statenames["paged_ademamix"] = str2statenames["paged_ademamix_scheduled"] = [("m1_m2", "state1"), ("nu", "state2")]
str2statenames["ademamix8bit_blockwise"] = str2statenames["ademamix8bit_blockwise_scheduled"] = [
("m1_m2", "state1", "qmap1", "absmax1"),
("nu", "state2", "qmap2", "absmax2"),
]
str2statenames["paged_ademamix8bit_blockwise"] = [
("m1_m2", "state1", "qmap1", "absmax1"),
("nu", "state2", "qmap2", "absmax2"),
]
gtype = [torch.float32, torch.float16, torch.bfloat16][2]
dim2 = 4096
dim1 = 4096
# dim2 = 1024
# dim1 = 1024
device = "xpu"
check_precision = True
gradient_scale = 0.01
# gradient_scale = 1
def assert_most_approx_close(a, b, rtol=1e-3, atol=1e-3, max_error_count=0):
idx = torch.isclose(a, b, rtol=rtol, atol=atol)
error_count = (idx == 0).sum().item()
if error_count > max_error_count:
print(f"Too many values not close: assert {error_count} < {max_error_count}")
torch.testing.assert_close(a, b, rtol=rtol, atol=atol)
def get_temp_dir():
path = f"/tmp/autoswap/{uuid.uuid4()}"
os.makedirs(path, exist_ok=True)
return path
def rm_path(path):
shutil.rmtree(path)
def test_optimizer8bit(dim1, dim2, gtype, optim_name, device):
torch.set_printoptions(precision=10)
if dim1 == 1 and dim2 == 1:
return
p1 = torch.randn(dim1, dim2, device=device, dtype=gtype) * 0.1
p2 = p1.clone()
p1 = p1.float()
blocksize = 256
torch_optimizer = str2optimizers[optim_name][0]([p1])
bnb_optimizer = str2optimizers[optim_name][1]([p2])
if gtype == torch.float32:
atol, rtol = 3e-3, 1e-3
# atol, rtol = 5e-3, 1e-3
patol, prtol = 1e-5, 1e-3
elif gtype == torch.bfloat16:
atol, rtol = 3e-3, 1e-3
# atol, rtol = 5e-3, 1e-3
patol, prtol = 1e-4, 1e-2
else:
atol, rtol = 3e-3, 1e-3
# atol, rtol = 5e-3, 1e-3
patol, prtol = 1e-5, 1e-3
errors = []
relerrors = []
for i in range(50):
g = torch.randn(dim1, dim2, device=device, dtype=gtype) * gradient_scale
p1.grad = g.clone().float()
p2.grad = g.clone()
p_copy = p1.clone()
sync_gpu(p1)
start = time.time()
torch_optimizer.step()
sync_gpu(p1)
stop = time.time()
print(f"Torch step (eager): {1000 * (stop - start):.3f}ms")
sync_gpu(p1)
start = time.time()
bnb_optimizer.step()
sync_gpu(p1)
stop = time.time()
print(f"BNB step: {1000 * (stop - start):.3f}ms")
# since Lion can have pretty noisy updates where things lie at the boundary
if check_precision:
assert_most_approx_close(p1, p2.float(), patol, prtol, max_error_count=0)
dequant_states = []
for name1, name2, qmap, max_val in str2statenames[optim_name]:
## For AdEMAMix, we need to dequantize [p2][name2][0] and [p2][name2][1]
## separately and then stack them. The qmap is shared, but absmax is also stacked.
# if optim_name == "ademamix8bit_blockwise" and name1 == "m1_m2":
# m1 = F.dequantize_blockwise(
# code=bnb_optimizer.state[p2][qmap],
# absmax=bnb_optimizer.state[p2][max_val][0],
# A=bnb_optimizer.state[p2][name2][0],
# blocksize=blocksize,
# )
# m2 = F.dequantize_blockwise(
# code=bnb_optimizer.state[p2][qmap],
# absmax=bnb_optimizer.state[p2][max_val][1],
# A=bnb_optimizer.state[p2][name2][1],
# blocksize=blocksize,
# )
# s1 = torch.stack((m1, m2))
if True:
s2 = F.dequantize_blockwise(
code=bnb_optimizer.state[p2][qmap],
absmax=bnb_optimizer.state[p2][max_val],
A=bnb_optimizer.state[p2][name2],
blocksize=blocksize,
)
code = bnb_optimizer.state[p2][qmap]
absmax = bnb_optimizer.state[p2][max_val]
A = bnb_optimizer.state[p2][name2]
s1 = torch_optimizer.state[p1][name1]
diff = s1 - s2
# import pdb; pdb.set_trace()
num_not_close = torch.isclose(torch_optimizer.state[p1][name1], s2, atol=atol, rtol=rtol) == 0
if check_precision:
assert num_not_close.sum().item() < 20
dequant_states.append(s2.clone())
err = torch.abs(p1 - p2)
relerr = err / (torch.abs(p1) + 1e-9)
if g.dtype == torch.bfloat16 and check_precision:
assert err.mean() <= 0.00017
assert relerr.mean() <= 0.0016
elif check_precision:
assert err.mean() < 0.00006
assert relerr.mean() < 0.0006
errors.append(err.mean().item())
relerrors.append(relerr.mean().item())
if i % 10 == 0 and i > 0:
for (name1, name2, qmap, max_val), s in zip(str2statenames[optim_name], dequant_states):
s1cpy = s.clone()
raws1cpy = bnb_optimizer.state[p2][name2].clone()
qmap1 = bnb_optimizer.state[p2][qmap].clone()
path = get_temp_dir()
torch.save(bnb_optimizer.state_dict(), join(path, "opt.pt"))
del bnb_optimizer
bnb_optimizer = None
bnb_optimizer = str2optimizers[optim_name][1]([p2])
bnb_optimizer.load_state_dict(torch.load(join(path, "opt.pt")))
rm_path(path)
torch.testing.assert_close(raws1cpy, bnb_optimizer.state[p2][name2])
torch.testing.assert_close(qmap1, bnb_optimizer.state[p2][qmap])
## For AdEMAMix, we need to dequantize [p2][name2][0] and [p2][name2][1]
## separately and then stack them. The qmap is shared, but absmax is also stacked.
if optim_name == "ademamix8bit_blockwise" and name1 == "m1_m2":
s2 = torch.stack(
(
F.dequantize_blockwise(
code=bnb_optimizer.state[p2][qmap],
absmax=bnb_optimizer.state[p2][max_val][0],
A=bnb_optimizer.state[p2][name2][0],
blocksize=blocksize,
),
F.dequantize_blockwise(
code=bnb_optimizer.state[p2][qmap],
absmax=bnb_optimizer.state[p2][max_val][1],
A=bnb_optimizer.state[p2][name2][1],
blocksize=blocksize,
),
)
)
else:
s2 = F.dequantize_blockwise(
code=bnb_optimizer.state[p2][qmap],
absmax=bnb_optimizer.state[p2][max_val],
A=bnb_optimizer.state[p2][name2],
blocksize=blocksize,
)
torch.testing.assert_close(s1cpy, s2)
num_not_close = torch.isclose(torch_optimizer.state[p1][name1], s2, atol=atol, rtol=rtol) == 0
if check_precision:
assert num_not_close.sum().item() < 20
# Lion can have pretty noisy updates where things lie at the boundary
assert_most_approx_close(p1, p2.float(), patol, prtol, max_error_count=0)
# the parameters diverge quickly. Here we keep them close
# together so we can test against the Adam error
p1.data = p1.data.to(gtype).float()
p2.copy_(p1.data)
torch.testing.assert_close(p1.to(gtype), p2)
for (name1, name2, qmap, max_val), s in zip(str2statenames[optim_name], dequant_states):
torch_optimizer.state[p1][name1].copy_(s.data)
test_optimizer8bit(dim1, dim2, gtype, optim_name, device)
``