mirror of https://github.com/commaai/tinygrad.git
more test cleanups (#2631)
* more test cleanups * move test example back
This commit is contained in:
parent
a63f48d3db
commit
232ed2af3f
|
@ -15,7 +15,7 @@ repos:
|
||||||
pass_filenames: false
|
pass_filenames: false
|
||||||
- id: mypy
|
- id: mypy
|
||||||
name: mypy
|
name: mypy
|
||||||
entry: mypy tinygrad/ extra/helpers.py
|
entry: mypy tinygrad/
|
||||||
language: system
|
language: system
|
||||||
always_run: true
|
always_run: true
|
||||||
pass_filenames: false
|
pass_filenames: false
|
||||||
|
@ -41,7 +41,7 @@ repos:
|
||||||
pass_filenames: false
|
pass_filenames: false
|
||||||
- id: example
|
- id: example
|
||||||
name: multi device tests
|
name: multi device tests
|
||||||
entry: python3 test/external/test_example.py
|
entry: python3 test/external/external_test_example.py
|
||||||
language: system
|
language: system
|
||||||
always_run: true
|
always_run: true
|
||||||
pass_filenames: false
|
pass_filenames: false
|
||||||
|
|
|
@ -26,12 +26,11 @@ def eval_resnet():
|
||||||
|
|
||||||
# evaluation on the mlperf classes of the validation set from imagenet
|
# evaluation on the mlperf classes of the validation set from imagenet
|
||||||
from extra.datasets.imagenet import iterate
|
from extra.datasets.imagenet import iterate
|
||||||
from extra.helpers import cross_process
|
|
||||||
|
|
||||||
BS = 64
|
BS = 64
|
||||||
n,d = 0,0
|
n,d = 0,0
|
||||||
st = time.perf_counter()
|
st = time.perf_counter()
|
||||||
iterator = cross_process(lambda: iterate(BS))
|
iterator = iterate(BS)
|
||||||
x,ny = next(iterator)
|
x,ny = next(iterator)
|
||||||
dat = Tensor(x)
|
dat = Tensor(x)
|
||||||
while dat is not None:
|
while dat is not None:
|
||||||
|
|
|
@ -1,50 +0,0 @@
|
||||||
import multiprocessing, subprocess
|
|
||||||
import cloudpickle
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
def _early_exec_process(qin, qout):
|
|
||||||
while True:
|
|
||||||
path, inp = qin.get()
|
|
||||||
try:
|
|
||||||
qout.put(subprocess.check_output(path, input=inp))
|
|
||||||
except Exception as e:
|
|
||||||
qout.put(e)
|
|
||||||
|
|
||||||
def enable_early_exec():
|
|
||||||
qin: multiprocessing.Queue = multiprocessing.Queue()
|
|
||||||
qout: multiprocessing.Queue = multiprocessing.Queue()
|
|
||||||
p = multiprocessing.Process(target=_early_exec_process, args=(qin, qout))
|
|
||||||
p.daemon = True
|
|
||||||
p.start()
|
|
||||||
def early_exec(x):
|
|
||||||
qin.put(x)
|
|
||||||
ret = qout.get()
|
|
||||||
if isinstance(ret, Exception): raise ret
|
|
||||||
else: return ret
|
|
||||||
return early_exec
|
|
||||||
|
|
||||||
def proc(itermaker, q) -> None:
|
|
||||||
try:
|
|
||||||
for x in itermaker(): q.put(x)
|
|
||||||
except Exception as e:
|
|
||||||
q.put(e)
|
|
||||||
finally:
|
|
||||||
q.put(None)
|
|
||||||
q.close()
|
|
||||||
|
|
||||||
class _CloudpickleFunctionWrapper:
|
|
||||||
def __init__(self, fn): self.fn = fn
|
|
||||||
def __getstate__(self): return cloudpickle.dumps(self.fn)
|
|
||||||
def __setstate__(self, pfn): self.fn = cloudpickle.loads(pfn)
|
|
||||||
def __call__(self, *args, **kwargs) -> Any: return self.fn(*args, **kwargs)
|
|
||||||
|
|
||||||
def cross_process(itermaker, maxsize=16):
|
|
||||||
q: multiprocessing.Queue = multiprocessing.Queue(maxsize)
|
|
||||||
# multiprocessing uses pickle which cannot dump lambdas, so use cloudpickle.
|
|
||||||
p = multiprocessing.Process(target=proc, args=(_CloudpickleFunctionWrapper(itermaker), q))
|
|
||||||
p.start()
|
|
||||||
while True:
|
|
||||||
ret = q.get()
|
|
||||||
if isinstance(ret, Exception): raise ret
|
|
||||||
elif ret is None: break
|
|
||||||
else: yield ret
|
|
1
setup.py
1
setup.py
|
@ -46,7 +46,6 @@ setup(name='tinygrad',
|
||||||
"opencv-python",
|
"opencv-python",
|
||||||
"tabulate",
|
"tabulate",
|
||||||
"safetensors",
|
"safetensors",
|
||||||
"cloudpickle",
|
|
||||||
"transformers",
|
"transformers",
|
||||||
"sentencepiece",
|
"sentencepiece",
|
||||||
"tiktoken",
|
"tiktoken",
|
||||||
|
|
|
@ -1,57 +0,0 @@
|
||||||
#!/usr/bin/env python
|
|
||||||
import os, cloudpickle, tempfile, unittest, subprocess
|
|
||||||
from extra.helpers import enable_early_exec, cross_process, _CloudpickleFunctionWrapper
|
|
||||||
|
|
||||||
def normalize_line_endings(s): return s.replace(b'\r\n', b'\n')
|
|
||||||
|
|
||||||
class TestEarlyExec(unittest.TestCase):
|
|
||||||
def setUp(self) -> None:
|
|
||||||
self.early_exec = enable_early_exec()
|
|
||||||
|
|
||||||
def early_exec_py_file(self, file_content, exec_args):
|
|
||||||
with tempfile.NamedTemporaryFile(suffix=".py", delete=False) as temp:
|
|
||||||
temp.write(file_content)
|
|
||||||
temp_path = temp.name
|
|
||||||
try:
|
|
||||||
output = self.early_exec((["python3", temp_path] + exec_args, None))
|
|
||||||
return output
|
|
||||||
finally:
|
|
||||||
os.remove(temp_path)
|
|
||||||
|
|
||||||
def test_enable_early_exec(self):
|
|
||||||
output = self.early_exec_py_file(b'print("Hello, world!")', [])
|
|
||||||
self.assertEqual(b"Hello, world!\n", normalize_line_endings(output))
|
|
||||||
|
|
||||||
def test_enable_early_exec_with_arg(self):
|
|
||||||
output = self.early_exec_py_file(b'import sys\nprint("Hello, " + sys.argv[1] + "!")', ["world"])
|
|
||||||
self.assertEqual(b"Hello, world!\n", normalize_line_endings(output))
|
|
||||||
|
|
||||||
def test_enable_early_exec_process_exception(self):
|
|
||||||
with self.assertRaises(subprocess.CalledProcessError):
|
|
||||||
self.early_exec_py_file(b'raise Exception("Test exception")', [])
|
|
||||||
|
|
||||||
def test_enable_early_exec_type_exception(self):
|
|
||||||
with self.assertRaises(TypeError):
|
|
||||||
self.early_exec((["python3"], "print('Hello, world!')"))
|
|
||||||
|
|
||||||
class TestCrossProcess(unittest.TestCase):
|
|
||||||
|
|
||||||
def test_cross_process(self):
|
|
||||||
def _iterate():
|
|
||||||
for i in range(10): yield i
|
|
||||||
results = list(cross_process(_iterate))
|
|
||||||
self.assertEqual(list(range(10)), results)
|
|
||||||
|
|
||||||
def test_cross_process_exception(self):
|
|
||||||
def _iterate():
|
|
||||||
for i in range(10):
|
|
||||||
if i == 5: raise ValueError("Test exception")
|
|
||||||
yield i
|
|
||||||
with self.assertRaises(ValueError): list(cross_process(_iterate))
|
|
||||||
|
|
||||||
def test_CloudpickleFunctionWrapper(self):
|
|
||||||
def add(x, y): return x + y
|
|
||||||
self.assertEqual(7, cloudpickle.loads(cloudpickle.dumps(_CloudpickleFunctionWrapper(add)))(3, 4))
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
unittest.main()
|
|
Loading…
Reference in New Issue