From da5cf1185911bfc639ecdc0c7f738df4f166cb40 Mon Sep 17 00:00:00 2001 From: chenyu Date: Fri, 23 Aug 2024 23:19:44 -0400 Subject: [PATCH] fix acc init value for MUL (#6263) --- test/test_ops.py | 3 ++- tinygrad/codegen/uopgraph.py | 3 ++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/test/test_ops.py b/test/test_ops.py index 8e5de5fd..f7f8c171 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -2,7 +2,7 @@ import time, math, unittest import numpy as np from typing import List, Callable import torch -from tinygrad.helpers import getenv, IMAGE, DEBUG, CI +from tinygrad.helpers import getenv, IMAGE, DEBUG, CI, Context from tinygrad import Tensor, Device, dtypes from tinygrad.tensor import _to_np_dtype import functools @@ -894,6 +894,7 @@ class TestOps(unittest.TestCase): def test_prod(self): helper_test_op(None, lambda x: x.prod(), vals=[[1.0, 2.0, 3.0]]) + with Context(NOOPT=1): helper_test_op(None, lambda x: x.prod(), vals=[[1.0, 2.0, 3.0]]) helper_test_op([(3,4,5,6)], lambda x: x.prod(dim=3), lambda x: x.prod(axis=3)) helper_test_op([(3,4,5,6)], lambda x: x.prod(dim=1), lambda x: x.prod(axis=1)) helper_test_op([(3,4,5,6)], lambda x: x.prod(dim=1, keepdim=True), lambda x: x.prod(axis=1, keepdim=True)) diff --git a/tinygrad/codegen/uopgraph.py b/tinygrad/codegen/uopgraph.py index d856b4c8..712d4df5 100644 --- a/tinygrad/codegen/uopgraph.py +++ b/tinygrad/codegen/uopgraph.py @@ -400,7 +400,8 @@ def do_reduce(root:UOp): ret = root.src[0] if len(reduce_parented): assert root.dtype is not None - const = UOp.const(root.dtype, 0 if root.arg is BinaryOps.ADD else dtypes.min(root.dtype.scalar())) + # TODO: helper to reuse this in 0 size folding + const = UOp.const(root.dtype, {BinaryOps.ADD:0, BinaryOps.MUL:1, BinaryOps.MAX:dtypes.min(root.dtype.scalar())}[root.arg]) acc = UOp(UOps.DEFINE_ACC, root.dtype, (const,) + tuple(reduce_parented), (acc_number,)) acc_number += 1 ret = UOp(UOps.PHI, root.dtype, (acc, acc.alu(root.arg, ret)))