validate stable_diffusion output (#5163)

changed default steps, forgot to update validation
This commit is contained in:
chenyu 2024-06-26 16:42:21 -04:00 committed by GitHub
parent 21b225ac45
commit e4a5870b36
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 1 additions and 1 deletions

View File

@ -643,7 +643,7 @@ if __name__ == "__main__":
if not args.noshow: im.show()
# validation!
if args.prompt == default_prompt and args.steps == 5 and args.seed == 0 and args.guidance == 7.5:
if args.prompt == default_prompt and args.steps == 10 and args.seed == 0 and args.guidance == 7.5:
ref_image = Tensor(np.array(Image.open(Path(__file__).parent / "stable_diffusion_seed0.png")))
distance = (((x - ref_image).cast(dtypes.float) / ref_image.max())**2).mean().item()
assert distance < 3e-4, colored(f"validation failed with {distance=}", "red")