hotfix: only validate stable diffusion when using threefry (#5166)

This commit is contained in:
chenyu 2024-06-26 16:50:38 -04:00 committed by GitHub
parent e4a5870b36
commit 0ba093dea0
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 2 additions and 2 deletions

View File

@ -9,7 +9,7 @@ from collections import namedtuple
from PIL import Image
import numpy as np
from tinygrad import Device, GlobalCounters, dtypes, Tensor, TinyJit
from tinygrad.helpers import Timing, Context, getenv, fetch, colored, tqdm
from tinygrad.helpers import Timing, Context, getenv, fetch, colored, tqdm, THREEFRY
from tinygrad.nn import Conv2d, Linear, GroupNorm, LayerNorm, Embedding
from tinygrad.nn.state import torch_load, load_state_dict, get_state_dict
@ -643,7 +643,7 @@ if __name__ == "__main__":
if not args.noshow: im.show()
# validation!
if args.prompt == default_prompt and args.steps == 10 and args.seed == 0 and args.guidance == 7.5:
if args.prompt == default_prompt and args.steps == 10 and args.seed == 0 and args.guidance == 7.5 and THREEFRY:
ref_image = Tensor(np.array(Image.open(Path(__file__).parent / "stable_diffusion_seed0.png")))
distance = (((x - ref_image).cast(dtypes.float) / ref_image.max())**2).mean().item()
assert distance < 3e-4, colored(f"validation failed with {distance=}", "red")