diff --git a/test/unit/test_disk_tensor.py b/test/unit/test_disk_tensor.py
index 5aae147e..6b11831a 100644
--- a/test/unit/test_disk_tensor.py
+++ b/test/unit/test_disk_tensor.py
@@ -3,16 +3,17 @@ import unittest
 import numpy as np
 from tinygrad.tensor import Tensor, Device
 from tinygrad.nn.state import safe_load, safe_save, get_state_dict, torch_load
-from tinygrad.helpers import fetch, temp
+from tinygrad.helpers import CI, fetch, temp
 from tinygrad.helpers import Timing
 
 def compare_weights_both(url):
   import torch
   fn = fetch(url)
   tg_weights = get_state_dict(torch_load(fn))
-  torch_weights = get_state_dict(torch.load(fn), tensor_type=torch.Tensor)
+  torch_weights = get_state_dict(torch.load(fn, map_location=torch.device('cpu')), tensor_type=torch.Tensor)
   assert list(tg_weights.keys()) == list(torch_weights.keys())
   for k in tg_weights:
+    if torch_weights[k].dtype == torch.bfloat16: torch_weights[k] = torch_weights[k].float() # numpy doesn't support bfloat16
     np.testing.assert_equal(tg_weights[k].numpy(), torch_weights[k].numpy(), err_msg=f"mismatch at {k}, {tg_weights[k].shape}")
   print(f"compared {len(tg_weights)} weights")
 
@@ -23,6 +24,14 @@ class TestTorchLoad(unittest.TestCase):
   def test_load_enet_alt(self): compare_weights_both("https://download.pytorch.org/models/efficientnet_b0_rwightman-3dd342df.pth")
   # pytorch zip format
   def test_load_convnext(self): compare_weights_both('https://dl.fbaipublicfiles.com/convnext/convnext_tiny_1k_224_ema.pth')
+
+  # for GPU, cl_khr_fp16 isn't supported
+  # for LLVM, it segfaults because it can't link to the casting function
+  # CUDACPU architecture is sm_35 but we need at least sm_70 to run fp16 ALUs
+  @unittest.skipIf(Device.DEFAULT in ["GPU", "LLVM", "CUDA"] and CI, "fp16 broken in some backends")
+  @unittest.skipIf(Device.DEFAULT == "TORCH", "torch doesn't support the way we load bfloat (cast to uint32)")
+  def test_load_llama2bfloat(self): compare_weights_both("https://huggingface.co/qazalin/bf16-lightweight/resolve/main/consolidated.00.pth?download=true")
+
   # TODO: support pytorch tar format with minimal lines
   #def test_load_resnet(self): compare_weights_both('https://download.pytorch.org/models/resnet50-19c8e357.pth')
 
diff --git a/tinygrad/nn/state.py b/tinygrad/nn/state.py
index 11628416..0882e246 100644
--- a/tinygrad/nn/state.py
+++ b/tinygrad/nn/state.py
@@ -77,7 +77,7 @@ def torch_load(fn:str):
     # upstream LLaMA also does this conversion:
     # https://github.com/facebookresearch/llama/blob/6c7fe276574e78057f917549435a2554000a876d/llama/generation.py#L95
     # TODO: should this be done in the example instead? or maybe we don't need this anymore with better bfloat16 support
-    if storage[1] == dtypes.bfloat16: ret = ret.cast(dtypes.uint16).to(Device.DEFAULT).cast(dtypes.uint32).mul(1<<16).bitcast(dtypes.float32).half()
+    if storage[1] == dtypes.bfloat16: ret = ret.cast(dtypes.uint16).to(Device.DEFAULT).cast(dtypes.uint32).mul(1<<16).contiguous().bitcast(dtypes.float32).half()
     else: ret = ret.cast(storage[1])
 
     # 7 lines to deal with permuted tensors. NOTE: this currently requires reading off the disk
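The one-character-visible change in `torch_load` leans on the fact that a bfloat16 value is just the top 16 bits of the equivalent float32, so widening the raw bits to uint32 and shifting them into the high half recovers the float32 value; the added `.contiguous()` appears to force that shifted uint32 result to be realized before the `.bitcast`. Below is a minimal standalone numpy sketch of the same widening trick, not tinygrad code; the helper name `bf16_bits_to_float32` is made up for illustration.

```python
import numpy as np

# Sketch of the bf16 -> f32 widening used in torch_load above: place the raw
# bfloat16 bits in the high 16 bits of a uint32, then reinterpret as float32.
def bf16_bits_to_float32(bits: np.ndarray) -> np.ndarray:
    # same effect as .cast(uint32).mul(1<<16).bitcast(float32) in the patch
    return (bits.astype(np.uint32) << 16).view(np.float32)

# 0x3F80 and 0x4000 are the bfloat16 bit patterns for 1.0 and 2.0
print(bf16_bits_to_float32(np.array([0x3F80, 0x4000], dtype=np.uint16)))  # [1. 2.]
```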