regenerate kernel ast dataset (#2968)

added back the log ast function and removed hacks that work around the old dataset
This commit is contained in:
chenyu 2024-01-01 20:26:17 -05:00 committed by GitHub
parent cc2969f690
commit b1d9e54ea3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 8 additions and 14 deletions

Binary file not shown.

View File

@ -12,6 +12,9 @@ WINO=1 STEPS=3 python3 examples/hlb_cifar10.py
python3 examples/stable_diffusion.py --noshow
python3 examples/llama.py --prompt "hello" --count 5
python3 examples/gpt2.py --count 5
python3 HALF=1 examples/gpt2.py --count 5
python3 python examples/beautiful_mnist.py
python3 python examples/beautiful_cartpole.py
python3 examples/mlperf/model_spec.py
python3 examples/yolov8.py ./test/models/efficientnet/Chicken.jpg
openpilot/go.sh

View File

@ -3,21 +3,13 @@ from tinygrad.ops import LazyOp, TernaryOps, BinaryOps, UnaryOps, ReduceOps, Buf
from tinygrad.dtype import dtypes
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.shape.view import View
from tinygrad.shape.symbolic import Variable
from tinygrad.shape.symbolic import Variable, NumNode
inf, nan = float('inf'), float('nan')
# HACK: it used to be called MEM
setattr(BufferOps, "MEM", BufferOps.LOAD)
# HACK: no more NOOP
setattr(UnaryOps, "NOOP", UnaryOps.NEG)
# kernel unpacker
from tinygrad.codegen.linearizer import Linearizer
def ast_str_to_ast(ast_str:str) -> LazyOp: return eval(ast_str)
def ast_str_to_lin(ast_str:str):
# HACK: it used to not have stores
from test.test_linearizer_failures import helper_add_store
return Linearizer(helper_add_store(ast_str_to_ast(ast_str)))
def ast_str_to_lin(ast_str:str): return Linearizer(ast_str_to_ast(ast_str))
# load worlds, a dataset of about 12k kernels
import gzip
@ -27,9 +19,6 @@ from tinygrad.helpers import dedup
def load_worlds(filter_reduce=True, filter_noimage=True, filter_novariable=True):
fn = Path(__file__).parent.parent / "datasets/sops.gz"
ast_strs = dedup(gzip.open(fn).read().decode('utf-8').strip().split("\n"))
# HACK: TernaryOps.WHERE has vin[0] as non-bool in the data set
ignore_ops = ["TernaryOps.WHERE", "BinaryOps.CMPLT", "BinaryOps.MAX", "BinaryOps.ADD", "BinaryOps.DIV", "BinaryOps.MUL"]
ast_strs = [x for x in ast_strs if not any(y in x for y in ignore_ops)]
if filter_reduce: ast_strs = [x for x in ast_strs if "ReduceOps" in x]
if filter_noimage: ast_strs = [x for x in ast_strs if "dtypes.image" not in x]
if filter_novariable: ast_strs = [x for x in ast_strs if "Variable" not in x]

View File

@ -2,7 +2,7 @@ from typing import List, Dict, Optional, cast
from tinygrad.ops import LoadOps, ScheduleItem, BufferOps, GlobalCounters
from tinygrad.device import Device, Buffer, BufferCopy, JITRunner, update_stats
from tinygrad.graph import print_tree, realized_lazybuffer
from tinygrad.helpers import prod, colored
from tinygrad.helpers import prod, colored, getenv
from tinygrad.shape.symbolic import Variable
# *** schedule running ***
@ -20,9 +20,11 @@ def lower_schedule_item(si:ScheduleItem) -> Optional[JITRunner]:
if si.ast.op is LoadOps.CUSTOM: return CustomOp(si.ast.arg)
return Device[si.out.device].get_runner(si.ast)
logops = open(getenv("LOGOPS", ""), "a") if getenv("LOGOPS", "") else None
def run_schedule(schedule:List[ScheduleItem]):
while len(schedule):
si = schedule.pop(0)
if logops and si.ast.op not in LoadOps: logops.write(str(si.ast)+"\n")
# get the program
prg = lower_schedule_item(si)