diff --git a/dnd_character/equipment.py b/dnd_character/equipment.py index 610e62f..85aea76 100644 --- a/dnd_character/equipment.py +++ b/dnd_character/equipment.py @@ -62,5 +62,12 @@ def __iter__(self): yield k, v -def Item(index: str) -> _Item: - return _Item(**SRD_equipment[index]) +def Item(item: Union[str, dict]) -> _Item: + """ + Create new Item by calling with string (e.g., torch) + Deserialize item by calling with a dict + """ + if type(item) == str: + return _Item(**SRD_equipment[item]) + else: + return _Item(**item) diff --git a/dnd_character/monsters.py b/dnd_character/monsters.py index 945c9f0..80fdd39 100644 --- a/dnd_character/monsters.py +++ b/dnd_character/monsters.py @@ -12,7 +12,7 @@ @dataclass(kw_only=True) class _Monster: - """Dataclass for items. Deserialize item with `_Monster(**dict)` or Monster() function""" + """Dataclass for monsters. Deserialize item with `_Monster(**dict)` or Monster() function""" index: str uid: str = uuid4().hex @@ -57,5 +57,14 @@ def __iter__(self): yield k, v -def Monster(index: str) -> _Monster: - return _Monster(**SRD_monsters[index]) +def Monster(monster: Union[str, dict]) -> _Monster: + """ + Create new Monster by calling with string (e.g., zombie) + Deserialize monster by calling with a dict + """ + if type(monster) == str: + # new monster + return _Monster(**SRD_monsters[monster]) + else: + # deserialized monster + return _Monster(**monster) diff --git a/tests/test_equipment.py b/tests/test_equipment.py index 9b4fe55..d591908 100644 --- a/tests/test_equipment.py +++ b/tests/test_equipment.py @@ -346,3 +346,8 @@ def test_item_serialization(item_name: str, expected_value: dict): item = Item(item_name) serialized_item = literal_eval(str(dict(item))) assert all([serialized_item[k] == v for k, v in expected_value.items()]) + + +def test_item_function_deserializes_dict(): + torch = Item("torch") + assert Item(dict(torch)) == torch diff --git a/tests/test_monsters.py b/tests/test_monsters.py index ea5a4c7..34b632b 100644 --- a/tests/test_monsters.py +++ b/tests/test_monsters.py @@ -79,3 +79,8 @@ def test_zombie_serialization(): } serialized_zombie = dict(zombie) assert all([serialized_zombie[k] == v for k, v in expected_zombie.items()]) + + +def test_monster_function_deserializes_dict(): + roper = Monster("roper") + assert Monster(dict(roper)) == roper