mirror of https://github.com/commaai/tinygrad.git
35 lines
1.4 KiB
Python
35 lines
1.4 KiB
Python
import argparse
|
|
from extra.optimization.helpers import ast_str_to_lin
|
|
|
|
from tinygrad import dtypes
|
|
from tinygrad.helpers import BEAM, getenv
|
|
from tinygrad.device import Device, Compiled
|
|
from tinygrad.codegen.linearizer import Linearizer
|
|
from tinygrad.features.search import time_linearizer, beam_search, bufs_from_lin
|
|
|
|
|
|
if __name__ == '__main__':
|
|
parser = argparse.ArgumentParser(description="Run a search for the optimal opts for a kernel", formatter_class=argparse.ArgumentDefaultsHelpFormatter)
|
|
parser.add_argument("--ast", type=str, default=None, help="the ast for the kernel to be optimized")
|
|
parser.add_argument("--file", type=str, default=None, help="a file containing asts to be optimized, one per line")
|
|
args = parser.parse_args()
|
|
|
|
device: Compiled = Device[Device.DEFAULT]
|
|
print(f"optimizing for {Device.DEFAULT}")
|
|
|
|
if args.ast is not None:
|
|
ast_strs = [args.ast]
|
|
elif args.file is not None:
|
|
with open(args.file, 'r') as file:
|
|
ast_strs = file.readlines()
|
|
|
|
for i, ast_str in enumerate(ast_strs):
|
|
print(f"optimizing {i}/{len(ast_strs)}\nast={ast_str}")
|
|
lin = ast_str_to_lin(ast_str, opts=device.compiler.compiler_opts)
|
|
rawbufs = bufs_from_lin(lin)
|
|
lin = beam_search(lin, rawbufs, getenv("BEAM", 8), bool(getenv("BEAM_ESTIMATE", 1)))
|
|
|
|
tm = time_linearizer(lin, rawbufs, allow_test_size=False, cnt=10)
|
|
print(f"final time {tm*1e6:9.0f} us: {lin.colored_shape()}")
|
|
print(lin.applied_opts)
|