validate sdxl output and put it in benchmark (#5211)

* validate sdxl output and put it in benchmark

* don't print fetch progress_bar in CI
This commit is contained in:
chenyu 2024-06-28 11:40:52 -04:00 committed by GitHub
parent 63fa4e2a0e
commit 7090eac8cb
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 22 additions and 4 deletions

View File

@ -39,6 +39,8 @@ jobs:
# run: echo "RUN_PROCESS_REPLAY=1" >> $GITHUB_ENV
- name: Run Stable Diffusion
run: JIT=2 THREEFRY=1 python3 examples/stable_diffusion.py --seed 0 --noshow --timing | tee sd.txt
- name: Run SDXL
run: JIT=2 THREEFRY=1 python3 examples/sdxl.py --seed 0 --noshow | tee sdxl.txt
- name: Run model inference benchmark
run: METAL=1 python3 test/external/external_model_benchmark.py
- name: Test speed vs torch
@ -103,6 +105,7 @@ jobs:
matmul.txt
matmul_half.txt
sd.txt
sdxl.txt
beautiful_mnist.txt
train_cifar.txt
train_cifar_half.txt
@ -156,6 +159,8 @@ jobs:
run: CUDA=1 PTX=1 M_START=12 M_STOP=20 M_STEP=1 N_START=6 N_STOP=10 N_STEP=1 K_START=28 K_STOP=36 K_STEP=1 HALF=1 TC_OPT=2 python3 ./extra/gemm/fuzz_matmul.py
- name: Run Stable Diffusion
run: NV=1 THREEFRY=1 python3 examples/stable_diffusion.py --seed 0 --noshow --timing | tee sd.txt
- name: Run SDXL
run: NV=1 THREEFRY=1 python3 examples/sdxl.py --seed 0 --noshow | tee sdxl.txt
- name: Run LLaMA
run: |
NV=1 JIT=0 python3 examples/llama.py --gen 1 --prompt "Hello." --count 10 --temperature 0 --timing | tee llama_unjitted.txt
@ -195,6 +200,7 @@ jobs:
matmul_ptx.txt
matmul_nv.txt
sd.txt
sdxl.txt
llama_unjitted.txt
llama_jitted.txt
llama_beam.txt
@ -309,6 +315,8 @@ jobs:
# run: HSA=1 M_START=12 M_STOP=20 M_STEP=1 N_START=12 N_STOP=20 N_STEP=1 K_START=28 K_STOP=36 K_STEP=1 HALF=1 TC_OPT=2 DEBUG=2 python3 ./extra/gemm/fuzz_matmul.py
- name: Run Stable Diffusion
run: AMD=1 THREEFRY=1 python3 examples/stable_diffusion.py --seed 0 --noshow --timing | tee sd.txt
- name: Run SDXL
run: AMD=1 THREEFRY=1 python3 examples/sdxl.py --seed 0 --noshow | tee sdxl.txt
- name: Run LLaMA 7B
run: |
AMD=1 JIT=0 python3 examples/llama.py --gen 1 --prompt "Hello." --count 10 --temperature 0 --timing | tee llama_unjitted.txt
@ -359,6 +367,7 @@ jobs:
matmul.txt
matmul_amd.txt
sd.txt
sdxl.txt
mixtral.txt
testmoreamdbenchmark:

View File

@ -6,7 +6,7 @@
from tinygrad import Tensor, TinyJit, dtypes
from tinygrad.nn import Linear, Conv2d, GroupNorm, LayerNorm, Embedding
from tinygrad.nn.state import safe_load, load_state_dict
from tinygrad.helpers import fetch, trange
from tinygrad.helpers import fetch, trange, colored, THREEFRY
from examples.stable_diffusion import ClipTokenizer, ResnetBlock, Mid, Downsample, Upsample
import numpy as np
@ -854,9 +854,10 @@ class DPMPP2MSampler:
if __name__ == "__main__":
default_prompt = "a horse sized cat eating a bagel"
parser = argparse.ArgumentParser(description="Run SDXL", formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument('--steps', type=int, default=5, help="The number of diffusion steps")
parser.add_argument('--prompt', type=str, default="a horse sized cat eating a bagel", help="Description of image to generate")
parser.add_argument('--steps', type=int, default=10, help="The number of diffusion steps")
parser.add_argument('--prompt', type=str, default=default_prompt, help="Description of image to generate")
parser.add_argument('--out', type=str, default=Path(tempfile.gettempdir()) / "rendered.png", help="Output filename")
parser.add_argument('--seed', type=int, help="Set the random latent seed")
parser.add_argument('--guidance', type=float, default=6.0, help="Prompt strength")
@ -910,3 +911,11 @@ if __name__ == "__main__":
if not args.noshow:
im.show()
# validation!
if args.prompt == default_prompt and args.steps == 10 and args.seed == 0 and args.guidance == 6.0 and args.width == args.height == 1024 \
and not args.weights and THREEFRY:
ref_image = Tensor(np.array(Image.open(Path(__file__).parent / "sdxl_seed0.png")))
distance = (((x - ref_image).cast(dtypes.float) / ref_image.max())**2).mean().item()
assert distance < 3e-4, colored(f"validation failed with {distance=}", "red")
print(colored(f"output validated with {distance=}", "green"))

BIN
examples/sdxl_seed0.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 1.3 MiB

View File

@ -239,7 +239,7 @@ def fetch(url:str, name:Optional[Union[pathlib.Path, str]]=None, subdir:Optional
with urllib.request.urlopen(url, timeout=10) as r:
assert r.status == 200
total_length = int(r.headers.get('content-length', 0))
progress_bar = tqdm(total=total_length, unit='B', unit_scale=True, desc=f"{url}: ")
progress_bar = tqdm(total=total_length, unit='B', unit_scale=True, desc=f"{url}: ", disable=CI)
(path := fp.parent).mkdir(parents=True, exist_ok=True)
with tempfile.NamedTemporaryFile(dir=path, delete=False) as f:
while chunk := r.read(16384): progress_bar.update(f.write(chunk))