diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 96698ec6..fe3569b4 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -240,9 +240,8 @@ jobs: key: downloads-cache-metal-${{ env.DOWNLOAD_CACHE_VERSION }} - name: Test LLaMA compile speed run: PYTHONPATH="." METAL=1 python test/external/external_test_speed_llama.py - #- name: Run dtype test - # run: DEBUG=4 METAL=1 python -m pytest -n=auto test/test_dtype.py - # dtype test has issues on test_half_to_int8 + - name: Run dtype test + run: DEBUG=4 METAL=1 python -m pytest -n=auto test/test_dtype.py - name: Check Device.DEFAULT (METAL) and print some source run: | METAL=1 python -c "from tinygrad import Device; assert Device.DEFAULT == 'METAL', Device.DEFAULT" diff --git a/test/test_ops.py b/test/test_ops.py index 69dee1ed..a378cac0 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -865,7 +865,6 @@ class TestOps(unittest.TestCase): lambda x,w: torch.nn.functional.conv_transpose2d(x,w, stride=stride).relu(), lambda x,w: Tensor.conv_transpose2d(x,w,stride=stride).relu(), atol=1e-4, grad_rtol=1e-5) - @unittest.skipIf(Device.DEFAULT == "METAL" and CI, "broken in METAL CI") def test_output_padded_conv_transpose2d(self): for output_padding, stride in [((1,1), (2,3)), ((2,1), (3,2))]: helper_test_op([(2,4,6,5), (4,4,3,3),(4,)], @@ -1036,14 +1035,12 @@ class TestOps(unittest.TestCase): lambda x,w: torch.nn.functional.conv2d(torch.nn.functional.pad(x, p),w).relu(), lambda x,w: Tensor.conv2d(x,w,padding=p).relu(), atol=1e-4) - @unittest.skipIf(Device.DEFAULT == "METAL" and CI, "broken in METAL CI") def test_padded_conv2d_p21(self): bs,cin,H,W,padding = 4, 3, 3, 3, (2,1) helper_test_op([(bs,cin,11,28), (4,cin,H,W)], lambda x,w: torch.nn.functional.conv2d(x,w,padding=padding).relu(), lambda x,w: Tensor.conv2d(x,w,padding=padding).relu(), atol=1e-4) - @unittest.skipIf(Device.DEFAULT == "METAL" and CI, "broken in METAL CI") def test_padded_conv2d_p22(self): bs,cin,H,W,padding = 4, 3, 3, 3, (2,2) helper_test_op([(bs,cin,11,28), (4,cin,H,W)],