mirror of https://github.com/commaai/tinygrad.git
fix acc init value for MUL (#6263)
This commit is contained in:
parent
a7bf20c7cd
commit
da5cf11859
|
@ -2,7 +2,7 @@ import time, math, unittest
|
|||
import numpy as np
|
||||
from typing import List, Callable
|
||||
import torch
|
||||
from tinygrad.helpers import getenv, IMAGE, DEBUG, CI
|
||||
from tinygrad.helpers import getenv, IMAGE, DEBUG, CI, Context
|
||||
from tinygrad import Tensor, Device, dtypes
|
||||
from tinygrad.tensor import _to_np_dtype
|
||||
import functools
|
||||
|
@ -894,6 +894,7 @@ class TestOps(unittest.TestCase):
|
|||
|
||||
def test_prod(self):
|
||||
helper_test_op(None, lambda x: x.prod(), vals=[[1.0, 2.0, 3.0]])
|
||||
with Context(NOOPT=1): helper_test_op(None, lambda x: x.prod(), vals=[[1.0, 2.0, 3.0]])
|
||||
helper_test_op([(3,4,5,6)], lambda x: x.prod(dim=3), lambda x: x.prod(axis=3))
|
||||
helper_test_op([(3,4,5,6)], lambda x: x.prod(dim=1), lambda x: x.prod(axis=1))
|
||||
helper_test_op([(3,4,5,6)], lambda x: x.prod(dim=1, keepdim=True), lambda x: x.prod(axis=1, keepdim=True))
|
||||
|
|
|
@ -400,7 +400,8 @@ def do_reduce(root:UOp):
|
|||
ret = root.src[0]
|
||||
if len(reduce_parented):
|
||||
assert root.dtype is not None
|
||||
const = UOp.const(root.dtype, 0 if root.arg is BinaryOps.ADD else dtypes.min(root.dtype.scalar()))
|
||||
# TODO: helper to reuse this in 0 size folding
|
||||
const = UOp.const(root.dtype, {BinaryOps.ADD:0, BinaryOps.MUL:1, BinaryOps.MAX:dtypes.min(root.dtype.scalar())}[root.arg])
|
||||
acc = UOp(UOps.DEFINE_ACC, root.dtype, (const,) + tuple(reduce_parented), (acc_number,))
|
||||
acc_number += 1
|
||||
ret = UOp(UOps.PHI, root.dtype, (acc, acc.alu(root.arg, ret)))
|
||||
|
|
Loading…
Reference in New Issue