fix acc init value for MUL (#6263)

This commit is contained in:
chenyu 2024-08-23 23:19:44 -04:00 committed by GitHub
parent a7bf20c7cd
commit da5cf11859
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 4 additions and 2 deletions

View File

@ -2,7 +2,7 @@ import time, math, unittest
import numpy as np
from typing import List, Callable
import torch
from tinygrad.helpers import getenv, IMAGE, DEBUG, CI
from tinygrad.helpers import getenv, IMAGE, DEBUG, CI, Context
from tinygrad import Tensor, Device, dtypes
from tinygrad.tensor import _to_np_dtype
import functools
@ -894,6 +894,7 @@ class TestOps(unittest.TestCase):
def test_prod(self):
helper_test_op(None, lambda x: x.prod(), vals=[[1.0, 2.0, 3.0]])
with Context(NOOPT=1): helper_test_op(None, lambda x: x.prod(), vals=[[1.0, 2.0, 3.0]])
helper_test_op([(3,4,5,6)], lambda x: x.prod(dim=3), lambda x: x.prod(axis=3))
helper_test_op([(3,4,5,6)], lambda x: x.prod(dim=1), lambda x: x.prod(axis=1))
helper_test_op([(3,4,5,6)], lambda x: x.prod(dim=1, keepdim=True), lambda x: x.prod(axis=1, keepdim=True))

View File

@ -400,7 +400,8 @@ def do_reduce(root:UOp):
ret = root.src[0]
if len(reduce_parented):
assert root.dtype is not None
const = UOp.const(root.dtype, 0 if root.arg is BinaryOps.ADD else dtypes.min(root.dtype.scalar()))
# TODO: helper to reuse this in 0 size folding
const = UOp.const(root.dtype, {BinaryOps.ADD:0, BinaryOps.MUL:1, BinaryOps.MAX:dtypes.min(root.dtype.scalar())}[root.arg])
acc = UOp(UOps.DEFINE_ACC, root.dtype, (const,) + tuple(reduce_parented), (acc_number,))
acc_number += 1
ret = UOp(UOps.PHI, root.dtype, (acc, acc.alu(root.arg, ret)))