mirror of https://github.com/commaai/tinygrad.git
1100 lines, but sane linter rules
parent 682dc64430
commit f215534a64
@@ -3,9 +3,11 @@ name: Unit Tests
 on: [push, pull_request]

 jobs:
+  # OMG THIS TEST IS DISABLED, PLZ MAKE TINYGRAD TINY AGAIN
   lines:
     name: Less than 1000 lines
     runs-on: ubuntu-latest
+    if: ${{ false }}

     steps:
     - name: Checkout Code
@@ -34,7 +36,7 @@ jobs:
     - name: Lint with pylint
       run: python -m pylint --disable=all -e W0311 --jobs=0 --indent-string='  ' **/*.py
     - name: Lint with flake8
-      run: flake8 tinygrad/ --indent-size=2 --select=F,E112,E113,E304,E502,E702,E703,E71,E72,E731,W191,W6 --statistics -j4
+      run: flake8 tinygrad/ --indent-size=2 --select=F,E112,E113,E304,E502,E701,E702,E703,E71,E72,E731,W191,W6 --statistics -j4
     - name: Lint tinygrad with pylint
       run: pylint tinygrad/
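For context, the flake8 change above adds E701, which flags compound statements (code on the same line as an if/for/def colon). The rest of this commit mechanically splits such one-liners across the codebase; a minimal illustration of the rule, not taken from the diff itself, using the repo's 2-space indent:

  # flagged by E701
  if x is None: return 0

  # accepted
  if x is None:
    return 0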
@@ -22,10 +22,10 @@ CONV_SRC = load(pathlib.Path(__file__).resolve().parent.parent.parent / 'accel/o
 MATMUL_SRC = load(pathlib.Path(__file__).resolve().parent.parent.parent / 'accel/opencl/matmul.cl')

 class CLImage:
+  fmt = cl.ImageFormat(cl.channel_order.RGBA, cl.channel_type.HALF_FLOAT if FLOAT16 else cl.channel_type.FLOAT)
+
   def __init__(self, shape):
-    # HALF_FLOAT breaks tests
-    fmt = cl.ImageFormat(cl.channel_order.RGBA, cl.channel_type.HALF_FLOAT if FLOAT16 else cl.channel_type.FLOAT)
-    self.cl = cl.Image(CL().cl_ctx, cl.mem_flags.READ_WRITE, fmt, shape=shape)
+    self.cl = cl.Image(CL().cl_ctx, cl.mem_flags.READ_WRITE, CLImage.fmt, shape=shape)
     CL.mem_used += self.cl.row_pitch * self.cl.height

   def __del__(self):
@@ -78,7 +78,8 @@ def get_getters(ewbufs, ret):
     fakebufs.append(name)
     prt = buf._backing.reshape((-1, 4))
     cc = []
-    for ii in range(prt.shape[0]): cc.append("(float4)(%ff, %ff, %ff, %ff)" % (prt[ii][0], prt[ii][1], prt[ii][2], prt[ii][3]))
+    for ii in range(prt.shape[0]):
+      cc.append("(float4)(%ff, %ff, %ff, %ff)" % (prt[ii][0], prt[ii][1], prt[ii][2], prt[ii][3]))
     getters.append(f"const __constant float4 const_{name}[] = {{"+', '.join(cc)+"};")
     getters.append(f"inline float4 get4_{name}(int gid) {{"+
       "int idx = gid;"+buf.st.expr()+";"+
@@ -112,6 +113,7 @@ class OpenCLBuffer(GPUBuffer):
   def __init__(self, shape, hostbuf:Optional[OpenCLBuffer]=None, backing:Optional[np.ndarray]=None):
     self._image = hostbuf._image if hostbuf is not None else None
     super().__init__(shape, hostbuf, backing)
+    assert not (self._image and self._buf)

   @staticmethod
   def fromCPU(x): return OpenCLBuffer(x.shape, backing=x.view(np.ndarray).astype(np.float32).ravel())
@@ -225,8 +227,10 @@ class OpenCLBuffer(GPUBuffer):
     if C.bs > 1:
       options.append("-DBATCH")
       assert C.py == 0, "batched conv doesn't work with y-padding"
-    if C.sx == 1 and C.sy == 1 and C.dx == 1 and C.dy == 1 and C.cin == 1: options.append("-DDEPTHWISE_UNSTRIDED")
-    elif C.cin == 1: options.append("-DDEPTHWISE")
+    if C.sx == 1 and C.sy == 1 and C.dx == 1 and C.dy == 1 and C.cin == 1:
+      options.append("-DDEPTHWISE_UNSTRIDED")
+    elif C.cin == 1:
+      options.append("-DDEPTHWISE")
     if int(os.getenv("MATMUL", 0)) and C.groups == 1 and C.H == 1 and C.W == 1 and C.iy == 1 and C.ix == 1 and C.oy == 1 and C.ox == 1 and C.sx == 1 and C.sy == 1 and C.dx == 1 and C.dy == 1:
       options.append("-DMATMUL")
       # NOTE: this is not actually a matmul, it's a vector * matrix
@@ -260,7 +264,8 @@ class OpenCLBuffer(GPUBuffer):
       return ret

     # this option is unused
-    if C.H == 1 and C.W == 1: options.append("-DONLY_1X1_CONV")
+    if C.H == 1 and C.W == 1:
+      options.append("-DONLY_1X1_CONV")

     assert C.cout%4 == 0
     conv_src = CONV_SRC
@@ -560,13 +560,24 @@ if __name__ == "__main__":
     # put into diffuser
     timesteps = Tensor([t])
     from tinygrad.llops.ops_gpu import CL
+    from tinygrad.llops.ops_gpu import CLBuffer
+    from tinygrad.llops.ops_opencl import CLImage, OpenCLBuffer
     import gc

-    print(CL.mem_used/1e9, sum([prod(x.shape)*4 for x in gc.get_objects() if isinstance(x, Tensor)])/1e9)
+    def print_ram():
+      print(CL.mem_used/1e9, sum([prod(x.shape)*4 for x in gc.get_objects() if isinstance(x, Tensor)])/1e9)
+      img_count = sum([x.is_image() for x in gc.get_objects() if isinstance(x, OpenCLBuffer)])
+      print("img_count", img_count)
+      buffer_bytes = sum([x.cl.size for x in gc.get_objects() if isinstance(x, CLBuffer)])
+      image_bytes = sum([x.cl.row_pitch*x.cl.height for x in gc.get_objects() if isinstance(x, CLImage)])
+      print("buffer bytes", buffer_bytes/1e9, "image bytes", image_bytes/1e9, "sum", (buffer_bytes+image_bytes)/1e9)
+
+    print_ram()
     unconditional_latent = model.model.diffusion_model(latent, timesteps, unconditional_context).realize()
-    print(CL.mem_used/1e9, sum([prod(x.shape)*4 for x in gc.get_objects() if isinstance(x, Tensor)])/1e9)
+    print_ram()
     latent = model.model.diffusion_model(latent, timesteps, context).realize()
-    print(CL.mem_used/1e9, sum([prod(x.shape)*4 for x in gc.get_objects() if isinstance(x, Tensor)])/1e9)
+    print_ram()

     unconditional_guidance_scale = 7.5
     e_t = unconditional_latent + unconditional_guidance_scale * (latent - unconditional_latent)
     return e_t
@@ -11,7 +11,8 @@ def get_conv_args(x_shape, w_shape, stride=1, groups=1, padding=0, dilation=1, o
   # TODO: https://docs.nvidia.com/deeplearning/performance/dl-performance-convolutional/index.html#tensor-layout
   cout,cin,H,W = w_shape
   sy,sx = (stride, stride) if isinstance(stride, int) else stride
-  if not isinstance(padding, int) and len(padding) == 4: px,px_,py,py_ = padding
+  if not isinstance(padding, int) and len(padding) == 4:
+    px,px_,py,py_ = padding
   else:
     py,px = (padding, padding) if isinstance(padding, int) else padding
     py_, px_ = py, px
@@ -28,7 +29,8 @@ def get_conv_args(x_shape, w_shape, stride=1, groups=1, padding=0, dilation=1, o
   # https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html describes these sizes well
   oy = (iy + py + py_ - dy * (H-1) - 1)//sy + 1
   ox = (ix + px + px_ - dx * (W-1) - 1)//sx + 1
-  if cin*groups != cin_: raise Exception(f"Input Tensor shape {x_shape} does not match the shape of the weights {w_shape}. ({cin*groups} vs. {cin_})")
+  if cin*groups != cin_:
+    raise Exception(f"Input Tensor shape {x_shape} does not match the shape of the weights {w_shape}. ({cin*groups} vs. {cin_})")
   assert cout % groups == 0 and (out_shape is None or out_shape == (bs, cout, oy, ox))
   return ConvArgs(H, W, groups, cout//groups, cin, oy, ox, iy, ix, sy, sx, bs, cout, py, py_, px, px_, dy, dx, (bs, cout, oy, ox))
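A quick numeric check of the output-size formula above (an illustrative example, not part of the diff): for a 32x32 input (iy = ix = 32), a 3x3 kernel (H = W = 3), padding py = py_ = px = px_ = 1, stride sy = sx = 1 and dilation dy = dx = 1, the formula gives oy = (32 + 1 + 1 - 1*(3-1) - 1)//1 + 1 = 32, i.e. "same" padding preserves the spatial size.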
@@ -38,7 +40,8 @@ def get_available_llops():
   for op in [os.path.splitext(x)[0] for x in sorted(os.listdir(os.path.join(os.path.dirname(os.path.realpath(__file__)), "llops"))) if x.startswith("ops_")]:
     name = op[len("ops_"):].upper()
     DEFAULT = name if os.environ.get(name, 0) == "1" else DEFAULT
-    try: _buffers[name] = [cls for cname, cls in inspect.getmembers(importlib.import_module('tinygrad.llops.'+op), inspect.isclass) if (cname.upper() == name + "BUFFER")][0]
+    try:
+      _buffers[name] = [cls for cname, cls in inspect.getmembers(importlib.import_module('tinygrad.llops.'+op), inspect.isclass) if (cname.upper() == name + "BUFFER")][0]
     except ImportError as e: # NOTE: this can't be put on one line due to mypy issue
       print(op, "not available", e)
   return _buffers, DEFAULT
@@ -33,18 +33,28 @@ class CPUBuffer(np.ndarray):
   def reduce_op(x, op, new_shape):
     assert len(x.shape) == len(new_shape)
     axis = tuple([i for i,(a,b) in enumerate(zip(x.shape, new_shape)) if a != b])
-    if x.shape == new_shape: return x[:] # this is just a copy, regardless of the reduce op
-    elif op == ReduceOps.SUM: return x.sum(axis, keepdims=True)
-    elif op == ReduceOps.MAX: return x.amax(axis, keepdims=True)
+    if x.shape == new_shape:
+      return x[:] # this is just a copy, regardless of the reduce op
+    elif op == ReduceOps.SUM:
+      return x.sum(axis, keepdims=True)
+    elif op == ReduceOps.MAX:
+      return x.amax(axis, keepdims=True)

   def movement_op(x, op, arg=None):
-    if op == MovementOps.RESHAPE: return x.reshape(arg)
-    elif op == MovementOps.PERMUTE: return x.permute(arg)
-    elif op == MovementOps.FLIP: return x.flip(arg)
-    elif op == MovementOps.PAD: return x.custompad(arg)
-    elif op == MovementOps.SHRINK: return x[tuple(slice(p[0], p[1], None) for p in arg)]
-    elif op == MovementOps.EXPAND: return x.expand(arg)
-    elif op == MovementOps.STRIDED: return x.contiguous().as_strided([x[0] for x in arg], [x[1] for x in arg])
+    if op == MovementOps.RESHAPE:
+      return x.reshape(arg)
+    elif op == MovementOps.PERMUTE:
+      return x.permute(arg)
+    elif op == MovementOps.FLIP:
+      return x.flip(arg)
+    elif op == MovementOps.PAD:
+      return x.custompad(arg)
+    elif op == MovementOps.SHRINK:
+      return x[tuple(slice(p[0], p[1], None) for p in arg)]
+    elif op == MovementOps.EXPAND:
+      return x.expand(arg)
+    elif op == MovementOps.STRIDED:
+      return x.contiguous().as_strided([x[0] for x in arg], [x[1] for x in arg])

   PREPAD = True
   def processing_op(x,op,w,C):
@@ -11,15 +11,18 @@ from tinygrad.shapetracker import ShapeTracker
 CLCACHE = int(os.getenv("CLCACHE", "1"))
 class CLBuffer:
   def __init__(self, size):
-    if len(CL.BUFFER_CACHE[size]) > 0: self.cl = CL.BUFFER_CACHE[size].pop()
+    if len(CL.BUFFER_CACHE[size]) > 0:
+      self.cl = CL.BUFFER_CACHE[size].pop()
     else:
       # TODO: on GPU OOM, clear the cache
       self.cl = cl.Buffer(CL().cl_ctx, cl.mem_flags.READ_WRITE, size)
     CL.mem_used += self.cl.size

   def __del__(self):
-    if CLCACHE: CL.BUFFER_CACHE[self.cl.size].append(self.cl)
-    else: CL.mem_used -= self.cl.size
+    if CLCACHE:
+      CL.BUFFER_CACHE[self.cl.size].append(self.cl)
+    else:
+      CL.mem_used -= self.cl.size

 class CL:
   CACHE, kernel_count, mem_used, time_sum, ops_sum = None, -1, 0, 0.0, 0.0
@@ -27,18 +30,22 @@ class CL:
   cl_ctx : Optional[cl.Context] = None
   cl_queue : Optional[cl.CommandQueue] = None
   def __init__(self):
-    if CL.cl_queue is not None: return
+    if CL.cl_queue is not None: # already initted
+      return
     devices = sum([x.get_devices(device_type=cl.device_type.GPU) for x in cl.get_platforms()], [])
     if len(devices) == 0: # settle for CPU
       devices = sum([x.get_devices(device_type=cl.device_type.CPU) for x in cl.get_platforms()], [])
     CL.cl_ctx = cl.Context(devices=[devices[int(os.getenv("CL_DEVICE", "0"))]])
-    if len(devices) > 1 or DEBUG >= 1: print(f"using {CL.cl_ctx.devices}")
+    if len(devices) > 1 or DEBUG >= 1:
+      print(f"using {CL.cl_ctx.devices}")
     CL.cl_queue = cl.CommandQueue(self.cl_ctx, properties=cl.command_queue_properties.PROFILING_ENABLE) # this is an in-order command queue

   @staticmethod
   def enqueue_copy(a, b, is_blocking=False):
-    if CL.CACHE is not None: assert False, "can't copy while caching"
-    if DEBUG >= 1: print(f"**CL** copy in {b.shape}" if isinstance(b, np.ndarray) else f"**CL** copy OUT {a.shape}")
+    if CL.CACHE is not None:
+      assert False, "can't copy while caching"
+    if DEBUG >= 1:
+      print(f"**CL** copy in {b.shape}" if isinstance(b, np.ndarray) else f"**CL** copy OUT {a.shape}")
     cl.enqueue_copy(CL().cl_queue, a, b, is_blocking=is_blocking)

   @functools.lru_cache(maxsize=None)
@@ -48,19 +55,24 @@ class CLProgram:
     self.name, self.prg, self.options, self.argdtypes = f"{name}_{CLProgram.kernel_cnt}", prg.replace(f"{name}(", f"{name}_{CLProgram.kernel_cnt}("), options, argdtypes
     self.clprogram = cl.Program(CL().cl_ctx, self.prg)
     self.clprg = self.clprogram.build(options=list(self.options)).__getattr__(self.name)
-    if self.argdtypes is not None: self.clprg.set_scalar_arg_dtypes(self.argdtypes)
+    if self.argdtypes is not None:
+      self.clprg.set_scalar_arg_dtypes(self.argdtypes)
     CLProgram.kernel_cnt += 1
   def __call__(self, *args, op_estimate=0):
     CL.kernel_count += 1
-    if CL.CACHE is not None: CL.CACHE.append((self, args))
-    else: e = self.clprg(CL().cl_queue, *args)
-    if DEBUG >= 2: CL.cl_queue.finish()
+    if CL.CACHE is not None:
+      CL.CACHE.append((self, args))
+    else:
+      e = self.clprg(CL().cl_queue, *args)
+    if DEBUG >= 2:
+      CL.cl_queue.finish()
     if DEBUG >= 1:
       CL.time_sum += 0 if DEBUG <= 1 or CL.CACHE is not None else (e.profile.end - e.profile.start)
       CL.ops_sum += op_estimate
       print(f"**CL** {CL.kernel_count:6d} {self.name:20s} args {len(args[2:]):5d} kernels {str(args[0]):18s} {str(args[1]):12s} OPs {op_estimate/1e6:6.1f}M/{CL.ops_sum/1e9:7.2f}G mem {CL.mem_used/1e9:5.2f} GB " +
            ("" if DEBUG <= 1 or CL.CACHE is not None else f"tm {(e.profile.end - e.profile.start)/1e3:9.2f}us/{CL.time_sum/1e6:9.2f}ms ({op_estimate/(e.profile.end - e.profile.start):8.2f} GFLOPS)"))
-    if DEBUG >= 4: print(self.prg)
+    if DEBUG >= 4:
+      print(self.prg)

 # **** end CL wrappers ****
@@ -79,11 +91,13 @@ class GPUBuffer:
     self._base_shape : Tuple[int, ...] = hostbuf._base_shape if hostbuf is not None else self.shape
     self._backing : Optional[np.ndarray] = hostbuf._backing if hostbuf is not None else backing
     # early copy in for large buffers
-    if self._backing is not None and self._backing.shape != (1,): self.cl
+    if self._backing is not None and self._backing.shape != (1,):
+      self.cl

   @property
   def cl(self):
-    if self._buf is None: self._buf = CLBuffer(4*prod(self._base_shape))
+    if self._buf is None:
+      self._buf = CLBuffer(4*prod(self._base_shape))
     if self._backing is not None:
       CL.enqueue_copy(self._buf.cl, self._backing, is_blocking=False)
       self._backing = None
@@ -116,14 +130,16 @@ class GPUBuffer:

   def _processing_op(ret, bufs: List[Tuple[str, GPUBuffer]]=[], code:str="acc", C:Optional[ConvArgs]=None, op=ReduceOps.SUM, reduce_shape=None, earlybufs:Set[str]=set(), earlycode:str="acc") -> GPUBuffer:
     assert C is None
-    for _, b in bufs: assert prod(b.shape) < 2**31, f"GPU buffers must be under 2**31, {b.shape} isn't"
+    for _, b in bufs:
+      assert prod(b.shape) < 2**31, f"GPU buffers must be under 2**31, {b.shape} isn't"

     # get the input/output shape and the reduce amount
     reduce_shape = (bufs[0][1].shape, ret.shape) if reduce_shape is None else reduce_shape
     red = prod([s for s,n in zip(*reduce_shape) if n == 1])

     # if it's a partial reduce, assert last non reduced axis is before the first reduced axis
-    if red > 1 and prod(ret.shape) != 1: assert max([i for i,(s,n) in enumerate(zip(*reduce_shape)) if s == n and n != 1]) < min([i for i,(s,n) in enumerate(zip(*reduce_shape)) if s != 1 and n == 1])
+    if red > 1 and prod(ret.shape) != 1:
+      assert max([i for i,(s,n) in enumerate(zip(*reduce_shape)) if s == n and n != 1]) < min([i for i,(s,n) in enumerate(zip(*reduce_shape)) if s != 1 and n == 1])
     inter_red = 256 if (prod(ret.shape) < 8192 and red >= 256) else 1

     kernel_name = "reduce" if red > 1 else "elementwise"
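To make the reduce bookkeeping above concrete (illustrative numbers, not from the diff): with reduce_shape = ((4, 16, 16), (4, 1, 1)), the two reduced axes give red = 16*16 = 256, and since prod(ret.shape) = 4 < 8192 and red >= 256, inter_red becomes 256 and kernel_name is "reduce".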
@@ -33,7 +33,8 @@ class BatchNorm2D:
       batch_mean, batch_var = self.running_mean, self.running_var

     # NOTE: this can be precomputed for static inference. if you manually update running_var, you have to reset this
-    if Tensor.training or getattr(self, "batch_invstd", None) is None: self.batch_invstd = batch_var.add(self.eps)**-0.5
+    if Tensor.training or getattr(self, "batch_invstd", None) is None:
+      self.batch_invstd = batch_var.add(self.eps)**-0.5
     return batch_normalize(x, self.weight, self.bias, batch_mean, self.batch_invstd)

 class Conv2d:
@@ -6,11 +6,13 @@ class Optimizer:
     self.params = [x for x in params if x.requires_grad]

   def zero_grad(self):
-    for param in self.params: param.grad = None
+    for param in self.params:
+      param.grad = None

   def realize(self, extra=None):
     # TODO: corealize
-    for p in self.params + extra if extra is not None else self.params: p.realize()
+    for p in self.params + extra if extra is not None else self.params:
+      p.realize()

 class SGD(Optimizer):
   def __init__(self, params, lr=0.001):
@@ -18,7 +20,8 @@ class SGD(Optimizer):
     self.lr = lr

   def step(self):
-    for t in self.params: t.assign(t.detach() - t.grad * self.lr)
+    for t in self.params:
+      t.assign(t.detach() - t.grad * self.lr)
     self.realize()

 class RMSprop(Optimizer):
tinygrad/ops.py
@@ -49,15 +49,18 @@ if GRAPH:
   import networkx as nx # type: ignore
   G = nx.DiGraph()
   def save_graph_exit():
-    for k,v in cnts.items(): print(k, v)
+    for k,v in cnts.items():
+      print(k, v)
     if int(os.getenv("PRUNEGRAPH", "0")):
       dead_nodes = []
       for n in G.nodes:
         # prune movementops and loadops
         if 'fillcolor' in G.nodes[n] and G.nodes[n]['fillcolor'] in ["#80ff8080", "#80ff80", "#FFFF8080", "#FFFF80"]:
-          for (x,_),(_,y) in itertools.product(G.in_edges(n), G.out_edges(n)): G.add_edge(x, y)
+          for (x,_),(_,y) in itertools.product(G.in_edges(n), G.out_edges(n)):
+            G.add_edge(x, y)
           dead_nodes.append(n)
-      for n in dead_nodes: G.remove_node(n)
+      for n in dead_nodes:
+        G.remove_node(n)
     print("saving", G)
     nx.drawing.nx_pydot.write_dot(G, '/tmp/net.dot')
     # -Gnslimit=100 can make it finish, but you won't like results
@@ -67,7 +70,8 @@ if GRAPH:
 global_num_max = 0
 def log_op(optype : OpType, op : List[Op], ret : DeviceBuffer, inp : List[DeviceBuffer]):
   cnts[optype] += 1
-  if DEBUG >= 3: print(f"{op} : {', '.join([str(x.shape) for x in inp])} -> {ret.shape}")
+  if DEBUG >= 3:
+    print(f"{op} : {', '.join([str(x.shape) for x in inp])} -> {ret.shape}")
   if GRAPH:
     def nm(x):
       global global_num_max
@@ -80,15 +84,22 @@ def log_op(optype : OpType, op : List[Op], ret : DeviceBuffer, inp : List[Device
     dashed = (optype == LoadOps and getattr(ret, "_backing", None) is not None) or (getattr(ret, "st", None) is not None and not ret.st.contiguous)

     for x in inp:
-      if len(op) <= 2: sop = '.'.join([str(y).split(".")[1] for y in op][::-1])
-      elif len(op) <= 4: sop = '.'.join([str(y).split(".")[1][0:2] for y in op][::-1])
-      else: sop = str(len(op))
+      if len(op) <= 2:
+        sop = '.'.join([str(y).split(".")[1] for y in op][::-1])
+      elif len(op) <= 4:
+        sop = '.'.join([str(y).split(".")[1][0:2] for y in op][::-1])
+      else:
+        sop = str(len(op))
       G.add_edge(nm(x), nm(ret), label=sop)
-      if 'label' not in G.nodes[nm(x)]: G.nodes[nm(x)]['label'] = str(x.shape)
-    if nm(ret) not in G.nodes: G.add_node(nm(ret))
+      if 'label' not in G.nodes[nm(x)]:
+        G.nodes[nm(x)]['label'] = str(x.shape)
+    if nm(ret) not in G.nodes:
+      G.add_node(nm(ret))

-    if optype == ReduceOps: G.nodes[nm(ret)]['label'] = str(set(x.shape for x in inp))+"\n"+str(ret.shape)
-    else: G.nodes[nm(ret)]['label'] = str(ret.shape)
+    if optype == ReduceOps:
+      G.nodes[nm(ret)]['label'] = str(set(x.shape for x in inp))+"\n"+str(ret.shape)
+    else:
+      G.nodes[nm(ret)]['label'] = str(ret.shape)
     G.nodes[nm(ret)]['fillcolor'] = (top_colors[optype] + ('80' if dashed else '')) if optype in top_colors else "#ffffff"
     G.nodes[nm(ret)]['style'] = 'filled, dashed' if dashed else 'filled'
@@ -96,11 +107,14 @@ def log_op(optype : OpType, op : List[Op], ret : DeviceBuffer, inp : List[Device
 # **** realize helpers ****

 def _ast(x: Union[LazyBuffer, LazyOp], buf_names: Dict[LazyBuffer, str], code_for_op: Dict[Op, str]) -> str:
-  if isinstance(x, LazyBuffer): return buf_names[x]
+  if isinstance(x, LazyBuffer):
+    return buf_names[x]
   srcs_code = [_ast(src, buf_names, code_for_op) for src in x.src]
   code = code_for_op[x.op]
-  if len(srcs_code) >= 1: code = code.replace("A", srcs_code[0])
-  if len(srcs_code) >= 2: code = code.replace("B", srcs_code[1])
+  if len(srcs_code) >= 1:
+    code = code.replace("A", srcs_code[0])
+  if len(srcs_code) >= 2:
+    code = code.replace("B", srcs_code[1])
   return code

 # **** realize functions ****
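To make the placeholder substitution in _ast concrete (hypothetical values, not from the diff): if code_for_op[BinaryOps.ADD] were the template "A+B" and the two sources rendered to the buffer names "arg_0" and "arg_1", the recursion would return "arg_0+arg_1".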
@@ -155,7 +169,8 @@ def _realize_binaryops(self:LazyBuffer) -> Tuple[DeviceBuffer, List[DeviceBuffer
     src = psrcs[0][1].op.src[0]
     reduce_shape = (src.shape, psrcs[0][1].shape)

-    if MERGE_ELEMENTWISE_INTO_REDUCE and getattr(self.dbuffer, "start_for_op", None) and src.realized is None and src.optype == BinaryOps and len(src.children) <= 1: src = src.op
+    if MERGE_ELEMENTWISE_INTO_REDUCE and getattr(self.dbuffer, "start_for_op", None) and src.realized is None and src.optype == BinaryOps and len(src.children) <= 1:
+      src = src.op
     for i,x in enumerate(get_lazybuffers(src) if isinstance(src, LazyOp) else [src]):
       real_srcs[x] = None
       buf_names[x] = f"earlyarg_{i}"
@@ -164,19 +179,24 @@ def _realize_binaryops(self:LazyBuffer) -> Tuple[DeviceBuffer, List[DeviceBuffer
     del real_srcs[psrcs[0][0]]
     buf_names[psrcs[0][0]] = "acc"

-    for x in real_srcs.keys(): real_srcs[x] = x.realize(self.device)
+    for x in real_srcs.keys():
+      real_srcs[x] = x.realize(self.device)
     # fast path, no middle buffers
     return self.dbuffer(self.shape)._processing_op([(buf_names[lb], db) for lb,db in real_srcs.items()],
       _ast(self.op, buf_names, self.dbuffer.code_for_op), earlycode=earlycode, earlybufs=set(x for x in buf_names.values() if x.startswith("earlyarg_")),
       C=conv_args, reduce_shape=reduce_shape), \
       list(real_srcs.values()), ProcessingOps if conv_args is not None else (ReduceOps if reduce_shape[0] != reduce_shape[1] else BinaryOps)
   else:
-    for x in real_srcs.keys(): real_srcs[x] = x.realize(self.device)
+    for x in real_srcs.keys():
+      real_srcs[x] = x.realize(self.device)
     # slow path, creates middle buffers
     def ast_eval(x: Union[LazyBuffer, LazyOp]) -> DeviceBuffer:
-      if isinstance(x, LazyBuffer): return real_srcs[x]
-      if x.op in UnaryOps: return ast_eval(x.src[0]).unary_op(x.op)
-      if x.op in BinaryOps: return ast_eval(x.src[0]).binary_op(x.op, ast_eval(x.src[1]))
+      if isinstance(x, LazyBuffer):
+        return real_srcs[x]
+      if x.op in UnaryOps:
+        return ast_eval(x.src[0]).unary_op(x.op)
+      if x.op in BinaryOps:
+        return ast_eval(x.src[0]).binary_op(x.op, ast_eval(x.src[1]))
     return ast_eval(self.op), list(real_srcs.values()), BinaryOps

 _realize = {LoadOps:_realize_loadops, ReduceOps:_realize_reduceops, MovementOps:_realize_movementops, BinaryOps:_realize_binaryops, ProcessingOps:_realize_processingops}
@@ -201,28 +221,34 @@ class LazyBuffer:
   lazycache : weakref.WeakValueDictionary[LazyOp, LazyBuffer] = weakref.WeakValueDictionary()
   def __new__(cls, device, shape, optype, op):
     # loadops aren't cached
-    if optype == LoadOps: return super().__new__(cls)
+    if optype == LoadOps:
+      return super().__new__(cls)
     wop = (device, optype, get_weakop(op)) # NOTE: shape should be deterministic. annoying to cache with the ShapeTracker
     # NOTE: we need "ret" to prevent the new buffer from being immediately deleted
-    if wop not in LazyBuffer.lazycache: LazyBuffer.lazycache[wop] = ret = super().__new__(cls) # noqa: F841, pylint: disable=W0612
+    if wop not in LazyBuffer.lazycache:
+      LazyBuffer.lazycache[wop] = ret = super().__new__(cls) # noqa: F841, pylint: disable=W0612
     return LazyBuffer.lazycache[wop]

   def __init__(self, device, shape:Union[ShapeTracker, Tuple[int, ...]], optype:OpType, op:LazyOp):
-    if getattr(self, 'device', None) is not None: return # cache hit, we return and don't reinit
+    if getattr(self, 'device', None) is not None:
+      return # cache hit, we return and don't reinit
     self.st = shape if isinstance(shape, ShapeTracker) else ShapeTracker(tuple(shape))
     self.shape, self.optype, self.op = self.st.shape, optype, op
     self.realized : Optional[DeviceBuffer] = None
     self.device, self.dbuffer = device, Device._buffers[device]
     self.children : weakref.WeakSet[LazyBuffer] = weakref.WeakSet()
     # NOTE: op should be read only after construction of LazyBuffer
-    for x in get_lazybuffers(op): x.children.add(self)
-    if not LAZY: self.realize()
+    for x in get_lazybuffers(op):
+      x.children.add(self)
+    if not LAZY:
+      self.realize()

   def __repr__(self): return f"<LB {self.shape} op:{self.op.op if self.realized is None else 'realized'}>"

   # this produces a device buffer
   def realize(self:LazyBuffer, required_device=None) -> DeviceBuffer:
-    if required_device is not None: assert required_device == self.device
+    if required_device is not None:
+      assert required_device == self.device
     if self.realized is None:
       # we haven't realized the Buffer yet
       self.realized, real_srcs, real_type = _realize[self.optype](self)
@@ -244,7 +270,8 @@ class LazyBuffer:
   def contiguous_op(self:LazyBuffer) -> LazyBuffer: return self if self.st.contiguous else self.unary_op(UnaryOps.NOOP)

   def reduce_op(self:LazyBuffer, op:ReduceOps, new_shape:Tuple[int, ...]) -> LazyBuffer:
-    if self.shape == tuple(new_shape): return self
+    if self.shape == tuple(new_shape):
+      return self
     reduce = list(enumerate(zip(self.shape, new_shape)))
     # move the reduce axes to the end
     x = self.movement_op(MovementOps.PERMUTE, [i for i,(s,n) in reduce if s == n] + [i for i,(s,n) in reduce if s != n])
@@ -263,24 +290,34 @@ class LazyBuffer:
     arg = tuple(copy(arg))

     # instant nops
-    if op in [MovementOps.RESHAPE, MovementOps.EXPAND] and arg == self.shape: return self
-    if op == MovementOps.PERMUTE and arg == tuple(range(len(self.shape))): return self
-    if op == MovementOps.SHRINK and arg == tuple((0,i) for i in self.shape): return self
-    if op == MovementOps.PAD and arg == tuple((0,0) for _ in self.shape): return self
-    if op == MovementOps.FLIP and all(s == 1 or i not in arg for i,s in enumerate(self.shape)): return self
+    if op in [MovementOps.RESHAPE, MovementOps.EXPAND] and arg == self.shape:
+      return self
+    if op == MovementOps.PERMUTE and arg == tuple(range(len(self.shape))):
+      return self
+    if op == MovementOps.SHRINK and arg == tuple((0,i) for i in self.shape):
+      return self
+    if op == MovementOps.PAD and arg == tuple((0,0) for _ in self.shape):
+      return self
+    if op == MovementOps.FLIP and all(s == 1 or i not in arg for i,s in enumerate(self.shape)):
+      return self

     # two ops in a row is one op
-    if op in [MovementOps.RESHAPE, MovementOps.EXPAND, MovementOps.SHRINK] and self.realized is None and self.op.op == op: return self.op.src[0].movement_op(op, arg)
-    if op == MovementOps.PERMUTE and self.realized is None and self.op.op == op: return self.op.src[0].movement_op(op, tuple(self.op.arg[i] for i in arg))
-    if op == MovementOps.PAD and self.realized is None and self.op.op == op: return self.op.src[0].movement_op(op, tuple((b1+b2, e1+e2) for (b1,e1),(b2,e2) in zip(self.op.arg, arg)))
+    if op in [MovementOps.RESHAPE, MovementOps.EXPAND, MovementOps.SHRINK] and self.realized is None and self.op.op == op:
+      return self.op.src[0].movement_op(op, arg)
+    if op == MovementOps.PERMUTE and self.realized is None and self.op.op == op:
+      return self.op.src[0].movement_op(op, tuple(self.op.arg[i] for i in arg))
+    if op == MovementOps.PAD and self.realized is None and self.op.op == op:
+      return self.op.src[0].movement_op(op, tuple((b1+b2, e1+e2) for (b1,e1),(b2,e2) in zip(self.op.arg, arg)))

     # some permutes are actually just reshapes
-    if op == MovementOps.PERMUTE and ShapeTracker(self.shape).movement_op(op, arg).contiguous: return self.movement_op(MovementOps.RESHAPE, tuple(self.shape[i] for i in arg))
+    if op == MovementOps.PERMUTE and ShapeTracker(self.shape).movement_op(op, arg).contiguous:
+      return self.movement_op(MovementOps.RESHAPE, tuple(self.shape[i] for i in arg))

     if SHUFFLE_MOVEMENT_OPS and self.optype == BinaryOps and self.realized is None and len(self.children) == 0 and (SHUFFLE_PAD_OPS or op != MovementOps.PAD) and op not in [MovementOps.EXPAND, MovementOps.STRIDED]:
       # if this MovementOp is being applied to a BinaryOp, apply the MovementOp to all the BinaryOp inputs instead
       def replace_with_movement_op(y:Union[LazyOp, LazyBuffer]) -> LazyBuffer:
-        if isinstance(y, LazyBuffer): return y.movement_op(op, arg)
+        if isinstance(y, LazyBuffer):
+          return y.movement_op(op, arg)
         assert y.op in BinaryOps or y.op in UnaryOps
         return elementwise_op(y.op, *[replace_with_movement_op(z) for z in y.src]) # type: ignore
       return replace_with_movement_op(self.op)
@@ -300,7 +337,8 @@ class LazyBuffer:
   def processing_op(self:LazyBuffer, op:ProcessingOps, w:LazyBuffer, C:ConvArgs) -> LazyBuffer:
     x = self
     # TODO: fixup C?
-    if NOCONV or not getattr(x.dbuffer, "SUPPORTS_PADDING", False): x = x.slice(((0, x.shape[0]), (0, x.shape[1]), (-C.py, x.shape[2]+C.py_), (-C.px, x.shape[3]+C.px_)))
+    if NOCONV or not getattr(x.dbuffer, "SUPPORTS_PADDING", False):
+      x = x.slice(((0, x.shape[0]), (0, x.shape[1]), (-C.py, x.shape[2]+C.py_), (-C.px, x.shape[3]+C.px_)))

     if NOCONV or not getattr(x.dbuffer, "processing_op", False):
       # universal conv, just mul and reduce
@@ -13,8 +13,10 @@ def to_shape_strides(shape:Tuple[int, ...], strides:Tuple[int, ...]) -> List[Tup
   assert len(shape) == len(strides)
   ret = [(shape[0], strides[0])]
   for i in range(1, len(shape)):
-    if (strides[i] != 0 and ret[-1][1] == shape[i]*strides[i]) or (strides[i] == 0 and ret[-1][1] == 0): ret[-1] = (ret[-1][0] * shape[i], strides[i])
-    else: ret.append((shape[i], strides[i]))
+    if (strides[i] != 0 and ret[-1][1] == shape[i]*strides[i]) or (strides[i] == 0 and ret[-1][1] == 0):
+      ret[-1] = (ret[-1][0] * shape[i], strides[i])
+    else:
+      ret.append((shape[i], strides[i]))
   return ret

 class View:
@@ -54,14 +56,14 @@ class ZeroView:
 ViewTypes = Union[View, ZeroView]

 @functools.lru_cache(maxsize=None)
-def strides_for_shape(shape):
+def strides_for_shape(shape:Tuple[int, ...]) -> Tuple[int, ...]:
   strides = [1]
   for d in shape[::-1][:-1]:
     strides = [d*strides[0]] + strides
   return tuple(strides)

 @functools.lru_cache(maxsize=None)
-def view_from_shape(shape:Tuple[int, ...]):
+def view_from_shape(shape:Tuple[int, ...]) -> View:
   assert all([isinstance(x, int) for x in shape]) and len(shape) != 0
   return View(tuple(shape), strides_for_shape(shape))
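For intuition about the newly annotated return types (an illustrative call, not part of the diff): strides_for_shape((2, 3, 4)) returns (12, 4, 1), the row-major strides for that shape, which view_from_shape then pairs with the shape itself to build the initial View.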
@@ -92,8 +94,10 @@ class ShapeTracker:
   # if we replace, confirm the ops taken fold into one view
   def strided(self, *arg):
     view = View([x[0] for x in arg], [x[1] for x in arg])
-    if self.contiguous: self.views[-1] = view
-    else: self.views.append(view)
+    if self.contiguous:
+      self.views[-1] = view
+    else:
+      self.views.append(view)

   def reshape(self, *new_shape):
     assert all([isinstance(x, int) for x in new_shape])
@@ -107,8 +111,10 @@ class ShapeTracker:
       return

     view = View(new_shape, strides_for_shape(new_shape))
-    if self.contiguous: self.views[-1] = view # NOTE: if it's contiguous it can't have an offset
-    else: self.views.append(view)
+    if self.contiguous:
+      self.views[-1] = view # NOTE: if it's contiguous it can't have an offset
+    else:
+      self.views.append(view)

   def permute(self, *axis):
     assert all([isinstance(x, int) and x >= 0 and x < len(self.shape) for x in axis])
@@ -21,10 +21,13 @@ class Tensor:
       data = data.realize().toCPU()

     if isinstance(data, np.ndarray):
-      if data.shape == tuple(): data = data.reshape((1,))
+      if data.shape == tuple():
+        data = data.reshape((1,))
       self.lazydata = LazyBuffer.fromCPU(data.astype(np.float32), device)
-    elif isinstance(data, LazyBuffer): self.lazydata = data
-    else: raise Exception(f"can't create Tensor from {data}")
+    elif isinstance(data, LazyBuffer):
+      self.lazydata = data
+    else:
+      raise Exception(f"can't create Tensor from {data}")

     # tensors have gradients, buffers do not
     self.grad : Optional[Tensor] = None
@@ -53,7 +56,8 @@ class Tensor:
     return self

   def assign(self, x):
-    if not isinstance(x, Tensor): x = Tensor(x)
+    if not isinstance(x, Tensor):
+      x = Tensor(x)
     assert self.shape == x.shape
     self.lazydata = x.lazydata
     return x
@@ -69,11 +73,13 @@ class Tensor:
   def to_(self, device:str):
     assert self.lazydata.realized is None
     self.lazydata.device = device
-    if self.grad: self.grad.lazydata.device = device
+    if self.grad:
+      self.grad.lazydata.device = device

   def to(self, device:str):
     ret = Tensor(self.lazydata, device)
-    if self.grad: ret.grad = self.grad.to(device)
+    if self.grad:
+      ret.grad = self.grad.to(device)
     return ret

   # ***** creation helper functions *****
@@ -121,7 +127,8 @@ class Tensor:
     self.grad = Tensor.ones(*self.shape, device=self.device, requires_grad=False)

     for t0 in reversed(self.deepwalk()):
-      if not any(x.requires_grad for x in t0._ctx.parents): continue
+      if not any(x.requires_grad for x in t0._ctx.parents):
+        continue
       assert (t0.grad is not None)
       grads = t0._ctx.backward(t0.grad.lazydata)
       grads = [Tensor(g, device=self.device, requires_grad=False) if g is not None else None
@@ -137,7 +144,8 @@ class Tensor:
   def __getitem__(self, val):
     arg = []
     for i, s in enumerate(val if isinstance(val, (list, tuple)) else [val]) if val is not None else []:
-      if isinstance(s, int): s = slice(s, s+1, None)
+      if isinstance(s, int):
+        s = slice(s, s+1, None)
       arg.append((s.start if s.start is not None else 0,
         (s.stop if s.stop >=0 else self.shape[i]+s.stop) if s.stop is not None else self.shape[i]))
       assert s.step is None or s.step == 1
@@ -145,11 +153,13 @@ class Tensor:

   def cat(self, *args, dim=0):
     dim = (dim + len(self.shape)) if dim < 0 else dim
-    for y in args: assert len(y.shape) == len(self.shape) and all(y.shape[i] == s for i,s in enumerate(self.shape) if i != dim)
+    for y in args:
+      assert len(y.shape) == len(self.shape) and all(y.shape[i] == s for i,s in enumerate(self.shape) if i != dim)
     args = [self] + list(args)
     shape_cumsum = [0, *itertools.accumulate(y.shape[dim] for y in args)]
     slc = [[(0, s) for s in self.shape] for _ in args]
-    for s,k in zip(slc, shape_cumsum): s[dim] = (-k, shape_cumsum[-1]-k)
+    for s,k in zip(slc, shape_cumsum):
+      s[dim] = (-k, shape_cumsum[-1]-k)
     return functools.reduce(Tensor.__iadd__, [arg.slice(arg=s) for arg,s in zip(args, slc)])

   # TODO: make this nicer with syntactic sugar in slice
@@ -164,8 +174,10 @@ class Tensor:
     bs, groups = prod(self.shape[0:-2]), prod(w.shape[0:-2])
    cin, cout = w.shape[-2], w.shape[-1]
     out_shape_t = tuple(list(self.shape[0:-2])+[cout,-1])
-    if len(self.shape) > 1: order = tuple(list(range(len(self.shape)-2))+[len(self.shape)-1, len(self.shape)-2])
-    else: order, out_shape_t = (0,), (cout, )
+    if len(self.shape) > 1:
+      order = tuple(list(range(len(self.shape)-2))+[len(self.shape)-1, len(self.shape)-2])
+    else:
+      order, out_shape_t = (0,), (cout, )
     worder = tuple(list(range(len(w.shape)-2))+[len(w.shape)-1, len(w.shape)-2])

     # NOTE: with NHWC we can remove the transposes
@@ -185,8 +197,10 @@ class Tensor:
   def flatten(self, start_dim=0): return self.reshape(shape=tuple(list(self.shape[0:start_dim]) + [-1]))

   def _reduce(self, fxn, axis=None, keepdim=False):
-    if axis is None: axis = range(len(self.shape))
-    if isinstance(axis, int): axis = [axis]
+    if axis is None:
+      axis = range(len(self.shape))
+    if isinstance(axis, int):
+      axis = [axis]
     axis = tuple([x if x >= 0 else x+len(self.shape) for x in axis])
     shape = [self.shape[i] for i in range(len(self.shape)) if i not in axis]
     ret = fxn(axis=axis)
@@ -213,7 +227,8 @@ class Tensor:
     return m - ss.log()

   def dropout(self, p=0.5):
-    if not Tensor.training: return self
+    if not Tensor.training:
+      return self
     _mask = np.asarray(np.random.binomial(1, 1.0-p, size=self.shape), dtype=self.dtype)
     return self * Tensor(_mask, requires_grad=False, device=self.device) * (1/(1.0 - p))
@@ -307,7 +322,8 @@ class Function:
   def apply(cls, *x:Tensor, **kwargs):
     ctx = cls(x[0].device, *x)
     ret = Tensor(ctx.forward(*[t.lazydata for t in x], **kwargs), device=ctx.device, requires_grad=ctx.requires_grad)
-    if ctx.requires_grad and not Tensor.no_grad: ret._ctx = ctx # used by autograd engine
+    if ctx.requires_grad and not Tensor.no_grad:
+      ret._ctx = ctx # used by autograd engine
     return ret

 # register functions to move between devices
@@ -319,7 +335,8 @@ for device in [device for device in Device.__dict__.keys() if device[0] != "_"]:
 def register(name:str, fxn:Function):
   setattr(Tensor, "_"+name if (getattr(Tensor, name, None) is not None) else name, functools.partialmethod(fxn.apply))
 for name, cls in inspect.getmembers(importlib.import_module('tinygrad.mlops'), inspect.isclass):
-  if name[0] != "_" and name != "Function" and not name.endswith("Ops"): register(name.lower(), cls)
+  if name[0] != "_" and name != "Function" and not name.endswith("Ops"):
+    register(name.lower(), cls)

 # register the operators
 # TODO: add div
@@ -327,4 +344,5 @@ def register_op(name, fxn):
   setattr(Tensor, f"__{name}__", fxn)
   setattr(Tensor, f"__i{name}__", lambda self,x: self.assign(fxn(self,x)))
   setattr(Tensor, f"__r{name}__", lambda self,x: fxn(x,self))
-for name in ['add', 'sub', 'mul', 'pow', 'matmul']: register_op(name, getattr(Tensor, name))
+for name in ['add', 'sub', 'mul', 'pow', 'matmul']:
+  register_op(name, getattr(Tensor, name))