GitHub - SusungHong/Self-Attention-Guidance: Official implementation of the paper "Improving Sample Quality of Diffusion Models Using Self-Attention Guidance" (ICCV`23)

Self-Attention Diffusion Guidance (ICCV`23)

This is the implementation of the paper "Improving Sample Quality of Diffusion Models Using Self-Attention Guidance" by Hong et al. For insights from our exploration of the self-attention maps of diffusion models and for detailed explanations, please see our Paper and Project Page.

This repository is based on openai/guided-diffusion, and we modified feature extraction code from yandex-research/ddpm-segmentation to get the self-attention maps. The major implementation of our method is in ./guided_diffusion/gaussian_diffusion.py and ./guided_diffusion/unet.py.

All you need to do is set up the environment, download existing models, and sample from them using our implementation. Neither further training nor a dataset is needed to apply self-attention guidance!

Updates

2023-08-14: This repository supports DDIM sampling with SAG.

2023-02-19: The Gradio Demo 🤗 of SAG for Stable Diffusion is now available.

2023-02-16: The Stable Diffusion pipeline of SAG is now available at huggingface/diffusers 🤗🧨.

2023-02-01: The demo for Stable Diffusion is now available in Colab.

Environment

  • Python 3.8, PyTorch 1.11.0
  • 8 x NVIDIA RTX 3090 (set backend="gloo" in ./guided_diffusion/dist_util.py if P2P access is not available)
git clone https://github.com/KU-CVLAB/Self-Attention-Guidance
conda create -n sag python=3.8 anaconda
conda activate sag
conda install mpi4py
pip install torch==1.11.0+cu113 torchvision==0.12.0+cu113 -f https://download.pytorch.org/whl/torch_stable.html
pip install blobfile

Downloading Pretrained Diffusion Models (and Classifiers for CG)

Pretrained weights for ImageNet and LSUN can be downloaded from the openai/guided-diffusion repository. Download them and place them in the ./models/ directory.
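Before sampling, it can help to confirm that the checkpoints are where the commands expect them. The following is a minimal sketch, not part of the repository: the filenames are the ones passed to --model_path and --classifier_path in the sampling commands of this README, and the helper name `missing_checkpoints` is hypothetical.

```python
from pathlib import Path

# Checkpoint filenames referenced by the sampling commands in this README.
EXPECTED_CHECKPOINTS = [
    "128x128_diffusion.pt",
    "128x128_classifier.pt",
    "256x256_diffusion_uncond.pt",
    "256x256_classifier.pt",
    "lsun_cat.pt",
    "lsun_horse.pt",
]


def missing_checkpoints(model_dir="./models", names=EXPECTED_CHECKPOINTS):
    """Return the expected checkpoint names that are not present in model_dir."""
    root = Path(model_dir)
    return [name for name in names if not (root / name).is_file()]


if __name__ == "__main__":
    missing = missing_checkpoints()
    if missing:
        print("Missing checkpoints:", ", ".join(missing))
    else:
        print("All expected checkpoints found.")
```

Only the checkpoints for the examples you intend to run are needed; pass a shorter list as `names` to check just those.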

Sampling from Pretrained Diffusion Models

You can sample from pretrained diffusion models with self-attention guidance by changing SAG_FLAGS in the following commands. The commands expect NUM_GPUS to be set in your shell to the number of processes (GPUs) to launch. Note that sampling with --guide_scale 1.0 means sampling without self-attention guidance. Below are five examples.

  • ImageNet 128x128 model (--classifier_guidance False deactivates classifier guidance):
SAMPLE_FLAGS="--batch_size 64 --num_samples 10000 --timestep_respacing 250"
MODEL_FLAGS="--attention_resolutions 32,16,8 --class_cond True --diffusion_steps 1000 --image_size 128 --learn_sigma True --noise_schedule linear --num_channels 256 --num_heads 4 --num_res_blocks 2 --resblock_updown True --use_fp16 True --use_scale_shift_norm True"
SAG_FLAGS="--guide_scale 1.1 --guide_start 250 --sel_attn_block output --sel_attn_depth 8 --blur_sigma 3 --classifier_guidance True"
mpiexec -n $NUM_GPUS python classifier_sample.py $SAG_FLAGS $MODEL_FLAGS --classifier_scale 0.5 --classifier_path models/128x128_classifier.pt --model_path models/128x128_diffusion.pt $SAMPLE_FLAGS
  • ImageNet 256x256 model (--class_cond True for conditional models):
SAMPLE_FLAGS="--batch_size 16 --num_samples 10000 --timestep_respacing 250"
MODEL_FLAGS="--attention_resolutions 32,16,8 --class_cond False --diffusion_steps 1000 --image_size 256 --learn_sigma True --noise_schedule linear --num_channels 256 --num_head_channels 64 --num_res_blocks 2 --resblock_updown True --use_fp16 True --use_scale_shift_norm True"
SAG_FLAGS="--guide_scale 1.5 --guide_start 250 --sel_attn_block output --sel_attn_depth 2 --blur_sigma 9 --classifier_guidance False"
mpiexec -n $NUM_GPUS python classifier_sample.py $SAG_FLAGS $MODEL_FLAGS --classifier_scale 0.0 --classifier_path models/256x256_classifier.pt --model_path models/256x256_diffusion_uncond.pt $SAMPLE_FLAGS
  • LSUN Cat model (respaced to 250 steps):
SAMPLE_FLAGS="--batch_size 16 --num_samples 10000 --timestep_respacing 250"
MODEL_FLAGS="--attention_resolutions 32,16,8 --class_cond False --diffusion_steps 1000 --dropout 0.1 --image_size 256 --learn_sigma True --noise_schedule linear --num_channels 256 --num_head_channels 64 --num_res_blocks 2 --resblock_updown True --use_fp16 True --use_scale_shift_norm True"
SAG_FLAGS="--guide_scale 1.05 --guide_start 250 --sel_attn_block output --sel_attn_depth 2 --blur_sigma 9 --classifier_guidance False"
mpiexec -n $NUM_GPUS python image_sample.py $SAG_FLAGS $MODEL_FLAGS --model_path models/lsun_cat.pt $SAMPLE_FLAGS
  • LSUN Horse model (respaced to 250 steps):
SAMPLE_FLAGS="--batch_size 16 --num_samples 10000 --timestep_respacing 250"
MODEL_FLAGS="--attention_resolutions 32,16,8 --class_cond False --diffusion_steps 1000 --dropout 0.1 --image_size 256 --learn_sigma True --noise_schedule linear --num_channels 256 --num_head_channels 64 --num_res_blocks 2 --resblock_updown True --use_fp16 True --use_scale_shift_norm True"
SAG_FLAGS="--guide_scale 1.01 --guide_start 250 --sel_attn_block output --sel_attn_depth 2 --blur_sigma 9 --classifier_guidance False"
mpiexec -n $NUM_GPUS python image_sample.py $SAG_FLAGS $MODEL_FLAGS --model_path models/lsun_horse.pt $SAMPLE_FLAGS
  • ImageNet 128x128 model (DDIM 25 steps):
MODEL_FLAGS="--attention_resolutions 32,16,8 --class_cond True --image_size 128 --learn_sigma True --num_channels 256 --num_heads 4 --num_res_blocks 2 --resblock_updown True --use_fp16 True --use_scale_shift_norm True"
CLASSIFIER_FLAGS="--image_size 128 --classifier_attention_resolutions 32,16,8 --classifier_depth 2 --classifier_width 128 --classifier_pool attention --classifier_resblock_updown True --classifier_use_scale_shift_norm True --classifier_scale 1.0 --classifier_use_fp16 True"
SAMPLE_FLAGS="--batch_size 8 --num_samples 8 --timestep_respacing ddim25 --use_ddim True"
SAG_FLAGS="--guide_scale 1.1 --guide_start 25 --sel_attn_block output --sel_attn_depth 8 --blur_sigma 3 --classifier_guidance True"
mpiexec -n $NUM_GPUS python classifier_sample.py \
    --model_path models/128x128_diffusion.pt \
    --classifier_path models/128x128_classifier.pt \
    $MODEL_FLAGS $CLASSIFIER_FLAGS $SAMPLE_FLAGS $SAG_FLAGS
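When comparing several guidance settings, the SAG_FLAGS string can be generated rather than edited by hand. The sketch below only assembles the flag names that appear in the commands above; the helper `build_sag_flags` is hypothetical, and the scales in the loop are arbitrary illustration values, not recommended settings.

```python
import shlex


def build_sag_flags(guide_scale, guide_start, sel_attn_block, sel_attn_depth,
                    blur_sigma, classifier_guidance):
    """Assemble a SAG_FLAGS string from the flags used in this README."""
    args = [
        "--guide_scale", str(guide_scale),
        "--guide_start", str(guide_start),
        "--sel_attn_block", str(sel_attn_block),
        "--sel_attn_depth", str(sel_attn_depth),
        "--blur_sigma", str(blur_sigma),
        "--classifier_guidance", str(classifier_guidance),
    ]
    # shlex.join quotes arguments only when the shell would need it.
    return shlex.join(args)


# Print one SAG_FLAGS assignment per scale, ready to paste into a shell.
for scale in (1.0, 1.1, 1.2):
    flags = build_sag_flags(scale, 250, "output", 8, 3, True)
    print('SAG_FLAGS="' + flags + '"')
```

The first iteration uses --guide_scale 1.0, which corresponds to sampling without self-attention guidance and can serve as a baseline run.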

Results

Compatibility of self-attention guidance (SAG) and classifier guidance (CG) on ImageNet 128x128 model:

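| SAG | CG | FID | sFID | Precision | Recall |
|-----|----|------|------|-----------|--------|
|     |    | 5.91 | 5.09 | 0.70 | 0.65 |
|     | V  | 2.97 | 5.09 | 0.78 | 0.59 |
| V   |    | 5.11 | 4.09 | 0.72 | 0.65 |
| V   | V  | 2.58 | 4.35 | 0.79 | 0.59 |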

Results on pretrained models:

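| Model | # of steps | Self-attention guidance scale | FID | sFID | IS | Precision | Recall |
|-------|------------|-------------------------------|-----|------|----|-----------|--------|
| ImageNet 256×256 (Uncond.) | 250 | 0.0 (baseline) | 26.21 | 6.35 | 39.70 | 0.61 | 0.63 |
| ImageNet 256×256 (Uncond.) | 250 | 0.5 | 20.31 | 5.09 | 45.30 | 0.66 | 0.61 |
| ImageNet 256×256 (Uncond.) | 250 | 0.8 | 20.08 | 5.77 | 45.56 | 0.68 | 0.59 |
| ImageNet 256×256 (Cond.) | 250 | 0.0 (baseline) | 10.94 | 6.02 | 100.98 | 0.69 | 0.63 |
| ImageNet 256×256 (Cond.) | 250 | 0.2 | 9.41 | 5.28 | 104.79 | 0.70 | 0.62 |
| LSUN Cat 256×256 | 250 | 0.0 (baseline) | 7.03 | 8.24 | - | 0.60 | 0.53 |
| LSUN Cat 256×256 | 250 | 0.05 | 6.87 | 8.21 | - | 0.60 | 0.50 |
| LSUN Horse 256×256 | 250 | 0.0 (baseline) | 3.45 | 7.55 | - | 0.68 | 0.56 |
| LSUN Horse 256×256 | 250 | 0.01 | 3.43 | 7.51 | - | 0.68 | 0.55 |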
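The guidance scales in the table above (0.5, 0.2, 0.05, 0.01) appear to be offset by one from the --guide_scale values in the sampling commands (1.5, 1.05, 1.01), which matches the note that --guide_scale 1.0 disables guidance. The following minimal sketch illustrates that reading. It assumes the flag plays the role of (1 + s) in a rule that extrapolates from the noise prediction for the blurred input toward the prediction for the original input; the function and variable names are hypothetical and are not taken from the repository's code.

```python
def sag_combine(eps_original, eps_blurred, guide_scale):
    """Extrapolate from the blurred-input prediction toward the original one.

    guide_scale == 1.0 returns eps_original unchanged (no guidance);
    values above 1.0 push further away from the blurred prediction.
    """
    return [b + guide_scale * (o - b) for o, b in zip(eps_original, eps_blurred)]


# Toy values standing in for per-pixel noise predictions.
eps_original = [2.0, -1.0, 0.5]  # prediction for the noisy input
eps_blurred = [1.0, -1.0, 1.5]   # prediction for the selectively blurred input

print(sag_combine(eps_original, eps_blurred, 1.0))  # identical to eps_original
print(sag_combine(eps_original, eps_blurred, 1.5))  # table scale s = 0.5 under this assumption
```

Under this assumption, the baseline rows (scale 0.0) correspond to --guide_scale 1.0.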

Cite as

@inproceedings{hong2023improving,
  title={Improving sample quality of diffusion models using self-attention guidance},
  author={Hong, Susung and Lee, Gyuseong and Jang, Wooseok and Kim, Seungryong},
  booktitle={Proceedings of the IEEE/CVF International Conference on Computer Vision},
  pages={7462--7471},
  year={2023}
}