mirror of https://github.com/commaai/tinygrad.git
validate sdxl output and put it in benchmark (#5211)
* validate sdxl output and put it in benchmark * don't print fetch progress_bar in CI
This commit is contained in:
parent
63fa4e2a0e
commit
7090eac8cb
|
@ -39,6 +39,8 @@ jobs:
|
|||
# run: echo "RUN_PROCESS_REPLAY=1" >> $GITHUB_ENV
|
||||
- name: Run Stable Diffusion
|
||||
run: JIT=2 THREEFRY=1 python3 examples/stable_diffusion.py --seed 0 --noshow --timing | tee sd.txt
|
||||
- name: Run SDXL
|
||||
run: JIT=2 THREEFRY=1 python3 examples/sdxl.py --seed 0 --noshow | tee sdxl.txt
|
||||
- name: Run model inference benchmark
|
||||
run: METAL=1 python3 test/external/external_model_benchmark.py
|
||||
- name: Test speed vs torch
|
||||
|
@ -103,6 +105,7 @@ jobs:
|
|||
matmul.txt
|
||||
matmul_half.txt
|
||||
sd.txt
|
||||
sdxl.txt
|
||||
beautiful_mnist.txt
|
||||
train_cifar.txt
|
||||
train_cifar_half.txt
|
||||
|
@ -156,6 +159,8 @@ jobs:
|
|||
run: CUDA=1 PTX=1 M_START=12 M_STOP=20 M_STEP=1 N_START=6 N_STOP=10 N_STEP=1 K_START=28 K_STOP=36 K_STEP=1 HALF=1 TC_OPT=2 python3 ./extra/gemm/fuzz_matmul.py
|
||||
- name: Run Stable Diffusion
|
||||
run: NV=1 THREEFRY=1 python3 examples/stable_diffusion.py --seed 0 --noshow --timing | tee sd.txt
|
||||
- name: Run SDXL
|
||||
run: NV=1 THREEFRY=1 python3 examples/sdxl.py --seed 0 --noshow | tee sdxl.txt
|
||||
- name: Run LLaMA
|
||||
run: |
|
||||
NV=1 JIT=0 python3 examples/llama.py --gen 1 --prompt "Hello." --count 10 --temperature 0 --timing | tee llama_unjitted.txt
|
||||
|
@ -195,6 +200,7 @@ jobs:
|
|||
matmul_ptx.txt
|
||||
matmul_nv.txt
|
||||
sd.txt
|
||||
sdxl.txt
|
||||
llama_unjitted.txt
|
||||
llama_jitted.txt
|
||||
llama_beam.txt
|
||||
|
@ -309,6 +315,8 @@ jobs:
|
|||
# run: HSA=1 M_START=12 M_STOP=20 M_STEP=1 N_START=12 N_STOP=20 N_STEP=1 K_START=28 K_STOP=36 K_STEP=1 HALF=1 TC_OPT=2 DEBUG=2 python3 ./extra/gemm/fuzz_matmul.py
|
||||
- name: Run Stable Diffusion
|
||||
run: AMD=1 THREEFRY=1 python3 examples/stable_diffusion.py --seed 0 --noshow --timing | tee sd.txt
|
||||
- name: Run SDXL
|
||||
run: AMD=1 THREEFRY=1 python3 examples/sdxl.py --seed 0 --noshow | tee sdxl.txt
|
||||
- name: Run LLaMA 7B
|
||||
run: |
|
||||
AMD=1 JIT=0 python3 examples/llama.py --gen 1 --prompt "Hello." --count 10 --temperature 0 --timing | tee llama_unjitted.txt
|
||||
|
@ -359,6 +367,7 @@ jobs:
|
|||
matmul.txt
|
||||
matmul_amd.txt
|
||||
sd.txt
|
||||
sdxl.txt
|
||||
mixtral.txt
|
||||
|
||||
testmoreamdbenchmark:
|
||||
|
|
|
@ -6,7 +6,7 @@
|
|||
from tinygrad import Tensor, TinyJit, dtypes
|
||||
from tinygrad.nn import Linear, Conv2d, GroupNorm, LayerNorm, Embedding
|
||||
from tinygrad.nn.state import safe_load, load_state_dict
|
||||
from tinygrad.helpers import fetch, trange
|
||||
from tinygrad.helpers import fetch, trange, colored, THREEFRY
|
||||
from examples.stable_diffusion import ClipTokenizer, ResnetBlock, Mid, Downsample, Upsample
|
||||
import numpy as np
|
||||
|
||||
|
@ -854,9 +854,10 @@ class DPMPP2MSampler:
|
|||
|
||||
|
||||
if __name__ == "__main__":
|
||||
default_prompt = "a horse sized cat eating a bagel"
|
||||
parser = argparse.ArgumentParser(description="Run SDXL", formatter_class=argparse.ArgumentDefaultsHelpFormatter)
|
||||
parser.add_argument('--steps', type=int, default=5, help="The number of diffusion steps")
|
||||
parser.add_argument('--prompt', type=str, default="a horse sized cat eating a bagel", help="Description of image to generate")
|
||||
parser.add_argument('--steps', type=int, default=10, help="The number of diffusion steps")
|
||||
parser.add_argument('--prompt', type=str, default=default_prompt, help="Description of image to generate")
|
||||
parser.add_argument('--out', type=str, default=Path(tempfile.gettempdir()) / "rendered.png", help="Output filename")
|
||||
parser.add_argument('--seed', type=int, help="Set the random latent seed")
|
||||
parser.add_argument('--guidance', type=float, default=6.0, help="Prompt strength")
|
||||
|
@ -910,3 +911,11 @@ if __name__ == "__main__":
|
|||
|
||||
if not args.noshow:
|
||||
im.show()
|
||||
|
||||
# validation!
|
||||
if args.prompt == default_prompt and args.steps == 10 and args.seed == 0 and args.guidance == 6.0 and args.width == args.height == 1024 \
|
||||
and not args.weights and THREEFRY:
|
||||
ref_image = Tensor(np.array(Image.open(Path(__file__).parent / "sdxl_seed0.png")))
|
||||
distance = (((x - ref_image).cast(dtypes.float) / ref_image.max())**2).mean().item()
|
||||
assert distance < 3e-4, colored(f"validation failed with {distance=}", "red")
|
||||
print(colored(f"output validated with {distance=}", "green"))
|
||||
|
|
Binary file not shown.
After Width: | Height: | Size: 1.3 MiB |
|
@ -239,7 +239,7 @@ def fetch(url:str, name:Optional[Union[pathlib.Path, str]]=None, subdir:Optional
|
|||
with urllib.request.urlopen(url, timeout=10) as r:
|
||||
assert r.status == 200
|
||||
total_length = int(r.headers.get('content-length', 0))
|
||||
progress_bar = tqdm(total=total_length, unit='B', unit_scale=True, desc=f"{url}: ")
|
||||
progress_bar = tqdm(total=total_length, unit='B', unit_scale=True, desc=f"{url}: ", disable=CI)
|
||||
(path := fp.parent).mkdir(parents=True, exist_ok=True)
|
||||
with tempfile.NamedTemporaryFile(dir=path, delete=False) as f:
|
||||
while chunk := r.read(16384): progress_bar.update(f.write(chunk))
|
||||
|
|
Loading…
Reference in New Issue