mirror of https://github.com/commaai/tinygrad.git
Fix float16 in CLANG on Linux
This commit is contained in:
parent
803b0aef28
commit
dc9a6b4bb7
|
@ -14,6 +14,7 @@ class TestDtype(unittest.TestCase):
|
|||
na = a.numpy()
|
||||
print(na, na.dtype, a.lazydata.realized)
|
||||
assert na.dtype == np.float16
|
||||
np.testing.assert_allclose(na, [1,2,3,4])
|
||||
|
||||
def test_half_add(self):
|
||||
a = Tensor([1,2,3,4], dtype=dtypes.float16)
|
||||
|
@ -21,6 +22,7 @@ class TestDtype(unittest.TestCase):
|
|||
c = a+b
|
||||
print(c.numpy())
|
||||
assert c.dtype == dtypes.float16
|
||||
np.testing.assert_allclose(c.numpy(), [2,4,6,8])
|
||||
|
||||
def test_upcast_float(self):
|
||||
# NOTE: there's no downcasting support
|
||||
|
@ -29,6 +31,7 @@ class TestDtype(unittest.TestCase):
|
|||
na = a.numpy()
|
||||
print(na, na.dtype)
|
||||
assert na.dtype == np.float32
|
||||
np.testing.assert_allclose(na, [1,2,3,4])
|
||||
|
||||
def test_half_add_upcast(self):
|
||||
a = Tensor([1,2,3,4], dtype=dtypes.float16)
|
||||
|
@ -36,6 +39,7 @@ class TestDtype(unittest.TestCase):
|
|||
c = a+b
|
||||
print(c.numpy())
|
||||
assert c.dtype == dtypes.float32
|
||||
np.testing.assert_allclose(c.numpy(), [2,4,6,8])
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
|
@ -14,8 +14,9 @@ class ClangProgram:
|
|||
prg = "#include <math.h>\n#define max(x,y) ((x>y)?x:y)\n#define half __fp16\n" + prg
|
||||
# TODO: is there a way to not write this to disk?
|
||||
fn = f"/tmp/clang_{hashlib.md5(prg.encode('utf-8')).hexdigest()}.{'dylib' if platform.system() == 'Darwin' else 'so'}"
|
||||
# NOTE: --rtlib=compiler-rt fixes float16 on Linux; it defines __gnu_h2f_ieee and __gnu_f2h_ieee
|
||||
if not os.path.exists(fn):
|
||||
subprocess.check_output(['clang', '-shared', '-O2', '-Wall','-Werror', '-lm', '-fPIC', '-x', 'c', '-', '-o', fn+".tmp"], input=prg.encode('utf-8'))
|
||||
subprocess.check_output(['clang', '-shared', '-O2', '-Wall','-Werror', '-lm', '--rtlib=compiler-rt', '-fPIC', '-x', 'c', '-', '-o', fn+".tmp"], input=prg.encode('utf-8'))
|
||||
os.rename(fn+".tmp", fn)
|
||||
self.lib = ctypes.CDLL(fn)
|
||||
self.fxn = self.lib[name]
|
||||
|
|
Loading…
Reference in New Issue