mirror of https://github.com/commaai/tinygrad.git
Removing METAL Skips as CI works (#2488)
* Test metal CI * remove metal and CI restrictions * enable dtype tests for metal ci
This commit is contained in:
parent
5588922884
commit
cf0c9096a9
|
@ -240,9 +240,8 @@ jobs:
|
|||
key: downloads-cache-metal-${{ env.DOWNLOAD_CACHE_VERSION }}
|
||||
- name: Test LLaMA compile speed
|
||||
run: PYTHONPATH="." METAL=1 python test/external/external_test_speed_llama.py
|
||||
#- name: Run dtype test
|
||||
# run: DEBUG=4 METAL=1 python -m pytest -n=auto test/test_dtype.py
|
||||
# dtype test has issues on test_half_to_int8
|
||||
- name: Run dtype test
|
||||
run: DEBUG=4 METAL=1 python -m pytest -n=auto test/test_dtype.py
|
||||
- name: Check Device.DEFAULT (METAL) and print some source
|
||||
run: |
|
||||
METAL=1 python -c "from tinygrad import Device; assert Device.DEFAULT == 'METAL', Device.DEFAULT"
|
||||
|
|
|
@ -865,7 +865,6 @@ class TestOps(unittest.TestCase):
|
|||
lambda x,w: torch.nn.functional.conv_transpose2d(x,w, stride=stride).relu(),
|
||||
lambda x,w: Tensor.conv_transpose2d(x,w,stride=stride).relu(), atol=1e-4, grad_rtol=1e-5)
|
||||
|
||||
@unittest.skipIf(Device.DEFAULT == "METAL" and CI, "broken in METAL CI")
|
||||
def test_output_padded_conv_transpose2d(self):
|
||||
for output_padding, stride in [((1,1), (2,3)), ((2,1), (3,2))]:
|
||||
helper_test_op([(2,4,6,5), (4,4,3,3),(4,)],
|
||||
|
@ -1036,14 +1035,12 @@ class TestOps(unittest.TestCase):
|
|||
lambda x,w: torch.nn.functional.conv2d(torch.nn.functional.pad(x, p),w).relu(),
|
||||
lambda x,w: Tensor.conv2d(x,w,padding=p).relu(), atol=1e-4)
|
||||
|
||||
@unittest.skipIf(Device.DEFAULT == "METAL" and CI, "broken in METAL CI")
|
||||
def test_padded_conv2d_p21(self):
|
||||
bs,cin,H,W,padding = 4, 3, 3, 3, (2,1)
|
||||
helper_test_op([(bs,cin,11,28), (4,cin,H,W)],
|
||||
lambda x,w: torch.nn.functional.conv2d(x,w,padding=padding).relu(),
|
||||
lambda x,w: Tensor.conv2d(x,w,padding=padding).relu(), atol=1e-4)
|
||||
|
||||
@unittest.skipIf(Device.DEFAULT == "METAL" and CI, "broken in METAL CI")
|
||||
def test_padded_conv2d_p22(self):
|
||||
bs,cin,H,W,padding = 4, 3, 3, 3, (2,2)
|
||||
helper_test_op([(bs,cin,11,28), (4,cin,H,W)],
|
||||
|
|
Loading…
Reference in New Issue