mirror of https://github.com/commaai/tinygrad.git
remove get_lazyop_info (#5570)

* don't use get_lazyop_info more
* keep that min
* no ptx for that test
This commit is contained in:
parent
9d7edc9269
commit
2de82b8a5d
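For orientation, the pattern this commit replaces across the tree: instead of walking the AST with get_lazyop_info to estimate flops and bytes moved, estimates are now read off the lowered Program (op_estimate / mem_estimate), which are derived from the linearized uops via flops_mem. A minimal sketch of the new pattern, assembled only from imports and attributes visible in this diff (the exact body of the test helper get_stats may differ):

from tinygrad import Tensor
from tinygrad.engine.schedule import create_schedule
from tinygrad.engine.realize import lower_schedule_item

a, b = Tensor.empty(100, 100), Tensor.empty(100, 100)
si = create_schedule([(a + b).lazydata])[-1]    # schedule the fused add kernel
ei = lower_schedule_item(si)                    # lower it; ei.prg carries the estimates
print(ei.prg.op_estimate, ei.prg.mem_estimate)  # uop-derived flop and byte estimates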
@@ -7,10 +7,9 @@ from tinygrad.device import Compiled
from tinygrad.engine.schedule import create_schedule
from tinygrad.engine.search import time_linearizer, beam_search, bufs_from_lin
from tinygrad.helpers import DEBUG, ansilen, getenv
from tinygrad.ops import MetaOps, get_lazyop_info
from tinygrad.ops import MetaOps
from tinygrad.shape.symbolic import sym_infer

def get_sched_resnet():
  mdl = ResNet50()
  optim = (nn.optim.LARS if getenv("LARS") else nn.optim.SGD)(nn.state.get_parameters(mdl))
@@ -78,8 +77,6 @@ if __name__ == "__main__":
  running_gflops = 0
  usage = {}
  for i,si in enumerate(sched):
    ops = get_lazyop_info(si.ast.src[0]).flops

    if DEBUG >= 2: print(si.ast)

    rawbufs = bufs_from_lin(Kernel(si.ast))
@@ -107,6 +104,7 @@ if __name__ == "__main__":
    choices = []
    for lin in lins:
      tm = time_linearizer(lin, rawbufs, allow_test_size=False, cnt=10)
      ops = lin.to_program().op_estimate
      gflops = sym_infer(ops, {k:k.min for k in lin.ast.vars()})*1e-9/tm
      choices.append((tm, gflops, lin.linearize()))
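For context on the gflops line above: op_estimate can be a symbolic expression when the kernel has variable shapes, and sym_infer substitutes a concrete value for every Variable, here each variable's minimum. A toy illustration with made-up names and numbers:

from tinygrad.shape.symbolic import Variable, sym_infer

v = Variable("batch", 1, 16)          # a symbolic dimension bounded to 1..16
ops = v * 100_000                     # pretend flop count that depends on it
flops = sym_infer(ops, {v: v.min})    # substitute the minimum -> 100_000
gflops = flops * 1e-9 / 0.001         # with a 1 ms runtime this is 0.1 GFLOPS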
@@ -1,5 +1,6 @@
import unittest
from tinygrad import Tensor
from tinygrad.helpers import getenv
from tinygrad.engine.schedule import create_schedule
from tinygrad.engine.realize import lower_schedule_item
from tinygrad.codegen.uops import flops_mem, UOps, UOp
@@ -7,12 +8,6 @@ from tinygrad.codegen.uopgraph import UOpGraph
from tinygrad.ops import BinaryOps, TernaryOps
from tinygrad.dtype import dtypes

# TODO: can copy this in here when we remove it
#from tinygrad.ops import get_lazyop_info
#info = get_lazyop_info(ast)
#print(ops, mem, expected_mem)
#print(info.flops, info.mem_estimate)

# **************** new FlopCounter ****************

def get_stats(x:Tensor):
@@ -21,6 +16,7 @@ def get_stats(x:Tensor):
  return ei.prg.op_estimate, ei.prg.mem_estimate

class TestUOpsStats(unittest.TestCase):
  @unittest.skipIf(getenv("PTX"), "wrong in PTX")
  def test_simple_add(self):
    a = Tensor.empty(100,100)
    b = Tensor.empty(100,100)
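(As a rough worked check, not shown in this hunk: for a fused 100x100 float32 add one would expect about 100*100 = 10,000 ops and about 3*10,000*4 = 120,000 bytes of traffic, two loads plus one store; the actual assertions are outside this hunk.)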
@@ -1,90 +0,0 @@
#!/usr/bin/env python
import unittest
from tinygrad import dtypes, Tensor
from tinygrad.helpers import prod
from tinygrad.ops import LazyOp, UnaryOps, BinaryOps, ReduceOps, get_lazyop_info, BufferOps, MemBuffer
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.codegen.kernel import Kernel
from tinygrad.codegen.uops import flops_mem

class TestFlopCounter(unittest.TestCase):
  def setUp(self):
    self.buf0 = LazyOp(BufferOps.LOAD, (), MemBuffer(1, dtypes.float32, ShapeTracker.from_shape((4,))))
    self.buf1 = LazyOp(BufferOps.LOAD, (), MemBuffer(2, dtypes.float32, ShapeTracker.from_shape((4,))))
    self.buf2 = LazyOp(BufferOps.LOAD, (), MemBuffer(2, dtypes.float32, ShapeTracker.from_shape((4,4))))

  def compare_flop_counters(self, ast):
    info = get_lazyop_info(ast.src[0])
    lin = Kernel(ast)
    # NOTE: why does hand coded optimizations change flops for the GEMM?
    #lin.hand_coded_optimizations()
    lin.linearize()
    ops, mem = flops_mem(lin.uops.uops, ignore_indexing=True)
    run_count = prod((lin.global_size or []) + (lin.local_size or []))
    self.assertEqual(info.flops, ops*run_count)
    print(info.flops, info.mem_estimate, "vs", ops*run_count, mem*run_count)
    #lin.uops.print()

  def test_flops_sin(self):
    op0 = LazyOp(UnaryOps.SIN, (self.buf0,), None)
    info = get_lazyop_info(op0)
    self.assertEqual(info.flops, 4)

  def test_flops_add(self):
    op0 = LazyOp(BinaryOps.ADD, (self.buf0,self.buf1,), None)
    info = get_lazyop_info(op0)
    self.assertEqual(info.flops, 4)

  def test_flops_add_twice(self):
    op0 = LazyOp(BinaryOps.ADD, (self.buf0,self.buf1,), None)
    op1 = LazyOp(BinaryOps.ADD, (op0,self.buf1,), None)
    info = get_lazyop_info(op1)
    self.assertEqual(info.flops, 8)

  def test_flops_add_self(self):
    op0 = LazyOp(BinaryOps.ADD, (self.buf0,self.buf1,), None)
    op1 = LazyOp(BinaryOps.ADD, (op0,op0,), None)
    info = get_lazyop_info(op1)
    self.assertEqual(info.flops, 8)

  def test_flops_add_roundabout_self(self):
    op0 = LazyOp(BinaryOps.ADD, (self.buf0,self.buf1,), None)
    op1 = LazyOp(BinaryOps.ADD, (op0,self.buf1,), None)
    op2 = LazyOp(BinaryOps.ADD, (op0,op1,), None)
    info = get_lazyop_info(op2)
    self.assertEqual(info.flops, 12)

  def test_flops_red(self):
    op0 = LazyOp(BinaryOps.MUL, (self.buf0,self.buf1,), None)
    op1 = LazyOp(ReduceOps.SUM, (op0,), (0,))
    op2 = LazyOp(BinaryOps.ADD, (op1, op1,), None)
    info = get_lazyop_info(op2)
    self.assertEqual(info.flops, 9)

  def test_flops_sum1d(self):
    op0 = LazyOp(ReduceOps.SUM, (self.buf0,), (0,))
    info = get_lazyop_info(op0)
    self.assertEqual(info.flops, 4)
    self.assertEqual(info.shape, (1,))

  def test_flops_sum2d(self):
    op0 = LazyOp(ReduceOps.SUM, (self.buf2,), (0,))
    info = get_lazyop_info(op0)
    self.assertEqual(info.flops, 16)
    self.assertEqual(info.shape, (1,4))

    op1 = LazyOp(ReduceOps.SUM, (op0,), (1,))
    info = get_lazyop_info(op1)
    self.assertEqual(info.flops, 16+4)
    self.assertEqual(info.shape, (1,1))

  def test_flops_conv(self):
    out = Tensor.empty(16,3,16,16).conv2d(Tensor.empty(64,3,3,3))
    self.compare_flop_counters(out.schedule()[-1].ast)

  def test_flops_gemm(self):
    out = Tensor.empty(4,16,16) @ Tensor.empty(4,16,16)
    self.compare_flop_counters(out.schedule()[-1].ast)

if __name__ == '__main__':
  unittest.main()
@@ -4,8 +4,7 @@ from dataclasses import replace
from collections import defaultdict
from typing import Optional, List, Tuple, cast, Dict, Union, Final, DefaultDict

from tinygrad.ops import LazyOp, UnaryOps, BinaryOps, ReduceOps, MemBuffer, ConstBuffer, BufferOps, MetaOps, UNSAFE_PAD_OPS, \
  verify_lazyop, KernelInfo, get_lazyop_info
from tinygrad.ops import LazyOp, UnaryOps, BinaryOps, ReduceOps, MemBuffer, ConstBuffer, BufferOps, MetaOps, UNSAFE_PAD_OPS, verify_lazyop, KernelInfo
from tinygrad.device import Device
from tinygrad.renderer import Renderer, TensorCore, Program
from tinygrad.dtype import dtypes, ImageDType
@@ -776,8 +775,7 @@ class Kernel:
    if getenv("RUN_PROCESS_REPLAY"):
      table_name = f"process_replay_{getenv('GITHUB_RUN_ID', 'HEAD')}"
      diskcache_put(table_name, id(self), (self.ast, self.opts, self.applied_opts, name, src, {k:v.value for k,v in ContextVar._cache.items()}))
    info = get_lazyop_info(self.ast.src[0]) # TODO: this should be removed
    ops, mem = flops_mem(self.uops.uops)
    ops, mem = flops_mem(self.uops.uops, ignore_indexing=True)
    run_count = prod((self.global_size or []) + (self.local_size or []))
    return Program(self.name, src, self.opts.device, self.global_size, self.local_size,
                   self.uops, min(info.flops, ops * run_count), min(info.mem_estimate, mem * run_count))
                   self.uops, ops * run_count, min(mem * run_count, sum(arg.dtype.itemsize * arg.st.real_size() for arg in self.membufs)))
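The "keep that min" bullet in the commit message refers to the replacement line above: flops_mem counts a single kernel invocation, so flops scale cleanly by run_count, but multiplying the per-invocation memory traffic by run_count would presumably overcount buffers that every workgroup re-reads, so mem_estimate stays capped at the total size of the kernel's buffers (the sum of itemsize * real_size over self.membufs).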
@@ -3,7 +3,7 @@ from typing import Union, Tuple, Any, List, Dict, Callable
import functools, hashlib, math, operator, ctypes, struct
from enum import Enum, auto
from dataclasses import dataclass
from tinygrad.helpers import prod, dedup, pretty_print
from tinygrad.helpers import dedup, pretty_print
from tinygrad.dtype import dtypes, DType, ConstType
from tinygrad.shape.symbolic import Variable, sint
from tinygrad.shape.shapetracker import ShapeTracker
@@ -97,36 +97,6 @@ class LazyOp:
  def const(val, dtype:DType, shape:Tuple[sint, ...]):
    return LazyOp(BufferOps.CONST, (), ConstBuffer(val, dtype, ShapeTracker.from_shape(()).reshape((1,)*len(shape)).expand(shape)))

# **************** independent FlopCounter ****************

@dataclass
class FlopCounter:
  shape: Tuple[int, ...]
  flops: sint
  mem: Dict[int, int]
  @property
  def mem_estimate(self): return sum(self.mem.values())
  def consume_flops(self):
    self.flops, ret = 0, self.flops
    return ret

InterpretedFlopCounter: Dict[Op, Callable] = {
  BufferOps.LOAD: lambda arg: FlopCounter(arg.st.shape, 0, {arg.idx: arg.dtype.itemsize * arg.st.real_size()}),
  BufferOps.CONST: lambda arg: FlopCounter(arg.st.shape, 0, {}),
  BufferOps.STORE: lambda self,arg: FlopCounter(arg.st.shape, self.consume_flops(), {**self.mem, arg.idx: arg.dtype.itemsize * arg.st.real_size()}),
  UnaryOps.CAST: lambda self,arg: FlopCounter(self.shape, self.consume_flops(), self.mem), # cast uses no flops
  UnaryOps.BITCAST: lambda self,arg: FlopCounter(self.shape, self.consume_flops(), self.mem), # bitcast uses no flops
  **{op:lambda self: FlopCounter(self.shape, self.consume_flops() + prod(self.shape), self.mem) for op in UnaryOps if op not in {UnaryOps.CAST, UnaryOps.BITCAST}}, # noqa: E501
  **{op:lambda self,y: FlopCounter(self.shape, self.consume_flops() + y.consume_flops() + prod(self.shape), {**self.mem, **y.mem}) for op in BinaryOps}, # noqa: E501
  **{op:lambda self,axis: FlopCounter(tuple(1 if i in axis else s for i,s in enumerate(self.shape)), self.consume_flops() + prod(self.shape), self.mem) for op in ReduceOps}, # noqa: E501
  TernaryOps.WHERE: lambda self,y,z: FlopCounter(self.shape, self.consume_flops() + y.consume_flops() + z.consume_flops() + prod(self.shape), {**self.mem, **y.mem, **z.mem})} # noqa: E501

@functools.lru_cache(None)
def get_lazyop_info(ast:LazyOp) -> FlopCounter:
  @functools.lru_cache(None) # NOTE: this cache needs to be recreated for new ASTs
  def run_ast(ast): return InterpretedFlopCounter[ast.op](*([run_ast(x) for x in ast.src]+([ast.arg] if ast.arg is not None else [])))
  return run_ast(ast)

# **************** ops in python ****************

def hook_overflow(dv, fxn):
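Reading the removed interpreter above against the simplest deleted test (test_flops_add, two 4-element float32 buffers): each BufferOps.LOAD contributes 0 flops and 16 bytes of mem, and BinaryOps.ADD adds prod((4,)) = 4 flops, so get_lazyop_info reported flops == 4, matching the deleted assertion, and a mem_estimate of 2 * 4 * 4 = 32 bytes.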