Limit dims based on max size (#1390)

* working

* whitespace

* changed defaults to None

* linter

* last linter error
Diogo 2023-07-31 22:18:19 -04:00 committed by GitHub
parent b2fde9ec36
commit ba5e3818a0
9 changed files with 28 additions and 20 deletions


@@ -5,9 +5,6 @@ from tinygrad.tensor import Tensor
 from tinygrad.lazy import LAZY
 from tinygrad.ops import GlobalCounters
 from tinygrad.graph import nm
-import pytest
-
-pytestmark = pytest.mark.webgpu
 
 N = 200  # has to be bigger than the cache to fail


@@ -3,7 +3,7 @@ import numpy as np
 from tinygrad.tensor import Tensor
 import pytest
-pytestmark = [pytest.mark.exclude_cuda, pytest.mark.webgpu]
+pytestmark = [pytest.mark.exclude_cuda]
 
 class TestConv(unittest.TestCase):
   def test_simple(self):


@@ -9,7 +9,7 @@ from tinygrad.nn import BatchNorm2d, Conv1d, ConvTranspose1d, Conv2d, ConvTransp
 import torch
 import pytest
-pytestmark = [pytest.mark.exclude_cuda, pytest.mark.webgpu]
+pytestmark = [pytest.mark.exclude_cuda]
 
 class TestNN(unittest.TestCase):


@@ -6,9 +6,6 @@ import unittest
 from tinygrad.tensor import Tensor
 from tinygrad.helpers import getenv, IMAGE, DEBUG, CI
 from tinygrad.lazy import Device
-import pytest
-
-pytestmark = pytest.mark.webgpu
 
 if CI:
   import warnings
@@ -807,7 +804,7 @@ class TestOps(unittest.TestCase):
       lambda x,w: torch.nn.functional.conv_transpose3d(x,w).relu(),
       lambda x,w: Tensor.conv_transpose2d(x,w).relu(), atol=1e-4, grad_rtol=1e-5)
-  @unittest.skipIf((IMAGE>0 or (Device.DEFAULT == "WEBGPU" and getenv("CI","") != "")), "no conv1d on images")
+  @unittest.skipIf((IMAGE>0), "no conv1d on images")
   def test_conv1d(self):
     for bs in [1,8]:
       for cin in [1,3]:


@@ -18,7 +18,7 @@ from tinygrad.helpers import colored, getenv, DEBUG, CI
 from tinygrad.jit import TinyJit
 import pytest
-pytestmark = [pytest.mark.exclude_cuda, pytest.mark.exclude_gpu, pytest.mark.exclude_clang, pytest.mark.webgpu]
+pytestmark = [pytest.mark.exclude_cuda, pytest.mark.exclude_gpu, pytest.mark.exclude_clang]
 
 IN_CHANS = [int(x) for x in getenv("IN_CHANS", "4,16,64").split(",")]
@@ -130,7 +130,7 @@ class TestBigSpeed(unittest.TestCase):
   def test_large_conv_3x3(self): helper_test_conv(bs=4, in_chans=128, out_chans=128, kernel_size=3, img_size_y=130, img_size_x=130)
   def test_large_conv_5x5(self): helper_test_conv(bs=4, in_chans=128, out_chans=128, kernel_size=5, img_size_y=130, img_size_x=130)
-@unittest.skipIf((getenv("BIG") == 1 or Device.DEFAULT == "WEBGPU"), "only big tests")
+@unittest.skipIf((getenv("BIG") == 1), "only big tests")
 class TestSpeed(unittest.TestCase):
   def test_sub(self):
     def f(a, b): return a-b


@@ -2,13 +2,9 @@ import dataclasses
 import numpy as np
 import torch
 import unittest
-import itertools
-from tinygrad.tensor import Tensor, Device
+from tinygrad.tensor import Tensor
 from tinygrad.helpers import dtypes
 from extra.gradcheck import numerical_jacobian, jacobian, gradcheck
-import pytest
-
-pytestmark = pytest.mark.webgpu
 
 x_init = np.random.randn(1,3).astype(np.float32)
 U_init = np.random.randn(3,3).astype(np.float32)


@@ -21,6 +21,7 @@ class CStyleLanguage(NamedTuple):
   gid: List[str] = []
   lid: List[str] = []
   global_max: List[int] = []
+  local_max: List[int] = []
   extra_args: List[str] = []
   float4: Optional[str] = None
   half_prekernel: Optional[str] = None
@@ -195,7 +196,7 @@ class CStyleCodegen(Linearizer):
   def codegen(self):
     self.process()
-    if self.lang.global_max: self.limit_global_dims(len(self.lang.gid), self.lang.global_max)  # NOTE: this is optional now
+    if self.lang.global_max: self.limit_global_dims(len(self.lang.gid), self.lang.global_max, self.lang.local_max)  # NOTE: this is optional now
     self.linearize()
     prg, global_size, local_size = uops_to_cstyle(self.uops, self.lang)

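The new local_max field sits alongside global_max on the language config, and codegen now forwards both caps to limit_global_dims. A minimal standalone sketch of that pattern follows; the config class, names and values are made up for illustration and are not tinygrad's actual module layout:

# Sketch only: mirrors how a CStyleLanguage-style config carries per-dimension
# caps that codegen() forwards to limit_global_dims. All names here are made up.
from typing import List, NamedTuple

class LangConfig(NamedTuple):
  gid: List[str] = []                  # stands in for the gid index expressions
  global_max: List[int] = []           # max launch size per global dimension
  local_max: List[int] = []            # max workgroup size per local dimension

lang = LangConfig(gid=["gidx0", "gidx1", "gidx2"],
                  global_max=[65535, 65535, 65535],
                  local_max=[256, 256, 64])

# dims are only limited when global_max is set, and both caps ride along:
if lang.global_max:
  args = (len(lang.gid), lang.global_max, lang.local_max)
  print(args)  # the arguments handed to limit_global_dims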

@@ -595,8 +595,20 @@ class Linearizer:
     for i,x in enumerate(rets): self.sts[i].reshape(tuple([y[0] for y in x]))
 
   # ******************** GPU simplifiers ********************
+  def _limit_size(self, x: Tuple[int], max_size: List) -> Tuple[int, ...]:
+    new_shape,dims = list(x), len(x)
+    for i in range(dims):
+      next_idx = (i + 1) % dims
+      while new_shape[i] > max_size[i]:
+        new_shape[i] = new_shape[i] // 2
+        if (new_shape[next_idx] <= max_size[next_idx]):
+          new_shape[next_idx] = new_shape[next_idx] * 2
+        else:
+          next_idx = (next_idx + 1) % dims
+          new_shape[next_idx] = new_shape[next_idx] * 2
+    return tuple(new_shape)
-  def limit_global_dims(self, limit, global_max):
+  def limit_global_dims(self, limit: int, global_max: List[int], local_max: List[int]):
     # sometimes, there's more dimensions than len(self.lang.gid).
     # compact all the dimensions into the first
     # NOTE: this might make multiview shapetrackers
@@ -608,7 +620,10 @@ class Linearizer:
     # and then padded right with 1s if its length < 3 which makes this part a bit awkward to write
     global_dims = self.first_reduce-self.local_dims
     if global_dims > 0:
-      assert max(global_max) >= max(self.full_shape[0:global_dims]), f"device max allocation {max(self.full_shape[0:global_dims])} exceeds global dim maximum {max(global_max)}"
+      if global_max:
+        tmp = global_max[:global_dims] + (local_max[:self.local_dims] if local_max else [])
+        if max(global_max) < max(self.full_shape[:global_dims]): self.reshape_and_permute(lambda x: self._limit_size(x, tmp + [math.inf] * (len(self.full_shape)-len(tmp))), None)
+        assert max(global_max) >= max(self.full_shape[:global_dims]), f"device max allocation {max(self.full_shape[:global_dims])} exceeds global dim maximum {max(global_max)}"
       for i in range(global_dims-1):
         if self.full_shape[i] > global_max[i]:
           order = list(range(len(self.full_shape)))

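The heart of the change is _limit_size: any dimension above its cap is repeatedly halved, and each factor of 2 is pushed into the next dimension that still has headroom, so the total element count stays the same while every dimension is driven under its limit. A standalone sketch of the same loop outside the Linearizer; the example shape and caps are made up:

import math
from typing import List, Tuple

def limit_size(x: Tuple[int, ...], max_size: List) -> Tuple[int, ...]:
  # same halve-and-redistribute loop as Linearizer._limit_size above
  new_shape, dims = list(x), len(x)
  for i in range(dims):
    next_idx = (i + 1) % dims
    while new_shape[i] > max_size[i]:
      new_shape[i] //= 2                    # shrink the oversized dimension
      if new_shape[next_idx] <= max_size[next_idx]:
        new_shape[next_idx] *= 2            # hand the factor of 2 to a dim with room
      else:
        next_idx = (next_idx + 1) % dims    # that dim is full, move to the next one
        new_shape[next_idx] *= 2
  return tuple(new_shape)

# e.g. a 2**20-long first dim folded under a 65535 cap:
print(limit_size((1048576, 4, 1), [65535, 65535, math.inf]))
# -> (32768, 128, 1); the element count 4194304 is preserved and every dim fits

Dimensions beyond the known caps are treated as unbounded, which is what the [math.inf] padding in limit_global_dims does before handing the shape to _limit_size.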

@@ -11,6 +11,8 @@ class WGSLLanguage(CStyleLanguage):
   gid = [f"i32(gindex.{'xyz'[x]})" for x in range(3)]
   lid = [f"i32(lindex.{'xyz'[x]})" for x in range(3)]
   size_prefix = "let"
+  global_max = [65535, 65535, 65535]
+  local_max = [256, 256, 64]
   barrier="workgroupBarrier();"
   generic_var_prefix = "var "
   external_local_bufs = True
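
These values match WebGPU's baseline compute limits (at most 65535 workgroups per grid dimension, and per-dimension workgroup sizes of 256, 256 and 64), which is presumably why they were chosen as the WGSL backend's caps. A small illustrative check of what the caps mean for a dispatch; the shapes below are made up:

global_max = [65535, 65535, 65535]   # max workgroups dispatched per dimension
local_max  = [256, 256, 64]          # max workgroup (local) size per dimension

def launch_fits(global_size, local_size):
  # a dispatch is only valid if every dimension respects both caps
  return all(g <= mg for g, mg in zip(global_size, global_max)) and \
         all(l <= ml for l, ml in zip(local_size, local_max))

print(launch_fits([1048576, 1, 1], [1, 1, 1]))  # False: first dim exceeds 65535
print(launch_fits([32768, 32, 1],  [1, 1, 1]))  # True: same work, reshaped to fit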