test: fix test_tensor_core_padded on CUDA and add to benchmarks (#4258)

* test: fix test_tensor_core_padded on CUDA and add to benchmarks

* fix linter

* run both tests in one call
This commit is contained in:
Francis Lam 2024-04-22 20:22:11 -07:00 committed by GitHub
parent a90de3b574
commit 3f6c7ca8bf
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 56 additions and 45 deletions

View File

@ -34,6 +34,8 @@ jobs:
run: METAL=1 python3 test/external/external_model_benchmark.py
- name: Test speed vs torch
run: BIG=2 MPS=1 python3 test/test_speed_v_torch.py | tee torch_speed.txt
- name: Test tensor cores
run: METAL=1 python3 test/test_linearizer.py TestLinearizer.test_tensor_cores TestLinearizer.test_tensor_cores_padded
- name: Run Tensor Core GEMM
run: |
DEBUG=2 python3 extra/gemm/simple_matmul.py | tee matmul.txt
@ -112,6 +114,10 @@ jobs:
run: CUDA=1 python3 test/external/external_model_benchmark.py
- name: Test speed vs torch
run: CUDA=1 BIG=2 TORCHCUDA=1 python3 test/test_speed_v_torch.py | tee torch_speed.txt
- name: Test tensor cores
run: |
CUDA=1 python3 test/test_linearizer.py TestLinearizer.test_tensor_cores TestLinearizer.test_tensor_cores_padded
PTX=1 python3 test/test_linearizer.py TestLinearizer.test_tensor_cores TestLinearizer.test_tensor_cores_padded
- name: Run Tensor Core GEMM (CUDA)
run: |
CUDA=1 HALF=1 DEBUG=2 python3 extra/gemm/simple_matmul.py | tee matmul.txt
@ -197,6 +203,10 @@ jobs:
# run: |
# python3 -c "import torch; print(torch.__version__)"
# LD_PRELOAD="/opt/rocm/lib/libhsa-runtime64.so" HSA=1 BIG=2 TORCHCUDA=1 python3 test/test_speed_v_torch.py | tee torch_speed.txt
- name: Test tensor cores
run: |
HSA=1 python3 test/test_linearizer.py TestLinearizer.test_tensor_cores TestLinearizer.test_tensor_cores_padded
KFD=1 python3 test/test_linearizer.py TestLinearizer.test_tensor_cores TestLinearizer.test_tensor_cores_padded
- name: Run Tensor Core GEMM (HSA)
run: HSA=1 HALF=1 DEBUG=2 python3 extra/gemm/simple_matmul.py | tee matmul.txt
- name: Run Tensor Core GEMM (KFD)

View File

@ -15,6 +15,42 @@ from tinygrad.helpers import prod, Context, getenv
from tinygrad.dtype import DType, dtypes
from tinygrad.codegen.uops import UOpGraph
def helper_tc_allclose(m:int, k:int, n:int, dtype_in:DType, dtype_out:DType, tc_opt:int):
a, b = Tensor.rand(m, k, dtype=dtype_in), Tensor.rand(k, n, dtype=dtype_in)
np_a, np_b = a.numpy(), b.numpy()
r = a.matmul(b, acc_dtype=dtype_out)
sched = create_schedule([r.lazydata])
realized_ast = sched[-1].ast[0]
run_schedule(sched)
out = r.numpy()
k = Linearizer(realized_ast)
k.apply_tensor_cores(1, tc_opt=tc_opt)
k.linearize()
assert len([uop for uop in k.uops if uop.uop is UOps.WMMA]) > 0, "tensor core not triggered"
assert len([x for x in k.applied_opts if x.op is OptOps.TC]) == 1, "tensor core opt not included"
np_c = np_a @ np_b
if dtype_out == dtypes.half: tc_atol, tc_rtol = 1e-2, 1e-3
elif dtype_in == dtypes.bfloat16: tc_atol, tc_rtol = 1e-2, 3e-3
else: tc_atol, tc_rtol = 5e-3, 1e-4
np.testing.assert_allclose(np_c, out, atol=tc_atol, rtol=tc_rtol)
def helper_tc_ensure_uops_and_opts_count(m:int, k:int, n:int, dtype_in:DType, dtype_out:DType, tc_opt:int, ensure_triggered:bool=True):
a, b = Tensor.rand(m, k, dtype=dtype_in), Tensor.rand(k, n, dtype=dtype_in)
r = a.matmul(b, acc_dtype=dtype_out)
sched = create_schedule([r.lazydata])
realized_ast = sched[-1].ast[0]
k = Linearizer(realized_ast)
k.apply_tensor_cores(1, tc_opt=tc_opt)
k.linearize()
wmmas = len([uop for uop in k.uops if uop.uop is UOps.WMMA])
tcs = len([x for x in k.applied_opts if x.op is OptOps.TC])
if ensure_triggered:
assert wmmas > 0, "tensor core not triggered"
assert tcs == 1, "tensor core opt not included"
else:
assert wmmas == 0, "tensor core is incorrectly triggered"
assert tcs == 0, "tensor core opt is incorrectly included"
class TestLinearizer(unittest.TestCase):
def test_arg_dedup(self):
a, b = Tensor.randn(4), Tensor.randn(4)
@ -178,23 +214,7 @@ class TestLinearizer(unittest.TestCase):
self.skipTest("device doesn't have tensor cores")
for tc in tensor_cores[Device[Device.DEFAULT].compiler.compiler_opts.device]:
if getenv("EMULATE_CUDA") and (tc.dtype_in == dtypes.bfloat16 or tc.dtype_out == dtypes.bfloat16): continue
a, b = Tensor.rand(tc.dims[1], tc.dims[2], dtype=tc.dtype_in), Tensor.rand(tc.dims[2], tc.dims[0], dtype=tc.dtype_in)
np_a, np_b = a.numpy(), b.numpy()
r = a.matmul(b, acc_dtype=tc.dtype_out)
sched = create_schedule([r.lazydata])
realized_ast = sched[-1].ast[0]
run_schedule(sched)
out = r.numpy()
k = Linearizer(realized_ast)
k.apply_tensor_cores(1)
k.linearize()
assert len([uop for uop in k.uops if uop.uop is UOps.WMMA]) == 1, "tensor core not triggered"
assert len([x for x in k.applied_opts if x.op is OptOps.TC]) == 1, "tensor core opt not included"
np_c = np_a @ np_b
if tc.dtype_out == dtypes.half: tc_atol, tc_rtol = 1e-2, 1e-3
elif tc.dtype_in == dtypes.bfloat16: tc_atol, tc_rtol = 1e-2, 3e-3
else: tc_atol, tc_rtol = 5e-3, 1e-4
np.testing.assert_allclose(np_c, out, atol=tc_atol, rtol=tc_rtol)
helper_tc_allclose(tc.dims[1], tc.dims[2], tc.dims[0], tc.dtype_in, tc.dtype_out, tc_opt=0)
def test_tensor_cores_padded(self):
if not Device[Device.DEFAULT].compiler.compiler_opts.has_tensor_cores:
@ -203,41 +223,22 @@ class TestLinearizer(unittest.TestCase):
if getenv("EMULATE_CUDA") and (tc.dtype_in == dtypes.bfloat16 or tc.dtype_out == dtypes.bfloat16): continue
pad = 1
def ensure_uops_and_opts_count(m:int, k:int, n:int, tc_opt:int, ensure_triggered:bool=True):
a, b = Tensor.rand(m, k, dtype=tc.dtype_in), Tensor.rand(k, n, dtype=tc.dtype_in)
r = a.matmul(b, acc_dtype=tc.dtype_out)
sched = create_schedule([r.lazydata])
realized_ast = sched[-1].ast[0]
k = Linearizer(realized_ast)
k.apply_tensor_cores(1, tc_opt=tc_opt)
k.linearize()
wmmas = len([uop for uop in k.uops if uop.uop is UOps.WMMA])
tcs = len([x for x in k.applied_opts if x.op is OptOps.TC])
if ensure_triggered:
assert wmmas > 0, "tensor core not triggered"
assert tcs == 1, "tensor core opt not included"
else:
assert wmmas == 0, "tensor core is incorrectly triggered"
assert tcs == 0, "tensor core opt is incorrectly included"
# check that TC is triggered for TC_OPT=2
ensure_uops_and_opts_count(tc.dims[0]+pad, tc.dims[2]+pad, tc.dims[1]+pad, tc_opt=2, ensure_triggered=True)
helper_tc_ensure_uops_and_opts_count(tc.dims[0]+pad, tc.dims[2]+pad, tc.dims[1]+pad,tc.dtype_in, tc.dtype_out, tc_opt=2, ensure_triggered=True)
# check that TC is not triggered for TC_OPT<2
ensure_uops_and_opts_count(tc.dims[0]+pad, tc.dims[2]+pad, tc.dims[1]+pad, tc_opt=1, ensure_triggered=False)
helper_tc_ensure_uops_and_opts_count(tc.dims[0]+pad, tc.dims[2]+pad, tc.dims[1]+pad,
tc.dtype_in, tc.dtype_out, tc_opt=1, ensure_triggered=False)
helper_tc_ensure_uops_and_opts_count(tc.dims[0]+pad, tc.dims[2]+pad, tc.dims[1]+pad,
tc.dtype_in, tc.dtype_out, tc_opt=0, ensure_triggered=False)
# check excessive padding doesn't trigger padded TC in TC_OPT=2
ensure_uops_and_opts_count(tc.dims[0]//2, tc.dims[2], tc.dims[1], tc_opt=2, ensure_triggered=False)
ensure_uops_and_opts_count(tc.dims[0], tc.dims[2]//2, tc.dims[1], tc_opt=2, ensure_triggered=False)
ensure_uops_and_opts_count(tc.dims[0], tc.dims[2], tc.dims[1]//2, tc_opt=2, ensure_triggered=False)
helper_tc_ensure_uops_and_opts_count(tc.dims[0]//2, tc.dims[2], tc.dims[1], tc.dtype_in, tc.dtype_out, tc_opt=2, ensure_triggered=False)
helper_tc_ensure_uops_and_opts_count(tc.dims[0], tc.dims[2]//2, tc.dims[1], tc.dtype_in, tc.dtype_out, tc_opt=2, ensure_triggered=False)
helper_tc_ensure_uops_and_opts_count(tc.dims[0], tc.dims[2], tc.dims[1]//2, tc.dtype_in, tc.dtype_out, tc_opt=2, ensure_triggered=False)
# check correctness
a, b = Tensor.rand(tc.dims[1]+pad, tc.dims[2]+pad, dtype=tc.dtype_in), Tensor.rand(tc.dims[2]+pad, tc.dims[0]+pad, dtype=tc.dtype_in)
r = a.matmul(b, acc_dtype=tc.dtype_out)
(atol, rtol) = ((0.25, 0.01) if tc.dtype_out == dtypes.half else (3e-2, 1e-3)) if tc.dtype_in == dtypes.half else (1e-4, 1e-4)
helper_linearizer_opt(r, [
[Opt(OptOps.TC, axis=0, amt=2)],
], atol=atol, rtol=rtol)
helper_tc_allclose(tc.dims[1]+pad, tc.dims[2]+pad, tc.dims[0]+pad, tc.dtype_in, tc.dtype_out, tc_opt=2)
def test_limit_dims_to_max_5d_global(self):
t = Tensor.empty(3, 4, 5, 6, 7).pad(((1, 1), (1, 1), (1, 1), (1, 1), (1, 1))) + 1