From e4a5870b36e34a1750e4a58d81793232735be003 Mon Sep 17 00:00:00 2001 From: chenyu Date: Wed, 26 Jun 2024 16:42:21 -0400 Subject: [PATCH] validate stable_diffusion output (#5163) changed default steps, forgot to update validation --- examples/stable_diffusion.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/stable_diffusion.py b/examples/stable_diffusion.py index 5b1df9a1..a5e6e3b0 100644 --- a/examples/stable_diffusion.py +++ b/examples/stable_diffusion.py @@ -643,7 +643,7 @@ if __name__ == "__main__": if not args.noshow: im.show() # validation! - if args.prompt == default_prompt and args.steps == 5 and args.seed == 0 and args.guidance == 7.5: + if args.prompt == default_prompt and args.steps == 10 and args.seed == 0 and args.guidance == 7.5: ref_image = Tensor(np.array(Image.open(Path(__file__).parent / "stable_diffusion_seed0.png"))) distance = (((x - ref_image).cast(dtypes.float) / ref_image.max())**2).mean().item() assert distance < 3e-4, colored(f"validation failed with {distance=}", "red")