typing fixup

George Hotz 2023-02-27 09:52:04 -08:00
parent 9aaa7edd74
commit e74779f19d
8 changed files with 48 additions and 35 deletions


@@ -40,7 +40,9 @@ jobs:
     - name: Lint tinygrad with pylint
       run: pylint tinygrad/
     - name: Run mypy
-      run: mypy tinygrad/ test/ --ignore-missing-imports
+      run: |
+        mypy tinygrad/ --ignore-missing-imports --check-untyped-defs
+        mypy test/ --ignore-missing-imports
   testcpu:
     name: CPU Tests
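
The new --check-untyped-defs flag makes mypy type-check the bodies of functions that have no annotations, which it otherwise skips entirely. A minimal sketch of the difference, using a hypothetical function rather than tinygrad code:

  def scale(values, factor):   # unannotated: default mypy skips this body
    bias: int = "zero"         # wrong type, reported only with --check-untyped-defs
    return [v * factor + bias for v in values]

Plain "mypy file.py" reports nothing here; adding --check-untyped-defs flags the bad assignment, which is why tinygrad/ now gets the stricter run while test/ keeps the looser one.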


@@ -15,7 +15,13 @@ repos:
     pass_filenames: false
   - id: mypy
     name: mypy
-    entry: mypy tinygrad/ test/ --ignore-missing-imports
+    entry: mypy tinygrad/ --ignore-missing-imports --check-untyped-defs
+    language: system
+    always_run: true
+    pass_filenames: false
+  - id: mypytest
+    name: mypytest
+    entry: mypy test/ --ignore-missing-imports
     language: system
     always_run: true
     pass_filenames: false


@@ -3,7 +3,7 @@ import os, math, functools, time
 from typing import Tuple, Union
 def dedup(x): return list(dict.fromkeys(x))   # retains list order
-def prod(x): return math.prod(x)
+def prod(x) -> int: return math.prod(x)
 def argfix(*x): return tuple() if len(x) == 0 else tuple(x[0]) if isinstance(x[0], (tuple, list)) else tuple(x)
 def argsort(x): return sorted(range(len(x)), key=x.__getitem__)   # https://stackoverflow.com/questions/3382352/equivalent-of-numpy-argsort-in-basic-python
 def all_same(items): return all(x == items[0] for x in items) if len(items) > 0 else True
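
A quick sanity check of the helpers above, as a hedged sketch (it assumes they are importable from tinygrad.helpers, the module this hunk touches):

  from tinygrad.helpers import prod, dedup, argfix, all_same

  assert prod((2, 3, 4)) == 24                     # math.prod over a shape tuple
  assert prod(()) == 1                             # empty product is 1, so the int annotation still holds
  assert dedup([1, 2, 1, 3]) == [1, 2, 3]          # order-preserving dedup
  assert argfix((2, 3)) == argfix(2, 3) == (2, 3)  # a single tuple/list argument gets unpacked
  assert all_same([4, 4, 4]) and all_same([])      # an empty list counts as all-same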


@@ -1,18 +1,18 @@
-from typing import Callable, List, Tuple
+from typing import Callable, List, Tuple, Any, Dict
 import itertools
 from tinygrad.lazy import Device
 from tinygrad.tensor import Tensor
-from tinygrad.ops import GlobalCounters
+from tinygrad.ops import GlobalCounters, DeviceBuffer
 class TinyJit:
-  def __init__(self, fxn):
+  def __init__(self, fxn:Callable):
     self.fxn = fxn
     self.cnt = 0
-    self.jit_cache : List[Tuple[Callable, List]] = []
+    self.jit_cache : List[Tuple[Callable, Any]] = []  # TODO: Any should be List[DeviceBuffer], but this fails
     self.ret = None
-    self.input_replace = {}
+    self.input_replace : Dict[DeviceBuffer, Any] = {}
-  def __call__(self, *args, **kwargs):
+  def __call__(self, *args, **kwargs) -> Any:
     if Device.DEFAULT != "GPU": return self.fxn(*args, **kwargs)  # only jit on the GPU
     input_tensors = {k:v.realize().lazydata.realized._buf for k,v in itertools.chain(enumerate(args), kwargs.items()) if isinstance(v, Tensor)}
     assert len(input_tensors) != 0, "no inputs to JIT"
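
For context, TinyJit is used as a decorator. The snippet below is a hypothetical usage sketch, assuming the class above lives in tinygrad.jit and a GPU backend is the default (per the Device.DEFAULT check):

  from tinygrad.jit import TinyJit
  from tinygrad.tensor import Tensor

  @TinyJit
  def step(x: Tensor) -> Tensor:
    return (x * 2 + 1).realize()    # realize inside the jitted function so its kernels get captured

  for _ in range(5):                # early calls record the GPU kernels, later calls replay them
    out = step(Tensor.rand(4, 4))   # fresh inputs are swapped in via input_replace on replay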


@@ -1,3 +1,4 @@
+from typing import Optional
 from tinygrad.tensor import Tensor
 class BatchNorm2d:

@@ -59,15 +60,16 @@ class Linear:
 class GroupNorm:
   def __init__(self, num_groups, num_channels, eps=1e-5, affine=True):
-    self.num_groups, self.num_channels, self.eps, self.affine = num_groups, num_channels, eps, affine
-    self.weight, self.bias = (Tensor.ones(num_channels), Tensor.zeros(num_channels)) if affine else (None, None)
+    self.num_groups, self.num_channels, self.eps = num_groups, num_channels, eps
+    self.weight : Optional[Tensor] = Tensor.ones(num_channels) if affine else None
+    self.bias : Optional[Tensor] = Tensor.zeros(num_channels) if affine else None
   def __call__(self, x:Tensor):
     # reshape for layernorm to work as group norm
     # subtract mean and divide stddev
     x = x.reshape(x.shape[0], self.num_groups, -1).layernorm(eps=self.eps).reshape(x.shape)
-    if not self.affine: return x
+    if self.weight is None or self.bias is None: return x
     # elementwise_affine on channels
     return x * self.weight.reshape(1, -1, 1, 1) + self.bias.reshape(1, -1, 1, 1)
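
Checking weight and bias against None, rather than testing self.affine, also gives mypy what it needs to narrow Optional[Tensor] down to Tensor before the reshape calls. A minimal sketch of that narrowing pattern, with hypothetical names:

  from typing import Optional

  class Affine:
    def __init__(self, scale: Optional[float] = None):
      self.scale: Optional[float] = scale
    def apply(self, x: float) -> float:
      if self.scale is None: return x   # mypy narrows self.scale to float past this line
      return x * self.scale             # no "unsupported operand for Optional" error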


@@ -35,6 +35,7 @@ def map_buffers(real_srcs, x:LazyOp) -> LazyOp:
 # a placeholder class to extend by the exec classes
 class DeviceBuffer:
+  _buf: Any     # underlying buffer
   shape: Any    # should be Tuple[int, ...] but ndarray and torch.tensor have incompatible types
   @staticmethod
   def fromCPU(x:np.ndarray) -> DeviceBuffer: raise NotImplementedError("must be implemented")
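
Declaring _buf: Any (like shape: Any) is an annotation-only class attribute: it tells mypy that every DeviceBuffer has a _buf field, which the JIT reads via lazydata.realized._buf above, without creating any value at class scope. A minimal sketch of the idea with hypothetical names:

  from typing import Any

  class Base:
    _buf: Any          # declaration only; nothing is assigned at class level
    shape: Any

  def read(b: Base) -> Any:
    return b._buf      # accepted by mypy because the attribute is declared on Base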


@@ -48,6 +48,8 @@ class ZeroView:
     # fake properties
     self.strides, self.contiguous, self.offset = strides_for_shape(self.shape), False, 0
+  def __repr__(self): return f"ZeroView({self.old_shape}, {self.arg})"
   def expr_node(self, idx=None, valid=None):
     if idx is None: idx = Variable('idx', 0, prod([y-x for x,y in self.arg]))
     expr, acc = [valid] if valid is not None else [], 1

@@ -57,7 +59,7 @@ class ZeroView:
       acc *= ns
     return Variable.ands(expr)
-  def __repr__(self): return f"ZeroView({self.old_shape}, {self.arg})"
+  def expr_idxs(self, idxs, offset=0): raise NotImplementedError("ZeroView doesn't support expr_idxs")
 ViewTypes = Union[View, ZeroView]


@@ -113,24 +113,24 @@ class Tensor:
   # ***** creation helper functions *****
   # TODO: remove use of numpy here and make lazy
-  @classmethod
-  def zeros(cls, *shape, **kwargs): return cls([0], **kwargs).reshape([1]*len(shape)).expand(shape)
+  @staticmethod
+  def zeros(*shape, **kwargs): return Tensor([0], **kwargs).reshape([1]*len(shape)).expand(shape)
-  @classmethod
-  def ones(cls, *shape, **kwargs): return cls([1], **kwargs).reshape([1]*len(shape)).expand(shape)
+  @staticmethod
+  def ones(*shape, **kwargs): return Tensor([1], **kwargs).reshape([1]*len(shape)).expand(shape)
-  @classmethod
-  def zeros_like(cls, tensor, **kwargs): return cls.zeros(*tensor.shape, **kwargs)
+  @staticmethod
+  def zeros_like(tensor, **kwargs): return Tensor.zeros(*tensor.shape, **kwargs)
-  @classmethod
-  def empty(cls, *shape, **kwargs): return cls.zeros(*shape, **kwargs)
+  @staticmethod
+  def empty(*shape, **kwargs): return Tensor.zeros(*shape, **kwargs)
-  @classmethod
-  def eye(cls, dim, **kwargs): return cls([1], **kwargs).slice(((0,dim+1),)).reshape(1, dim+1).expand(dim, dim+1).reshape(dim*(dim+1)).slice(((0,dim*dim),)).reshape(dim, dim)
+  @staticmethod
+  def eye(dim, **kwargs): return Tensor([1], **kwargs).slice(((0,dim+1),)).reshape(1, dim+1).expand(dim, dim+1).reshape(dim*(dim+1)).slice(((0,dim*dim),)).reshape(dim, dim)
   # TODO: requires cumsum to remove numpy
-  @classmethod
-  def arange(cls, stop, start=0, step=1, **kwargs): return cls(np.arange(start=start, stop=stop, step=step, dtype=np.float32), **kwargs)
+  @staticmethod
+  def arange(stop, start=0, step=1, **kwargs): return Tensor(np.arange(start=start, stop=stop, step=step, dtype=np.float32), **kwargs)
   # ***** (numpy) rng helper functions *****
   # TODO: move randomness generation out of numpy

@@ -139,24 +139,24 @@ class Tensor:
   @staticmethod
   def manual_seed(seed=None): Tensor._rng = np.random.default_rng(seed=seed)
-  @classmethod
-  def rand(cls, *shape, **kwargs): return cls(cls._rng.random(size=shape, dtype=np.float32), **kwargs)
+  @staticmethod
+  def rand(*shape, **kwargs): return Tensor(Tensor._rng.random(size=shape, dtype=np.float32), **kwargs)
   # TODO: replace with a transformation from uniform -> gaussian
-  @classmethod
-  def randn(cls, *shape, **kwargs): return cls(cls._rng.standard_normal(size=shape, dtype=np.float32), **kwargs)
+  @staticmethod
+  def randn(*shape, **kwargs): return Tensor(Tensor._rng.standard_normal(size=shape, dtype=np.float32), **kwargs)
   # ***** rng hlops *****
-  @classmethod
-  def uniform(cls, *shape, **kwargs): return cls.rand(*shape, **kwargs) * 2 - 1
+  @staticmethod
+  def uniform(*shape, **kwargs) -> Tensor: return Tensor.rand(*shape, **kwargs) * 2 - 1
-  @classmethod
-  def scaled_uniform(cls, *shape, **kwargs): return cls.uniform(*shape, **kwargs) * (prod(shape)**-0.5)
+  @staticmethod
+  def scaled_uniform(*shape, **kwargs) -> Tensor: return Tensor.uniform(*shape, **kwargs) * (prod(shape)**-0.5)
-  @classmethod
   # https://www.tensorflow.org/api_docs/python/tf/keras/initializers/GlorotUniform
-  def glorot_uniform(cls, *shape, **kwargs): return cls.uniform(*shape, **kwargs) * ((6/(shape[0]+prod(shape[1:])))**0.5)
+  @staticmethod
+  def glorot_uniform(*shape, **kwargs) -> Tensor: return Tensor.uniform(*shape, **kwargs) * ((6/(shape[0]+prod(shape[1:])))**0.5)
   # ***** toposort and backward pass *****
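
Typical call sites are unaffected by the @classmethod to @staticmethod switch, since these helpers are still invoked on the Tensor class; the body now names Tensor directly instead of the untyped cls. A hedged usage sketch, grounded in the definitions above:

  from tinygrad.tensor import Tensor

  a = Tensor.zeros(2, 3)           # zeros broadcast out to shape (2, 3)
  b = Tensor.eye(3)                # 3x3 identity built from slice/reshape/expand
  c = Tensor.uniform(4, 4)         # uniform values in [-1, 1)
  d = Tensor.glorot_uniform(8, 4)  # Glorot/Xavier-scaled uniform init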