fix float16 in CLANG on linux

This commit is contained in:
George Hotz 2023-03-11 21:51:22 -08:00
parent 803b0aef28
commit dc9a6b4bb7
2 changed files with 6 additions and 1 deletions

View File

@ -14,6 +14,7 @@ class TestDtype(unittest.TestCase):
na = a.numpy()
print(na, na.dtype, a.lazydata.realized)
assert na.dtype == np.float16
np.testing.assert_allclose(na, [1,2,3,4])
def test_half_add(self):
a = Tensor([1,2,3,4], dtype=dtypes.float16)
@ -21,6 +22,7 @@ class TestDtype(unittest.TestCase):
c = a+b
print(c.numpy())
assert c.dtype == dtypes.float16
np.testing.assert_allclose(c.numpy(), [2,4,6,8])
def test_upcast_float(self):
# NOTE: there's no downcasting support
@ -29,6 +31,7 @@ class TestDtype(unittest.TestCase):
na = a.numpy()
print(na, na.dtype)
assert na.dtype == np.float32
np.testing.assert_allclose(na, [1,2,3,4])
def test_half_add_upcast(self):
a = Tensor([1,2,3,4], dtype=dtypes.float16)
@ -36,6 +39,7 @@ class TestDtype(unittest.TestCase):
c = a+b
print(c.numpy())
assert c.dtype == dtypes.float32
np.testing.assert_allclose(c.numpy(), [2,4,6,8])
if __name__ == '__main__':
unittest.main()

View File

@ -14,8 +14,9 @@ class ClangProgram:
prg = "#include <math.h>\n#define max(x,y) ((x>y)?x:y)\n#define half __fp16\n" + prg
# TODO: is there a way to not write this to disk?
fn = f"/tmp/clang_{hashlib.md5(prg.encode('utf-8')).hexdigest()}.{'dylib' if platform.system() == 'Darwin' else 'so'}"
# NOTE: --rtlib=compiler-rt fixes float16 on Linux, it defines __gnu_h2f_ieee and __gnu_f2h_ieee
if not os.path.exists(fn):
subprocess.check_output(['clang', '-shared', '-O2', '-Wall','-Werror', '-lm', '-fPIC', '-x', 'c', '-', '-o', fn+".tmp"], input=prg.encode('utf-8'))
subprocess.check_output(['clang', '-shared', '-O2', '-Wall','-Werror', '-lm', '--rtlib=compiler-rt', '-fPIC', '-x', 'c', '-', '-o', fn+".tmp"], input=prg.encode('utf-8'))
os.rename(fn+".tmp", fn)
self.lib = ctypes.CDLL(fn)
self.fxn = self.lib[name]