diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml
index d2604827..01496f60 100644
--- a/.github/workflows/test.yml
+++ b/.github/workflows/test.yml
@@ -40,7 +40,9 @@ jobs:
     - name: Lint tinygrad with pylint
       run: pylint tinygrad/
     - name: Run mypy
-      run: mypy tinygrad/ test/ --ignore-missing-imports
+      run: |
+        mypy tinygrad/ --ignore-missing-imports --check-untyped-defs
+        mypy test/ --ignore-missing-imports
 
   testcpu:
     name: CPU Tests
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 4c2d68d0..8b2b01f8 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -15,7 +15,13 @@ repos:
         pass_filenames: false
       - id: mypy
         name: mypy
-        entry: mypy tinygrad/ test/ --ignore-missing-imports
+        entry: mypy tinygrad/ --ignore-missing-imports --check-untyped-defs
         language: system
         always_run: true
         pass_filenames: false
+      - id: mypytest
+        name: mypytest
+        entry: mypy test/ --ignore-missing-imports
+        language: system
+        always_run: true
+        pass_filenames: false
diff --git a/tinygrad/helpers.py b/tinygrad/helpers.py
index d9782389..3ed45f41 100644
--- a/tinygrad/helpers.py
+++ b/tinygrad/helpers.py
@@ -3,7 +3,7 @@ import os, math, functools, time
 from typing import Tuple, Union
 
 def dedup(x): return list(dict.fromkeys(x))  # retains list order
-def prod(x): return math.prod(x)
+def prod(x) -> int: return math.prod(x)
 def argfix(*x): return tuple() if len(x) == 0 else tuple(x[0]) if isinstance(x[0], (tuple, list)) else tuple(x)
 def argsort(x): return sorted(range(len(x)), key=x.__getitem__)  # https://stackoverflow.com/questions/3382352/equivalent-of-numpy-argsort-in-basic-python
 def all_same(items): return all(x == items[0] for x in items) if len(items) > 0 else True
diff --git a/tinygrad/jit.py b/tinygrad/jit.py
index 94c86d82..57c83d37 100644
--- a/tinygrad/jit.py
+++ b/tinygrad/jit.py
@@ -1,18 +1,18 @@
-from typing import Callable, List, Tuple
+from typing import Callable, List, Tuple, Any, Dict
 import itertools
 from tinygrad.lazy import Device
 from tinygrad.tensor import Tensor
-from tinygrad.ops import GlobalCounters
+from tinygrad.ops import GlobalCounters, DeviceBuffer
 
 class TinyJit:
-  def __init__(self, fxn):
+  def __init__(self, fxn:Callable):
     self.fxn = fxn
     self.cnt = 0
-    self.jit_cache : List[Tuple[Callable, List]] = []
+    self.jit_cache : List[Tuple[Callable, Any]] = []  # TODO: Any should be List[DeviceBuffer], but this fails
     self.ret = None
-    self.input_replace = {}
+    self.input_replace : Dict[DeviceBuffer, Any]= {}
 
-  def __call__(self, *args, **kwargs):
+  def __call__(self, *args, **kwargs) -> Any:
     if Device.DEFAULT != "GPU": return self.fxn(*args, **kwargs)  # only jit on the GPU
     input_tensors = {k:v.realize().lazydata.realized._buf for k,v in itertools.chain(enumerate(args), kwargs.items()) if isinstance(v, Tensor)}
     assert len(input_tensors) != 0, "no inputs to JIT"
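As an aside, not part of the diff: a minimal sketch of how TinyJit is typically applied, to show what the new Callable/Dict annotations describe. The `add` function and shapes below are made up for illustration, and per the guard above, the JIT path only engages when Device.DEFAULT is "GPU"; otherwise the wrapped function simply runs eagerly.

```python
# Hypothetical usage sketch, not code from this commit.
from tinygrad.tensor import Tensor
from tinygrad.jit import TinyJit

@TinyJit
def add(a: Tensor, b: Tensor) -> Tensor:
  return (a + b).realize()  # jitted functions should hand back realized Tensors

x, y = Tensor.randn(4, 4), Tensor.randn(4, 4)
for _ in range(3):  # early calls populate jit_cache; later calls replay the cached kernels
  out = add(x, y)
```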
diff --git a/tinygrad/nn/__init__.py b/tinygrad/nn/__init__.py
index 178aed23..af247bb6 100644
--- a/tinygrad/nn/__init__.py
+++ b/tinygrad/nn/__init__.py
@@ -1,3 +1,4 @@
+from typing import Optional
 from tinygrad.tensor import Tensor
 
 class BatchNorm2d:
@@ -59,15 +60,16 @@ class Linear:
 
 class GroupNorm:
   def __init__(self, num_groups, num_channels, eps=1e-5, affine=True):
-    self.num_groups, self.num_channels, self.eps, self.affine = num_groups, num_channels, eps, affine
-    self.weight, self.bias = (Tensor.ones(num_channels), Tensor.zeros(num_channels)) if affine else (None, None)
+    self.num_groups, self.num_channels, self.eps = num_groups, num_channels, eps
+    self.weight : Optional[Tensor] = Tensor.ones(num_channels) if affine else None
+    self.bias : Optional[Tensor] = Tensor.zeros(num_channels) if affine else None
 
   def __call__(self, x:Tensor):
     # reshape for layernorm to work as group norm
     # subtract mean and divide stddev
     x = x.reshape(x.shape[0], self.num_groups, -1).layernorm(eps=self.eps).reshape(x.shape)
 
-    if not self.affine: return x
+    if self.weight is None or self.bias is None: return x
     # elementwise_affine on channels
     return x * self.weight.reshape(1, -1, 1, 1) + self.bias.reshape(1, -1, 1, 1)
 
diff --git a/tinygrad/ops.py b/tinygrad/ops.py
index a1ff7427..2fe68171 100644
--- a/tinygrad/ops.py
+++ b/tinygrad/ops.py
@@ -35,6 +35,7 @@ def map_buffers(real_srcs, x:LazyOp) -> LazyOp:
 
 # a placeholder class to extend by the exec classes
 class DeviceBuffer:
+  _buf: Any   # underlying buffer
   shape: Any  # should be Tuple[int, ...] but ndarray and torch.tensor have incompatible types
   @staticmethod
   def fromCPU(x:np.ndarray) -> DeviceBuffer: raise NotImplementedError("must be implemented")
diff --git a/tinygrad/shape/__init__.py b/tinygrad/shape/__init__.py
index f148e06e..62b5a388 100644
--- a/tinygrad/shape/__init__.py
+++ b/tinygrad/shape/__init__.py
@@ -48,6 +48,8 @@ class ZeroView:
     # fake properties
     self.strides, self.contiguous, self.offset = strides_for_shape(self.shape), False, 0
 
+  def __repr__(self): return f"ZeroView({self.old_shape}, {self.arg})"
+
   def expr_node(self, idx=None, valid=None):
     if idx is None: idx = Variable('idx', 0, prod([y-x for x,y in self.arg]))
     expr, acc = [valid] if valid is not None else [], 1
@@ -57,7 +59,7 @@ class ZeroView:
       acc *= ns
     return Variable.ands(expr)
 
-  def __repr__(self): return f"ZeroView({self.old_shape}, {self.arg})"
+  def expr_idxs(self, idxs, offset=0): raise NotImplementedError("ZeroView doesn't support expr_idxs")
 
 ViewTypes = Union[View, ZeroView]
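Circling back to the GroupNorm change above, here is a small illustrative check (assumed, not taken from the repo) of what replacing `self.affine` with Optional weight/bias means for callers; the shapes are arbitrary.

```python
# Hypothetical check, not code from this commit.
from tinygrad.tensor import Tensor
from tinygrad.nn import GroupNorm

gn = GroupNorm(num_groups=2, num_channels=4, affine=False)
assert gn.weight is None and gn.bias is None  # affine=False now means both stay None

x = Tensor.randn(1, 4, 8, 8)
out = gn(x)  # normalized only; the weight/bias elementwise-affine branch is skipped
```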
diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py
index 57463970..cc5d3393 100644
--- a/tinygrad/tensor.py
+++ b/tinygrad/tensor.py
@@ -113,24 +113,24 @@ class Tensor:
   # ***** creation helper functions *****
   # TODO: remove use of numpy here and make lazy
 
-  @classmethod
-  def zeros(cls, *shape, **kwargs): return cls([0], **kwargs).reshape([1]*len(shape)).expand(shape)
+  @staticmethod
+  def zeros(*shape, **kwargs): return Tensor([0], **kwargs).reshape([1]*len(shape)).expand(shape)
 
-  @classmethod
-  def ones(cls, *shape, **kwargs): return cls([1], **kwargs).reshape([1]*len(shape)).expand(shape)
+  @staticmethod
+  def ones(*shape, **kwargs): return Tensor([1], **kwargs).reshape([1]*len(shape)).expand(shape)
 
-  @classmethod
-  def zeros_like(cls, tensor, **kwargs): return cls.zeros(*tensor.shape, **kwargs)
+  @staticmethod
+  def zeros_like(tensor, **kwargs): return Tensor.zeros(*tensor.shape, **kwargs)
 
-  @classmethod
-  def empty(cls, *shape, **kwargs): return cls.zeros(*shape, **kwargs)
+  @staticmethod
+  def empty(*shape, **kwargs): return Tensor.zeros(*shape, **kwargs)
 
-  @classmethod
-  def eye(cls, dim, **kwargs): return cls([1], **kwargs).slice(((0,dim+1),)).reshape(1, dim+1).expand(dim, dim+1).reshape(dim*(dim+1)).slice(((0,dim*dim),)).reshape(dim, dim)
+  @staticmethod
+  def eye(dim, **kwargs): return Tensor([1], **kwargs).slice(((0,dim+1),)).reshape(1, dim+1).expand(dim, dim+1).reshape(dim*(dim+1)).slice(((0,dim*dim),)).reshape(dim, dim)
 
   # TODO: requires cumsum to remove numpy
-  @classmethod
-  def arange(cls, stop, start=0, step=1, **kwargs): return cls(np.arange(start=start, stop=stop, step=step, dtype=np.float32), **kwargs)
+  @staticmethod
+  def arange(stop, start=0, step=1, **kwargs): return Tensor(np.arange(start=start, stop=stop, step=step, dtype=np.float32), **kwargs)
 
   # ***** (numpy) rng helper functions *****
   # TODO: move randomness generation out of numpy
@@ -139,24 +139,24 @@ class Tensor:
   @staticmethod
   def manual_seed(seed=None): Tensor._rng = np.random.default_rng(seed=seed)
 
-  @classmethod
-  def rand(cls, *shape, **kwargs): return cls(cls._rng.random(size=shape, dtype=np.float32), **kwargs)
+  @staticmethod
+  def rand(*shape, **kwargs): return Tensor(Tensor._rng.random(size=shape, dtype=np.float32), **kwargs)
 
   # TODO: replace with a transformation from uniform -> gaussian
-  @classmethod
-  def randn(cls, *shape, **kwargs): return cls(cls._rng.standard_normal(size=shape, dtype=np.float32), **kwargs)
+  @staticmethod
+  def randn(*shape, **kwargs): return Tensor(Tensor._rng.standard_normal(size=shape, dtype=np.float32), **kwargs)
 
   # ***** rng hlops *****
 
-  @classmethod
-  def uniform(cls, *shape, **kwargs): return cls.rand(*shape, **kwargs) * 2 - 1
+  @staticmethod
+  def uniform(*shape, **kwargs) -> Tensor: return Tensor.rand(*shape, **kwargs) * 2 - 1
 
-  @classmethod
-  def scaled_uniform(cls, *shape, **kwargs): return cls.uniform(*shape, **kwargs) * (prod(shape)**-0.5)
+  @staticmethod
+  def scaled_uniform(*shape, **kwargs) -> Tensor: return Tensor.uniform(*shape, **kwargs) * (prod(shape)**-0.5)
 
-  @classmethod # https://www.tensorflow.org/api_docs/python/tf/keras/initializers/GlorotUniform
-  def glorot_uniform(cls, *shape, **kwargs): return cls.uniform(*shape, **kwargs) * ((6/(shape[0]+prod(shape[1:])))**0.5)
+  @staticmethod
+  def glorot_uniform(*shape, **kwargs) -> Tensor: return Tensor.uniform(*shape, **kwargs) * ((6/(shape[0]+prod(shape[1:])))**0.5)
 
   # ***** toposort and backward pass *****
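One behavioral note on the @classmethod to @staticmethod switch, with a hedged sketch below: the creation helpers no longer construct through `cls`, so a subclass of Tensor gets a plain Tensor back from helpers such as `arange` that previously returned `cls(...)`. The `MyTensor` subclass here is hypothetical and only exists for illustration.

```python
# Illustrative only; MyTensor is a made-up subclass, not part of tinygrad.
from tinygrad.tensor import Tensor

class MyTensor(Tensor): pass

t = MyTensor.arange(3)
print(type(t))  # the old @classmethod version returned a MyTensor; the @staticmethod version returns Tensor
```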