mirror of https://github.com/commaai/tinygrad.git
lower schedule (#2559)

* lower schedule
* remove RAND, and don't put load in the JIT yet
* better fix for that test
parent 077567f62d → commit 6733425095
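In short: realizing a schedule is now split into two stages. lower_schedule pairs each ScheduleItem with an optional JITRunner, and run_schedule just allocates the output and execs the runner. A rough paraphrase of the new run_schedule from the realize.py hunks below (simplified: the output-buffer aliasing check and symbolic-shape sizing are omitted):

```python
# simplified paraphrase of the new run_schedule, not verbatim
for si, fxn in lower_schedule(schedule):   # fxn is a JITRunner, or None for LoadOps.EMPTY
  si.out.realized = si.out.output_buffer if si.out.output_buffer is not None else \
    Buffer(si.out.device, prod(si.out.shape), si.out.dtype)
  if fxn: fxn.exec([si.out.realized] + [x.realized for x in si.inputs], si.var_vals)
```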
@@ -98,13 +98,13 @@ class LazyOp:
   src: Tuple[Union[LazyOp, LazyBuffer], ...]  # the sources
   arg: Optional[Any] = None  # and an optional static argument
 
-# there's currently 27 Ops you have to implement for an accelerator.
+# there's currently 26 Ops you have to implement for an accelerator.
 class UnaryOps(Enum): EXP2 = auto(); LOG2 = auto(); CAST = auto(); SIN = auto(); SQRT = auto()
 class BinaryOps(Enum): ADD = auto(); SUB = auto(); MUL = auto(); DIV = auto(); CMPLT = auto(); MAX = auto()
 class ReduceOps(Enum): SUM = auto(); MAX = auto()
 class MovementOps(Enum): RESHAPE = auto(); PERMUTE = auto(); EXPAND = auto(); PAD = auto(); SHRINK = auto(); STRIDE = auto()
 class TernaryOps(Enum): MULACC = auto(); WHERE = auto()
-class LoadOps(Enum): EMPTY = auto(); RAND = auto(); CONST = auto(); FROM = auto(); CONTIGUOUS = auto(); CUSTOM = auto()
+class LoadOps(Enum): EMPTY = auto(); CONST = auto(); FROM = auto(); CONTIGUOUS = auto(); CUSTOM = auto()
 
 # NOTE: if you have a CompiledBuffer(DeviceBuffer)
 # you do not need to implement the MovementOps
 # as they are handled by the ShapeTracker(in tinygrad/shape/shapetracker.py, code 7/10)

@@ -11,7 +11,7 @@ unary_op (NOOP, EXP2, LOG2, CAST, SIN, SQRT) # A -> A
 reduce_op (SUM, MAX) # A -> B (smaller size, B has 1 in shape)
 binary_op (ADD, SUB, MUL, DIV, CMPEQ, MAX) # A + A -> A (all the same size)
 movement_op (EXPAND, RESHAPE, PERMUTE, PAD, SHRINK, STRIDE) # A -> B (different size)
-load_op (EMPTY, RAND, CONST, FROM, CONTIGUOUS, CUSTOM) # -> A (initialize data on device)
+load_op (EMPTY, CONST, FROM, CONTIGUOUS, CUSTOM) # -> A (initialize data on device)
 ternary_op (WHERE) # A, A, A -> A
 ternary_op [[optional]] (MULACC) # A * A -> B
 ```

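For reference, a LazyOp is just an op from these enums plus a tuple of sources and an optional arg; an AST built from them is what Device[...].get_runner(si.ast) later turns into a kernel. A hand-built sketch of a tiny AST (load buffer 1, square it, store to buffer 0), using MemBuffer and ShapeTracker as they appear in the hunks further down; real ASTs are produced by the scheduler, not written by hand:

```python
from tinygrad.ops import LazyOp, BinaryOps, BufferOps, MemBuffer
from tinygrad.helpers import dtypes
from tinygrad.shape.shapetracker import ShapeTracker

st  = ShapeTracker.from_shape((4,))
ld  = LazyOp(BufferOps.LOAD, (), MemBuffer(1, dtypes.float32, st))       # read input buffer 1
mul = LazyOp(BinaryOps.MUL, (ld, ld))                                    # square it elementwise
ast = LazyOp(BufferOps.STORE, (mul,), MemBuffer(0, dtypes.float32, st))  # write output buffer 0
```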
@@ -88,8 +88,9 @@ class TestInferenceMinKernels(unittest.TestCase):
     args_tiny = {"dim": 512, "hidden_dim": 1024, "n_heads": 8, "n_layers": 4, "norm_eps": 1e-05, "vocab_size": 1000}
     model = Transformer(**args_tiny)
     for p in get_parameters(model): p.assign(np.zeros(p.shape, dtype=p.dtype.np))
+    inp = Tensor([[1,2,3,4]])
     with CLCache(100):
-      model(Tensor([[1,2,3,4]]), 0).realize()
+      model(inp, 0).realize()
 
 @unittest.skipUnless(Device.DEFAULT == "GPU", "Not Implemented")
 class TestOptBinOp(unittest.TestCase):

@@ -5,8 +5,8 @@ from tinygrad.nn.state import get_parameters
 # for speed
 def derandomize(x):
   if isinstance(x, LazyOp):
-    new_op = LoadOps.EMPTY if x.op == LoadOps.RAND else x.op
-    return LazyOp(new_op, tuple([derandomize(s) for s in x.src]), None if x.op == LoadOps.RAND else x.arg)
+    new_op = LoadOps.EMPTY if x.op == LoadOps.CUSTOM else x.op
+    return LazyOp(new_op, tuple([derandomize(s) for s in x.src]), None if x.op == LoadOps.CUSTOM else x.arg)
   x.op = derandomize(x.op)
   return x
 

@@ -33,6 +33,31 @@ class _Device:
     return "CPU"
 Device = _Device()
 
+# **************** base Runner + helpers ****************
+
+class JITRunner:
+  def __init__(self):
+    self.op_estimate, self.mem_estimate = 0, 0
+  def exec(self, rawbufs:List[Buffer], var_vals:Optional[Dict[Variable, int]]=None) -> Optional[float]:
+    var_vals = var_vals if var_vals is not None else {}
+    from tinygrad.jit import CacheCollector
+    et = self(rawbufs, var_vals)
+    CacheCollector.add(self, rawbufs, var_vals)
+    return et
+  def __call__(self, rawbufs:List[Buffer], var_vals:Dict[Variable, int], wait=False, jit=False) -> Optional[float]:
+    raise NotImplementedError("override this")
+
+def update_stats(name:str, op_estimate:sint, mem_estimate:sint, var_vals: Optional[Dict[Variable, int]], et: Optional[float], buf_count, jit=False, num_kernels=1, lra: Optional[Dict]=None):
+  if var_vals is None: var_vals = {}
+  op_estimate, mem_estimate = sym_infer(op_estimate, var_vals), sym_infer(mem_estimate, var_vals)
+  if DEBUG >= 2:
+    print(f"{colored(f'*** {GlobalCounters.kernel_count:4d}', ('magenta' if num_kernels == 1 else 'CYAN') if jit else None)} {name+' '*(37-ansilen(name))} arg {buf_count:3d} sz {str(lra.get('global_size', '') if lra else ''):18s} {str(lra.get('local_size', '') if lra else ''):12s} OPs {int(op_estimate/1e6):6d}M/{GlobalCounters.global_ops/1e9:7.2f}G mem {GlobalCounters.mem_used/1e9:5.2f} GB " +
+          (str() if et is None else f"tm {et*1e6:9.2f}us/{GlobalCounters.time_sum_s*1e3:9.2f}ms ({op_estimate/((et or 1e-20)*1e9):8.2f} GFLOPS, {mem_estimate/((et or 1e-20)*1e9):7.2f} GB/s)"))
+  GlobalCounters.kernel_count += num_kernels
+  GlobalCounters.global_ops += op_estimate
+  GlobalCounters.global_mem += mem_estimate
+  if et is not None: GlobalCounters.time_sum_s += et
+
 # **************** Buffer / Allocator ****************
 
 class Buffer:

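With JITRunner moved here (next to Device and Buffer), everything executable shares one interface: compiled kernels, buffer copies, and custom ops all implement __call__, and callers go through exec() so the invocation is also recorded with CacheCollector and can be replayed by the JIT. A minimal sketch of a custom runner; the class name and its print-only behavior are made up for illustration (CustomOp in the schedule-lowering hunk further down is the real in-tree example):

```python
from typing import Dict, List, Optional
from tinygrad.device import Buffer, JITRunner
from tinygrad.shape.symbolic import Variable

class PrintRunner(JITRunner):  # hypothetical runner, illustration only
  def __call__(self, rawbufs:List[Buffer], var_vals:Dict[Variable, int], wait=False, jit=False) -> Optional[float]:
    # a real runner would launch work on the device here and may return elapsed time in seconds
    print(f"would run on {len(rawbufs)} buffers")
    return None

# going through exec() (rather than calling the runner directly) is what adds the call to
# CacheCollector, so a capturing TinyJit can record and replay it
PrintRunner().exec([])
```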
@@ -47,27 +72,6 @@ class Buffer:
     if self.device == Device.DEFAULT: GlobalCounters.mem_used -= self.size * self.dtype.itemsize
     self.allocator.free(self._buf, self.size, self.dtype)
   def __repr__(self): return f"<buf device:{self.device} size:{self.size}>"
-  def copy_(self, src:Buffer):
-    assert self.size == src.size and self.dtype == src.dtype, "buffer copy size/dtype mismatch"
-    if hasattr(self.allocator, 'transfer') and type(self.allocator) is type(src.allocator):
-      # fast path, used on HIP between GPUs
-      self.allocator.transfer(self._buf, src._buf, self.size*self.dtype.itemsize)
-      return
-    if getenv("FROM_BUFFER") and hasattr(self.allocator, 'from_buffer') and hasattr(self.allocator, 'transfer') and hasattr(src.allocator, 'as_buffer'):
-      # fast path, used on Metal in OS X Sonoma
-      # NOTE: this is *only* faster if the pages from disk are already loaded into memory
-      fb = self.allocator.from_buffer(src.allocator.as_buffer(src._buf))
-      if fb:
-        self.allocator.transfer(self._buf, fb, self.size*self.dtype.itemsize)
-        return
-    if hasattr(self.allocator, 'as_buffer'):
-      # fast(ish) path, uses readinto in diskbuffers
-      src.allocator.copyout(self.allocator.as_buffer(self._buf), src._buf)
-    elif hasattr(src.allocator, 'as_buffer'):
-      self.allocator.copyin(self._buf, src.allocator.as_buffer(src._buf))
-    else:
-      # slow path, allocates a CPU buffer
-      self.copyin(src.toCPU().data)
   def copyin(self, mv:memoryview):
     mv = flat_mv(mv)
     assert len(mv) == self.size*self.dtype.itemsize, f"size mismatch, {len(mv)=} != {self.dtype=} {self.size=}"

@@ -82,6 +86,33 @@ class Buffer:
     if self.size > 0: self.allocator.copyout(flat_mv(ret.data), self._buf)
     return ret
 
+class _BufferCopy(JITRunner):
+  # TODO: make wait work
+  def __call__(self, rawbufs:List[Buffer], var_vals:Dict[Variable, int], wait=False, jit=False):
+    dest, src = rawbufs
+    assert dest.size == src.size and dest.dtype == src.dtype, "buffer copy size/dtype mismatch"
+    if DEBUG >= 2: print(f"*** copy {dest.device} <- {src.device} size {dest.size:<16d} dtype {dest.dtype}")
+    if hasattr(dest.allocator, 'transfer') and type(dest.allocator) is type(src.allocator):
+      # fast path, used on HIP between GPUs
+      dest.allocator.transfer(dest._buf, src._buf, dest.size*dest.dtype.itemsize)
+      return
+    if getenv("FROM_BUFFER") and hasattr(dest.allocator, 'from_buffer') and hasattr(dest.allocator, 'transfer') and hasattr(src.allocator, 'as_buffer'):
+      # fast path, used on Metal in OS X Sonoma
+      # NOTE: this is *only* faster if the pages from disk are already loaded into memory
+      fb = dest.allocator.from_buffer(src.allocator.as_buffer(src._buf))
+      if fb:
+        dest.allocator.transfer(dest._buf, fb, dest.size*dest.dtype.itemsize)
+        return
+    if hasattr(dest.allocator, 'as_buffer'):
+      # fast(ish) path, uses readinto in diskbuffers
+      src.allocator.copyout(dest.allocator.as_buffer(dest._buf), src._buf)
+    elif hasattr(src.allocator, 'as_buffer'):
+      dest.allocator.copyin(dest._buf, src.allocator.as_buffer(src._buf))
+    else:
+      # slow path, allocates a CPU buffer
+      dest.copyin(src.toCPU().data)
+BufferCopy = _BufferCopy()
+
 # TODO: size, dest, src are the same type. can we enforce this?
 class Allocator:
   def alloc(self, size:int, dtype:DType):

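Buffer.copy_ (removed above) is replaced by this _BufferCopy runner, so a LoadOps.FROM schedule item is executed and JIT-captured like any other kernel. A rough usage sketch; which path _BufferCopy takes (transfer, as_buffer, or the slow toCPU round-trip) depends on the device's allocator, and this assumes the default device accepts copyin of a host memoryview:

```python
import numpy as np
from tinygrad.helpers import dtypes
from tinygrad.device import Device, Buffer, BufferCopy

src = Buffer(Device.DEFAULT, 4, dtypes.float32)
src.copyin(np.arange(4, dtype=np.float32).data)   # host -> device
dst = Buffer(Device.DEFAULT, 4, dtypes.float32)

BufferCopy.exec([dst, src])   # dest first, then src, matching _BufferCopy.__call__
print(dst.toCPU())            # expected: [0. 1. 2. 3.]
```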
@@ -115,31 +146,6 @@ class _MallocAllocator(LRUAllocator):
   def copyout(self, dest:memoryview, src): ctypes.memmove(from_mv(dest), src, len(dest))
 MallocAllocator = _MallocAllocator()
 
-# **************** base Runner + helpers ****************
-
-class JITRunner:
-  def __init__(self):
-    self.op_estimate, self.mem_estimate = 0, 0
-  def exec(self, rawbufs:List[Buffer], var_vals:Optional[Dict[Variable, int]]=None) -> Optional[float]:
-    var_vals = var_vals if var_vals is not None else {}
-    from tinygrad.jit import CacheCollector
-    et = self(rawbufs, var_vals)
-    CacheCollector.add(self, rawbufs, var_vals)
-    return et
-  def __call__(self, rawbufs:List[Buffer], var_vals:Dict[Variable, int], wait=False, jit=False) -> Optional[float]:
-    raise NotImplementedError("override this")
-
-def update_stats(name:str, op_estimate:sint, mem_estimate:sint, var_vals: Optional[Dict[Variable, int]], et: Optional[float], buf_count, jit=False, num_kernels=1, lra: Optional[Dict]=None):
-  if var_vals is None: var_vals = {}
-  op_estimate, mem_estimate = sym_infer(op_estimate, var_vals), sym_infer(mem_estimate, var_vals)
-  if DEBUG >= 2:
-    print(f"{colored(f'*** {GlobalCounters.kernel_count:4d}', ('magenta' if num_kernels == 1 else 'CYAN') if jit else None)} {name+' '*(37-ansilen(name))} arg {buf_count:3d} sz {str(lra.get('global_size', '') if lra else ''):18s} {str(lra.get('local_size', '') if lra else ''):12s} OPs {int(op_estimate/1e6):6d}M/{GlobalCounters.global_ops/1e9:7.2f}G mem {GlobalCounters.mem_used/1e9:5.2f} GB " +
-          (str() if et is None else f"tm {et*1e6:9.2f}us/{GlobalCounters.time_sum_s*1e3:9.2f}ms ({op_estimate/((et or 1e-20)*1e9):8.2f} GFLOPS, {mem_estimate/((et or 1e-20)*1e9):7.2f} GB/s)"))
-  GlobalCounters.kernel_count += num_kernels
-  GlobalCounters.global_ops += op_estimate
-  GlobalCounters.global_mem += mem_estimate
-  if et is not None: GlobalCounters.time_sum_s += et
-
 # **************** for Interpreted Devices ****************
 
 class InterpretedASTRunner(JITRunner):

@@ -86,7 +86,7 @@ def create_lazybuffer(device:str, st:ShapeTracker, optype:OpType, op:LazyOp, dty
   if 0 in st.shape: return LazyBuffer(device, ShapeTracker.from_shape(st.shape), LoadOps, LazyOp(LoadOps.CONST, tuple(), 0.0), dtype)
 
   # fromcpu aren't cached
-  if not LAZYCACHE or (optype is LoadOps and op.op in {LoadOps.EMPTY, LoadOps.RAND, LoadOps.CONST}): return LazyBuffer(device, st, optype, op, dtype, base=base)
+  if not LAZYCACHE or (optype is LoadOps and op.op in {LoadOps.EMPTY, LoadOps.CUSTOM, LoadOps.CONST}): return LazyBuffer(device, st, optype, op, dtype, base=base)
 
   # wop is the deduping key. i feel this used to compare more deeply
   wop = (device, dtype, optype, ref(op), ref(base) if base else None)

@@ -166,8 +166,8 @@ class LazyBuffer:
 
     op, base_bufs = _replace_bufferops(op)
 
-    # add the store
+    if op.op not in LoadOps:
+      # add the store
       info = get_lazyop_info(op)
       assert info.dtype == self.dtype or isinstance(self.dtype, ImageDType), f"dtype mismatch {info.dtype=} != {self.dtype=}"

@@ -179,6 +179,9 @@ class LazyBuffer:
       # TODO: why doesn't this match?
       #assert info.shape == self.shape, f"shape mismatch {info.shape=} != {self.shape=}"
       op = LazyOp(BufferOps.STORE, (op, ), MemBuffer(0, self.dtype, ShapeTracker.from_shape(info.shape)))
+    else:
+      # check loadop validity of bufferops
+      for i,s in enumerate(op.src): assert isinstance(s, LazyOp) and s.op == BufferOps.LOAD and s.arg.idx == i+1 and s.arg.st.contiguous, f"bad LoadOps src {i}: {s}"
 
     return ret + [ScheduleItem(op, self, tuple(base_bufs), {k:var_vals[k] for k in vars_from_ast(op)})]

@@ -17,7 +17,7 @@ class ReduceOps(Enum): SUM = auto(); MAX = auto() # noqa: E702
 class BufferOps(Enum): LOAD = auto(); CONST = auto(); STORE = auto() # noqa: E702
 # Ops below this line are not allowed in ASTs
 class MovementOps(Enum): RESHAPE = auto(); PERMUTE = auto(); EXPAND = auto(); PAD = auto(); SHRINK = auto(); STRIDE = auto(); AS_STRIDED = auto() # noqa: E702
-class LoadOps(Enum): EMPTY = auto(); RAND = auto(); CONST = auto(); FROM = auto(); CONTIGUOUS = auto(); CUSTOM = auto() # noqa: E702
+class LoadOps(Enum): EMPTY = auto(); CONST = auto(); FROM = auto(); CONTIGUOUS = auto(); CUSTOM = auto() # noqa: E702
 
 Op = Union[UnaryOps, BinaryOps, ReduceOps, MovementOps, LoadOps, TernaryOps, BufferOps]
 OpType = Union[Type[UnaryOps], Type[BinaryOps], Type[ReduceOps], Type[MovementOps], Type[LoadOps], Type[TernaryOps], Type[BufferOps]]

@@ -1,17 +1,32 @@
-from typing import List, cast, Dict, Callable
-import numpy as np
-from tinygrad.ops import ScheduleItem, LazyOp, LoadOps, BufferOps
-from tinygrad.device import Device, Buffer
+from typing import List, Dict, Tuple, Generator, Optional
+from tinygrad.ops import ScheduleItem, LoadOps, BufferOps
+from tinygrad.device import Device, Buffer, BufferCopy, JITRunner
 from tinygrad.graph import log_schedule_item, print_tree
-from tinygrad.helpers import DEBUG, prod
+from tinygrad.helpers import prod
 from tinygrad.shape.symbolic import Variable
 
-def run_schedule(schedule:List[ScheduleItem], disable_logging=False):
-  # NOTE: if you for loop the schedule it's slow because nothing frees
+class CustomOp(JITRunner):
+  def __init__(self, fxn):
+    self.fxn = fxn
+    super().__init__()
+  def __call__(self, rawbufs:List[Buffer], var_vals:Dict[Variable, int], wait=False, jit=False): self.fxn(*rawbufs)
+
+def lower_schedule(schedule:List[ScheduleItem]) -> Generator[Tuple[ScheduleItem, Optional[JITRunner]], None, None]:
+  while len(schedule):
+    si = schedule.pop(0)
+    assert all(si.out.device == x.device for x in si.inputs) or si.ast.op is LoadOps.FROM, f"all devices must be the same, {si.out.device} != {[x.device for x in si.inputs]} {print_tree(si.ast) or ''}"
+    if si.ast.op is LoadOps.EMPTY: yield si, None
+    elif si.ast.op is LoadOps.FROM: yield si, BufferCopy
+    elif si.ast.op is LoadOps.CUSTOM: yield si, CustomOp(si.ast.arg)
+    else: yield si, Device[si.out.device].get_runner(si.ast)
+    del si.out.op
+    for v in si.out.views: del v.op
+
+def run_schedule(schedule:List[ScheduleItem], disable_logging=False):
+  for si, fxn in lower_schedule(schedule):
     if not disable_logging: log_schedule_item(si)
     assert all(x.realized for x in si.inputs), "can't run schedule, some inputs aren't realized"
-    assert all(si.out.device == x.device for x in si.inputs) or si.ast.op is LoadOps.FROM, f"all devices must be the same, {si.out.device} != {[x.device for x in si.inputs]} {print_tree(si.ast) or ''}"
 
     # check if we can reuse the output buffer
     # if it's aliased, don't use it
     # TODO: this is pretty wrong actually, who knows where else this buffer is used?

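run_schedule is now a thin loop over lower_schedule, which also makes the ScheduleItem-to-runner mapping easy to inspect on its own. A sketch of doing that directly; it assumes LazyBuffer.schedule() as in this version of the codebase, and note that lower_schedule consumes the list and drops the ops, so this is for poking around rather than normal use (normal use is just .realize()):

```python
from tinygrad.tensor import Tensor
from tinygrad.realize import lower_schedule

t = (Tensor.rand(16) + 1).sum()
sched = t.lazydata.schedule()   # list of ScheduleItem; nothing has run yet
for si, runner in lower_schedule(sched):
  # e.g. the LoadOps.CUSTOM item (the rand) maps to CustomOp, the add+sum AST to a device runner
  print(si.ast.op, "->", type(runner).__name__ if runner is not None else None)
```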
@@ -24,35 +39,15 @@ def run_schedule(schedule:List[ScheduleItem], disable_logging=False):
       if any(not x.arg.st.contiguous for x in si.ast.get_lazyops() if x.op == BufferOps.LOAD and x.arg.idx == i+1):
         si.out.output_buffer = None
         break
 
     # we don't have an output buffer, we have to create it, and create to max size if it has symbolic shape
     si.out.realized = si.out.output_buffer if si.out.output_buffer is not None else \
       Buffer(si.out.device, prod((s if isinstance(s, int) else s.max for s in si.out.shape)), si.out.dtype)
-    if si.ast.op in LoadOps:
-      if DEBUG >= 2: print(f"*** {si.ast.op:>15s} {f'{si.out.device} <- {si.inputs[0].device}' if si.ast.op is LoadOps.FROM else si.out.device:25s} sz {si.out.realized.size:5d} shape {si.out.shape} dtype {si.out.dtype} arg {si.ast.arg}")
-      # confirm the LoadOps are contiguous and in order
-      for i,s in enumerate(si.ast.src): assert isinstance(s, LazyOp) and s.op == BufferOps.LOAD and s.arg.idx == i+1 and s.arg.st.contiguous, f"bad LoadOps src {i}: {s}"
-      kwargs = {"arg": si.ast.arg} if si.ast.arg is not None else {}
-      LOAD_OPS_DISPATCHER[cast(LoadOps, si.ast.op)](si.out.realized, *[x.realized for x in si.inputs], **kwargs)
-    else:
-      Device[si.out.device].get_runner(si.ast).exec([si.out.realized] + [x.realized for x in si.inputs], si.var_vals)
-    del si.out.op
-    for v in si.out.views: del v.op
-    #assert si.out.realized and isinstance(si.out.realized, Device[si.out.device].buffer), f"device mismatch on realized got {type(si.out.realized)} expected {si.out.device}"
-    assert si.out.realized.dtype == si.out.dtype, f"realized dtype is incorrect, {si.out.realized.dtype} != {si.out.dtype}"
-
-# *** LoadOps implementation ***
+
+    # get all the buffers
+    rawbufs = [si.out.realized] + [x.realized for x in si.inputs]
 
-# TODO: remove this and write the RNG in tinygrad
-def _realize_rand(buffer: Buffer, arg) -> None:
-  rng = np.random.default_rng(arg)
-  rng_np_buffer = rng.random(size=buffer.size, dtype=np.float32).astype(dtype=buffer.dtype.np, copy=False)
-  buffer.copyin(rng_np_buffer.data)
-
-def _realize_custom(*buffers: Buffer, arg) -> None: arg(*buffers)
-
-LOAD_OPS_DISPATCHER: Dict[LoadOps, Callable] = {
-  LoadOps.EMPTY: lambda x: None,
-  LoadOps.RAND: _realize_rand,
-  LoadOps.FROM: Buffer.copy_,
-  LoadOps.CUSTOM: _realize_custom
-}
+    # run the function and put in JIT
+    if fxn: fxn.exec(rawbufs, si.var_vals)
+    assert si.out.realized.device == si.out.device, f"realized device is incorrect, {si.out.realized.device=} != {si.out.device=}"
+    assert si.out.realized.dtype == si.out.dtype, f"realized dtype is incorrect, {si.out.realized.dtype=} != {si.out.dtype=}"

@@ -10,7 +10,7 @@ import numpy as np
 from tinygrad.helpers import ImageDType, argfix, make_pair, getenv, IMAGE, DEBUG, flatten, DType, dtypes, prod, all_int, round_up
 from tinygrad.lazy import LazyBuffer
 from tinygrad.ops import LoadOps
-from tinygrad.device import Device
+from tinygrad.device import Device, Buffer
 from tinygrad.shape.symbolic import sint
 from tinygrad.realize import run_schedule
 

@@ -156,8 +156,7 @@ class Tensor:
 
   @staticmethod
   def rand(*shape, **kwargs):
-    Tensor._seed += 1
-    return Tensor._loadop(LoadOps.RAND, prod((shape:=argfix(*shape))), arg=Tensor._seed, **kwargs).reshape(shape)
+    return Tensor._loadop(LoadOps.CUSTOM, prod((shape:=argfix(*shape))), arg=custom_random, **kwargs).reshape(shape)
 
   # ***** creation helper functions *****

@@ -828,3 +827,11 @@ if IMAGE:
   from tinygrad.features.image import image_conv2d, image_dot
   setattr(Tensor, "conv2d", image_conv2d)
   setattr(Tensor, "dot", image_dot)
+
+# TODO: remove the custom op and replace with threefry
+def custom_random(out:Buffer):
+  Tensor._seed += 1
+  if DEBUG >= 2: print(f"*** rand {out.device} seed {Tensor._seed} size {out.size:<16d} dtype {out.dtype}")
+  rng = np.random.default_rng(Tensor._seed)
+  rng_np_buffer = rng.random(size=out.size, dtype=np.float32).astype(dtype=out.dtype.np, copy=False)
+  out.copyin(rng_np_buffer.data)

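RAND is gone as a LoadOp: Tensor.rand now emits a LoadOps.CUSTOM item whose arg is custom_random, and CustomOp just calls that arg with the realized buffers. Any callable with the same shape works. A small sketch; fill_with_ones and the direct Buffer use are illustrative, not part of the API:

```python
import numpy as np
from tinygrad.helpers import dtypes
from tinygrad.device import Device, Buffer

# same calling convention as custom_random(out): it receives the realized output Buffer
def fill_with_ones(out: Buffer):
  out.copyin(np.ones(out.size, dtype=out.dtype.np).data)

buf = Buffer(Device.DEFAULT, 8, dtypes.float32)
fill_with_ones(buf)   # this is all CustomOp(si.ast.arg) does with a schedule item's buffers
print(buf.toCPU())    # expected: eight 1.0s
```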