mirror of https://github.com/commaai/tinygrad.git

commit e74779f19d (parent 9aaa7edd74)
typing fixup
@@ -40,7 +40,9 @@ jobs:
     - name: Lint tinygrad with pylint
       run: pylint tinygrad/
     - name: Run mypy
-      run: mypy tinygrad/ test/ --ignore-missing-imports
+      run: |
+        mypy tinygrad/ --ignore-missing-imports --check-untyped-defs
+        mypy test/ --ignore-missing-imports

   testcpu:
     name: CPU Tests

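Note: splitting the single mypy invocation lets tinygrad/ run under the stricter --check-untyped-defs flag while test/ keeps only --ignore-missing-imports. By default mypy skips the bodies of functions that have no annotations; --check-untyped-defs makes it check them anyway. A minimal sketch of the difference (hypothetical function, not from the repo):

    def add_one(x):        # unannotated signature, so mypy normally skips the body
      s = "total: "
      return s + 1         # str + int: reported only when untyped defs are checked

    # mypy example.py --ignore-missing-imports                       -> no error (body skipped)
    # mypy example.py --ignore-missing-imports --check-untyped-defs  -> operator error on `s + 1`
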
@@ -15,7 +15,13 @@ repos:
         pass_filenames: false
       - id: mypy
         name: mypy
-        entry: mypy tinygrad/ test/ --ignore-missing-imports
+        entry: mypy tinygrad/ --ignore-missing-imports --check-untyped-defs
+        language: system
+        always_run: true
+        pass_filenames: false
+      - id: mypytest
+        name: mypytest
+        entry: mypy test/ --ignore-missing-imports
         language: system
         always_run: true
         pass_filenames: false

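Note: the pre-commit config mirrors the CI split above: the existing mypy hook now covers only tinygrad/ with --check-untyped-defs, and a separate mypytest hook covers test/, so the local hooks and the CI job enforce the same checks. Each hook can still be run on its own, e.g. `pre-commit run mypy --all-files` or `pre-commit run mypytest --all-files`.
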
@@ -3,7 +3,7 @@ import os, math, functools, time
 from typing import Tuple, Union

 def dedup(x): return list(dict.fromkeys(x)) # retains list order
-def prod(x): return math.prod(x)
+def prod(x) -> int: return math.prod(x)
 def argfix(*x): return tuple() if len(x) == 0 else tuple(x[0]) if isinstance(x[0], (tuple, list)) else tuple(x)
 def argsort(x): return sorted(range(len(x)), key=x.__getitem__) # https://stackoverflow.com/questions/3382352/equivalent-of-numpy-argsort-in-basic-python
 def all_same(items): return all(x == items[0] for x in items) if len(items) > 0 else True

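Note: the `-> int` annotation on prod matches its use on shape tuples, where math.prod over a sequence of ints (or an empty one) yields an int. A quick illustrative check, not from the repo:

    import math

    def prod(x) -> int: return math.prod(x)

    assert prod((3, 4, 5)) == 60   # product of shape dims
    assert prod(()) == 1           # math.prod of an empty sequence is 1
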
@@ -1,18 +1,18 @@
-from typing import Callable, List, Tuple
+from typing import Callable, List, Tuple, Any, Dict
 import itertools
 from tinygrad.lazy import Device
 from tinygrad.tensor import Tensor
-from tinygrad.ops import GlobalCounters
+from tinygrad.ops import GlobalCounters, DeviceBuffer

 class TinyJit:
-  def __init__(self, fxn):
+  def __init__(self, fxn:Callable):
     self.fxn = fxn
     self.cnt = 0
-    self.jit_cache : List[Tuple[Callable, List]] = []
+    self.jit_cache : List[Tuple[Callable, Any]] = [] # TODO: Any should be List[DeviceBuffer], but this fails
     self.ret = None
-    self.input_replace = {}
+    self.input_replace : Dict[DeviceBuffer, Any]= {}

-  def __call__(self, *args, **kwargs):
+  def __call__(self, *args, **kwargs) -> Any:
     if Device.DEFAULT != "GPU": return self.fxn(*args, **kwargs) # only jit on the GPU
     input_tensors = {k:v.realize().lazydata.realized._buf for k,v in itertools.chain(enumerate(args), kwargs.items()) if isinstance(v, Tensor)}
     assert len(input_tensors) != 0, "no inputs to JIT"

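Note: the annotations above only describe the existing flow: TinyJit wraps a callable whose Tensor arguments and realized outputs are reused across calls, and on devices other than GPU it just calls through. A minimal usage sketch under those assumptions (hypothetical function name, module path assumed to be tinygrad.jit):

    from tinygrad.tensor import Tensor
    from tinygrad.jit import TinyJit

    @TinyJit
    def step(x:Tensor) -> Tensor:
      return (x * 2 + 1).realize()    # realized output so the jit can capture the buffers

    for _ in range(3):                # early calls record, later calls replay the cached kernels
      out = step(Tensor.rand(4, 4))   # Tensor inputs are required (see the assert above)
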
@@ -1,3 +1,4 @@
+from typing import Optional
 from tinygrad.tensor import Tensor

 class BatchNorm2d:

@@ -59,15 +60,16 @@ class Linear:

 class GroupNorm:
   def __init__(self, num_groups, num_channels, eps=1e-5, affine=True):
-    self.num_groups, self.num_channels, self.eps, self.affine = num_groups, num_channels, eps, affine
-    self.weight, self.bias = (Tensor.ones(num_channels), Tensor.zeros(num_channels)) if affine else (None, None)
+    self.num_groups, self.num_channels, self.eps = num_groups, num_channels, eps
+    self.weight : Optional[Tensor] = Tensor.ones(num_channels) if affine else None
+    self.bias : Optional[Tensor] = Tensor.zeros(num_channels) if affine else None

   def __call__(self, x:Tensor):
     # reshape for layernorm to work as group norm
     # subtract mean and divide stddev
     x = x.reshape(x.shape[0], self.num_groups, -1).layernorm(eps=self.eps).reshape(x.shape)

-    if not self.affine: return x
+    if self.weight is None or self.bias is None: return x
     # elementwise_affine on channels
     return x * self.weight.reshape(1, -1, 1, 1) + self.bias.reshape(1, -1, 1, 1)

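Note: keeping weight and bias as separate Optional[Tensor] attributes and testing them directly (instead of a boolean affine flag) is what lets mypy narrow the types: after the `is None` check, both attributes are known to be Tensor on the return line. A standalone sketch of the same narrowing pattern (hypothetical class, plain lists instead of Tensors):

    from typing import List, Optional

    class Affine:
      def __init__(self, dim:int, affine:bool=True):
        self.weight: Optional[List[float]] = [1.0]*dim if affine else None
        self.bias: Optional[List[float]] = [0.0]*dim if affine else None

      def __call__(self, x:List[float]) -> List[float]:
        # checking the attributes themselves (not a separate flag) is what lets mypy
        # narrow Optional[List[float]] down to List[float] on the line below
        if self.weight is None or self.bias is None: return x
        return [xi*w + b for xi, w, b in zip(x, self.weight, self.bias)]
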
@@ -35,6 +35,7 @@ def map_buffers(real_srcs, x:LazyOp) -> LazyOp:

 # a placeholder class to extend by the exec classes
 class DeviceBuffer:
+  _buf: Any # underlying buffer
   shape: Any # should be Tuple[int, ...] but ndarray and torch.tensor have incompatible types
   @staticmethod
   def fromCPU(x:np.ndarray) -> DeviceBuffer: raise NotImplementedError("must be implemented")

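Note: declaring `_buf: Any` on the DeviceBuffer base class gives the attribute a known type on every backend buffer; this is what the jit change above leans on when it reaches into `v.realize().lazydata.realized._buf` and types `input_replace` as Dict[DeviceBuffer, Any].
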
@@ -48,6 +48,8 @@ class ZeroView:
     # fake properties
     self.strides, self.contiguous, self.offset = strides_for_shape(self.shape), False, 0

+  def __repr__(self): return f"ZeroView({self.old_shape}, {self.arg})"
+
   def expr_node(self, idx=None, valid=None):
     if idx is None: idx = Variable('idx', 0, prod([y-x for x,y in self.arg]))
     expr, acc = [valid] if valid is not None else [], 1

@@ -57,7 +59,7 @@ class ZeroView:
       acc *= ns
     return Variable.ands(expr)

-  def __repr__(self): return f"ZeroView({self.old_shape}, {self.arg})"
+  def expr_idxs(self, idxs, offset=0): raise NotImplementedError("ZeroView doesn't support expr_idxs")

 ViewTypes = Union[View, ZeroView]

@@ -113,24 +113,24 @@ class Tensor:
   # ***** creation helper functions *****
   # TODO: remove use of numpy here and make lazy

-  @classmethod
-  def zeros(cls, *shape, **kwargs): return cls([0], **kwargs).reshape([1]*len(shape)).expand(shape)
+  @staticmethod
+  def zeros(*shape, **kwargs): return Tensor([0], **kwargs).reshape([1]*len(shape)).expand(shape)

-  @classmethod
-  def ones(cls, *shape, **kwargs): return cls([1], **kwargs).reshape([1]*len(shape)).expand(shape)
+  @staticmethod
+  def ones(*shape, **kwargs): return Tensor([1], **kwargs).reshape([1]*len(shape)).expand(shape)

-  @classmethod
-  def zeros_like(cls, tensor, **kwargs): return cls.zeros(*tensor.shape, **kwargs)
+  @staticmethod
+  def zeros_like(tensor, **kwargs): return Tensor.zeros(*tensor.shape, **kwargs)

-  @classmethod
-  def empty(cls, *shape, **kwargs): return cls.zeros(*shape, **kwargs)
+  @staticmethod
+  def empty(*shape, **kwargs): return Tensor.zeros(*shape, **kwargs)

-  @classmethod
-  def eye(cls, dim, **kwargs): return cls([1], **kwargs).slice(((0,dim+1),)).reshape(1, dim+1).expand(dim, dim+1).reshape(dim*(dim+1)).slice(((0,dim*dim),)).reshape(dim, dim)
+  @staticmethod
+  def eye(dim, **kwargs): return Tensor([1], **kwargs).slice(((0,dim+1),)).reshape(1, dim+1).expand(dim, dim+1).reshape(dim*(dim+1)).slice(((0,dim*dim),)).reshape(dim, dim)

   # TODO: requires cumsum to remove numpy
-  @classmethod
-  def arange(cls, stop, start=0, step=1, **kwargs): return cls(np.arange(start=start, stop=stop, step=step, dtype=np.float32), **kwargs)
+  @staticmethod
+  def arange(stop, start=0, step=1, **kwargs): return Tensor(np.arange(start=start, stop=stop, step=step, dtype=np.float32), **kwargs)

   # ***** (numpy) rng helper functions *****
   # TODO: move randomness generation out of numpy

@@ -139,24 +139,24 @@ class Tensor:
   @staticmethod
   def manual_seed(seed=None): Tensor._rng = np.random.default_rng(seed=seed)

-  @classmethod
-  def rand(cls, *shape, **kwargs): return cls(cls._rng.random(size=shape, dtype=np.float32), **kwargs)
+  @staticmethod
+  def rand(*shape, **kwargs): return Tensor(Tensor._rng.random(size=shape, dtype=np.float32), **kwargs)

   # TODO: replace with a transformation from uniform -> gaussian
-  @classmethod
-  def randn(cls, *shape, **kwargs): return cls(cls._rng.standard_normal(size=shape, dtype=np.float32), **kwargs)
+  @staticmethod
+  def randn(*shape, **kwargs): return Tensor(Tensor._rng.standard_normal(size=shape, dtype=np.float32), **kwargs)

   # ***** rng hlops *****

-  @classmethod
-  def uniform(cls, *shape, **kwargs): return cls.rand(*shape, **kwargs) * 2 - 1
+  @staticmethod
+  def uniform(*shape, **kwargs) -> Tensor: return Tensor.rand(*shape, **kwargs) * 2 - 1

-  @classmethod
-  def scaled_uniform(cls, *shape, **kwargs): return cls.uniform(*shape, **kwargs) * (prod(shape)**-0.5)
+  @staticmethod
+  def scaled_uniform(*shape, **kwargs) -> Tensor: return Tensor.uniform(*shape, **kwargs) * (prod(shape)**-0.5)

-  @classmethod
   # https://www.tensorflow.org/api_docs/python/tf/keras/initializers/GlorotUniform
-  def glorot_uniform(cls, *shape, **kwargs): return cls.uniform(*shape, **kwargs) * ((6/(shape[0]+prod(shape[1:])))**0.5)
+  @staticmethod
+  def glorot_uniform(*shape, **kwargs) -> Tensor: return Tensor.uniform(*shape, **kwargs) * ((6/(shape[0]+prod(shape[1:])))**0.5)

   # ***** toposort and backward pass *****

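Note: call sites are unaffected by the classmethod-to-staticmethod switch, since both resolve through the class; the difference is that the bodies now name Tensor explicitly instead of cls, giving mypy a concrete type to check under --check-untyped-defs (at the cost of subclasses no longer dispatching through cls). A small usage sketch (shapes chosen arbitrarily):

    from tinygrad.tensor import Tensor

    x = Tensor.zeros(2, 3)            # creation helpers are still called on the class
    y = Tensor.uniform(2, 3)          # now annotated to return Tensor
    w = Tensor.glorot_uniform(8, 4)   # uniform init scaled by fan-in/fan-out
    print(x.shape, y.shape, w.shape)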