minor cleanups, remove dead files (#2398)

* minor cleanups, remove dead files

* s.name

* use disk

* pytest passes on mac
This commit is contained in:
George Hotz 2023-11-23 09:01:50 -08:00 committed by GitHub
parent 66c75f30c6
commit 4f8f0ac139
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 17 additions and 22 deletions

View File

@ -1,8 +0,0 @@
#!/bin/bash
# note: if we compile tinygrad/nn/__init__.py __dict__ no longer works, and optimizers will silently fail
mypyc --check-untyped-defs --explicit-package-bases --warn-unreachable tinygrad/shape/shapetracker.py tinygrad/shape/symbolic.py \
tinygrad/helpers.py tinygrad/mlops.py tinygrad/tensor.py tinygrad/graph.py \
#tinygrad/codegen/gpu.py tinygrad/runtime/ops_metal.py
#tinygrad/codegen/ast.py
#tinygrad/nn/__init__.py
#tinygrad/ops.py tinygrad/runtime/ops_metal.py tinygrad/runtime/ops_gpu.py tinygrad/runtime/ops_cpu.py tinygrad/lazy.py

View File

@ -37,7 +37,7 @@ def fetch_as_file(url):
def download_file(url, fp, skip_if_exists=True):
if skip_if_exists and Path(fp).is_file() and Path(fp).stat().st_size > 0:
return
r = requests.get(url, stream=True)
r = requests.get(url, stream=True, timeout=10)
assert r.status_code == 200
progress_bar = tqdm(total=int(r.headers.get('content-length', 0)), unit='B', unit_scale=True, desc=url)
(path := Path(fp).parent).mkdir(parents=True, exist_ok=True)

View File

@ -1,3 +0,0 @@
#!/bin/bash
rm tinygrad/*.so tinygrad/codegen/*.so tinygrad/shape/*.so tinygrad/nn/*.so tinygrad/runtime/*.so *.so

View File

@ -14,9 +14,7 @@ from PIL import Image
@unittest.skipIf(CI, "no internet tests in CI")
class TestFetch(unittest.TestCase):
def test_fetch_bad_http(self):
self.assertRaises(AssertionError, fetch, 'http://httpstat.us/500')
self.assertRaises(AssertionError, fetch, 'http://httpstat.us/404')
self.assertRaises(AssertionError, fetch, 'http://httpstat.us/400')
self.assertRaises(AssertionError, fetch, 'http://www.google.com/404')
def test_fetch_small(self):
assert(len(fetch('https://google.com'))>0)

View File

@ -49,7 +49,7 @@ class TestTrain(unittest.TestCase):
train_one_step(model,X,Y)
check_gc()
@unittest.skipIf(Device.DEFAULT == "WEBGPU", "too many buffers for webgpu")
@unittest.skipIf(Device.DEFAULT in ["METAL", "WEBGPU"], "too many buffers for webgpu and metal")
def test_vit(self):
model = ViT()
X = np.zeros((BS,3,224,224), dtype=np.float32)

View File

@ -62,4 +62,5 @@ class TestWhisper(unittest.TestCase):
with self.assertRaises(Exception):
transcribe_waveform(self.model, self.enc, waveforms)
if __name__ == '__main__':
unittest.main()

View File

@ -2,6 +2,7 @@ import unittest
from tinygrad import Tensor
from tinygrad.ops import Device
from tinygrad.helpers import Timing, CI
import multiprocessing.shared_memory as shared_memory
N = 4096 if CI else 16384
class TestCopySpeed(unittest.TestCase):
@ -9,13 +10,18 @@ class TestCopySpeed(unittest.TestCase):
def setUpClass(cls): Device[Device.DEFAULT].synchronize()
def testCopySHMtoDefault(self):
t = Tensor.empty(N, N, device="disk:/dev/shm/test_X").realize()
#t = Tensor.empty(N, N, device="disk:shm:test_X").realize()
s = shared_memory.SharedMemory(name="test_X", create=True, size=N*N*4)
s.close()
if CI:
t = Tensor.empty(N, N, device="disk:/dev/shm/test_X").realize()
else:
t = Tensor.empty(N, N, device="disk:shm:test_X").realize()
for _ in range(3):
with Timing("sync: ", on_exit=lambda ns: f" @ {t.nbytes()/ns:.2f} GB/s"):
with Timing("queue: "):
t.to(Device.DEFAULT).realize()
Device[Device.DEFAULT].synchronize()
s.unlink()
def testCopyCPUtoDefault(self):
t = Tensor.rand(N, N, device="cpu").realize()
@ -45,6 +51,8 @@ class TestCopySpeed(unittest.TestCase):
@unittest.skipIf(CI, "CI doesn't have 6 GPUs")
def testCopyCPUto6GPUs(self):
from tinygrad.runtime.ops_gpu import CL
if len(CL.devices) != 6: raise unittest.SkipTest("computer doesn't have 6 GPUs")
t = Tensor.rand(N, N, device="cpu").realize()
print(f"buffer: {t.nbytes()*1e-9:.2f} GB")
for _ in range(3):

View File

@ -72,7 +72,7 @@ class Tensor:
data = LazyBuffer.loadop(LoadOps.CONST, tuple(), dtype or dtypes.from_np(data.dtype), device, data.item())
else:
data = LazyBuffer.fromCPU(data.astype(dtype.np) if dtype is not None and dtype.np is not None else data)
else: raise RuntimeError(f"can't create Tensor from {data}")
else: raise RuntimeError(f"can't create Tensor from {data} with type {type(data)}")
# data is a LazyBuffer, but it might be on the wrong device
self.lazydata = data if data.device == device else data.copy_to_device(device)
@ -665,7 +665,7 @@ class Tensor:
return (x, y)
def _to_float(self, x:Union[Tensor, float]):
return x.lazydata.op.arg if isinstance(x, Tensor) and not x.lazydata.realized and x.lazydata.op.op == LoadOps.CONST and not x.requires_grad \
return x.lazydata.base.op.arg if isinstance(x, Tensor) and x.lazydata.is_unrealized_const() and not x.requires_grad \
and x.lazydata.st.contiguous and self._broadcasted(x)[0].shape == self.shape else x
def add(self, x:Union[Tensor, float], reverse=False) -> Tensor:
@ -715,7 +715,6 @@ class Tensor:
# ***** binary op wrappers (18 wasted lines to make the typechecker happy) *****
# NOTE: __pow__ and friends are broken in mypyc with the ** operator
def __add__(self, x) -> Tensor: return self.add(x)
def __sub__(self, x) -> Tensor: return self.sub(x)
def __mul__(self, x) -> Tensor: return self.mul(x)