diff --git a/.github/workflows/benchmark.yml b/.github/workflows/benchmark.yml index 484b5e3b..77a85160 100644 --- a/.github/workflows/benchmark.yml +++ b/.github/workflows/benchmark.yml @@ -39,6 +39,8 @@ jobs: # run: echo "RUN_PROCESS_REPLAY=1" >> $GITHUB_ENV - name: Run Stable Diffusion run: JIT=2 THREEFRY=1 python3 examples/stable_diffusion.py --seed 0 --noshow --timing | tee sd.txt + - name: Run SDXL + run: JIT=2 THREEFRY=1 python3 examples/sdxl.py --seed 0 --noshow | tee sdxl.txt - name: Run model inference benchmark run: METAL=1 python3 test/external/external_model_benchmark.py - name: Test speed vs torch @@ -103,6 +105,7 @@ jobs: matmul.txt matmul_half.txt sd.txt + sdxl.txt beautiful_mnist.txt train_cifar.txt train_cifar_half.txt @@ -156,6 +159,8 @@ jobs: run: CUDA=1 PTX=1 M_START=12 M_STOP=20 M_STEP=1 N_START=6 N_STOP=10 N_STEP=1 K_START=28 K_STOP=36 K_STEP=1 HALF=1 TC_OPT=2 python3 ./extra/gemm/fuzz_matmul.py - name: Run Stable Diffusion run: NV=1 THREEFRY=1 python3 examples/stable_diffusion.py --seed 0 --noshow --timing | tee sd.txt + - name: Run SDXL + run: NV=1 THREEFRY=1 python3 examples/sdxl.py --seed 0 --noshow | tee sdxl.txt - name: Run LLaMA run: | NV=1 JIT=0 python3 examples/llama.py --gen 1 --prompt "Hello." --count 10 --temperature 0 --timing | tee llama_unjitted.txt @@ -195,6 +200,7 @@ jobs: matmul_ptx.txt matmul_nv.txt sd.txt + sdxl.txt llama_unjitted.txt llama_jitted.txt llama_beam.txt @@ -309,6 +315,8 @@ jobs: # run: HSA=1 M_START=12 M_STOP=20 M_STEP=1 N_START=12 N_STOP=20 N_STEP=1 K_START=28 K_STOP=36 K_STEP=1 HALF=1 TC_OPT=2 DEBUG=2 python3 ./extra/gemm/fuzz_matmul.py - name: Run Stable Diffusion run: AMD=1 THREEFRY=1 python3 examples/stable_diffusion.py --seed 0 --noshow --timing | tee sd.txt + - name: Run SDXL + run: AMD=1 THREEFRY=1 python3 examples/sdxl.py --seed 0 --noshow | tee sdxl.txt - name: Run LLaMA 7B run: | AMD=1 JIT=0 python3 examples/llama.py --gen 1 --prompt "Hello." --count 10 --temperature 0 --timing | tee llama_unjitted.txt @@ -359,6 +367,7 @@ jobs: matmul.txt matmul_amd.txt sd.txt + sdxl.txt mixtral.txt testmoreamdbenchmark: diff --git a/examples/sdxl.py b/examples/sdxl.py index 7fee5ecb..0e5857e1 100644 --- a/examples/sdxl.py +++ b/examples/sdxl.py @@ -6,7 +6,7 @@ from tinygrad import Tensor, TinyJit, dtypes from tinygrad.nn import Linear, Conv2d, GroupNorm, LayerNorm, Embedding from tinygrad.nn.state import safe_load, load_state_dict -from tinygrad.helpers import fetch, trange +from tinygrad.helpers import fetch, trange, colored, THREEFRY from examples.stable_diffusion import ClipTokenizer, ResnetBlock, Mid, Downsample, Upsample import numpy as np @@ -854,9 +854,10 @@ class DPMPP2MSampler: if __name__ == "__main__": + default_prompt = "a horse sized cat eating a bagel" parser = argparse.ArgumentParser(description="Run SDXL", formatter_class=argparse.ArgumentDefaultsHelpFormatter) - parser.add_argument('--steps', type=int, default=5, help="The number of diffusion steps") - parser.add_argument('--prompt', type=str, default="a horse sized cat eating a bagel", help="Description of image to generate") + parser.add_argument('--steps', type=int, default=10, help="The number of diffusion steps") + parser.add_argument('--prompt', type=str, default=default_prompt, help="Description of image to generate") parser.add_argument('--out', type=str, default=Path(tempfile.gettempdir()) / "rendered.png", help="Output filename") parser.add_argument('--seed', type=int, help="Set the random latent seed") parser.add_argument('--guidance', type=float, default=6.0, help="Prompt strength") @@ -910,3 +911,11 @@ if __name__ == "__main__": if not args.noshow: im.show() + + # validation! + if args.prompt == default_prompt and args.steps == 10 and args.seed == 0 and args.guidance == 6.0 and args.width == args.height == 1024 \ + and not args.weights and THREEFRY: + ref_image = Tensor(np.array(Image.open(Path(__file__).parent / "sdxl_seed0.png"))) + distance = (((x - ref_image).cast(dtypes.float) / ref_image.max())**2).mean().item() + assert distance < 3e-4, colored(f"validation failed with {distance=}", "red") + print(colored(f"output validated with {distance=}", "green")) diff --git a/examples/sdxl_seed0.png b/examples/sdxl_seed0.png new file mode 100644 index 00000000..f41b3600 Binary files /dev/null and b/examples/sdxl_seed0.png differ diff --git a/tinygrad/helpers.py b/tinygrad/helpers.py index 51a9e346..cd9ba086 100644 --- a/tinygrad/helpers.py +++ b/tinygrad/helpers.py @@ -239,7 +239,7 @@ def fetch(url:str, name:Optional[Union[pathlib.Path, str]]=None, subdir:Optional with urllib.request.urlopen(url, timeout=10) as r: assert r.status == 200 total_length = int(r.headers.get('content-length', 0)) - progress_bar = tqdm(total=total_length, unit='B', unit_scale=True, desc=f"{url}: ") + progress_bar = tqdm(total=total_length, unit='B', unit_scale=True, desc=f"{url}: ", disable=CI) (path := fp.parent).mkdir(parents=True, exist_ok=True) with tempfile.NamedTemporaryFile(dir=path, delete=False) as f: while chunk := r.read(16384): progress_bar.update(f.write(chunk))