validate stable diffusion for seed 0 (#2773)

* validate stable diffusion for seed 0

the closest false positive i can get is with the setup and one less step. dist = 0.0036
same setup with fp16 has dist=5e-6.
so setting validation threshold to 1e-4 should be good

* run with --seed 0
This commit is contained in:
chenyu 2023-12-15 00:07:09 -05:00 committed by GitHub
parent 9afa8009c1
commit a044125c39
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 14 additions and 6 deletions

View File

@ -33,7 +33,7 @@ jobs:
- name: Run Tensor Core GEMM
run: DEBUG=2 python3 extra/gemm/simple_matmul.py | tee matmul.txt
- name: Run Stable Diffusion
run: python3 examples/stable_diffusion.py --noshow --timing | tee sd.txt
run: python3 examples/stable_diffusion.py --seed 0 --noshow --timing | tee sd.txt
- name: Run LLaMA
run: |
JIT=0 python3 examples/llama.py --gen 1 --prompt "Hello." --count 10 --temperature 0 --timing | tee llama_unjitted.txt
@ -118,7 +118,7 @@ jobs:
- name: Run Tensor Core GEMM
run: HIP=1 HALF=1 DEBUG=2 python3 extra/gemm/simple_matmul.py | tee matmul.txt
- name: Run Stable Diffusion
run: python3 examples/stable_diffusion.py --noshow --timing | tee sd.txt
run: python3 examples/stable_diffusion.py --seed 0 --noshow --timing | tee sd.txt
- name: Run LLaMA
run: |
JIT=0 python3 examples/llama.py --gen 1 --prompt "Hello." --count 10 --temperature 0 --timing | tee llama_unjitted.txt

View File

@ -6,10 +6,12 @@ import gzip, argparse, math, re
from functools import lru_cache
from collections import namedtuple
from PIL import Image
import numpy as np
from tqdm import tqdm
from tinygrad.tensor import Tensor
from tinygrad import Device
from tinygrad.helpers import dtypes, GlobalCounters, Timing, Context, getenv, fetch
from tinygrad.helpers import dtypes, GlobalCounters, Timing, Context, getenv, fetch, colored
from tinygrad.nn import Conv2d, Linear, GroupNorm, LayerNorm, Embedding
from tinygrad.nn.state import torch_load, load_state_dict, get_state_dict
from tinygrad.jit import TinyJit
@ -574,9 +576,10 @@ class StableDiffusion:
# cond_stage_model.transformer.text_model
if __name__ == "__main__":
default_prompt = "a horse sized cat eating a bagel"
parser = argparse.ArgumentParser(description='Run Stable Diffusion', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument('--steps', type=int, default=5, help="Number of steps in diffusion")
parser.add_argument('--prompt', type=str, default="a horse sized cat eating a bagel", help="Phrase to render")
parser.add_argument('--prompt', type=str, default=default_prompt, help="Phrase to render")
parser.add_argument('--out', type=str, default=Path(tempfile.gettempdir()) / "rendered.png", help="Output filename")
parser.add_argument('--noshow', action='store_true', help="Don't show the image")
parser.add_argument('--fp16', action='store_true', help="Cast the weights to float16")
@ -636,10 +639,15 @@ if __name__ == "__main__":
print(x.shape)
# save image
from PIL import Image
import numpy as np
im = Image.fromarray(x.numpy().astype(np.uint8, copy=False))
print(f"saving {args.out}")
im.save(args.out)
# Open image.
if not args.noshow: im.show()
# validation!
if args.prompt == default_prompt and args.steps == 5 and args.seed == 0 and args.guidance == 7.5:
ref_image = Tensor(np.array(Image.open(Path(__file__).parent / "stable_diffusion_seed0.png")))
distance = (((x - ref_image).cast(dtypes.float) / ref_image.max())**2).mean().item()
assert distance < 1e-4, f"validation failed with {distance=}"
print(colored(f"output validated with {distance=}", "green"))

Binary file not shown.

After

Width:  |  Height:  |  Size: 479 KiB