forked from castorini/daam

Diffusion attentive attribution maps for interpreting Stable Diffusion.

What the DAAM: Interpreting Stable Diffusion Using Cross Attention



Updated with aspect ratio support

Install PyTorch for your platform. Install DAAM with git clone https://github.com/rockerBOO/daam && pip install -e daam.

Updated to support Diffusers 0.16.1!

I regularly update this codebase. Please submit an issue if you have any questions.

In our paper, we propose diffusion attentive attribution maps (DAAM), a cross-attention-based approach for interpreting Stable Diffusion. Check out our demo: https://huggingface.co/spaces/tetrisd/Diffusion-Attentive-Attribution-Maps. See our documentation, hosted on GitHub Pages, and our Colab notebook, updated for v0.1.0.

Getting Started

First, install PyTorch for your platform. Then, install DAAM with git clone https://github.com/rockerBOO/daam && pip install -e daam. Finally, log in with huggingface-cli login to access many Stable Diffusion models; you'll need to get a token at HuggingFace.co.
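To confirm the installation steps took effect, a quick check such as the following can help. This is only a sketch: `missing_packages` is a hypothetical helper, not part of DAAM, and the package list is an assumption based on the imports used in the library example in this README.

```python
import importlib.util


def missing_packages(names=("torch", "diffusers", "daam")):
    """Return the top-level packages from `names` that cannot be found.

    Hypothetical helper for illustration; it only checks that each package
    is discoverable, not that it imports or runs correctly.
    """
    return [name for name in names if importlib.util.find_spec(name) is None]


if __name__ == "__main__":
    missing = missing_packages()
    if missing:
        print("Missing packages:", ", ".join(missing))
    else:
        print("All required packages were found.")
```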

Running the Website Demo

Simply run daam-demo in a shell and navigate to http://localhost:8080. The same demo as the one on HuggingFace Spaces will show up.

Using DAAM as a CLI Utility

DAAM comes with a simple generation script for people who want to try it out quickly. Try running:

$ mkdir -p daam-test && cd daam-test
$ daam "A dog running across the field."
$ ls
a.heat_map.png    field.heat_map.png    generation.pt   output.png  seed.txt
dog.heat_map.png  running.heat_map.png  prompt.txt

Your current working directory will now contain the generated image as output.png and a DAAM map for every word, as well as some auxiliary data. You can see more options for daam by running daam -h.
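In the listing above, each map is saved as `<word>.heat_map.png`. Assuming that naming pattern holds for every word, a minimal sketch for collecting the maps from Python might look like this. `collect_heat_maps` is a hypothetical helper, not part of DAAM.

```python
from pathlib import Path

# Suffix assumed from the example directory listing above.
SUFFIX = ".heat_map.png"


def collect_heat_maps(directory="."):
    """Map each word to the path of its heat-map image (hypothetical helper)."""
    heat_maps = {}
    for path in sorted(Path(directory).iterdir()):
        if path.is_file() and path.name.endswith(SUFFIX):
            # Strip the suffix to recover the word the map belongs to.
            word = path.name[: -len(SUFFIX)]
            heat_maps[word] = path
    return heat_maps


if __name__ == "__main__":
    for word, path in collect_heat_maps().items():
        print(f"{word}: {path}")
```

Run from inside daam-test, this would print one line per word-level map and skip output.png, prompt.txt, and the other auxiliary files.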

Using DAAM as a Library

Import and use DAAM as follows:

from daam import trace, set_seed
from diffusers import StableDiffusionPipeline
from matplotlib import pyplot as plt
import torch


model_id = 'stabilityai/stable-diffusion-2-base'
device = 'cuda'

pipe = StableDiffusionPipeline.from_pretrained(model_id, use_auth_token=True)
pipe = pipe.to(device)

prompt = 'A dog runs across the field'
gen = set_seed(0)  # for reproducibility

with torch.cuda.amp.autocast(dtype=torch.float16), torch.no_grad():
    with trace(pipe) as tc:
        out = pipe(prompt, num_inference_steps=30, generator=gen)
        heat_map = tc.compute_global_heat_map()
        heat_map = heat_map.compute_word_heat_map('dog')
        heat_map.plot_overlay(out.images[0])
        plt.show()
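The example above computes a map for a single word. To get one map per prompt word, a small wrapper along these lines may be enough. It is a sketch: `prompt_words` and `word_heat_maps` are hypothetical helpers, and the only DAAM call used is `compute_word_heat_map`, as shown above. Splitting on whitespace and stripping punctuation is an assumption about what counts as a word.

```python
import string


def prompt_words(prompt):
    """Split a prompt into unique lowercase words, stripped of punctuation."""
    seen = []
    for token in prompt.split():
        word = token.strip(string.punctuation).lower()
        if word and word not in seen:
            seen.append(word)
    return seen


def word_heat_maps(global_heat_map, prompt):
    """Return {word: heat map} for every word in the prompt.

    `global_heat_map` is expected to be the object returned by
    tc.compute_global_heat_map() in the example above.
    """
    return {
        word: global_heat_map.compute_word_heat_map(word)
        for word in prompt_words(prompt)
    }
```

Inside the `with trace(pipe) as tc:` block, this could be called as `maps = word_heat_maps(tc.compute_global_heat_map(), prompt)`, after which `maps['dog'].plot_overlay(out.images[0])` would draw the same overlay as before.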

You can also serialize and deserialize the DAAM maps:

from daam import GenerationExperiment, trace

with trace(pipe) as tc:
    pipe('A dog and a cat')
    exp = tc.to_experiment('experiment-dir')
    exp.save()  # experiment-dir now contains all the data and heat maps

exp = GenerationExperiment.load('experiment-dir')  # load the experiment

We'll continue adding docs. In the meantime, check out the GenerationExperiment, GlobalHeatMap, and DiffusionHeatMapHooker classes, as well as the daam/run/*.py example scripts. You can download the COCO-Gen dataset from the paper at http://ralphtang.com/coco-gen.tar.gz. If clicking the link doesn't work in your browser, copy and paste it into a new tab, or use a CLI utility such as wget.


Citation

@inproceedings{tang2023daam,
    title = "What the {DAAM}: Interpreting Stable Diffusion Using Cross Attention",
    author = "Tang, Raphael  and
      Liu, Linqing  and
      Pandey, Akshat  and
      Jiang, Zhiying  and
      Yang, Gefei  and
      Kumar, Karun  and
      Stenetorp, Pontus  and
      Lin, Jimmy  and
      Ture, Ferhan",
    booktitle = "Proceedings of the 61st Annual Meeting of the Association for Computational Linguistics (Volume 1: Long Papers)",
    year = "2023",
    url = "https://aclanthology.org/2023.acl-long.310",
}
