From 0ba093dea06299dfd5c0d13f88b8b37276f6b37e Mon Sep 17 00:00:00 2001 From: chenyu Date: Wed, 26 Jun 2024 16:50:38 -0400 Subject: [PATCH] hotfix: only validate stable diffusion when using threefry (#5166) --- examples/stable_diffusion.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/stable_diffusion.py b/examples/stable_diffusion.py index a5e6e3b0..87e4bcaa 100644 --- a/examples/stable_diffusion.py +++ b/examples/stable_diffusion.py @@ -9,7 +9,7 @@ from collections import namedtuple from PIL import Image import numpy as np from tinygrad import Device, GlobalCounters, dtypes, Tensor, TinyJit -from tinygrad.helpers import Timing, Context, getenv, fetch, colored, tqdm +from tinygrad.helpers import Timing, Context, getenv, fetch, colored, tqdm, THREEFRY from tinygrad.nn import Conv2d, Linear, GroupNorm, LayerNorm, Embedding from tinygrad.nn.state import torch_load, load_state_dict, get_state_dict @@ -643,7 +643,7 @@ if __name__ == "__main__": if not args.noshow: im.show() # validation! - if args.prompt == default_prompt and args.steps == 10 and args.seed == 0 and args.guidance == 7.5: + if args.prompt == default_prompt and args.steps == 10 and args.seed == 0 and args.guidance == 7.5 and THREEFRY: ref_image = Tensor(np.array(Image.open(Path(__file__).parent / "stable_diffusion_seed0.png"))) distance = (((x - ref_image).cast(dtypes.float) / ref_image.max())**2).mean().item() assert distance < 3e-4, colored(f"validation failed with {distance=}", "red")