mirror of https://github.com/commaai/tinygrad.git
Bitcast p2 bfloat16 tests + clang fix (#2635)
* add bf16 test support; this model takes me almost a minute to download though: https://huggingface.co/TinyPixel/Llama-2-7B-bf16-sharded/resolve/main/pytorch_model-00001-of-00014.bin?download=true (981M at ~24.2MB/s, about 40s)
* ensure we first load if it is bitcast to avoid taking the address of an rvalue
* tiny bf16 in the cloud, skip GPU
* should skip torch lint
* Revert "ensure we first load if it is bitcast to avoid taking the address of an rvalue" (reverts commit b86a28ab84bc1173764b2d480218e8de41a32390)
* break the kernel
* skip LLVM and GPU in CI
* skip CUDA
parent a29538a094
commit 73b067f5ce
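For context on the conversion exercised by the diff below: torch_load widens bfloat16 weights by taking the raw 16-bit patterns, shifting them into the upper half of a uint32, and bitcasting the result to float32. A minimal numpy sketch of that bit trick (example values only, not tinygrad code):

import numpy as np

# bfloat16 is just the top 16 bits of a float32, so widening is a shift plus a bitcast
bf16_bits = np.array([0x3F80, 0x4000, 0xC040], dtype=np.uint16)  # 1.0, 2.0, -3.0 in bfloat16
f32 = (bf16_bits.astype(np.uint32) << 16).view(np.float32)
print(f32)  # -> [ 1.  2. -3.]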
@@ -3,16 +3,17 @@ import unittest
 import numpy as np
 from tinygrad.tensor import Tensor, Device
 from tinygrad.nn.state import safe_load, safe_save, get_state_dict, torch_load
-from tinygrad.helpers import fetch, temp
+from tinygrad.helpers import CI, fetch, temp
 from tinygrad.helpers import Timing

 def compare_weights_both(url):
   import torch
   fn = fetch(url)
   tg_weights = get_state_dict(torch_load(fn))
-  torch_weights = get_state_dict(torch.load(fn), tensor_type=torch.Tensor)
+  torch_weights = get_state_dict(torch.load(fn, map_location=torch.device('cpu')), tensor_type=torch.Tensor)
   assert list(tg_weights.keys()) == list(torch_weights.keys())
   for k in tg_weights:
+    if torch_weights[k].dtype == torch.bfloat16: torch_weights[k] = torch_weights[k].float() # numpy doesn't support bfloat16
     np.testing.assert_equal(tg_weights[k].numpy(), torch_weights[k].numpy(), err_msg=f"mismatch at {k}, {tg_weights[k].shape}")
   print(f"compared {len(tg_weights)} weights")

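The new line in compare_weights_both is needed because numpy has no bfloat16 dtype, so calling .numpy() on a bfloat16 torch tensor raises; upcasting to float32 first sidesteps that. A small standalone illustration (not part of the diff):

import torch

t = torch.ones(3, dtype=torch.bfloat16)
# t.numpy() would raise here: numpy cannot represent bfloat16
print(t.float().numpy())  # upcast to float32 first -> [1. 1. 1.]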
@@ -23,6 +24,14 @@ class TestTorchLoad(unittest.TestCase):
   def test_load_enet_alt(self): compare_weights_both("https://download.pytorch.org/models/efficientnet_b0_rwightman-3dd342df.pth")
   # pytorch zip format
   def test_load_convnext(self): compare_weights_both('https://dl.fbaipublicfiles.com/convnext/convnext_tiny_1k_224_ema.pth')
+
+  # for GPU, cl_khr_fp16 isn't supported
+  # for LLVM, it segfaults because it can't link to the casting function
+  # CUDACPU architecture is sm_35 but we need at least sm_70 to run fp16 ALUs
+  @unittest.skipIf(Device.DEFAULT in ["GPU", "LLVM", "CUDA"] and CI, "fp16 broken in some backends")
+  @unittest.skipIf(Device.DEFAULT == "TORCH", "torch doesn't support the way we load bfloat (cast to uint32)")
+  def test_load_llama2bfloat(self): compare_weights_both("https://huggingface.co/qazalin/bf16-lightweight/resolve/main/consolidated.00.pth?download=true")
+
   # TODO: support pytorch tar format with minimal lines
   #def test_load_resnet(self): compare_weights_both('https://download.pytorch.org/models/resnet50-19c8e357.pth')

@@ -77,7 +77,7 @@ def torch_load(fn:str):
     # upstream LLaMA also does this conversion:
     # https://github.com/facebookresearch/llama/blob/6c7fe276574e78057f917549435a2554000a876d/llama/generation.py#L95
     # TODO: should this be done in the example instead? or maybe we don't need this anymore with better bfloat16 support
-    if storage[1] == dtypes.bfloat16: ret = ret.cast(dtypes.uint16).to(Device.DEFAULT).cast(dtypes.uint32).mul(1<<16).bitcast(dtypes.float32).half()
+    if storage[1] == dtypes.bfloat16: ret = ret.cast(dtypes.uint16).to(Device.DEFAULT).cast(dtypes.uint32).mul(1<<16).contiguous().bitcast(dtypes.float32).half()
     else: ret = ret.cast(storage[1])

     # 7 lines to deal with permuted tensors. NOTE: this currently requires reading off the disk
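The one-line change in torch_load inserts contiguous() between the shift and the bitcast. This realizes the shifted uint32 values into their own buffer ("break the kernel"), so the bitcast reads from memory instead of being fused with the multiply, which on the clang backend generated code that took the address of an rvalue. A rough sketch of the same chain on a toy buffer, assuming this revision's API (dtypes importable from tinygrad.helpers, Tensor.bitcast available):

from tinygrad.tensor import Tensor
from tinygrad.helpers import dtypes

# toy stand-in for a bfloat16 storage: raw 16-bit patterns for 1.0 and 2.0
bits = Tensor([0x3F80, 0x4000], dtype=dtypes.uint16)
# widen to uint32, shift into the high half, realize, then reinterpret as float32 and downcast
out = bits.cast(dtypes.uint32).mul(1 << 16).contiguous().bitcast(dtypes.float32).half()
print(out.numpy())  # -> [1. 2.]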