diff --git a/festim/boundary_conditions/dirichlet_bc.py b/festim/boundary_conditions/dirichlet_bc.py index 6bfe2f466..2ab2d61b6 100644 --- a/festim/boundary_conditions/dirichlet_bc.py +++ b/festim/boundary_conditions/dirichlet_bc.py @@ -99,6 +99,10 @@ def create_value( if "t" in arguments and "x" not in arguments and "T" not in arguments: # only t is an argument + if not isinstance(self.value(t=float(t)), (float, int)): + raise ValueError( + f"self.value should return a float or an int, not {type(self.value(t=float(t)))} " + ) self.value_fenics = F.as_fenics_constant( mesh=mesh, value=self.value(t=float(t)) ) diff --git a/festim/hydrogen_transport_problem.py b/festim/hydrogen_transport_problem.py index 14a808eaf..dd50acbfb 100644 --- a/festim/hydrogen_transport_problem.py +++ b/festim/hydrogen_transport_problem.py @@ -191,7 +191,11 @@ def define_temperature(self): # if temperature is callable, process accordingly elif callable(self.temperature): arguments = self.temperature.__code__.co_varnames - if "t" in arguments and "x" not in arguments and "T" not in arguments: + if "t" in arguments and "x" not in arguments: + if not isinstance(self.temperature(t=float(self.t)), (float, int)): + raise ValueError( + f"self.temperature should return a float or an int, not {type(self.temperature(t=float(self.t)))} " + ) # only t is an argument self.temperature_fenics = F.as_fenics_constant( mesh=self.mesh.mesh, value=self.temperature(t=float(self.t)) diff --git a/test/test_dirichlet_bc.py b/test/test_dirichlet_bc.py index b34cd6a7f..7fff122fa 100644 --- a/test/test_dirichlet_bc.py +++ b/test/test_dirichlet_bc.py @@ -226,6 +226,7 @@ def test_callable_x_only(): lambda x, t: 1.0 + x[0] + t, lambda x, t, T: 1.0 + x[0] + t + T, lambda x, t: ufl.conditional(ufl.lt(t, 1.0), 100.0 + x[0], 0.0), + lambda t: 100.0 if t < 1 else 0.0, ], ) def test_integration_with_HTransportProblem(value): @@ -283,6 +284,32 @@ def test_integration_with_HTransportProblem(value): assert np.isclose(computed_value, expected_value) +@pytest.mark.parametrize( + "value", + [ + lambda t: ufl.conditional(ufl.lt(t, 1.0), 1, 2), + lambda t: 1 + ufl.conditional(ufl.lt(t, 1.0), 1, 2.0), + lambda t: 2 * ufl.conditional(ufl.lt(t, 1.0), 1, 2.0), + lambda t: 2 / ufl.conditional(ufl.lt(t, 1.0), 1, 2.0), + ], +) +def test_define_value_error_if_ufl_conditional_t_only(value): + """Test that a ValueError is raised when the value attribute is a callable + of t only and contains a ufl conditional""" + + subdomain = F.SurfaceSubdomain1D(1, x=1) + species = F.Species("test") + + bc = F.DirichletBC(subdomain, value, species) + + t = fem.Constant(mesh, 0.0) + + with pytest.raises( + ValueError, match="self.value should return a float or an int, not " + ): + bc.create_value(mesh=mesh, function_space=None, temperature=None, t=t) + + def test_species_predefined(): """Test a ValueError is raised when the species defined in the boundary condition is not predefined in the model""" diff --git a/test/test_h_transport_problem.py b/test/test_h_transport_problem.py index a8dfc570f..ab30b9e77 100644 --- a/test/test_h_transport_problem.py +++ b/test/test_h_transport_problem.py @@ -77,6 +77,7 @@ def test_define_temperature_value_error_raised(): (lambda x: 1.0 + x[0], fem.Function), (lambda x, t: 1.0 + x[0] + t, fem.Function), (lambda x, t: ufl.conditional(ufl.lt(t, 1.0), 100.0 + x[0], 0.0), fem.Function), + (lambda t: 100.0 if t < 1 else 0.0, fem.Constant), ], ) def test_define_temperature(input, expected_type): @@ -97,6 +98,29 @@ def test_define_temperature(input, expected_type): assert isinstance(my_model.temperature_fenics, expected_type) +@pytest.mark.parametrize( + "input", + [ + lambda t: ufl.conditional(ufl.lt(t, 1.0), 1, 2), + lambda t: 1 + ufl.conditional(ufl.lt(t, 1.0), 1, 2.0), + lambda t: 2 * ufl.conditional(ufl.lt(t, 1.0), 1, 2.0), + lambda t: 2 / ufl.conditional(ufl.lt(t, 1.0), 1, 2.0), + ], +) +def test_define_temperature_error_if_ufl_conditional_t_only(input): + """Test that a ValueError is raised when the temperature attribute is a callable + of t only and contains a ufl conditional""" + my_model = F.HydrogenTransportProblem(mesh=test_mesh) + my_model.t = fem.Constant(test_mesh.mesh, 0.0) + + my_model.temperature = input + + with pytest.raises( + ValueError, match="self.temperature should return a float or an int, not " + ): + my_model.define_temperature() + + def test_iterate(): """Test that the iterate method updates the solution and time correctly""" # BUILD