remove get_lazyop_info (#5570)

* don't use get_lazyop_info more

* keep that min

* no ptx for that test
This commit is contained in:
George Hotz 2024-07-19 03:05:33 -07:00 committed by GitHub
parent 9d7edc9269
commit 2de82b8a5d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 8 additions and 136 deletions

View File

@ -7,10 +7,9 @@ from tinygrad.device import Compiled
from tinygrad.engine.schedule import create_schedule from tinygrad.engine.schedule import create_schedule
from tinygrad.engine.search import time_linearizer, beam_search, bufs_from_lin from tinygrad.engine.search import time_linearizer, beam_search, bufs_from_lin
from tinygrad.helpers import DEBUG, ansilen, getenv from tinygrad.helpers import DEBUG, ansilen, getenv
from tinygrad.ops import MetaOps, get_lazyop_info from tinygrad.ops import MetaOps
from tinygrad.shape.symbolic import sym_infer from tinygrad.shape.symbolic import sym_infer
def get_sched_resnet(): def get_sched_resnet():
mdl = ResNet50() mdl = ResNet50()
optim = (nn.optim.LARS if getenv("LARS") else nn.optim.SGD)(nn.state.get_parameters(mdl)) optim = (nn.optim.LARS if getenv("LARS") else nn.optim.SGD)(nn.state.get_parameters(mdl))
@ -78,8 +77,6 @@ if __name__ == "__main__":
running_gflops = 0 running_gflops = 0
usage = {} usage = {}
for i,si in enumerate(sched): for i,si in enumerate(sched):
ops = get_lazyop_info(si.ast.src[0]).flops
if DEBUG >= 2: print(si.ast) if DEBUG >= 2: print(si.ast)
rawbufs = bufs_from_lin(Kernel(si.ast)) rawbufs = bufs_from_lin(Kernel(si.ast))
@ -107,6 +104,7 @@ if __name__ == "__main__":
choices = [] choices = []
for lin in lins: for lin in lins:
tm = time_linearizer(lin, rawbufs, allow_test_size=False, cnt=10) tm = time_linearizer(lin, rawbufs, allow_test_size=False, cnt=10)
ops = lin.to_program().op_estimate
gflops = sym_infer(ops, {k:k.min for k in lin.ast.vars()})*1e-9/tm gflops = sym_infer(ops, {k:k.min for k in lin.ast.vars()})*1e-9/tm
choices.append((tm, gflops, lin.linearize())) choices.append((tm, gflops, lin.linearize()))

View File

@ -1,5 +1,6 @@
import unittest import unittest
from tinygrad import Tensor from tinygrad import Tensor
from tinygrad.helpers import getenv
from tinygrad.engine.schedule import create_schedule from tinygrad.engine.schedule import create_schedule
from tinygrad.engine.realize import lower_schedule_item from tinygrad.engine.realize import lower_schedule_item
from tinygrad.codegen.uops import flops_mem, UOps, UOp from tinygrad.codegen.uops import flops_mem, UOps, UOp
@ -7,12 +8,6 @@ from tinygrad.codegen.uopgraph import UOpGraph
from tinygrad.ops import BinaryOps, TernaryOps from tinygrad.ops import BinaryOps, TernaryOps
from tinygrad.dtype import dtypes from tinygrad.dtype import dtypes
# TODO: can copy this in here when we remove it
#from tinygrad.ops import get_lazyop_info
#info = get_lazyop_info(ast)
#print(ops, mem, expected_mem)
#print(info.flops, info.mem_estimate)
# **************** new FlopCounter **************** # **************** new FlopCounter ****************
def get_stats(x:Tensor): def get_stats(x:Tensor):
@ -21,6 +16,7 @@ def get_stats(x:Tensor):
return ei.prg.op_estimate, ei.prg.mem_estimate return ei.prg.op_estimate, ei.prg.mem_estimate
class TestUOpsStats(unittest.TestCase): class TestUOpsStats(unittest.TestCase):
@unittest.skipIf(getenv("PTX"), "wrong in PTX")
def test_simple_add(self): def test_simple_add(self):
a = Tensor.empty(100,100) a = Tensor.empty(100,100)
b = Tensor.empty(100,100) b = Tensor.empty(100,100)

View File

@ -1,90 +0,0 @@
#!/usr/bin/env python
import unittest
from tinygrad import dtypes, Tensor
from tinygrad.helpers import prod
from tinygrad.ops import LazyOp, UnaryOps, BinaryOps, ReduceOps, get_lazyop_info, BufferOps, MemBuffer
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.codegen.kernel import Kernel
from tinygrad.codegen.uops import flops_mem
class TestFlopCounter(unittest.TestCase):
def setUp(self):
self.buf0 = LazyOp(BufferOps.LOAD, (), MemBuffer(1, dtypes.float32, ShapeTracker.from_shape((4,))))
self.buf1 = LazyOp(BufferOps.LOAD, (), MemBuffer(2, dtypes.float32, ShapeTracker.from_shape((4,))))
self.buf2 = LazyOp(BufferOps.LOAD, (), MemBuffer(2, dtypes.float32, ShapeTracker.from_shape((4,4))))
def compare_flop_counters(self, ast):
info = get_lazyop_info(ast.src[0])
lin = Kernel(ast)
# NOTE: why does hand coded optimizations change flops for the GEMM?
#lin.hand_coded_optimizations()
lin.linearize()
ops, mem = flops_mem(lin.uops.uops, ignore_indexing=True)
run_count = prod((lin.global_size or []) + (lin.local_size or []))
self.assertEqual(info.flops, ops*run_count)
print(info.flops, info.mem_estimate, "vs", ops*run_count, mem*run_count)
#lin.uops.print()
def test_flops_sin(self):
op0 = LazyOp(UnaryOps.SIN, (self.buf0,), None)
info = get_lazyop_info(op0)
self.assertEqual(info.flops, 4)
def test_flops_add(self):
op0 = LazyOp(BinaryOps.ADD, (self.buf0,self.buf1,), None)
info = get_lazyop_info(op0)
self.assertEqual(info.flops, 4)
def test_flops_add_twice(self):
op0 = LazyOp(BinaryOps.ADD, (self.buf0,self.buf1,), None)
op1 = LazyOp(BinaryOps.ADD, (op0,self.buf1,), None)
info = get_lazyop_info(op1)
self.assertEqual(info.flops, 8)
def test_flops_add_self(self):
op0 = LazyOp(BinaryOps.ADD, (self.buf0,self.buf1,), None)
op1 = LazyOp(BinaryOps.ADD, (op0,op0,), None)
info = get_lazyop_info(op1)
self.assertEqual(info.flops, 8)
def test_flops_add_roundabout_self(self):
op0 = LazyOp(BinaryOps.ADD, (self.buf0,self.buf1,), None)
op1 = LazyOp(BinaryOps.ADD, (op0,self.buf1,), None)
op2 = LazyOp(BinaryOps.ADD, (op0,op1,), None)
info = get_lazyop_info(op2)
self.assertEqual(info.flops, 12)
def test_flops_red(self):
op0 = LazyOp(BinaryOps.MUL, (self.buf0,self.buf1,), None)
op1 = LazyOp(ReduceOps.SUM, (op0,), (0,))
op2 = LazyOp(BinaryOps.ADD, (op1, op1,), None)
info = get_lazyop_info(op2)
self.assertEqual(info.flops, 9)
def test_flops_sum1d(self):
op0 = LazyOp(ReduceOps.SUM, (self.buf0,), (0,))
info = get_lazyop_info(op0)
self.assertEqual(info.flops, 4)
self.assertEqual(info.shape, (1,))
def test_flops_sum2d(self):
op0 = LazyOp(ReduceOps.SUM, (self.buf2,), (0,))
info = get_lazyop_info(op0)
self.assertEqual(info.flops, 16)
self.assertEqual(info.shape, (1,4))
op1 = LazyOp(ReduceOps.SUM, (op0,), (1,))
info = get_lazyop_info(op1)
self.assertEqual(info.flops, 16+4)
self.assertEqual(info.shape, (1,1))
def test_flops_conv(self):
out = Tensor.empty(16,3,16,16).conv2d(Tensor.empty(64,3,3,3))
self.compare_flop_counters(out.schedule()[-1].ast)
def test_flops_gemm(self):
out = Tensor.empty(4,16,16) @ Tensor.empty(4,16,16)
self.compare_flop_counters(out.schedule()[-1].ast)
if __name__ == '__main__':
unittest.main()

View File

@ -4,8 +4,7 @@ from dataclasses import replace
from collections import defaultdict from collections import defaultdict
from typing import Optional, List, Tuple, cast, Dict, Union, Final, DefaultDict from typing import Optional, List, Tuple, cast, Dict, Union, Final, DefaultDict
from tinygrad.ops import LazyOp, UnaryOps, BinaryOps, ReduceOps, MemBuffer, ConstBuffer, BufferOps, MetaOps, UNSAFE_PAD_OPS, \ from tinygrad.ops import LazyOp, UnaryOps, BinaryOps, ReduceOps, MemBuffer, ConstBuffer, BufferOps, MetaOps, UNSAFE_PAD_OPS, verify_lazyop, KernelInfo
verify_lazyop, KernelInfo, get_lazyop_info
from tinygrad.device import Device from tinygrad.device import Device
from tinygrad.renderer import Renderer, TensorCore, Program from tinygrad.renderer import Renderer, TensorCore, Program
from tinygrad.dtype import dtypes, ImageDType from tinygrad.dtype import dtypes, ImageDType
@ -776,8 +775,7 @@ class Kernel:
if getenv("RUN_PROCESS_REPLAY"): if getenv("RUN_PROCESS_REPLAY"):
table_name = f"process_replay_{getenv('GITHUB_RUN_ID', 'HEAD')}" table_name = f"process_replay_{getenv('GITHUB_RUN_ID', 'HEAD')}"
diskcache_put(table_name, id(self), (self.ast, self.opts, self.applied_opts, name, src, {k:v.value for k,v in ContextVar._cache.items()})) diskcache_put(table_name, id(self), (self.ast, self.opts, self.applied_opts, name, src, {k:v.value for k,v in ContextVar._cache.items()}))
info = get_lazyop_info(self.ast.src[0]) # TODO: this should be removed ops, mem = flops_mem(self.uops.uops, ignore_indexing=True)
ops, mem = flops_mem(self.uops.uops)
run_count = prod((self.global_size or []) + (self.local_size or [])) run_count = prod((self.global_size or []) + (self.local_size or []))
return Program(self.name, src, self.opts.device, self.global_size, self.local_size, return Program(self.name, src, self.opts.device, self.global_size, self.local_size,
self.uops, min(info.flops, ops * run_count), min(info.mem_estimate, mem * run_count)) self.uops, ops * run_count, min(mem * run_count, sum(arg.dtype.itemsize * arg.st.real_size() for arg in self.membufs)))

View File

@ -3,7 +3,7 @@ from typing import Union, Tuple, Any, List, Dict, Callable
import functools, hashlib, math, operator, ctypes, struct import functools, hashlib, math, operator, ctypes, struct
from enum import Enum, auto from enum import Enum, auto
from dataclasses import dataclass from dataclasses import dataclass
from tinygrad.helpers import prod, dedup, pretty_print from tinygrad.helpers import dedup, pretty_print
from tinygrad.dtype import dtypes, DType, ConstType from tinygrad.dtype import dtypes, DType, ConstType
from tinygrad.shape.symbolic import Variable, sint from tinygrad.shape.symbolic import Variable, sint
from tinygrad.shape.shapetracker import ShapeTracker from tinygrad.shape.shapetracker import ShapeTracker
@ -97,36 +97,6 @@ class LazyOp:
def const(val, dtype:DType, shape:Tuple[sint, ...]): def const(val, dtype:DType, shape:Tuple[sint, ...]):
return LazyOp(BufferOps.CONST, (), ConstBuffer(val, dtype, ShapeTracker.from_shape(()).reshape((1,)*len(shape)).expand(shape))) return LazyOp(BufferOps.CONST, (), ConstBuffer(val, dtype, ShapeTracker.from_shape(()).reshape((1,)*len(shape)).expand(shape)))
# **************** independent FlopCounter ****************
@dataclass
class FlopCounter:
shape: Tuple[int, ...]
flops: sint
mem: Dict[int, int]
@property
def mem_estimate(self): return sum(self.mem.values())
def consume_flops(self):
self.flops, ret = 0, self.flops
return ret
InterpretedFlopCounter: Dict[Op, Callable] = {
BufferOps.LOAD: lambda arg: FlopCounter(arg.st.shape, 0, {arg.idx: arg.dtype.itemsize * arg.st.real_size()}),
BufferOps.CONST: lambda arg: FlopCounter(arg.st.shape, 0, {}),
BufferOps.STORE: lambda self,arg: FlopCounter(arg.st.shape, self.consume_flops(), {**self.mem, arg.idx: arg.dtype.itemsize * arg.st.real_size()}),
UnaryOps.CAST: lambda self,arg: FlopCounter(self.shape, self.consume_flops(), self.mem), # cast uses no flops
UnaryOps.BITCAST: lambda self,arg: FlopCounter(self.shape, self.consume_flops(), self.mem), # bitcast uses no flops
**{op:lambda self: FlopCounter(self.shape, self.consume_flops() + prod(self.shape), self.mem) for op in UnaryOps if op not in {UnaryOps.CAST, UnaryOps.BITCAST}}, # noqa: E501
**{op:lambda self,y: FlopCounter(self.shape, self.consume_flops() + y.consume_flops() + prod(self.shape), {**self.mem, **y.mem}) for op in BinaryOps}, # noqa: E501
**{op:lambda self,axis: FlopCounter(tuple(1 if i in axis else s for i,s in enumerate(self.shape)), self.consume_flops() + prod(self.shape), self.mem) for op in ReduceOps}, # noqa: E501
TernaryOps.WHERE: lambda self,y,z: FlopCounter(self.shape, self.consume_flops() + y.consume_flops() + z.consume_flops() + prod(self.shape), {**self.mem, **y.mem, **z.mem})} # noqa: E501
@functools.lru_cache(None)
def get_lazyop_info(ast:LazyOp) -> FlopCounter:
@functools.lru_cache(None) # NOTE: this cache needs to be recreated for new ASTs
def run_ast(ast): return InterpretedFlopCounter[ast.op](*([run_ast(x) for x in ast.src]+([ast.arg] if ast.arg is not None else [])))
return run_ast(ast)
# **************** ops in python **************** # **************** ops in python ****************
def hook_overflow(dv, fxn): def hook_overflow(dv, fxn):