# Source: mirror of https://github.com/commaai/tinygrad.git (141 lines, 5.7 KiB, Python)
import pathlib
|
|
import unittest
|
|
import numpy as np
|
|
from tinygrad.tensor import Tensor, Device
|
|
from tinygrad.state import safe_load, safe_save, get_state_dict, torch_load
|
|
from tinygrad.helpers import dtypes
|
|
from tinygrad.runtime.ops_disk import RawDiskBuffer
|
|
from tinygrad.helpers import Timing
|
|
from extra.utils import fetch_as_file, temp
|
|
|
|
def compare_weights_both(url):
  """Check that tinygrad's ``torch_load`` agrees with ``torch.load`` on the checkpoint at *url*.

  The file is obtained via ``fetch_as_file``. Both loaded objects are passed through
  ``get_state_dict``. The resulting key lists must match in order, and every tensor must
  be exactly equal. Prints how many weights were compared.
  """
  import torch
  fn = fetch_as_file(url)
  ours = get_state_dict(torch_load(fn))
  reference = get_state_dict(torch.load(fn), tensor_type=torch.Tensor)
  # key *order* is compared too, not just the key set
  assert list(ours) == list(reference)
  for name, tensor in ours.items():
    np.testing.assert_equal(tensor.numpy(), reference[name].numpy(), err_msg=f"mismatch at {name}, {tensor.shape}")
  print(f"compared {len(ours)} weights")
|
|
|
|
class TestTorchLoad(unittest.TestCase):
|
|
# pytorch pkl format
|
|
def test_load_enet(self): compare_weights_both("https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b0-355c32eb.pth")
|
|
# pytorch zip format
|
|
def test_load_enet_alt(self): compare_weights_both("https://download.pytorch.org/models/efficientnet_b0_rwightman-3dd342df.pth")
|
|
# pytorch zip format
|
|
def test_load_convnext(self): compare_weights_both('https://dl.fbaipublicfiles.com/convnext/convnext_tiny_1k_224_ema.pth')
|
|
# TODO: support pytorch tar format with minimal lines
|
|
#def test_load_resnet(self): compare_weights_both('https://download.pytorch.org/models/resnet50-19c8e357.pth')
|
|
|
|
# First LLaMA 7B shard, used as a large on-disk file for the read-speed benchmarks below.
test_fn = pathlib.Path(__file__).parent.parent.parent / "weights/LLaMA/7B/consolidated.00.pth"
# Benchmark a fixed 2 GiB prefix instead of the whole file (which would be test_fn.stat().st_size).
test_size = 2 * 1024**3
|
|
|
|
# sudo su -c 'sync; echo 1 > /proc/sys/vm/drop_caches' && python3 test/unit/test_disk_tensor.py TestRawDiskBuffer.test_readinto_read_speed
@unittest.skipIf(not test_fn.exists(), "download LLaMA weights for read in speed tests")
class TestRawDiskBuffer(unittest.TestCase):
  """Read-throughput benchmarks over the first ``test_size`` bytes of ``test_fn``.

  These tests only print timings via ``Timing``. They assert nothing about speed.
  """

  def test_readinto_read_speed(self):
    # baseline: a plain file read straight into a preallocated uint8 numpy array
    dst = np.empty(test_size, np.uint8)
    with open(test_fn, "rb") as fh, Timing("copy in ", lambda et_ns: f" {test_size/et_ns:.2f} GB/s"):
      fh.readinto(dst)

  def test_mmap_read_speed(self):
    # same copy, but sourced from RawDiskBuffer.toCPU() (the "mmap" path this test is named for)
    disk_buf = RawDiskBuffer(test_size, dtype=dtypes.uint8, device=test_fn)
    dst = np.empty(test_size, np.uint8)
    with Timing("copy in ", lambda et_ns: f" {test_size/et_ns:.2f} GB/s"):
      np.copyto(dst, disk_buf.toCPU())
