try out attr.s for its converter (doesn't work recursively yet, but interesting!)

This commit is contained in:
Shane Smiskol 2024-08-12 23:02:00 -07:00
parent 09bb822fab
commit ff2434f7bb
1 changed files with 48 additions and 6 deletions

View File

@ -1,4 +1,5 @@
from dataclasses import dataclass as _dataclass, field, is_dataclass
import attr
from dataclasses import field
from enum import Enum, StrEnum as _StrEnum, auto
from typing import dataclass_transform, get_origin
@ -9,6 +10,27 @@ def auto_field():
return auto_obj
# @dataclass_transform()
# def auto_dataclass(cls=None, /, **kwargs):
# cls_annotations = cls.__dict__.get('__annotations__', {})
# for name, typ in cls_annotations.items():
# current_value = getattr(cls, name, None)
# if current_value is auto_obj:
# origin_typ = get_origin(typ) or typ
# if isinstance(origin_typ, str):
# raise TypeError(f"Forward references are not supported for auto_field: '{origin_typ}'. Use a default_factory with lambda instead.")
# elif origin_typ in (int, float, str, bytes, list, tuple, set, dict, bool) or is_dataclass(origin_typ):
# setattr(cls, name, field(default_factory=origin_typ))
# elif origin_typ is None:
# setattr(cls, name, field(default=origin_typ))
# elif issubclass(origin_typ, Enum): # first enum is the default
# setattr(cls, name, field(default=next(iter(origin_typ))))
# else:
# raise TypeError(f"Unsupported type for auto_field: {origin_typ}")
#
# return _dataclass(cls, **kwargs)
@dataclass_transform()
def auto_dataclass(cls=None, /, **kwargs):
cls_annotations = cls.__dict__.get('__annotations__', {})
@ -18,16 +40,24 @@ def auto_dataclass(cls=None, /, **kwargs):
origin_typ = get_origin(typ) or typ
if isinstance(origin_typ, str):
raise TypeError(f"Forward references are not supported for auto_field: '{origin_typ}'. Use a default_factory with lambda instead.")
elif origin_typ in (int, float, str, bytes, list, tuple, set, dict, bool) or is_dataclass(origin_typ):
setattr(cls, name, field(default_factory=origin_typ))
elif origin_typ in (int, float, str, bytes, list, tuple, set, dict, bool):
setattr(cls, name, attr.attr(factory=origin_typ))
elif attr.has(origin_typ):
def convert(data, _origin_typ=origin_typ):
print('got data', data)
if attr.has(data):# or (not any(isinstance(k, dict) for k in data) and len(data)):
return data
# print('ret cls **data', cls)
return _origin_typ(**data)
setattr(cls, name, attr.attr(factory=origin_typ, converter=convert))
elif origin_typ is None:
setattr(cls, name, field(default=origin_typ))
setattr(cls, name, attr.attr(default=origin_typ))
elif issubclass(origin_typ, Enum): # first enum is the default
setattr(cls, name, field(default=next(iter(origin_typ))))
setattr(cls, name, attr.attr(default=next(iter(origin_typ))))
else:
raise TypeError(f"Unsupported type for auto_field: {origin_typ}")
return _dataclass(cls, **kwargs)
return attr.dataclass(cls, slots=True, **kwargs)
class StrEnum(_StrEnum):
@ -497,3 +527,15 @@ class CarParams:
class NetworkLocation(StrEnum):
fwdCamera = auto() # Standard/default integration at LKAS camera
gateway = auto() # Integration at vehicle's CAN gateway
# @attr.dataclass(slots=True)
@auto_dataclass
class Test:
# actuators: CarControl.Actuators = attr.attr(factory=lambda: CarControl.Actuators(), converter=lambda arg: CarControl.Actuators(**arg))
actuators: CarControl.Actuators = auto_field()
hudControl: CarControl.HUDControl = auto_field() # attr.attr(factory=lambda: CarControl.HUDControl(), converter=lambda arg: CarControl.HUDControl(**arg))
# Test(**{'actuators': {'gas': 1.0}})