Skip to content

Commit

Permalink
Merge pull request #627 from RemDelaporteMathurin/fix-conditional
Browse files Browse the repository at this point in the history
Better error message when users use conditional for functions of t only
  • Loading branch information
RemDelaporteMathurin authored Nov 2, 2023
2 parents 798a5cb + d97cf77 commit a651794
Show file tree
Hide file tree
Showing 4 changed files with 60 additions and 1 deletion.
4 changes: 4 additions & 0 deletions festim/boundary_conditions/dirichlet_bc.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,10 @@ def create_value(

if "t" in arguments and "x" not in arguments and "T" not in arguments:
# only t is an argument
if not isinstance(self.value(t=float(t)), (float, int)):
raise ValueError(
f"self.value should return a float or an int, not {type(self.value(t=float(t)))} "
)
self.value_fenics = F.as_fenics_constant(
mesh=mesh, value=self.value(t=float(t))
)
Expand Down
6 changes: 5 additions & 1 deletion festim/hydrogen_transport_problem.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,11 @@ def define_temperature(self):
# if temperature is callable, process accordingly
elif callable(self.temperature):
arguments = self.temperature.__code__.co_varnames
if "t" in arguments and "x" not in arguments and "T" not in arguments:
if "t" in arguments and "x" not in arguments:
if not isinstance(self.temperature(t=float(self.t)), (float, int)):
raise ValueError(
f"self.temperature should return a float or an int, not {type(self.temperature(t=float(self.t)))} "
)
# only t is an argument
self.temperature_fenics = F.as_fenics_constant(
mesh=self.mesh.mesh, value=self.temperature(t=float(self.t))
Expand Down
27 changes: 27 additions & 0 deletions test/test_dirichlet_bc.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,6 +226,7 @@ def test_callable_x_only():
lambda x, t: 1.0 + x[0] + t,
lambda x, t, T: 1.0 + x[0] + t + T,
lambda x, t: ufl.conditional(ufl.lt(t, 1.0), 100.0 + x[0], 0.0),
lambda t: 100.0 if t < 1 else 0.0,
],
)
def test_integration_with_HTransportProblem(value):
Expand Down Expand Up @@ -283,6 +284,32 @@ def test_integration_with_HTransportProblem(value):
assert np.isclose(computed_value, expected_value)


@pytest.mark.parametrize(
"value",
[
lambda t: ufl.conditional(ufl.lt(t, 1.0), 1, 2),
lambda t: 1 + ufl.conditional(ufl.lt(t, 1.0), 1, 2.0),
lambda t: 2 * ufl.conditional(ufl.lt(t, 1.0), 1, 2.0),
lambda t: 2 / ufl.conditional(ufl.lt(t, 1.0), 1, 2.0),
],
)
def test_define_value_error_if_ufl_conditional_t_only(value):
"""Test that a ValueError is raised when the value attribute is a callable
of t only and contains a ufl conditional"""

subdomain = F.SurfaceSubdomain1D(1, x=1)
species = F.Species("test")

bc = F.DirichletBC(subdomain, value, species)

t = fem.Constant(mesh, 0.0)

with pytest.raises(
ValueError, match="self.value should return a float or an int, not "
):
bc.create_value(mesh=mesh, function_space=None, temperature=None, t=t)


def test_species_predefined():
"""Test a ValueError is raised when the species defined in the boundary
condition is not predefined in the model"""
Expand Down
24 changes: 24 additions & 0 deletions test/test_h_transport_problem.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ def test_define_temperature_value_error_raised():
(lambda x: 1.0 + x[0], fem.Function),
(lambda x, t: 1.0 + x[0] + t, fem.Function),
(lambda x, t: ufl.conditional(ufl.lt(t, 1.0), 100.0 + x[0], 0.0), fem.Function),
(lambda t: 100.0 if t < 1 else 0.0, fem.Constant),
],
)
def test_define_temperature(input, expected_type):
Expand All @@ -97,6 +98,29 @@ def test_define_temperature(input, expected_type):
assert isinstance(my_model.temperature_fenics, expected_type)


@pytest.mark.parametrize(
"input",
[
lambda t: ufl.conditional(ufl.lt(t, 1.0), 1, 2),
lambda t: 1 + ufl.conditional(ufl.lt(t, 1.0), 1, 2.0),
lambda t: 2 * ufl.conditional(ufl.lt(t, 1.0), 1, 2.0),
lambda t: 2 / ufl.conditional(ufl.lt(t, 1.0), 1, 2.0),
],
)
def test_define_temperature_error_if_ufl_conditional_t_only(input):
"""Test that a ValueError is raised when the temperature attribute is a callable
of t only and contains a ufl conditional"""
my_model = F.HydrogenTransportProblem(mesh=test_mesh)
my_model.t = fem.Constant(test_mesh.mesh, 0.0)

my_model.temperature = input

with pytest.raises(
ValueError, match="self.temperature should return a float or an int, not "
):
my_model.define_temperature()


def test_iterate():
"""Test that the iterate method updates the solution and time correctly"""
# BUILD
Expand Down

0 comments on commit a651794

Please sign in to comment.