mirror of https://github.com/commaai/tinygrad.git
hotfix: only validate stable diffusion when using threefry (#5166)
This commit is contained in:
parent
e4a5870b36
commit
0ba093dea0
|
@ -9,7 +9,7 @@ from collections import namedtuple
|
|||
from PIL import Image
|
||||
import numpy as np
|
||||
from tinygrad import Device, GlobalCounters, dtypes, Tensor, TinyJit
|
||||
from tinygrad.helpers import Timing, Context, getenv, fetch, colored, tqdm
|
||||
from tinygrad.helpers import Timing, Context, getenv, fetch, colored, tqdm, THREEFRY
|
||||
from tinygrad.nn import Conv2d, Linear, GroupNorm, LayerNorm, Embedding
|
||||
from tinygrad.nn.state import torch_load, load_state_dict, get_state_dict
|
||||
|
||||
|
@ -643,7 +643,7 @@ if __name__ == "__main__":
|
|||
if not args.noshow: im.show()
|
||||
|
||||
# validation!
|
||||
if args.prompt == default_prompt and args.steps == 10 and args.seed == 0 and args.guidance == 7.5:
|
||||
if args.prompt == default_prompt and args.steps == 10 and args.seed == 0 and args.guidance == 7.5 and THREEFRY:
|
||||
ref_image = Tensor(np.array(Image.open(Path(__file__).parent / "stable_diffusion_seed0.png")))
|
||||
distance = (((x - ref_image).cast(dtypes.float) / ref_image.max())**2).mean().item()
|
||||
assert distance < 3e-4, colored(f"validation failed with {distance=}", "red")
|
||||
|
|
Loading…
Reference in New Issue