Files
carrotpilot/tinygrad_repo/tinygrad/codegen/lowerer.py
carrot 77a8919349 TR16 Model, fix radar routine (#211)
* UV+DTR model

* DTR model.. again.

* fix naviGPS

* fix radar...

* fix..

* test

* fix..

* carrot serv

* fix..

* fix.. fleet

* fix.. radar

* fix atc

* Steam Powered model..

* fix.. radarLatFactor range.. 200->500

* fix.. dbc..

* side

* SP v2

* brake light

* fix brakelight

* fix..

* add datetime...

* fix..

* fix..

* fix..

* fix..

* blind spot

* fix tz

* fix..

* ff

* radarLatFactor

* fix.. bsd

* Revert "fix.. bsd"

This reverts commit 1d0d143447.

* fix.. bsd side..

* test

* fix.. e2e conditions

* Revert "test"

This reverts commit 0ce791dbd6.

* TR16

* fix cut-in detect threshold  3.4 -> 2.6

* fix.. jerk_l limit 5->10

* fix..

* fix.. gm

* fix.. OPTIMA_H mass

* fix.. radar..

* fix radar..

* fix..

* Radar...

* fix..

* fix..

* fix..

* fix.. radartrack 3

* fix..

* fix..

* fix..

* merge..

* fix.. canfd

* fix..

* fix..

* fix..

* fix.. radard

* new cut_in

* Revert "new cut_in"

This reverts commit b9b6e9b333.

* fix..

* new cut_in detect...

* fix.. disp..

* fix..

* fix..

* fix.. center radar..

* fix.. radar y_sane..

* fix..

* fix..

* hkg jerk 10 -> 5

* fix..

* fix..

* fix.. radar dbc..

* fix..

* fix.. jLead filter..

* test new radar interface..

* fix..

* fix..

* test time...

* Revert "test time..."

This reverts commit 63e9187736.

* fix radar..

* fix..

* FireHose model..

* tinygrad

* Update interface.py

* fix..

* fix.. nff toyota corolla_tss2

* fix..

* fix..

* fix.. radar

* fix..

* fix.. radar, y_gate

* fix.. radar..

* fix.. for clone..

* scc radar enable at low speed..

* fix.. settings..

* fix.

* fix..

* fix.. radarTimeStep.

* TR16 model again..

* RELEASE.md

* fix cut-in detection...

* fix.. registration timeout 15sec..

* fix..

* fix.. radar processing.

* fix..

* fix..

* fix..

* fix..

* fix..

* fix..
2025-09-05 15:43:10 +09:00

98 lines
4.6 KiB
Python

# the job of the lowerer is to do indexing
import functools, operator
from typing import cast
from dataclasses import dataclass
from tinygrad.dtype import dtypes, AddrSpace, PtrDType
from tinygrad.uop.ops import KernelInfo, UOp, Ops, PatternMatcher, UPat, sint_to_uop, AxisType, graph_rewrite
# ***** indexing *****
@dataclass
class IndexContext:
  """Mutable state threaded through pm_lowerer while lowering a kernel AST.

  Fields are positional: callers build it as IndexContext(axis_types, idxs, start).
  """
  # one AxisType per axis position (looked up as axis_types[axis])
  axis_types: tuple[AxisType, ...]
  # index UOp currently in effect for each axis; empty for the root context from get_index
  idxs: list[UOp]
  # base id for ranges created by shape_to_idx; subblock advances it by 1000
  start: int = 0
def shape_to_idx(s, axis_types, start=0):
  """Create one RANGE UOp per dimension of the shape `s`.

  Range ids are numbered consecutively from `start`. Each range is tagged with the
  axis type at the same position in `axis_types`. Dimensions are paired with axis
  types via zip, so the result is only as long as the shorter of the two.
  """
  ranges = []
  for offset, (size, axis_type) in enumerate(zip(s, axis_types)):
    ranges.append(UOp.range(dtypes.int, sint_to_uop(size), start+offset, axistype=axis_type))
  return ranges
def get_index(ast:UOp) -> IndexContext:
  """Build the root IndexContext for lowering `ast`.

  Axis types come from `ast.arg.axis_types` when `ast.arg` is a KernelInfo whose axis
  count matches `ast.full_shape`. Otherwise they are inferred per dimension:
  - a dim whose size in `ast.shape` differs from its size in `ast.full_shape` is REDUCE;
  - every other dim is LOOP.
  The returned context has no indices yet and a range-id start of 0.
  """
  axis_types = ast.arg.axis_types if isinstance(ast.arg, KernelInfo) else ()
  if len(ast.full_shape) != len(axis_types):
    # Compare int dims by value: `is` only matches equal ints inside CPython's small-int cache,
    # so equal-but-distinct large dims (e.g. 1024) would otherwise be classified as REDUCE.
    # Non-int dims (presumably symbolic UOps) keep the original identity comparison.
    axis_types = tuple([AxisType.LOOP if s is fs or (isinstance(s, int) and isinstance(fs, int) and s == fs) else AxisType.REDUCE
                        for s,fs in zip(ast.shape, ast.full_shape)])
  return IndexContext(axis_types, [], 0)
# ***** lowering (given index) *****
def subblock(ctx: IndexContext, full_new_idx: list[UOp], src: UOp):
  """Lower `src` with pm_lowerer under a child context whose indices are `full_new_idx`.

  The parent's `start` is advanced by 1000 before rewriting, and the child starts from
  that advanced value. The rewrite is a bottom-up graph_rewrite named "subblock".
  NOTE(review): later bumps made inside the child (nested subblocks) are applied to the
  child object only and do not propagate back to `ctx.start`.
  """
  ctx.start += 1000
  child = IndexContext(ctx.axis_types, full_new_idx, ctx.start)
  return graph_rewrite(src, pm_lowerer, child, name="subblock", bottom_up=True)
def lower_reduce_axis(ctx: IndexContext, x: UOp):
  """Lower a REDUCE_AXIS into a REDUCE over fresh ranges for the axes in x.axis_arg.

  Axes listed in x.axis_arg get fresh ranges built from x.src[0].shape. Every other
  axis keeps the enclosing context's index. The source is lowered in a subblock with
  those indices. The REDUCE's extra sources are the reduced axes' indices in x.axis_arg
  order, and its arg is x.arg[0] (presumably the reduction op).
  """
  # fresh ranges must be built before subblock advances ctx.start
  fresh = shape_to_idx(x.src[0].shape, ctx.axis_types, ctx.start)
  indices = list(ctx.idxs)
  for axis in x.axis_arg:
    indices[axis] = fresh[axis]
  reduced_src = subblock(ctx, indices, x.src[0])
  reduce_ranges = tuple(indices[axis] for axis in x.axis_arg)
  return UOp(Ops.REDUCE, x.dtype, (reduced_src, *reduce_ranges), x.arg[0])
def lower_store(ctx: IndexContext, x: UOp, buf: UOp):
  """Lower a STORE whose target is VIEW(buf) into an explicitly indexed store.

  Steps:
  1. Build one fresh range per dimension of x.src[0].shape.
  2. Turn the store's view (x.st_arg) into an (idx, valid) pair over those ranges.
  3. Lower the stored value x.src[1] in a subblock. For each dimension, the subblock
     uses the fresh range if (idx, valid) references it, or if the enclosing context
     has no index at that position; otherwise it inherits ctx.idxs for that dimension.
  The result may additionally be wrapped in a BARRIER and then an IF; see the comments below.
  Note: the comprehensions below rebind `x` only inside their own scope; the parameter
  `x` is unaffected.
  """
  # TODO: reenable after REDUCE_AXIS is fixed
  #assert x.src[1].shape == x.src[0].shape, f"shape mismatch on store {x.src[1].shape} != {x.src[0].shape}"
  new_idxs = shape_to_idx(x.src[0].shape, ctx.axis_types, ctx.start)
  idx, valid = x.st_arg.to_indexed_uops(new_idxs)
  # fresh ranges that appear in the graph of idx/valid
  used_idxs = [x for x in UOp.sink(idx, valid).toposort() if x in new_idxs]
  real_new_idxs = []
  for i in range(len(x.src[0].shape)):
    if new_idxs[i] in used_idxs or len(ctx.idxs) <= i: real_new_idxs.append(new_idxs[i])
    else: real_new_idxs.append(ctx.idxs[i])
  stored = subblock(ctx, real_new_idxs, x.src[1])
  used_ranges = [x for x in used_idxs if x.op is Ops.RANGE]
  # the used RANGE UOps are passed as extra sources of the store
  ret = buf.index(idx, valid).store(stored, *used_ranges)
  # insert BARRIER if we are ending a LOCAL, IF if we are ending a GROUP_REDUCE
  # x.arg[0]%1000 maps a range back to its axis position (subblock offsets starts by multiples of 1000)
  if cast(PtrDType, buf.dtype).addrspace == AddrSpace.LOCAL and \
     any(ctx.axis_types[x.arg[0]%1000] in {AxisType.GROUP_REDUCE, AxisType.LOCAL} for x in used_ranges):
    ret = ret.barrier()
    # NOTE(review): SOURCE lost its indentation; nesting these two lines under the LOCAL check was
    # reconstructed (it matches the comment above) -- confirm against upstream tinygrad.
    # the IF condition is the AND of (range == 0) over every used GROUP_REDUCE range
    range_gates = [x.eq(0) for x in used_ranges if ctx.axis_types[x.arg[0]%1000] == AxisType.GROUP_REDUCE]
    if len(range_gates): ret = UOp(Ops.IF, src=(functools.reduce(operator.and_, range_gates), ret))
  return ret
def fixup_wmma(ctx:IndexContext, x:UOp):
  """Lower a WMMA once: re-index its sources and rewrite its axis args to range ids.

  Axis positions listed in x.arg[-1] get fresh ranges built from x.src[0].shape; other
  axes keep the enclosing context's index. All sources are lowered together in one subblock.
  In x.arg[-2] (groups of (axis, size) pairs) and x.arg[-1], each axis position is replaced
  by arg[0] of the index UOp now used for that axis. The result is tagged 1. A WMMA that
  already carries a tag is left alone (returns None).
  """
  if x.tag is not None: return None
  axis_groups, fresh_axes = x.arg[-2], x.arg[-1]
  fresh = shape_to_idx(x.src[0].shape, ctx.axis_types, ctx.start)
  indices = list(ctx.idxs)
  for axis in fresh_axes:
    indices[axis] = fresh[axis]
  lowered_srcs = subblock(ctx, indices, UOp.sink(*x.src)).src
  # NOTE: this assumes these are expanded. which now shouldn't change anything
  new_groups = tuple(tuple((indices[a].arg[0], sz) for a, sz in group) for group in axis_groups)
  new_axes = tuple(indices[a].arg[0] for a in fresh_axes)
  return x.replace(src=lowered_srcs, arg=x.arg[:-2] + (new_groups, new_axes), tag=1)
# Rewrite rules used by subblock (bottom-up graph_rewrite) to replace view-based ops with
# explicit indexing. Rules that take `ctx` read the current IndexContext's idxs.
pm_lowerer = PatternMatcher([
  # TODO: remove these hacks
  # hack for old style CONST(VIEW) (now it's just VIEW(CONST)): rebuild as c (src cleared) .view(v.arg)
  (UPat((Ops.DEFINE_VAR, Ops.CONST), src=(UPat(Ops.VIEW, name="v"),), name="c"), lambda c,v: c.replace(src=()).view(v.arg)),
  # hack for old style VALID (now it's just VIEW(CONST)): VALID(VIEW).where(c, 0) -> c (src cleared) .view(v.arg)
  (UPat(Ops.VALID, src=(UPat(Ops.VIEW, name="v"),)).where(UPat.cvar("c"), UPat(Ops.CONST, arg=0)), lambda c,v: c.replace(src=()).view(v.arg)),
  # consts and loads
  # VIEW(CONST/DEFINE_VAR): if no view has a mask the value passes through unchanged; otherwise it is
  # selected where the view's valid expression (at ctx.idxs) holds, else 0
  (UPat(Ops.VIEW, src=(UPat((Ops.CONST, Ops.DEFINE_VAR), name="c"),), name="view"),
   lambda ctx,view,c: c if all(x.mask is None for x in view.arg.views) else view.arg.to_indexed_uops(ctx.idxs)[1].where(c, c.const_like(0))),
  # LOAD from VIEW(buf): index buf with the (idx, valid) pair computed at ctx.idxs; remaining srcs are kept
  (UPat(Ops.LOAD, src=(UPat.var("buf").view(),), allow_any_len=True, name="x"),
   lambda ctx,buf,x: UOp(Ops.LOAD, x.dtype, (buf.index(*x.st_arg.to_indexed_uops(ctx.idxs)),)+x.src[1:])),
  # reduce/store/WMMA lowering: each handler lowers its sources in a subblock
  (UPat(Ops.REDUCE_AXIS, name="x"), lower_reduce_axis),
  (UPat(Ops.STORE, src=(UPat.var("buf").view(),), allow_any_len=True, name="x"), lower_store),
  (UPat(Ops.WMMA, name="x"), fixup_wmma),
  # axis fixups for WMMA: map CONTRACT/UNROLL axis positions to arg[0] of the matching ctx.idxs entry;
  # tag=1 marks the op as processed so this rule returns None for it afterwards
  (UPat((Ops.CONTRACT, Ops.UNROLL), name="x"),
   lambda ctx,x: x.replace(tag=1, arg=tuple([(ctx.idxs[a].arg[0], sz) for a,sz in x.arg])) if x.tag is None else None),
])