mirror of https://github.com/commaai/tinygrad.git
Limit dims based on max size (#1390)
* working
* whitespace
* changed defaults to None
* linter
* last linter error
This commit is contained in:
parent b2fde9ec36
commit ba5e3818a0
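Context for the change (not part of the commit): backends such as WebGPU cap each launch-grid dimension, so a kernel whose shape exceeds the cap in any single dimension cannot be dispatched as-is. A minimal sketch of the failing condition, using the 65535-per-dimension cap that the WGSLLanguage config below declares; the variable names here are illustrative only:

GLOBAL_MAX = [65535, 65535, 65535]  # per-dimension dispatch cap (see WGSLLanguage below)
shape = (262144, 4, 4)              # hypothetical kernel shape, first dim too large
needs_reshape = any(s > m for s, m in zip(shape, GLOBAL_MAX))
print(needs_reshape)  # True -> the shape must be split before dispatch

This commit teaches the Linearizer to do that splitting automatically.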
@@ -5,9 +5,6 @@ from tinygrad.tensor import Tensor
 from tinygrad.lazy import LAZY
 from tinygrad.ops import GlobalCounters
 from tinygrad.graph import nm
-import pytest
-
-pytestmark = pytest.mark.webgpu
 
 N = 200 # has to be bigger than the cache to fail
@@ -3,7 +3,7 @@ import numpy as np
 from tinygrad.tensor import Tensor
 import pytest
 
-pytestmark = [pytest.mark.exclude_cuda, pytest.mark.webgpu]
+pytestmark = [pytest.mark.exclude_cuda]
 
 class TestConv(unittest.TestCase):
   def test_simple(self):
@@ -9,7 +9,7 @@ from tinygrad.nn import BatchNorm2d, Conv1d, ConvTranspose1d, Conv2d, ConvTransp
 import torch
 import pytest
 
-pytestmark = [pytest.mark.exclude_cuda, pytest.mark.webgpu]
+pytestmark = [pytest.mark.exclude_cuda]
 
 class TestNN(unittest.TestCase):
@@ -6,9 +6,6 @@ import unittest
 from tinygrad.tensor import Tensor
 from tinygrad.helpers import getenv, IMAGE, DEBUG, CI
 from tinygrad.lazy import Device
-import pytest
-
-pytestmark = pytest.mark.webgpu
 
 if CI:
   import warnings
@@ -807,7 +804,7 @@ class TestOps(unittest.TestCase):
       lambda x,w: torch.nn.functional.conv_transpose3d(x,w).relu(),
       lambda x,w: Tensor.conv_transpose2d(x,w).relu(), atol=1e-4, grad_rtol=1e-5)
 
-  @unittest.skipIf((IMAGE>0 or (Device.DEFAULT == "WEBGPU" and getenv("CI","") != "")), "no conv1d on images")
+  @unittest.skipIf((IMAGE>0), "no conv1d on images")
   def test_conv1d(self):
     for bs in [1,8]:
       for cin in [1,3]:
@@ -18,7 +18,7 @@ from tinygrad.helpers import colored, getenv, DEBUG, CI
 from tinygrad.jit import TinyJit
 import pytest
 
-pytestmark = [pytest.mark.exclude_cuda, pytest.mark.exclude_gpu, pytest.mark.exclude_clang, pytest.mark.webgpu]
+pytestmark = [pytest.mark.exclude_cuda, pytest.mark.exclude_gpu, pytest.mark.exclude_clang]
 
 IN_CHANS = [int(x) for x in getenv("IN_CHANS", "4,16,64").split(",")]
@@ -130,7 +130,7 @@ class TestBigSpeed(unittest.TestCase):
   def test_large_conv_3x3(self): helper_test_conv(bs=4, in_chans=128, out_chans=128, kernel_size=3, img_size_y=130, img_size_x=130)
   def test_large_conv_5x5(self): helper_test_conv(bs=4, in_chans=128, out_chans=128, kernel_size=5, img_size_y=130, img_size_x=130)
 
-@unittest.skipIf((getenv("BIG") == 1 or Device.DEFAULT == "WEBGPU"), "only big tests")
+@unittest.skipIf((getenv("BIG") == 1), "only big tests")
 class TestSpeed(unittest.TestCase):
   def test_sub(self):
     def f(a, b): return a-b
@@ -2,13 +2,9 @@ import dataclasses
 import numpy as np
 import torch
 import unittest
 import itertools
-from tinygrad.tensor import Tensor, Device
+from tinygrad.tensor import Tensor
 from tinygrad.helpers import dtypes
 from extra.gradcheck import numerical_jacobian, jacobian, gradcheck
-import pytest
-
-pytestmark = pytest.mark.webgpu
-
 x_init = np.random.randn(1,3).astype(np.float32)
 U_init = np.random.randn(3,3).astype(np.float32)
@@ -21,6 +21,7 @@ class CStyleLanguage(NamedTuple):
   gid: List[str] = []
   lid: List[str] = []
   global_max: List[int] = []
+  local_max: List[int] = []
   extra_args: List[str] = []
   float4: Optional[str] = None
   half_prekernel: Optional[str] = None
@@ -195,7 +196,7 @@ class CStyleCodegen(Linearizer):
 
   def codegen(self):
     self.process()
-    if self.lang.global_max: self.limit_global_dims(len(self.lang.gid), self.lang.global_max) # NOTE: this is optional now
+    if self.lang.global_max: self.limit_global_dims(len(self.lang.gid), self.lang.global_max, self.lang.local_max) # NOTE: this is optional now
     self.linearize()
 
     prg, global_size, local_size = uops_to_cstyle(self.uops, self.lang)
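codegen() now threads the language's local_max through to limit_global_dims alongside global_max. A rough standalone sketch of the config shape involved (LangConfig is a stand-in name; the real fields live on CStyleLanguage above):

from typing import List, NamedTuple

class LangConfig(NamedTuple):   # stand-in for CStyleLanguage
  gid: List[str] = []           # global index variable names
  lid: List[str] = []           # local index variable names
  global_max: List[int] = []    # per-dimension grid caps
  local_max: List[int] = []     # per-dimension workgroup caps

lang = LangConfig(gid=["gid0", "gid1", "gid2"], global_max=[65535]*3, local_max=[256, 256, 64])
# the guarded call in codegen() then reads:
# if lang.global_max: self.limit_global_dims(len(lang.gid), lang.global_max, lang.local_max)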
@@ -595,8 +595,20 @@ class Linearizer:
     for i,x in enumerate(rets): self.sts[i].reshape(tuple([y[0] for y in x]))
 
   # ******************** GPU simplifiers ********************
-  def limit_global_dims(self, limit, global_max):
+  def _limit_size(self, x: Tuple[int], max_size: List) -> Tuple[int, ...]:
+    new_shape,dims = list(x), len(x)
+    for i in range(dims):
+      next_idx = (i + 1) % dims
+      while new_shape[i] > max_size[i]:
+        new_shape[i] = new_shape[i] // 2
+        if (new_shape[next_idx] <= max_size[next_idx]):
+          new_shape[next_idx] = new_shape[next_idx] * 2
+        else:
+          next_idx = (next_idx + 1) % dims
+          new_shape[next_idx] = new_shape[next_idx] * 2
+    return tuple(new_shape)
+
+  def limit_global_dims(self, limit: int, global_max: List[int], local_max: List[int]):
     # sometimes, there's more dimensions than len(self.lang.gid).
     # compact all the dimensions into the first
     # NOTE: this might make multiview shapetrackers
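The new _limit_size repeatedly halves any dimension that exceeds its cap and doubles a neighboring dimension to compensate, so (for even sizes) the total element count is preserved. A self-contained copy of the same logic plus a worked example; the input shapes are made up:

from typing import List, Tuple

def limit_size(x: Tuple[int, ...], max_size: List) -> Tuple[int, ...]:
  # mirror of Linearizer._limit_size above
  new_shape, dims = list(x), len(x)
  for i in range(dims):
    next_idx = (i + 1) % dims
    while new_shape[i] > max_size[i]:
      new_shape[i] //= 2                        # halve the oversized dim
      if new_shape[next_idx] <= max_size[next_idx]:
        new_shape[next_idx] *= 2                # push the factor of 2 into a neighbor
      else:
        next_idx = (next_idx + 1) % dims        # neighbor is full, move to the next one
        new_shape[next_idx] *= 2
  return tuple(new_shape)

print(limit_size((131072, 4), [65535, 65535]))            # (32768, 16): 131072*4 == 32768*16
print(limit_size((200000, 1, 1), [65535, 65535, 65535]))  # (50000, 4, 1)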
@@ -607,8 +619,11 @@ class Linearizer:
     # Check the global allocation limit, current the global_size will be flipped during codegen
     # and then padded right with 1s if its length < 3 which makes this part a bit awkward to write
     global_dims = self.first_reduce-self.local_dims
-    if global_dims > 0:
-      assert max(global_max) >= max(self.full_shape[0:global_dims]), f"device max allocation {max(self.full_shape[0:global_dims])} exceeds global dim maximum {max(global_max)}"
+    if global_dims > 0:
+      if global_max:
+        tmp = global_max[:global_dims] + (local_max[:self.local_dims] if local_max else [])
+        if max(global_max) < max(self.full_shape[:global_dims]): self.reshape_and_permute(lambda x: self._limit_size(x, tmp + [math.inf] * (len(self.full_shape)-len(tmp))), None)
+        assert max(global_max) >= max(self.full_shape[:global_dims]), f"device max allocation {max(self.full_shape[:global_dims])} exceeds global dim maximum {max(global_max)}"
       for i in range(global_dims-1):
         if self.full_shape[i] > global_max[i]:
           order = list(range(len(self.full_shape)))
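Before reshaping, the new guard builds a per-dimension cap vector: global caps for the global dims, local caps for the local dims, and math.inf padding so any remaining (e.g. reduce) dimensions are never split. A sketch with illustrative values only:

import math

global_max, local_max = [65535, 65535, 65535], [256, 256, 64]
global_dims, local_dims = 2, 1       # assumed split of the shape
full_shape = (131072, 4, 512, 16)    # hypothetical kernel shape

tmp = global_max[:global_dims] + (local_max[:local_dims] if local_max else [])
caps = tmp + [math.inf] * (len(full_shape) - len(tmp))
print(caps)  # [65535, 65535, 256, inf] -- the trailing dim is unconstrained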
@@ -11,6 +11,8 @@ class WGSLLanguage(CStyleLanguage):
   gid = [f"i32(gindex.{'xyz'[x]})" for x in range(3)]
   lid = [f"i32(lindex.{'xyz'[x]})" for x in range(3)]
   size_prefix = "let"
+  global_max = [65535, 65535, 65535]
+  local_max = [256, 256, 64]
   barrier="workgroupBarrier();"
   generic_var_prefix = "var "
   external_local_bufs = True
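These constants match WebGPU's default limits (maxComputeWorkgroupsPerDimension = 65535; maxComputeWorkgroupSizeX/Y/Z = 256/256/64). A small sketch of validating a dispatch against them; dispatch_ok is a hypothetical helper, not tinygrad API:

GLOBAL_MAX, LOCAL_MAX = [65535, 65535, 65535], [256, 256, 64]

def dispatch_ok(global_size, local_size):
  # every grid dim under the dispatch cap, every workgroup dim under the size cap
  return all(g <= m for g, m in zip(global_size, GLOBAL_MAX)) and \
         all(l <= m for l, m in zip(local_size, LOCAL_MAX))

print(dispatch_ok([1024, 1, 1], [256, 1, 1]))    # True
print(dispatch_ok([131072, 1, 1], [1, 1, 128]))  # False: both caps exceeded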