Bitcast p2 bfloat16 tests + clang fix (#2635)

* add bf16 test support (the bit-level conversion under test is sketched below, before the diffs)

this model takes me almost a minute to download though:

https://huggingface.co/TinyPixel/Llama-2-7B-bf16-sharded/resolve/main/pytorch_model-00001-of-00014.bin?download=true: 100%|█████████████████████████████| 981M/981M [00:40<00:00, 24.2MB/s]

* ensure we first load if it is bitcast to avoid taking the address of an rvalue

* tiny bf16 in the cloud

skip GPU

* should skip torch

lint

* Revert "ensure we first load if it is bitcast to avoid taking the address of an rvalue"

This reverts commit b86a28ab84bc1173764b2d480218e8de41a32390.

* break the kernel

* skip LLVM and GPU in CI

* skip CUDA
qazal authored 2023-12-08 13:30:10 -05:00 (committed by GitHub)
commit 73b067f5ce, parent a29538a094
2 changed files with 12 additions and 3 deletions
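
For context, bfloat16 is just the top 16 bits of a float32, so torch_load (second file below) recovers values by widening the raw 16-bit payload and shifting it into the high half, not by a numeric cast. A minimal numpy sketch of that bit trick (values are illustrative, not from the diff):

import numpy as np

def bf16_bits_to_fp32(raw):
  # widen to 32 bits, shift the payload into the high half, reinterpret the bits
  return (raw.astype(np.uint32) << 16).view(np.float32)

bits = np.array([0x3FC0, 0x4000], dtype=np.uint16)  # bf16 bit patterns for 1.5 and 2.0
print(bf16_bits_to_fp32(bits))  # [1.5 2. ]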


@@ -3,16 +3,17 @@ import unittest
 import numpy as np
 from tinygrad.tensor import Tensor, Device
 from tinygrad.nn.state import safe_load, safe_save, get_state_dict, torch_load
-from tinygrad.helpers import fetch, temp
+from tinygrad.helpers import CI, fetch, temp
 from tinygrad.helpers import Timing
 def compare_weights_both(url):
   import torch
   fn = fetch(url)
   tg_weights = get_state_dict(torch_load(fn))
-  torch_weights = get_state_dict(torch.load(fn), tensor_type=torch.Tensor)
+  torch_weights = get_state_dict(torch.load(fn, map_location=torch.device('cpu')), tensor_type=torch.Tensor)
   assert list(tg_weights.keys()) == list(torch_weights.keys())
   for k in tg_weights:
     if torch_weights[k].dtype == torch.bfloat16: torch_weights[k] = torch_weights[k].float() # numpy doesn't support bfloat16
     np.testing.assert_equal(tg_weights[k].numpy(), torch_weights[k].numpy(), err_msg=f"mismatch at {k}, {tg_weights[k].shape}")
   print(f"compared {len(tg_weights)} weights")
@@ -23,6 +24,14 @@ class TestTorchLoad(unittest.TestCase):
   def test_load_enet_alt(self): compare_weights_both("https://download.pytorch.org/models/efficientnet_b0_rwightman-3dd342df.pth")
   # pytorch zip format
   def test_load_convnext(self): compare_weights_both('https://dl.fbaipublicfiles.com/convnext/convnext_tiny_1k_224_ema.pth')
+  # for GPU, cl_khr_fp16 isn't supported
+  # for LLVM, it segfaults because it can't link to the casting function
+  # CUDACPU architecture is sm_35 but we need at least sm_70 to run fp16 ALUs
+  @unittest.skipIf(Device.DEFAULT in ["GPU", "LLVM", "CUDA"] and CI, "fp16 broken in some backends")
+  @unittest.skipIf(Device.DEFAULT == "TORCH", "torch doesn't support the way we load bfloat (cast to uint32)")
+  def test_load_llama2bfloat(self): compare_weights_both("https://huggingface.co/qazalin/bf16-lightweight/resolve/main/consolidated.00.pth?download=true")
   # TODO: support pytorch tar format with minimal lines
   #def test_load_resnet(self): compare_weights_both('https://download.pytorch.org/models/resnet50-19c8e357.pth')
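
The .float() upcast in compare_weights_both is there because numpy has no bfloat16 dtype: calling .numpy() on a bf16 torch tensor raises a TypeError. Since bf16 is a truncated fp32, the upcast is exact. A small illustration (assumes torch is installed):

import torch
import numpy as np

t = torch.tensor([1.5, 2.0], dtype=torch.bfloat16)
# t.numpy() would raise TypeError; upcast to float32 first, which is lossless for bf16
np.testing.assert_equal(t.float().numpy(), np.array([1.5, 2.0], dtype=np.float32))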

tinygrad/nn/state.py

@@ -77,7 +77,7 @@ def torch_load(fn:str):
     # upstream LLaMA also does this conversion:
     # https://github.com/facebookresearch/llama/blob/6c7fe276574e78057f917549435a2554000a876d/llama/generation.py#L95
     # TODO: should this be done in the example instead? or maybe we don't need this anymore with better bfloat16 support
-    if storage[1] == dtypes.bfloat16: ret = ret.cast(dtypes.uint16).to(Device.DEFAULT).cast(dtypes.uint32).mul(1<<16).bitcast(dtypes.float32).half()
+    if storage[1] == dtypes.bfloat16: ret = ret.cast(dtypes.uint16).to(Device.DEFAULT).cast(dtypes.uint32).mul(1<<16).contiguous().bitcast(dtypes.float32).half()
     else: ret = ret.cast(storage[1])
     # 7 lines to deal with permuted tensors. NOTE: this currently requires reading off the disk
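
The one-line change above is the clang fix from the title: .contiguous() forces the shifted integer to be realized into a buffer before .bitcast(). Without it the bitcast can fuse into the multiply, and the clang renderer must then take the address of an expression, roughly *(float*)&(val*65536), which is the "address of an rvalue" error the reverted commit tried to work around. A rough sketch of the fixed pattern, assuming the Tensor API of this revision (the dtypes import path is an assumption):

from tinygrad.tensor import Tensor
from tinygrad.helpers import dtypes  # dtypes lived in helpers at this revision

raw = Tensor([0x3FC0], dtype=dtypes.uint32)  # bf16 payload for 1.5, already cast to uint32
# realizing the product first means the bitcast reads from memory, not from an rvalue
val = raw.mul(1 << 16).contiguous().bitcast(dtypes.float32)
print(val.numpy())  # [1.5]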