|
|
@unittest.skipIf(Device.DEFAULT == "WEBGPU", "webgpu doesn't support uint8 datatype")
class TestSafetensors(unittest.TestCase):
  """Interop tests between tinygrad's ``safe_load``/``safe_save`` and the reference ``safetensors`` library.

  Comparisons use ``self.assert*`` / ``numpy.testing`` rather than bare ``assert``.
  Bare asserts are stripped under ``python -O`` and would turn these checks into no-ops.
  """

  def test_real_safetensors(self):
    """A file written by ``safetensors`` loads correctly, and re-saving it is byte-identical."""
    import torch
    from safetensors.torch import save_file
    torch.manual_seed(1337)
    tensors = {
      "weight1": torch.randn((16, 16)),
      "weight2": torch.arange(0, 17, dtype=torch.uint8),
      "weight3": torch.arange(0, 17, dtype=torch.int32).reshape(17,1,1),
      "weight4": torch.arange(0, 2, dtype=torch.uint8),
    }
    save_file(tensors, temp("model.safetensors"))

    ret = safe_load(temp("model.safetensors"))
    for k, v in tensors.items():
      np.testing.assert_array_equal(ret[k].numpy(), v.numpy())
    safe_save(ret, temp("model.safetensors_alt"))
    # tinygrad's writer must reproduce the reference file exactly, header included
    with open(temp("model.safetensors"), "rb") as f, open(temp("model.safetensors_alt"), "rb") as g:
      self.assertEqual(f.read(), g.read())
    ret2 = safe_load(temp("model.safetensors_alt"))
    for k, v in tensors.items():
      np.testing.assert_array_equal(ret2[k].numpy(), v.numpy())

  def test_efficientnet_safetensors(self):
    """A model state dict round-trips through ``safe_save``/``safe_load`` and is readable by ``safe_open``."""
    from models.efficientnet import EfficientNet
    model = EfficientNet(0)
    state_dict = get_state_dict(model)
    safe_save(state_dict, temp("eff0"))
    state_dict_loaded = safe_load(temp("eff0"))
    self.assertEqual(sorted(state_dict_loaded.keys()), sorted(state_dict.keys()))
    for k, v in state_dict.items():
      np.testing.assert_array_equal(v.numpy(), state_dict_loaded[k].numpy())

    # load with the real safetensors
    from safetensors import safe_open
    with safe_open(temp("eff0"), framework="pt", device="cpu") as f:
      self.assertEqual(sorted(f.keys()), sorted(state_dict.keys()))
      for k in f.keys():
        np.testing.assert_array_equal(f.get_tensor(k).numpy(), state_dict[k].numpy())
|
|
|
|
class TestDiskTensor(unittest.TestCase):
|
|
def test_empty(self):
|
|
pathlib.Path(temp("dt1")).unlink(missing_ok=True)
|
|
Tensor.empty(100, 100, device=f"disk:{temp('dt1')}")
|
|
|
|
def test_write_ones(self):
|
|
pathlib.Path(temp("dt2")).unlink(missing_ok=True)
|
|
|
|
out = Tensor.ones(10, 10, device="CPU")
|
|
outdisk = out.to(f"disk:{temp('dt2')}")
|
|
print(outdisk)
|
|
outdisk.realize()
|
|
del out, outdisk
|
|
|
|
# test file
|
|
with open(temp("dt2"), "rb") as f:
|
|
assert f.read() == b"\x00\x00\x80\x3F" * 100
|
|
|
|
# test load alt
|
|
reloaded = Tensor.empty(10, 10, device=f"disk:{temp('dt2')}")
|
|
out = reloaded.numpy()
|
|
assert np.all(out == 1.)
|
|
|
|
def test_slice(self):
|
|
pathlib.Path(temp("dt3")).unlink(missing_ok=True)
|
|
Tensor.arange(10, device="CPU").to(f"disk:{temp('dt3')}").realize()
|
|
|
|
slice_me = Tensor.empty(10, device=f"disk:{temp('dt3')}")
|
|
print(slice_me)
|
|
is_3 = slice_me[3:4].cpu()
|
|
assert is_3.numpy()[0] == 3
|
|
|
|
def test_slice_2d(self):
|
|
pathlib.Path(temp("dt5")).unlink(missing_ok=True)
|
|
Tensor.arange(100, device="CPU").to(f"disk:{temp('dt5')}").realize()
|
|
slice_me = Tensor.empty(10, 10, device=f"disk:{temp('dt5')}")
|
|
tst = slice_me[1].numpy()
|
|
print(tst)
|
|
np.testing.assert_allclose(tst, np.arange(10, 20))
|
|
|
|
def test_assign_slice(self):
|
|
pathlib.Path(temp("dt4")).unlink(missing_ok=True)
|
|
cc = Tensor.arange(10, device="CPU").to(f"disk:{temp('dt4')}").realize()
|
|
|
|
#cc.assign(np.ones(10)).realize()
|
|
print(cc[3:5].numpy())
|
|
cc[3:5].assign([13, 12]).realize()
|
|
print(cc.numpy())
|
|
|
|
if __name__ == "__main__":
  # running the file directly hands control to unittest's command-line runner
  unittest.main()
|
|
|