typing fixup

George Hotz 2023-02-27 09:52:04 -08:00
parent 9aaa7edd74
commit e74779f19d
8 changed files with 48 additions and 35 deletions

View File

@@ -40,7 +40,9 @@ jobs:
- name: Lint tinygrad with pylint
run: pylint tinygrad/
- name: Run mypy
- run: mypy tinygrad/ test/ --ignore-missing-imports
+ run: |
+ mypy tinygrad/ --ignore-missing-imports --check-untyped-defs
+ mypy test/ --ignore-missing-imports
testcpu:
name: CPU Tests
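The change above splits the single mypy run so that --check-untyped-defs can be enabled for the library code only. With that flag mypy also type-checks the bodies of functions that carry no annotations at all; by default those bodies are skipped. A minimal sketch (not from the repository) of the kind of error the stricter pass catches:

# hypothetical snippet: the function is unannotated, so default mypy skips its body
def count_label(xs):
  n = len(xs)          # inferred as int even without annotations
  return n + " items"  # flagged only with --check-untyped-defs: int + str

test/ keeps the looser settings, which is why it moved to its own invocation.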

View File

@@ -15,7 +15,13 @@ repos:
pass_filenames: false
- id: mypy
name: mypy
- entry: mypy tinygrad/ test/ --ignore-missing-imports
+ entry: mypy tinygrad/ --ignore-missing-imports --check-untyped-defs
language: system
always_run: true
pass_filenames: false
+ - id: mypytest
+ name: mypytest
+ entry: mypy test/ --ignore-missing-imports
+ language: system
+ always_run: true
+ pass_filenames: false

View File

@@ -3,7 +3,7 @@ import os, math, functools, time
from typing import Tuple, Union
def dedup(x): return list(dict.fromkeys(x)) # retains list order
- def prod(x): return math.prod(x)
+ def prod(x) -> int: return math.prod(x)
def argfix(*x): return tuple() if len(x) == 0 else tuple(x[0]) if isinstance(x[0], (tuple, list)) else tuple(x)
def argsort(x): return sorted(range(len(x)), key=x.__getitem__) # https://stackoverflow.com/questions/3382352/equivalent-of-numpy-argsort-in-basic-python
def all_same(items): return all(x == items[0] for x in items) if len(items) > 0 else True
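The new -> int return annotation on prod is small, but it gives mypy a concrete type to propagate wherever a shape product is used, which matters once --check-untyped-defs inspects otherwise unannotated callers. A hedged illustration, not part of the commit:

import math
def prod(x) -> int: return math.prod(x)

numel = prod((2, 3, 4))     # mypy now infers numel: int
label: str = prod((2, 3))   # error: assigning an int to a str variable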

View File

@@ -1,18 +1,18 @@
- from typing import Callable, List, Tuple
+ from typing import Callable, List, Tuple, Any, Dict
import itertools
from tinygrad.lazy import Device
from tinygrad.tensor import Tensor
- from tinygrad.ops import GlobalCounters
+ from tinygrad.ops import GlobalCounters, DeviceBuffer
class TinyJit:
- def __init__(self, fxn):
+ def __init__(self, fxn:Callable):
self.fxn = fxn
self.cnt = 0
- self.jit_cache : List[Tuple[Callable, List]] = []
+ self.jit_cache : List[Tuple[Callable, Any]] = [] # TODO: Any should be List[DeviceBuffer], but this fails
self.ret = None
- self.input_replace = {}
+ self.input_replace : Dict[DeviceBuffer, Any]= {}
- def __call__(self, *args, **kwargs):
+ def __call__(self, *args, **kwargs) -> Any:
if Device.DEFAULT != "GPU": return self.fxn(*args, **kwargs) # only jit on the GPU
input_tensors = {k:v.realize().lazydata.realized._buf for k,v in itertools.chain(enumerate(args), kwargs.items()) if isinstance(v, Tensor)}
assert len(input_tensors) != 0, "no inputs to JIT"
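For context on the annotations: TinyJit wraps a function that takes and returns Tensors; on the GPU backend the buffers recorded in jit_cache are replayed on later calls, while on other devices the wrapper simply calls the function (the first line of __call__ above). A rough usage sketch under those assumptions; the import path and the trailing realize() are assumptions, not taken from this diff:

from tinygrad.tensor import Tensor
from tinygrad.jit import TinyJit    # assumed module path for the class above

@TinyJit
def step(x: Tensor) -> Tensor:
  return x.relu().sum().realize()   # return a realized Tensor so its buffer can be cached

for _ in range(5):
  out = step(Tensor.rand(4, 4))     # on GPU, later calls reuse the recorded work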

View File

@@ -1,3 +1,4 @@
+ from typing import Optional
from tinygrad.tensor import Tensor
class BatchNorm2d:
@@ -59,15 +60,16 @@ class Linear:
class GroupNorm:
def __init__(self, num_groups, num_channels, eps=1e-5, affine=True):
- self.num_groups, self.num_channels, self.eps, self.affine = num_groups, num_channels, eps, affine
- self.weight, self.bias = (Tensor.ones(num_channels), Tensor.zeros(num_channels)) if affine else (None, None)
+ self.num_groups, self.num_channels, self.eps = num_groups, num_channels, eps
+ self.weight : Optional[Tensor] = Tensor.ones(num_channels) if affine else None
+ self.bias : Optional[Tensor] = Tensor.zeros(num_channels) if affine else None
def __call__(self, x:Tensor):
# reshape for layernorm to work as group norm
# subtract mean and divide stddev
x = x.reshape(x.shape[0], self.num_groups, -1).layernorm(eps=self.eps).reshape(x.shape)
- if not self.affine: return x
+ if self.weight is None or self.bias is None: return x
# elementwise_affine on channels
return x * self.weight.reshape(1, -1, 1, 1) + self.bias.reshape(1, -1, 1, 1)
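The last change is what makes the Optional annotations work: mypy cannot connect the old boolean affine flag to the Optional fields, but an explicit is None test narrows Optional[Tensor] to Tensor on the path that actually uses weight and bias. A stripped-down sketch of the pattern, illustrative rather than from the repo:

from typing import Optional

class Scale:
  def __init__(self, gain: Optional[float]):
    self.gain = gain
  def __call__(self, x: float) -> float:
    if self.gain is None: return x  # narrows self.gain to float below
    return x * self.gain            # no Optional error from mypy here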

View File

@@ -35,6 +35,7 @@ def map_buffers(real_srcs, x:LazyOp) -> LazyOp:
# a placeholder class to extend by the exec classes
class DeviceBuffer:
_buf: Any # underlying buffer
+ shape: Any # should be Tuple[int, ...] but ndarray and torch.tensor have incompatible types
@staticmethod
def fromCPU(x:np.ndarray) -> DeviceBuffer: raise NotImplementedError("must be implemented")
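The added shape attribute is deliberately typed Any: as the comment notes, the concrete buffers assign values whose static types mypy will not unify (numpy's tuple shape versus torch's), so the shared base class keeps the loosest type and lets subclasses fill it in. A rough sketch with an invented subclass name:

from typing import Any
import numpy as np

class DeviceBuffer:
  _buf: Any
  shape: Any  # concrete backends assign incompatible static types here

class CPUBuffer(DeviceBuffer):  # hypothetical backend, for illustration only
  def __init__(self, arr: np.ndarray):
    self._buf, self.shape = arr, arr.shape  # numpy reports Tuple[int, ...]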

View File

@@ -48,6 +48,8 @@ class ZeroView:
# fake properties
self.strides, self.contiguous, self.offset = strides_for_shape(self.shape), False, 0
+ def __repr__(self): return f"ZeroView({self.old_shape}, {self.arg})"
def expr_node(self, idx=None, valid=None):
if idx is None: idx = Variable('idx', 0, prod([y-x for x,y in self.arg]))
expr, acc = [valid] if valid is not None else [], 1
@@ -57,7 +59,7 @@ class ZeroView:
acc *= ns
return Variable.ands(expr)
- def __repr__(self): return f"ZeroView({self.old_shape}, {self.arg})"
+ def expr_idxs(self, idxs, offset=0): raise NotImplementedError("ZeroView doesn't support expr_idxs")
ViewTypes = Union[View, ZeroView]

View File

@@ -113,24 +113,24 @@ class Tensor:
# ***** creation helper functions *****
# TODO: remove use of numpy here and make lazy
- @classmethod
- def zeros(cls, *shape, **kwargs): return cls([0], **kwargs).reshape([1]*len(shape)).expand(shape)
+ @staticmethod
+ def zeros(*shape, **kwargs): return Tensor([0], **kwargs).reshape([1]*len(shape)).expand(shape)
- @classmethod
- def ones(cls, *shape, **kwargs): return cls([1], **kwargs).reshape([1]*len(shape)).expand(shape)
+ @staticmethod
+ def ones(*shape, **kwargs): return Tensor([1], **kwargs).reshape([1]*len(shape)).expand(shape)
- @classmethod
- def zeros_like(cls, tensor, **kwargs): return cls.zeros(*tensor.shape, **kwargs)
+ @staticmethod
+ def zeros_like(tensor, **kwargs): return Tensor.zeros(*tensor.shape, **kwargs)
- @classmethod
- def empty(cls, *shape, **kwargs): return cls.zeros(*shape, **kwargs)
+ @staticmethod
+ def empty(*shape, **kwargs): return Tensor.zeros(*shape, **kwargs)
- @classmethod
- def eye(cls, dim, **kwargs): return cls([1], **kwargs).slice(((0,dim+1),)).reshape(1, dim+1).expand(dim, dim+1).reshape(dim*(dim+1)).slice(((0,dim*dim),)).reshape(dim, dim)
+ @staticmethod
+ def eye(dim, **kwargs): return Tensor([1], **kwargs).slice(((0,dim+1),)).reshape(1, dim+1).expand(dim, dim+1).reshape(dim*(dim+1)).slice(((0,dim*dim),)).reshape(dim, dim)
# TODO: requires cumsum to remove numpy
- @classmethod
- def arange(cls, stop, start=0, step=1, **kwargs): return cls(np.arange(start=start, stop=stop, step=step, dtype=np.float32), **kwargs)
+ @staticmethod
+ def arange(stop, start=0, step=1, **kwargs): return Tensor(np.arange(start=start, stop=stop, step=step, dtype=np.float32), **kwargs)
# ***** (numpy) rng helper functions *****
# TODO: move randomness generation out of numpy
@@ -139,24 +139,24 @@ class Tensor:
@staticmethod
def manual_seed(seed=None): Tensor._rng = np.random.default_rng(seed=seed)
- @classmethod
- def rand(cls, *shape, **kwargs): return cls(cls._rng.random(size=shape, dtype=np.float32), **kwargs)
+ @staticmethod
+ def rand(*shape, **kwargs): return Tensor(Tensor._rng.random(size=shape, dtype=np.float32), **kwargs)
# TODO: replace with a transformation from uniform -> gaussian
- @classmethod
- def randn(cls, *shape, **kwargs): return cls(cls._rng.standard_normal(size=shape, dtype=np.float32), **kwargs)
+ @staticmethod
+ def randn(*shape, **kwargs): return Tensor(Tensor._rng.standard_normal(size=shape, dtype=np.float32), **kwargs)
# ***** rng hlops *****
- @classmethod
- def uniform(cls, *shape, **kwargs): return cls.rand(*shape, **kwargs) * 2 - 1
+ @staticmethod
+ def uniform(*shape, **kwargs) -> Tensor: return Tensor.rand(*shape, **kwargs) * 2 - 1
- @classmethod
- def scaled_uniform(cls, *shape, **kwargs): return cls.uniform(*shape, **kwargs) * (prod(shape)**-0.5)
+ @staticmethod
+ def scaled_uniform(*shape, **kwargs) -> Tensor: return Tensor.uniform(*shape, **kwargs) * (prod(shape)**-0.5)
- @classmethod
# https://www.tensorflow.org/api_docs/python/tf/keras/initializers/GlorotUniform
- def glorot_uniform(cls, *shape, **kwargs): return cls.uniform(*shape, **kwargs) * ((6/(shape[0]+prod(shape[1:])))**0.5)
+ @staticmethod
+ def glorot_uniform(*shape, **kwargs) -> Tensor: return Tensor.uniform(*shape, **kwargs) * ((6/(shape[0]+prod(shape[1:])))**0.5)
# ***** toposort and backward pass *****
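Call sites are unaffected by the @classmethod to @staticmethod switch above; the helpers now construct Tensor(...) directly instead of cls(...), which also means a Tensor subclass calling them gets a plain Tensor back. A quick usage sketch based on the definitions shown:

from tinygrad.tensor import Tensor

x = Tensor.zeros(2, 3)    # zeros built by reshape+expand of a one-element Tensor
i = Tensor.eye(3)         # 3x3 identity via slice/reshape/expand
r = Tensor.uniform(4, 4)  # rand() rescaled into [-1, 1)
print(x.shape, i.shape, r.shape)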