mirror of https://github.com/commaai/tinygrad.git
minor cleanups, remove dead files (#2398)
* minor cleanups, remove dead files
* s.name
* use disk
* pytest passes on mac

parent 66c75f30c6
commit 4f8f0ac139

@@ -1,8 +0,0 @@
-#!/bin/bash
-# note: if we compile tinygrad/nn/__init__.py __dict__ no longer works, and optimizers will silently fail
-mypyc --check-untyped-defs --explicit-package-bases --warn-unreachable tinygrad/shape/shapetracker.py tinygrad/shape/symbolic.py \
-tinygrad/helpers.py tinygrad/mlops.py tinygrad/tensor.py tinygrad/graph.py \
-#tinygrad/codegen/gpu.py tinygrad/runtime/ops_metal.py
-#tinygrad/codegen/ast.py
-#tinygrad/nn/__init__.py
-#tinygrad/ops.py tinygrad/runtime/ops_metal.py tinygrad/runtime/ops_gpu.py tinygrad/runtime/ops_cpu.py tinygrad/lazy.py

@@ -37,7 +37,7 @@ def fetch_as_file(url):
 def download_file(url, fp, skip_if_exists=True):
   if skip_if_exists and Path(fp).is_file() and Path(fp).stat().st_size > 0:
     return
-  r = requests.get(url, stream=True)
+  r = requests.get(url, stream=True, timeout=10)
   assert r.status_code == 200
   progress_bar = tqdm(total=int(r.headers.get('content-length', 0)), unit='B', unit_scale=True, desc=url)
   (path := Path(fp).parent).mkdir(parents=True, exist_ok=True)

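For context on the timeout above: it only bounds how long requests waits on the connection or a read, not the whole transfer. A standalone sketch of a streaming download in this style; this is an illustrative helper, not the repository's exact function, and the chunk size and file handling are assumptions.

# sketch: streaming download with a timeout and a tqdm progress bar (illustrative only)
from pathlib import Path
import requests
from tqdm import tqdm

def download_file_sketch(url, fp, skip_if_exists=True, chunk_size=16384):
  if skip_if_exists and Path(fp).is_file() and Path(fp).stat().st_size > 0:
    return
  r = requests.get(url, stream=True, timeout=10)  # timeout bounds connect/read waits, not the whole download
  assert r.status_code == 200
  Path(fp).parent.mkdir(parents=True, exist_ok=True)
  with tqdm(total=int(r.headers.get('content-length', 0)), unit='B', unit_scale=True, desc=url) as bar, open(fp, 'wb') as f:
    for chunk in r.iter_content(chunk_size=chunk_size):
      f.write(chunk)
      bar.update(len(chunk))
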
rmso.sh
@@ -1,3 +0,0 @@
-#!/bin/bash
-rm tinygrad/*.so tinygrad/codegen/*.so tinygrad/shape/*.so tinygrad/nn/*.so tinygrad/runtime/*.so *.so

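The deleted rmso.sh only swept up compiled extension artifacts. A rough Python equivalent of that cleanup, shown purely for illustration, using the same globs as the script:

# sketch: Python equivalent of the removed rmso.sh (illustrative only)
from pathlib import Path

for pattern in ["tinygrad/*.so", "tinygrad/codegen/*.so", "tinygrad/shape/*.so",
                "tinygrad/nn/*.so", "tinygrad/runtime/*.so", "*.so"]:
  for so in Path(".").glob(pattern):
    so.unlink(missing_ok=True)
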
@@ -14,9 +14,7 @@ from PIL import Image
 @unittest.skipIf(CI, "no internet tests in CI")
 class TestFetch(unittest.TestCase):
   def test_fetch_bad_http(self):
-    self.assertRaises(AssertionError, fetch, 'http://httpstat.us/500')
-    self.assertRaises(AssertionError, fetch, 'http://httpstat.us/404')
-    self.assertRaises(AssertionError, fetch, 'http://httpstat.us/400')
+    self.assertRaises(AssertionError, fetch, 'http://www.google.com/404')
 
   def test_fetch_small(self):
     assert(len(fetch('https://google.com'))>0)

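The test expects AssertionError because fetch flags HTTP failures with a bare assert on the status code (as download_file does above) rather than raising its own exception type. A minimal sketch of that pattern, using a hypothetical helper name rather than the real tinygrad fetch:

# sketch: assert-based fetch, mirroring why the test expects AssertionError (hypothetical helper)
import requests

def fetch_sketch(url):
  r = requests.get(url, timeout=10)
  assert r.status_code == 200, f"fetch of {url} failed with status {r.status_code}"
  return r.content

# fetch_sketch('http://www.google.com/404') raises AssertionError;
# fetch_sketch('https://google.com') returns a non-empty body.
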
@@ -49,7 +49,7 @@ class TestTrain(unittest.TestCase):
     train_one_step(model,X,Y)
     check_gc()
 
-  @unittest.skipIf(Device.DEFAULT == "WEBGPU", "too many buffers for webgpu")
+  @unittest.skipIf(Device.DEFAULT in ["METAL", "WEBGPU"], "too many buffers for webgpu and metal")
   def test_vit(self):
     model = ViT()
     X = np.zeros((BS,3,224,224), dtype=np.float32)

@@ -62,4 +62,5 @@ class TestWhisper(unittest.TestCase):
     with self.assertRaises(Exception):
       transcribe_waveform(self.model, self.enc, waveforms)
 
+if __name__ == '__main__':
+  unittest.main()

@@ -2,6 +2,7 @@ import unittest
 from tinygrad import Tensor
 from tinygrad.ops import Device
 from tinygrad.helpers import Timing, CI
+import multiprocessing.shared_memory as shared_memory
 
 N = 4096 if CI else 16384
 class TestCopySpeed(unittest.TestCase):

@@ -9,13 +10,18 @@ class TestCopySpeed(unittest.TestCase):
   def setUpClass(cls): Device[Device.DEFAULT].synchronize()
 
   def testCopySHMtoDefault(self):
-    t = Tensor.empty(N, N, device="disk:/dev/shm/test_X").realize()
-    #t = Tensor.empty(N, N, device="disk:shm:test_X").realize()
+    s = shared_memory.SharedMemory(name="test_X", create=True, size=N*N*4)
+    s.close()
+    if CI:
+      t = Tensor.empty(N, N, device="disk:/dev/shm/test_X").realize()
+    else:
+      t = Tensor.empty(N, N, device="disk:shm:test_X").realize()
     for _ in range(3):
       with Timing("sync: ", on_exit=lambda ns: f" @ {t.nbytes()/ns:.2f} GB/s"):
         with Timing("queue: "):
           t.to(Device.DEFAULT).realize()
         Device[Device.DEFAULT].synchronize()
+    s.unlink()
 
   def testCopyCPUtoDefault(self):
     t = Tensor.rand(N, N, device="cpu").realize()

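The on_exit lambda above reports bandwidth straight from the elapsed nanoseconds: bytes divided by nanoseconds is numerically the same as GB/s (1e9 bytes over 1e9 ns). A stdlib-only sketch of the same measurement, with a plain bytes() copy standing in for the tensor transfer; the segment name and size here are illustrative:

# sketch: bytes/ns == GB/s, measured on a shared-memory-backed buffer (illustrative only)
import time
import multiprocessing.shared_memory as shared_memory

N = 4096
shm = shared_memory.SharedMemory(name="test_X_sketch", create=True, size=N*N*4)
try:
  start = time.perf_counter_ns()
  data = bytes(shm.buf)          # stand-in for t.to(Device.DEFAULT).realize()
  ns = time.perf_counter_ns() - start
  print(f"copied {len(data)*1e-9:.3f} GB @ {len(data)/ns:.2f} GB/s")
finally:
  shm.close()
  shm.unlink()
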
@@ -45,6 +51,8 @@ class TestCopySpeed(unittest.TestCase):
 
+  @unittest.skipIf(CI, "CI doesn't have 6 GPUs")
   def testCopyCPUto6GPUs(self):
     from tinygrad.runtime.ops_gpu import CL
+    if len(CL.devices) != 6: raise unittest.SkipTest("computer doesn't have 6 GPUs")
     t = Tensor.rand(N, N, device="cpu").realize()
     print(f"buffer: {t.nbytes()*1e-9:.2f} GB")
     for _ in range(3):

@@ -72,7 +72,7 @@ class Tensor:
         data = LazyBuffer.loadop(LoadOps.CONST, tuple(), dtype or dtypes.from_np(data.dtype), device, data.item())
       else:
         data = LazyBuffer.fromCPU(data.astype(dtype.np) if dtype is not None and dtype.np is not None else data)
-    else: raise RuntimeError(f"can't create Tensor from {data}")
+    else: raise RuntimeError(f"can't create Tensor from {data} with type {type(data)}")
 
     # data is a LazyBuffer, but it might be on the wrong device
     self.lazydata = data if data.device == device else data.copy_to_device(device)

@@ -665,7 +665,7 @@ class Tensor:
     return (x, y)
 
   def _to_float(self, x:Union[Tensor, float]):
-    return x.lazydata.op.arg if isinstance(x, Tensor) and not x.lazydata.realized and x.lazydata.op.op == LoadOps.CONST and not x.requires_grad \
+    return x.lazydata.base.op.arg if isinstance(x, Tensor) and x.lazydata.is_unrealized_const() and not x.requires_grad \
       and x.lazydata.st.contiguous and self._broadcasted(x)[0].shape == self.shape else x
 
   def add(self, x:Union[Tensor, float], reverse=False) -> Tensor:

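The rewritten check leans on is_unrealized_const(): when the other operand is still just a lazily-created constant (and needs no gradient), its Python scalar can be folded straight into the op instead of materializing a buffer. A toy version of that unwrapping, with hypothetical classes rather than tinygrad's LazyBuffer:

# sketch: fold an unrealized constant back to its Python scalar (hypothetical classes)
class FakeLazy:
  def __init__(self, const=None, realized=False):
    self.const, self.realized = const, realized
  def is_unrealized_const(self):
    return self.const is not None and not self.realized

def to_float_sketch(x):
  # hand back the raw scalar while x is still just a constant; otherwise keep the lazy object
  return x.const if isinstance(x, FakeLazy) and x.is_unrealized_const() else x

assert to_float_sketch(FakeLazy(const=2.0)) == 2.0
assert isinstance(to_float_sketch(FakeLazy(const=2.0, realized=True)), FakeLazy)
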
@@ -715,7 +715,6 @@ class Tensor:
 
   # ***** binary op wrappers (18 wasted lines to make the typechecker happy) *****
 
-  # NOTE: __pow__ and friends are broken in mypyc with the ** operator
   def __add__(self, x) -> Tensor: return self.add(x)
   def __sub__(self, x) -> Tensor: return self.sub(x)
   def __mul__(self, x) -> Tensor: return self.mul(x)

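These wrappers trade a few lines for concrete, Tensor-returning signatures the typechecker can see, with the reflected __r*__ forms handled by a reverse flag on the named method. A toy illustration of the pattern, not the Tensor class itself:

# sketch: dunder wrappers delegating to one named method, reverse=True for the reflected forms (toy class)
class Scalar:
  def __init__(self, v): self.v = v
  def add(self, x, reverse=False) -> "Scalar":
    other = x.v if isinstance(x, Scalar) else x
    return Scalar(other + self.v) if reverse else Scalar(self.v + other)
  def __add__(self, x) -> "Scalar": return self.add(x)
  def __radd__(self, x) -> "Scalar": return self.add(x, reverse=True)

assert (Scalar(3) + 2).v == 5 and (2 + Scalar(3)).v == 5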