add caches there (#3042)

* add caches there

* no curl
This commit is contained in:
George Hotz 2024-01-08 13:02:16 -08:00 committed by GitHub
parent c5a941d466
commit 50754f1494
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 3 additions and 1 deletions

View File

@ -97,7 +97,7 @@ jobs:
- name: Compile C to native
run: clang -O2 recognize.c -lm -o recognize
- name: Test EfficientNet
run: curl https://media.istockphoto.com/photos/hen-picture-id831791190 | ./recognize | grep hen
run: cat test/models/efficientnet/Chicken.jpg | ./recognize | grep cock
testtorch:
name: Torch Tests

View File

@ -82,10 +82,12 @@ class View:
contiguous = offset == 0 and mask is None and strides == strides_for_shape(shape)
return View(shape, strides, offset, mask, contiguous)
@functools.lru_cache(None) # pylint: disable=method-cache-max-size-none
def vars(self) -> Set[Variable]:
flatten_mask = tuple(x for m in self.mask for x in m) if self.mask is not None else tuple()
return functools.reduce(operator.or_, [x.vars() for x in self.shape+self.strides+(self.offset,)+flatten_mask if isinstance(x, Node)], set())
@functools.lru_cache(None) # pylint: disable=method-cache-max-size-none
def unbind(self) -> View:
unbound_vars:Dict[Variable,Node] = {v: v.unbind()[0] for v in self.vars() if v.val is not None}
new_shape = tuple([s if isinstance(s, int) else s.substitute(unbound_vars) for s in self.shape])