mirror of https://github.com/commaai/tinygrad.git
A beautiful MNIST training example (#2272)
* beautiful mnist
* beautiful mnist example
* from tinygrad import Tensor
* more beautiful
* the jit is super core tinygrad
* globalcounters reset on jit run
* symlinks and exclude
* beautiful_cartpole
* evaluate is its own function
* no symlinks
* more beautiful
* jit reset for double speed
* type hinting for JIT
* beautiful_mnist gets 98%
* beautiful_mnist < 4s with BEAM=2
* better cartpole
* use actor critic
* zero_grad got lost
* delete double relu
* stable cartpole with PPO
* beautiful_cartpole is more beautiful
* REPLAY_BUFFER
* beautiful stuff typechecks
* None support in shape
* hp tuning
parent 74e6b6c9fc
commit c7b38b324b
@@ -0,0 +1,114 @@
from typing import Tuple
import time
from tinygrad import Tensor, TinyJit, nn, Variable
from tinygrad.helpers import dtypes  # TODO: wouldn't need this if argmax returned the right dtype
import gymnasium as gym
from tqdm import trange
import numpy as np  # TODO: remove numpy import

class ActorCritic:
  def __init__(self, in_features, out_features, hidden_state=32):
    self.l1 = nn.Linear(in_features, hidden_state)
    self.l2 = nn.Linear(hidden_state, out_features)

    self.c1 = nn.Linear(in_features, hidden_state)
    self.c2 = nn.Linear(hidden_state, 1)

  def __call__(self, obs:Tensor) -> Tuple[Tensor, Tensor]:
    x = self.l1(obs).tanh()
    act = self.l2(x).log_softmax()
    x = self.c1(obs).relu()
    return act, self.c2(x)

def evaluate(model:ActorCritic, test_env:gym.Env) -> float:
  (obs, _), terminated, truncated = test_env.reset(), False, False
  total_rew = 0.0
  while not terminated and not truncated:
    act = model(Tensor(obs))[0].argmax().cast(dtypes.int32).item()
    obs, rew, terminated, truncated, _ = test_env.step(act)
    total_rew += float(rew)
  return total_rew

# TODO: time should be < 5s on M1 Max
if __name__ == "__main__":
  env = gym.make('CartPole-v1')

  model = ActorCritic(env.observation_space.shape[0], int(env.action_space.n))  # type: ignore
  opt = nn.optim.Adam(nn.state.get_parameters(model), lr=1e-2)

  @TinyJit
  def train_step(x:Tensor, selected_action:Tensor, reward:Tensor, old_log_dist:Tensor) -> Tuple[Tensor, Tensor, Tensor]:
    with Tensor.train():
      log_dist, value = model(x)

      # get advantage
      advantage = reward.reshape(-1, 1) - value
      mask = selected_action.reshape(-1, 1) == Tensor.arange(log_dist.shape[1]).reshape(1, -1).expand(selected_action.shape[0], -1)
      masked_advantage = mask * advantage.detach()

      # PPO
      ratios = (log_dist - old_log_dist).exp() * masked_advantage
      clipped_ratios = ratios.clip(1-0.2, 1+0.2) * masked_advantage
      action_loss = -ratios.minimum(clipped_ratios).sum(-1).mean()

      entropy_loss = (log_dist.exp() * log_dist).sum(-1).mean()  # this encourages diversity
      critic_loss = advantage.square().mean()
      opt.zero_grad()
      (action_loss + entropy_loss*0.0005 + critic_loss).backward()
      opt.step()
      return action_loss.realize(), entropy_loss.realize(), critic_loss.realize()

  @TinyJit
  def get_action_dist(obs:Tensor) -> Tensor:
    # TODO: with no_grad
    Tensor.no_grad = True
    ret = model(obs)[0].exp().realize()
    Tensor.no_grad = False
    return ret

  BS = 256
  MAX_REPLAY_BUFFER = 2000
  st, steps = time.perf_counter(), 0
  Xn, An, Rn = [], [], []
  for i in (t:=trange(40)):
    get_action_dist.reset()  # NOTE: if you don't reset the jit here it captures the wrong model on the first run through

    obs:np.ndarray = env.reset()[0]
    rews, terminated, truncated = [], False, False
    # NOTE: we don't want to early stop since then the rewards are wrong for the last episode
    while not terminated and not truncated:
      # pick actions
      # TODO: move the multinomial into jitted tinygrad when JIT rand works
      # TODO: what's the temperature here?
      act = get_action_dist(Tensor(obs)).multinomial().item()

      # save this state action pair
      # TODO: don't use np.copy here on the CPU, what's the tinygrad way to do this and keep on device? need __setitem__ assignment
      Xn.append(np.copy(obs))
      An.append(act)

      obs, rew, terminated, truncated, _ = env.step(act)
      rews.append(float(rew))
    steps += len(rews)

    # reward to go
    # TODO: move this into tinygrad
    discounts = np.power(0.99, np.arange(len(rews)))
    Rn += [np.sum(rews[i:] * discounts[:len(rews)-i]) for i in range(len(rews))]

    Xn, An, Rn = Xn[-MAX_REPLAY_BUFFER:], An[-MAX_REPLAY_BUFFER:], Rn[-MAX_REPLAY_BUFFER:]
    X, A, R = Tensor(Xn), Tensor(An), Tensor(Rn)

    # TODO: make this work
    #vsz = Variable("sz", 1, MAX_REPLAY_BUFFER-1).bind(len(Xn))
    #X, A, R = Tensor(Xn).reshape(vsz, None), Tensor(An).reshape(vsz), Tensor(Rn).reshape(vsz)

    old_log_dist = model(X)[0]  # TODO: could save these instead of recomputing
    for i in range(5):
      samples = Tensor.randint(BS, high=X.shape[0]).realize()  # TODO: remove the need for this
      # TODO: is this recompiling based on the shape?
      action_loss, entropy_loss, critic_loss = train_step(X[samples], A[samples], R[samples], old_log_dist[samples])
    t.set_description(f"sz: {len(Xn):5d} steps/s: {steps/(time.perf_counter()-st):7.2f} action_loss: {action_loss.item():7.2f} entropy_loss: {entropy_loss.item():7.2f} critic_loss: {critic_loss.item():7.2f} reward: {sum(rews):6.2f}")

  test_rew = evaluate(model, gym.make('CartPole-v1', render_mode='human'))
  print(f"test reward: {test_rew}")
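
A small standalone illustration (mine, not part of the commit) of what the reward-to-go comprehension above computes, using a hypothetical three-step episode with reward 1.0 per step and the same 0.99 discount:

import numpy as np
rews = [1.0, 1.0, 1.0]
discounts = np.power(0.99, np.arange(len(rews)))
rtg = [np.sum(rews[i:] * discounts[:len(rews)-i]) for i in range(len(rews))]
# rtg == [1 + 0.99 + 0.99**2, 1 + 0.99, 1.0] ~= [2.9701, 1.99, 1.0]
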
@@ -0,0 +1,47 @@
# model based off https://towardsdatascience.com/going-beyond-99-mnist-handwritten-digits-recognition-cfff96337392
from typing import List, Callable
from tinygrad import Tensor, TinyJit, nn
from tinygrad.helpers import GlobalCounters
from extra.datasets import fetch_mnist
from tqdm import trange

class Model:
  def __init__(self):
    self.layers: List[Callable[[Tensor], Tensor]] = [
      nn.Conv2d(1, 32, 5), Tensor.relu,
      nn.Conv2d(32, 32, 5, bias=False),
      nn.BatchNorm2d(32), Tensor.relu, Tensor.max_pool2d,
      nn.Conv2d(32, 64, 3), Tensor.relu,
      nn.Conv2d(64, 64, 3, bias=False),
      nn.BatchNorm2d(64), Tensor.relu, Tensor.max_pool2d,
      lambda x: x.flatten(1), nn.Linear(576, 10)]

  def __call__(self, x:Tensor) -> Tensor: return x.sequential(self.layers)

if __name__ == "__main__":
  X_train, Y_train, X_test, Y_test = fetch_mnist(tensors=True)

  model = Model()
  opt = nn.optim.Adam(nn.state.get_parameters(model))

  # TODO: there's a compiler error if you comment out TinyJit since randint isn't being realized and there's something weird with int
  @TinyJit
  def train_step(samples:Tensor) -> Tensor:
    with Tensor.train():
      opt.zero_grad()
      # TODO: this "gather" of samples is very slow and not the desired way to do things in practice
      # will be under 5s when this is fixed
      loss = model(X_train[samples]).sparse_categorical_crossentropy(Y_train[samples]).backward()
      opt.step()
      return loss.realize()

  @TinyJit
  def get_test_acc() -> Tensor: return ((model(X_test).argmax(axis=1) == Y_test).mean()*100).realize()

  test_acc = float('nan')
  for i in (t:=trange(70)):
    GlobalCounters.reset()
    samples = Tensor.randint(512, high=X_train.shape[0])  # TODO: put this in the JIT when rand is fixed
    loss = train_step(samples)
    if i%10 == 9: test_acc = get_test_acc().item()
    t.set_description(f"loss: {loss.item():6.2f} test_accuracy: {test_acc:5.2f}%")
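
The 576 in nn.Linear(576, 10) is the flattened feature count after the conv/pool stack on a 28x28 input. A minimal sketch (mine, not from the commit; conv_out is a hypothetical helper) that reproduces the arithmetic:

def conv_out(n, k): return n - k + 1  # valid convolution, stride 1
s = 28
s = conv_out(s, 5); s = conv_out(s, 5); s //= 2  # 28 -> 24 -> 20 -> 10 after max_pool2d
s = conv_out(s, 3); s = conv_out(s, 3); s //= 2  # 10 -> 8 -> 6 -> 3 after max_pool2d
print(64 * s * s)  # 576 features feeding nn.Linear(576, 10)
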
@@ -5,14 +5,15 @@ from tinygrad.tensor import Tensor
from tinygrad.helpers import dtypes
from extra.utils import download_file

def fetch_mnist():
def fetch_mnist(tensors=False):
  parse = lambda file: np.frombuffer(gzip.open(file).read(), dtype=np.uint8).copy()
  dirname = Path(__file__).parent.resolve()
  X_train = parse(dirname / "mnist/train-images-idx3-ubyte.gz")[0x10:].reshape((-1, 28*28)).astype(np.float32)
  Y_train = parse(dirname / "mnist/train-labels-idx1-ubyte.gz")[8:]
  X_test = parse(dirname / "mnist/t10k-images-idx3-ubyte.gz")[0x10:].reshape((-1, 28*28)).astype(np.float32)
  Y_test = parse(dirname / "mnist/t10k-labels-idx1-ubyte.gz")[8:]
  return X_train, Y_train, X_test, Y_test
  if tensors: return Tensor(X_train).reshape(-1, 1, 28, 28), Tensor(Y_train), Tensor(X_test).reshape(-1, 1, 28, 28), Tensor(Y_test)
  else: return X_train, Y_train, X_test, Y_test

cifar_mean = [0.4913997551666284, 0.48215855929893703, 0.4465309133731618]
cifar_std = [0.24703225141799082, 0.24348516474564, 0.26158783926049628]
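
With the new tensors flag, callers get Tensors already shaped for the conv model instead of flat numpy arrays. A minimal usage sketch (mine), assuming the standard 60000/10000 MNIST split:

from extra.datasets import fetch_mnist
X_train, Y_train, X_test, Y_test = fetch_mnist(tensors=True)
print(X_train.shape)  # (60000, 1, 28, 28), per the reshape above
print(Y_train.shape)  # (60000,)
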
@@ -23,7 +23,7 @@ class TinyBobNet:
    return get_parameters(self)

  def forward(self, x):
    return x.dot(self.l1).relu().dot(self.l2).log_softmax()
    return x.dot(self.l1).relu().dot(self.l2)

# create a model with a conv layer
class TinyConvNet:
@@ -49,7 +49,7 @@ class TinyConvNet:
    x = self.bn1(x.conv2d(self.c1)).relu().max_pool2d()
    x = self.bn2(x.conv2d(self.c2)).relu().max_pool2d()
    x = x.reshape(shape=[x.shape[0], -1])
    return x.dot(self.l1).log_softmax()
    return x.dot(self.l1)

class TestMNIST(unittest.TestCase):
  def test_sgd_onestep(self):
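
Both test models drop their trailing log_softmax because sparse_categorical_crossentropy takes raw logits (see the Tensor.sparse_categorical_crossentropy hunk further down). A minimal sketch (mine, with made-up numbers) of that contract:

from tinygrad import Tensor
logits = Tensor([[2.0, 0.5, -1.0]])  # raw class scores for a batch of 1, no log_softmax applied
target = Tensor([0])                 # correct class index
print(logits.sparse_categorical_crossentropy(target).item())
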
@@ -19,8 +19,8 @@ class TestNN(unittest.TestCase):
    loss = loss_fun(input, target)

    input_tiny = Tensor(input.detach().numpy())
    taret_tiny = Tensor(target.detach().numpy())
    loss_tiny = input_tiny.sparse_categorical_crossentropy(taret_tiny)
    target_tiny = Tensor(target.detach().numpy())
    loss_tiny = input_tiny.sparse_categorical_crossentropy(target_tiny)

    np.testing.assert_allclose(loss_tiny.numpy(), loss.detach().numpy(), atol=1e-5, rtol=1e-6)
@@ -0,0 +1,3 @@
from tinygrad.tensor import Tensor # noqa: F401
from tinygrad.jit import TinyJit # noqa: F401
from tinygrad.shape.symbolic import Variable # noqa: F401
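
This new three-line package init (presumably tinygrad/__init__.py) is what lets the examples import the core API from the package root. A minimal sketch (mine) of the re-export in use:

from tinygrad import Tensor, TinyJit, Variable  # previously tinygrad.tensor, tinygrad.jit, tinygrad.shape.symbolic
print(Tensor([1.0, 2.0, 3.0]).sum().item())  # 6.0
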
@@ -22,7 +22,7 @@ actions += [

# returns time in seconds
def time_linearizer(lin:Linearizer, rawbufs:List[RawBuffer], allow_test_size=True, max_global_size=65536, cnt=3, disable_cache=False, clear_l2=False) -> float:
  key = {"ast": str(lin.ast), "opts": str(lin.applied_opts), "allow_test_size": allow_test_size, "max_global_size": max_global_size}
  key = {"ast": str(lin.ast), "opts": str(lin.applied_opts), "allow_test_size": allow_test_size, "max_global_size": max_global_size, "clear_l2": clear_l2}
  if not disable_cache and CACHELEVEL >= 2 and (val:=diskcache_get("time_linearizer", key)) is not None: return min(val)

  # Set the midpoint value value for var_vals to optimize shapes.
@@ -174,7 +174,7 @@ _cache_dir: str = getenv("XDG_CACHE_HOME", os.path.expanduser("~/Library/Caches"
CACHEDB: str = getenv("CACHEDB", os.path.abspath(os.path.join(_cache_dir, "tinygrad", "cache.db")))
CACHELEVEL = getenv("CACHELEVEL", 2)

VERSION = 7
VERSION = 8
_db_connection = None
def db_connection():
  global _db_connection
@@ -1,5 +1,5 @@
from __future__ import annotations
from typing import Callable, List, Tuple, Any, Dict, cast, Union, Optional
from typing import Callable, List, Tuple, Dict, cast, Union, Optional, TypeVar, Generic
import functools, itertools
from tinygrad.helpers import DEBUG, DType, merge_dicts
from tinygrad.ops import RawBuffer, Device, ASTRunner, BatchExecutor, JitItem
@@ -8,12 +8,17 @@ from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.shape.symbolic import Variable
from weakref import ref, WeakKeyDictionary

class TinyJit:
  def __init__(self, fxn:Callable):
    self.fxn: Callable = fxn
ReturnType = TypeVar('ReturnType')

class TinyJit(Generic[ReturnType]):
  def __init__(self, fxn:Callable[..., ReturnType]):
    self.fxn = fxn
    self.reset()

  def reset(self):
    self.jit_fxn: Optional[BatchExecutor] = None
    self.cnt: int = 0
    self.ret: Any = None
    self.ret: Optional[ReturnType] = None
    self.expected_vals: Optional[Tuple[Variable, ...]] = None
    self.expected_sts_dtype: Optional[Tuple[Tuple[ShapeTracker, DType], ...]] = None
@@ -25,14 +30,13 @@ class TinyJit:
  # add support for instance methods
  def __get__(self, obj, objtype): return functools.partial(self.__call__, obj)

  def __call__(self, *args, **kwargs) -> Any:
  def __call__(self, *args, **kwargs) -> ReturnType:
    # all inputs are realized
    input_tensors: Dict[Union[int, str], Tensor] = {cast(Union[int, str], k):v.realize() for k,v in itertools.chain(enumerate(args), kwargs.items()) if v.__class__ is Tensor}
    expected_sts_dtype = tuple([(v.lazydata.st.unbind(), v.dtype) for v in input_tensors.values()])

    # get rawbuffers
    input_rawbuffers: Dict[Union[int, str], RawBuffer] = {k:cast(RawBuffer, v.lazydata.realized) for k,v in input_tensors.items()}
    assert len(input_rawbuffers) != 0, "no inputs to JIT"
    assert len(set(input_rawbuffers.values())) == len(input_rawbuffers), "duplicate inputs to JIT"

    # get variables: they can either be in Tensors or passed in as arguments, and all must be bound. these are all global
@@ -41,7 +45,7 @@ class TinyJit:

    if self.cnt >= 2:
      assert self.expected_vals == expected_vals, "mismatch of var_vals"
      assert self.expected_sts_dtype == expected_sts_dtype, "mismatch of sts"
      assert self.expected_sts_dtype == expected_sts_dtype, f"mismatch of sts, expected {self.expected_sts_dtype} got {expected_sts_dtype}"
      assert self.jit_fxn, "didn't get jitted?"
      self.jit_fxn(input_rawbuffers, var_vals, DEBUG>=2)
    elif self.cnt == 1:
@@ -58,7 +62,7 @@ class TinyJit:
      self.ret = self.fxn(*args, **kwargs)

    self.cnt += 1
    return self.ret
    return cast(ReturnType, self.ret)

class PlaceHolder:
  def __init__(self, buf:RawBuffer): self.size, self.dtype, self._device, self.ref, self.buftype, self.bufid = buf.size, buf.dtype, getattr(buf, '_device', None), ref(buf), type(buf), id(buf._buf)
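
With TinyJit generic over ReturnType, the return annotation of the wrapped function survives for type checkers (previously __call__ returned Any). A minimal sketch (mine; the function name double is made up):

from tinygrad import Tensor, TinyJit

@TinyJit
def double(x:Tensor) -> Tensor:
  return (x * 2).realize()  # JIT outputs should be realized

y = double(Tensor([1.0, 2.0]))  # type checkers now see y as Tensor, not Any
print(y.numpy())
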
@@ -2,9 +2,10 @@ import math
from typing import Optional, Union, Tuple
from tinygrad.tensor import Tensor
from tinygrad.helpers import prod, all_int
from tinygrad.nn import optim, state # noqa: F401

class BatchNorm2d:
  def __init__(self, sz, eps=1e-5, affine=True, track_running_stats=True, momentum=0.1):
  def __init__(self, sz:int, eps=1e-5, affine=True, track_running_stats=True, momentum=0.1):
    self.eps, self.track_running_stats, self.momentum = eps, track_running_stats, momentum

    if affine: self.weight, self.bias = Tensor.ones(sz), Tensor.zeros(sz)
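
The added import re-exports optim and state under tinygrad.nn, which is what lets both examples write nn.optim.Adam(nn.state.get_parameters(model)). A minimal sketch (mine; the Tiny class is made up):

from tinygrad import Tensor, nn

class Tiny:
  def __init__(self): self.l1 = nn.Linear(4, 2)
  def __call__(self, x:Tensor) -> Tensor: return self.l1(x)

model = Tiny()
opt = nn.optim.Adam(nn.state.get_parameters(model), lr=1e-3)  # walks the object and finds l1.weight / l1.bias
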
@@ -361,7 +361,7 @@ def get_optimized_program(linearizer_opts:LinearizerOptions, to_program, ast:Laz
    if used_tensor_cores:
      lins.append(("hc", Linearizer(ast, linearizer_opts)))
      lins[-1][1].hand_coded_optimizations()
    timed = sorted([(nm, tk, time_linearizer(tk, test_rawbuffers, allow_test_size=False, disable_cache=True, clear_l2=True)) for nm, tk in lins], key=lambda x: x[2])
    timed = sorted([(nm, tk, time_linearizer(tk, test_rawbuffers, allow_test_size=False, clear_l2=True)) for nm, tk in lins], key=lambda x: x[2])
    if DEBUG >= 1: print(" < ".join(f"{nm:6s} : {lin.colored_shape(30, dense=True)} : {tm*1e6:8.2f} us" for nm, lin, tm in timed))
    k = timed[0][1]
  else:
@@ -181,6 +181,10 @@ class Tensor:
    src = Tensor.rand(2, *shape, **kwargs)
    return src[0].mul(2*math.pi).cos().mul((1 - src[1]).log().mul(-2).sqrt()).cast(Tensor.default_type if dtype is None else dtype)

  @staticmethod
  def randint(*shape, low=0, high=10, **kwargs) -> Tensor:
    return (Tensor.rand(*shape, **kwargs)*(high-low)+low).cast(dtypes.int32)

  @staticmethod
  def normal(*shape, mean=0.0, std=1.0, **kwargs) -> Tensor: return (std * Tensor.randn(*shape, **kwargs)) + mean
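
Tensor.randint is the helper both examples use to draw minibatch indices: uniform integers in [low, high), cast to int32. A minimal sketch (mine):

from tinygrad import Tensor
samples = Tensor.randint(512, high=60000)  # e.g. 512 indices into a 60000-row training set
print(samples.shape, samples.dtype)        # (512,), int32
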
@@ -213,9 +217,9 @@ class Tensor:
    assert replacement or num_samples == 1, "supported only with replacement"
    p = self.unsqueeze(0) if self.ndim == 1 else self
    cdf = p.cumsum(1)
    cdf /= cdf[:, -1].unsqueeze(1)
    cdf_normalized = cdf / cdf[:, -1].unsqueeze(1)
    unif_samples = Tensor.rand(num_samples, p.shape[0], 1)
    indices = (unif_samples.expand((-1, -1, p.shape[1])) >= cdf).sum(2).permute((1, 0))
    indices = (unif_samples.expand((-1, -1, p.shape[1])) >= cdf_normalized).sum(2).permute((1, 0))
    if self.ndim == 1: indices = indices.squeeze(0)
    return indices.cast(dtypes.int32)
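
multinomial now normalizes into a separate cdf_normalized tensor instead of dividing cdf in place; the CartPole example uses it to sample an action index from the policy's (not necessarily normalized) probabilities. A minimal sketch (mine, with made-up probabilities):

from tinygrad import Tensor
probs = Tensor([0.1, 0.7, 0.2])      # unnormalized weights are fine, the cdf is normalized internally
action = probs.multinomial().item()  # int32 index 0, 1, or 2; 1 is drawn most often
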
@@ -230,7 +234,7 @@ class Tensor:
      return nodes
    return _deepwalk(self, set(), [])

  def backward(self):
  def backward(self) -> Tensor:
    assert self.shape == tuple(), f"backward can only be called for scalar tensors, but it has shape {self.shape})"

    # fill in the first grad with one. don't use Tensor.ones because we don't need contiguous
@@ -247,11 +251,12 @@ class Tensor:
        assert g.shape == t.shape, f"grad shape must match tensor shape, {g.shape!r} != {t.shape!r}"
        t.grad = g if t.grad is None else (t.grad + g)
      del t0._ctx
    return self

  # ***** movement mlops *****
  def reshape(self, shape, *args) -> Tensor:
    new_shape = argfix(shape, *args)
    return mlops.Reshape.apply(self, shape=tuple([-prod(self.shape) // prod(new_shape) if s == -1 else s for s in new_shape]))
    return mlops.Reshape.apply(self, shape=tuple([-prod(self.shape) // prod(new_shape) if s == -1 else (s if s is not None else self.shape[i]) for i,s in enumerate(new_shape)]))
  def expand(self, shape, *args) -> Tensor: return mlops.Expand.apply(self, shape=tuple([x if x != -1 else s for s,x in zip(self.shape, argfix(shape, *args))]))
  def permute(self, order, *args) -> Tensor: return mlops.Permute.apply(self, order=argfix(order, *args))
  def flip(self, axis, *args) -> Tensor: return mlops.Flip.apply(self, axis=[x if x >= 0 else x+len(self.shape) for x in argfix(axis, *args)])
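
Two small API changes land here: backward() now returns the Tensor, so a loss expression can end in .backward() and still be assigned (as beautiful_mnist does), and reshape() accepts None to mean "keep this dimension's current size", which the commented-out symbolic replay buffer in beautiful_cartpole relies on. A minimal sketch (mine) of the reshape behavior:

from tinygrad import Tensor
x = Tensor.rand(4, 3, 2)
print(x.reshape(None, 6).shape)  # (4, 6): None keeps the first dimension's size
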
@@ -770,6 +775,7 @@ class Tensor:
    return (self.maximum(0) - y * self + (1 + self.abs().__neg__().exp()).log()).mean()

  def sparse_categorical_crossentropy(self, Y, ignore_index=-1) -> Tensor:
    # NOTE: self is a logits input
    loss_mask = Y != ignore_index
    y_counter = Tensor.arange(self.shape[-1], dtype=dtypes.int32, requires_grad=False, device=self.device).unsqueeze(0).expand(Y.numel(), self.shape[-1])
    y = ((y_counter == Y.flatten().reshape(-1, 1)).where(-1.0, 0) * loss_mask.reshape(-1, 1)).reshape(*Y.shape, self.shape[-1])