[XPU] Implemented 8bit optimizers in triton by Egor-Krivov · Pull Request #1692 · bitsandbytes-foundation/bitsandbytes

Benchmarking on 4096*4096 shapes, I get about a 2x performance gain from the 8-bit optimizers on a GPU Max 1100.

Torch is the 32-bit optimizer from torch, BNB is the 8-bit optimizer:

Torch step (eager): 2.972ms
BNB step: 1.325ms
Torch step (eager): 2.954ms
BNB step: 1.308ms
Torch step (eager): 2.985ms
BNB step: 1.290ms
Torch step (eager): 2.957ms
BNB step: 1.283ms
Torch step (eager): 2.951ms
BNB step: 1.320ms
Torch step (eager): 2.943ms
BNB step: 1.250ms
Torch step (eager): 2.959ms
BNB step: 1.318ms

For small shapes (1024*9), the difference is smaller:

Torch step (eager): 0.257ms
BNB step: 0.256ms
Torch step (eager): 0.253ms
BNB step: 0.260ms
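
As a rough illustration of the swap being timed (a minimal sketch, not part of the benchmark below; the model, shapes, and learning rate are made up, and it assumes an XPU build of bitsandbytes with these Triton kernels):

import torch
import bitsandbytes as bnb

# Toy parameter tensor on the XPU device (hypothetical shapes, illustration only).
model = torch.nn.Linear(4096, 4096, device="xpu", dtype=torch.bfloat16)

# 32-bit baseline would be torch.optim.Adam(model.parameters(), lr=1e-3).
# 8-bit drop-in replacement: optimizer state is kept in blockwise-quantized 8-bit buffers.
opt = bnb.optim.Adam8bit(model.parameters(), lr=1e-3, block_wise=True)

out = model(torch.randn(16, 4096, device="xpu", dtype=torch.bfloat16))
out.sum().backward()
opt.step()
opt.zero_grad()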

The benchmark is based on the optimizer tests:

import os
from os.path import join
import shutil
import time
import uuid

from lion_pytorch import Lion
import torch

import bitsandbytes as bnb
import bitsandbytes.functional as F
from bitsandbytes.utils import sync_gpu

# optim_name = "momentum8bit_blockwise"
# optim_name = "rmsprop8bit_blockwise"
# optim_name = "adagrad8bit_blockwise"
# optim_name = "adam8bit_blockwise"
optim_name = "ademamix8bit_blockwise"
# optim_name = "lion8bit_blockwise"

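# Each entry maps an optimizer name to (reference/torch optimizer ctor, bnb optimizer ctor);
# the first three legacy "pytorch" entries carry an extra leading None.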
str2optimizers = {}

k = 20
## TODO: maybe remove these three.
str2optimizers["adam_pytorch"] = (None, torch.optim.Adam, bnb.optim.Adam)
# str2optimizers["lion_pytorch"] = (None, Lion, bnb.optim.Lion)
str2optimizers["momentum_pytorch"] = (
    None,
    lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9),
    bnb.optim.Adam,
)

str2optimizers["adam"] = (torch.optim.Adam, bnb.optim.Adam)
str2optimizers["adam8bit_blockwise"] = (torch.optim.Adam, lambda pxx: bnb.optim.Adam8bit(pxx, block_wise=True))
str2optimizers["paged_adam"] = (torch.optim.Adam, bnb.optim.PagedAdam)
str2optimizers["paged_adamw"] = (torch.optim.AdamW, bnb.optim.PagedAdamW)
str2optimizers["paged_adam8bit_blockwise"] = (
    torch.optim.Adam,
    lambda pxx: bnb.optim.PagedAdam8bit(pxx, block_wise=True),
)
str2optimizers["paged_adamw8bit_blockwise"] = (
    torch.optim.AdamW,
    lambda pxx: bnb.optim.PagedAdamW8bit(pxx, block_wise=True),
)

str2optimizers["ademamix"] = (bnb.optim.ademamix._ReferenceAdEMAMix, bnb.optim.AdEMAMix)
str2optimizers["ademamix8bit_blockwise"] = (
    bnb.optim.ademamix._ReferenceAdEMAMix,
    lambda pxx: bnb.optim.AdEMAMix8bit(pxx),
)
str2optimizers["paged_ademamix"] = (bnb.optim.ademamix._ReferenceAdEMAMix, bnb.optim.PagedAdEMAMix)
str2optimizers["paged_ademamix8bit_blockwise"] = (
    bnb.optim.ademamix._ReferenceAdEMAMix,
    lambda pxx: bnb.optim.PagedAdEMAMix8bit(pxx),
)
str2optimizers["ademamix_scheduled"] = (
    lambda pxx: bnb.optim.ademamix._ReferenceAdEMAMix(pxx, t_alpha=k, t_beta3=k),
    lambda pxx: bnb.optim.AdEMAMix(pxx, t_alpha=k, t_beta3=k),
)
str2optimizers["paged_ademamix_scheduled"] = (
    lambda pxx: bnb.optim.ademamix._ReferenceAdEMAMix(pxx, t_alpha=k, t_beta3=k),
    lambda pxx: bnb.optim.PagedAdEMAMix(pxx, t_alpha=k, t_beta3=k),
)
str2optimizers["ademamix8bit_blockwise_scheduled"] = (
    lambda pxx: bnb.optim.ademamix._ReferenceAdEMAMix(pxx, t_alpha=100, t_beta3=100),
    lambda pxx: bnb.optim.AdEMAMix8bit(pxx, t_alpha=100, t_beta3=100),
)
str2optimizers["paged_ademamix8bit_blockwise_scheduled"] = (
    lambda pxx: bnb.optim.ademamix._ReferenceAdEMAMix(pxx, t_alpha=100, t_beta3=100),
    lambda pxx: bnb.optim.PagedAdEMAMix8bit(pxx, t_alpha=100, t_beta3=100),
)

str2optimizers["lion"] = (Lion, bnb.optim.Lion)
str2optimizers["paged_lion"] = (Lion, bnb.optim.PagedLion)
str2optimizers["lion8bit_blockwise"] = (Lion, lambda pxx: bnb.optim.Lion8bit(pxx, block_wise=True))
str2optimizers["paged_lion8bit_blockwise"] = (Lion, lambda pxx: bnb.optim.PagedLion8bit(pxx, block_wise=True))

str2optimizers["momentum"] = (
    lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9),
    lambda pxx: bnb.optim.SGD(pxx, 0.01, 0.9, block_wise=False),
)
str2optimizers["momentum8bit_blockwise"] = (
    lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9),
    lambda pxx: bnb.optim.SGD8bit(pxx, 0.01, 0.9, block_wise=True),
)

str2optimizers["rmsprop"] = (
    lambda pxx: torch.optim.RMSprop(pxx, 0.01, 0.9),
    lambda pxx: bnb.optim.RMSprop(pxx, 0.01, 0.9, block_wise=False),
)
str2optimizers["rmsprop8bit_blockwise"] = (
    lambda pxx: torch.optim.RMSprop(pxx, 0.01, 0.9),
    lambda pxx: bnb.optim.RMSprop8bit(pxx, 0.01, 0.9, block_wise=True),
)


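# Maps an optimizer name to its state tensors: (torch state name, bnb state name) for
# 32-bit optimizers, plus the quantization map and absmax buffer names for 8-bit blockwise ones.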
str2statenames = {}
str2statenames["adam"] = [("exp_avg", "state1"), ("exp_avg_sq", "state2")]
str2statenames["paged_adamw"] = [("exp_avg", "state1"), ("exp_avg_sq", "state2")]
str2statenames["paged_adam"] = [("exp_avg", "state1"), ("exp_avg_sq", "state2")]
str2statenames["lion"] = [("exp_avg", "state1")]
str2statenames["paged_lion"] = [("exp_avg", "state1")]
str2statenames["momentum"] = [("momentum_buffer", "state1")]
str2statenames["lamb"] = [("exp_avg", "state1"), ("exp_avg_sq", "state2")]
str2statenames["rmsprop"] = [("square_avg", "state1")]

str2statenames["adam8bit_blockwise"] = [
    ("exp_avg", "state1", "qmap1", "absmax1"),
    ("exp_avg_sq", "state2", "qmap2", "absmax2"),
]
str2statenames["paged_adam8bit_blockwise"] = [
    ("exp_avg", "state1", "qmap1", "absmax1"),
    ("exp_avg_sq", "state2", "qmap2", "absmax2"),
]
str2statenames["paged_adamw8bit_blockwise"] = [
    ("exp_avg", "state1", "qmap1", "absmax1"),
    ("exp_avg_sq", "state2", "qmap2", "absmax2"),
]

str2statenames["momentum8bit_blockwise"] = [("momentum_buffer", "state1", "qmap1", "absmax1")]
str2statenames["rmsprop8bit_blockwise"] = [("square_avg", "state1", "qmap1", "absmax1")]
str2statenames["lion8bit_blockwise"] = [("exp_avg", "state1", "qmap1", "absmax1")]
str2statenames["paged_lion8bit_blockwise"] = [("exp_avg", "state1", "qmap1", "absmax1")]

str2statenames["ademamix"] = str2statenames["ademamix_scheduled"] = [("m1_m2", "state1"), ("nu", "state2")]
str2statenames["paged_ademamix"] = str2statenames["paged_ademamix_scheduled"] = [("m1_m2", "state1"), ("nu", "state2")]
str2statenames["ademamix8bit_blockwise"] = str2statenames["ademamix8bit_blockwise_scheduled"] = [
    ("m1_m2", "state1", "qmap1", "absmax1"),
    ("nu", "state2", "qmap2", "absmax2"),
]
str2statenames["paged_ademamix8bit_blockwise"] = [
    ("m1_m2", "state1", "qmap1", "absmax1"),
    ("nu", "state2", "qmap2", "absmax2"),
]


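# Benchmark configuration: dtype, tensor shape, device, gradient scale,
# and whether to also run the precision checks alongside the timing.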
gtype = [torch.float32, torch.float16, torch.bfloat16][2]

dim2 = 4096
dim1 = 4096

# dim2 = 1024
# dim1 = 1024

device = "xpu"

check_precision = True

gradient_scale = 0.01
# gradient_scale = 1


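# Checks that a and b are elementwise close, failing (via torch.testing.assert_close)
# only if more than max_error_count elements are out of tolerance.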
def assert_most_approx_close(a, b, rtol=1e-3, atol=1e-3, max_error_count=0):
    idx = torch.isclose(a, b, rtol=rtol, atol=atol)
    error_count = (idx == 0).sum().item()
    if error_count > max_error_count:
        print(f"Too many values not close: assert {error_count} < {max_error_count}")
        torch.testing.assert_close(a, b, rtol=rtol, atol=atol)


def get_temp_dir():
    path = f"/tmp/autoswap/{uuid.uuid4()}"
    os.makedirs(path, exist_ok=True)
    return path


def rm_path(path):
    shutil.rmtree(path)


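# Runs the reference optimizer and the bnb 8-bit optimizer side by side for 50 steps,
# timing each step, comparing parameters and dequantized optimizer state, and exercising
# a state_dict save/load round trip every 10 steps.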
def test_optimizer8bit(dim1, dim2, gtype, optim_name, device):
    torch.set_printoptions(precision=10)

    if dim1 == 1 and dim2 == 1:
        return

    p1 = torch.randn(dim1, dim2, device=device, dtype=gtype) * 0.1
    p2 = p1.clone()
    p1 = p1.float()

    blocksize = 256

    torch_optimizer = str2optimizers[optim_name][0]([p1])
    bnb_optimizer = str2optimizers[optim_name][1]([p2])

    if gtype == torch.float32:
        atol, rtol = 3e-3, 1e-3
        # atol, rtol = 5e-3, 1e-3
        patol, prtol = 1e-5, 1e-3
    elif gtype == torch.bfloat16:
        atol, rtol = 3e-3, 1e-3
        # atol, rtol = 5e-3, 1e-3
        patol, prtol = 1e-4, 1e-2
    else:
        atol, rtol = 3e-3, 1e-3
        # atol, rtol = 5e-3, 1e-3
        patol, prtol = 1e-5, 1e-3

    errors = []
    relerrors = []

    for i in range(50):
        g = torch.randn(dim1, dim2, device=device, dtype=gtype) * gradient_scale
        p1.grad = g.clone().float()
        p2.grad = g.clone()

        p_copy = p1.clone()

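        # Time one step of each optimizer; sync_gpu makes sure queued kernels
        # finish before the timer stops.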
        sync_gpu(p1)
        start = time.time()
        torch_optimizer.step()
        sync_gpu(p1)
        stop = time.time()
        print(f"Torch step (eager): {1000 * (stop - start):.3f}ms")
        sync_gpu(p1)
        start = time.time()
        bnb_optimizer.step()
        sync_gpu(p1)
        stop = time.time()
        print(f"BNB step: {1000 * (stop - start):.3f}ms")

        # since Lion can have pretty noisy updates where things lie at the boundary
        if check_precision:
            assert_most_approx_close(p1, p2.float(), patol, prtol, max_error_count=0)

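        # Dequantize each 8-bit blockwise state tensor and compare it against the
        # corresponding full-precision state of the reference optimizer.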
        dequant_states = []
        for name1, name2, qmap, max_val in str2statenames[optim_name]:
            ## For AdEMAMix, we need to dequantize [p2][name2][0] and [p2][name2][1]
            ## separately and then stack them. The qmap is shared, but absmax is also stacked.
            # if optim_name == "ademamix8bit_blockwise" and name1 == "m1_m2":
            #     m1 = F.dequantize_blockwise(
            #         code=bnb_optimizer.state[p2][qmap],
            #         absmax=bnb_optimizer.state[p2][max_val][0],
            #         A=bnb_optimizer.state[p2][name2][0],
            #         blocksize=blocksize,
            #     )
            #     m2 = F.dequantize_blockwise(
            #         code=bnb_optimizer.state[p2][qmap],
            #         absmax=bnb_optimizer.state[p2][max_val][1],
            #         A=bnb_optimizer.state[p2][name2][1],
            #         blocksize=blocksize,
            #     )

            #     s1 = torch.stack((m1, m2))
            s2 = F.dequantize_blockwise(
                code=bnb_optimizer.state[p2][qmap],
                absmax=bnb_optimizer.state[p2][max_val],
                A=bnb_optimizer.state[p2][name2],
                blocksize=blocksize,
            )

            code = bnb_optimizer.state[p2][qmap]
            absmax = bnb_optimizer.state[p2][max_val]
            A = bnb_optimizer.state[p2][name2]

            s1 = torch_optimizer.state[p1][name1]
            diff = s1 - s2
            # import pdb; pdb.set_trace()

            num_not_close = torch.isclose(torch_optimizer.state[p1][name1], s2, atol=atol, rtol=rtol) == 0
            if check_precision:
                assert num_not_close.sum().item() < 20
            dequant_states.append(s2.clone())

        err = torch.abs(p1 - p2)
        relerr = err / (torch.abs(p1) + 1e-9)
        if g.dtype == torch.bfloat16 and check_precision:
            assert err.mean() <= 0.00017
            assert relerr.mean() <= 0.0016
        elif check_precision:
            assert err.mean() < 0.00006
            assert relerr.mean() < 0.0006

        errors.append(err.mean().item())
        relerrors.append(relerr.mean().item())

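        # Every 10 steps: serialize the bnb optimizer, reload it into a fresh instance,
        # and verify that the quantized state survives the round trip.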
        if i % 10 == 0 and i > 0:
            for (name1, name2, qmap, max_val), s in zip(str2statenames[optim_name], dequant_states):
                s1cpy = s.clone()
                raws1cpy = bnb_optimizer.state[p2][name2].clone()
                qmap1 = bnb_optimizer.state[p2][qmap].clone()

                path = get_temp_dir()
                torch.save(bnb_optimizer.state_dict(), join(path, "opt.pt"))
                del bnb_optimizer
                bnb_optimizer = None
                bnb_optimizer = str2optimizers[optim_name][1]([p2])
                bnb_optimizer.load_state_dict(torch.load(join(path, "opt.pt")))
                rm_path(path)
                torch.testing.assert_close(raws1cpy, bnb_optimizer.state[p2][name2])
                torch.testing.assert_close(qmap1, bnb_optimizer.state[p2][qmap])

                ## For AdEMAMix, we need to dequantize [p2][name2][0] and [p2][name2][1]
                ## separately and then stack them. The qmap is shared, but absmax is also stacked.
                if optim_name == "ademamix8bit_blockwise" and name1 == "m1_m2":
                    s2 = torch.stack(
                        (
                            F.dequantize_blockwise(
                                code=bnb_optimizer.state[p2][qmap],
                                absmax=bnb_optimizer.state[p2][max_val][0],
                                A=bnb_optimizer.state[p2][name2][0],
                                blocksize=blocksize,
                            ),
                            F.dequantize_blockwise(
                                code=bnb_optimizer.state[p2][qmap],
                                absmax=bnb_optimizer.state[p2][max_val][1],
                                A=bnb_optimizer.state[p2][name2][1],
                                blocksize=blocksize,
                            ),
                        )
                    )
                else:
                    s2 = F.dequantize_blockwise(
                        code=bnb_optimizer.state[p2][qmap],
                        absmax=bnb_optimizer.state[p2][max_val],
                        A=bnb_optimizer.state[p2][name2],
                        blocksize=blocksize,
                    )

                torch.testing.assert_close(s1cpy, s2)

                num_not_close = torch.isclose(torch_optimizer.state[p1][name1], s2, atol=atol, rtol=rtol) == 0
                if check_precision:
                    assert num_not_close.sum().item() < 20

            # Lion can have pretty noisy updates where things lie at the boundary
            assert_most_approx_close(p1, p2.float(), patol, prtol, max_error_count=0)

        # the parameters diverge quickly. Here we keep them close
        # together so we can test against the Adam error
        p1.data = p1.data.to(gtype).float()
        p2.copy_(p1.data)
        torch.testing.assert_close(p1.to(gtype), p2)
        for (name1, name2, qmap, max_val), s in zip(str2statenames[optim_name], dequant_states):
            torch_optimizer.state[p1][name1].copy_(s.data)


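# Run the timing/precision comparison with the configuration above.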
test_optimizer8bit(dim1, dim2, gtype, optim_name, device)
