diff --git a/emle/models/_emle.py b/emle/models/_emle.py index 344faf1..7a23c0d 100644 --- a/emle/models/_emle.py +++ b/emle/models/_emle.py @@ -348,7 +348,7 @@ def __init__( aev_computer=self._aev_computer, aev_mask=aev_mask, alpha_mode=self._alpha_mode, - species=self._species, + species=params.get("species", self._species), device=device, dtype=dtype, )