diff --git a/festim/hydrogen_transport_problem.py b/festim/hydrogen_transport_problem.py index 9898f4ba2..614e240d0 100644 --- a/festim/hydrogen_transport_problem.py +++ b/festim/hydrogen_transport_problem.py @@ -120,12 +120,8 @@ def define_markers_and_measures(self): dofs_facets, tags_facets = [], [] - # TODO this should be a property of mesh - fdim = self.mesh.mesh.topology.dim - 1 - vdim = self.mesh.mesh.topology.dim - # find all cells in domain and mark them as 0 - num_cells = self.mesh.mesh.topology.index_map(vdim).size_local + num_cells = self.mesh.mesh.topology.index_map(self.mesh.vdim).size_local mesh_cell_indices = np.arange(num_cells, dtype=np.int32) tags_volumes = np.full(num_cells, 0, dtype=np.int32) @@ -137,7 +133,9 @@ def define_markers_and_measures(self): if isinstance(sub_dom, F.VolumeSubdomain1D): # find all cells in subdomain and mark them as sub_dom.id self.volume_subdomains.append(sub_dom) - entities = sub_dom.locate_subdomain_entities(self.mesh.mesh, vdim) + entities = sub_dom.locate_subdomain_entities( + self.mesh.mesh, self.mesh.vdim + ) tags_volumes[entities] = sub_dom.id # dofs and tags need to be in np.in32 format for meshtags @@ -145,9 +143,11 @@ def define_markers_and_measures(self): tags_facets = np.array(tags_facets, dtype=np.int32) # define mesh tags - self.facet_meshtags = meshtags(self.mesh.mesh, fdim, dofs_facets, tags_facets) + self.facet_meshtags = meshtags( + self.mesh.mesh, self.mesh.fdim, dofs_facets, tags_facets + ) self.volume_meshtags = meshtags( - self.mesh.mesh, vdim, mesh_cell_indices, tags_volumes + self.mesh.mesh, self.mesh.vdim, mesh_cell_indices, tags_volumes ) # define measures diff --git a/festim/mesh/mesh.py b/festim/mesh/mesh.py index 5f6be3dac..bb2aa0654 100644 --- a/festim/mesh/mesh.py +++ b/festim/mesh/mesh.py @@ -1,15 +1,14 @@ -import ufl - - class Mesh: """ Mesh class Args: - mesh (dolfinx.mesh.Mesh, optional): the mesh. Defaults to None. + mesh (dolfinx.mesh.Mesh, optional): the mesh. Defaults to None. Attributes: mesh (dolfinx.mesh.Mesh): the mesh + vdim (int): the dimension of the mesh cells + fdim (int): the dimension of the mesh facets """ def __init__(self, mesh=None): @@ -25,3 +24,11 @@ def __init__(self, mesh=None): self.mesh.topology.create_connectivity( self.mesh.topology.dim - 1, self.mesh.topology.dim ) + + @property + def vdim(self): + return self.mesh.topology.dim + + @property + def fdim(self): + return self.mesh.topology.dim - 1 diff --git a/festim/mesh/mesh_1d.py b/festim/mesh/mesh_1d.py index f7de96bee..7378f4783 100644 --- a/festim/mesh/mesh_1d.py +++ b/festim/mesh/mesh_1d.py @@ -1,4 +1,4 @@ -from dolfinx import fem, mesh +from dolfinx import mesh from mpi4py import MPI import ufl import numpy as np @@ -10,7 +10,7 @@ class Mesh1D(Mesh): 1D Mesh Args: - vertices (list): the mesh x-coordinates (m) + vertices (list): the mesh x-coordinates (m) Attributes: vertices (list): the mesh x-coordinates (m) diff --git a/test/test_mesh.py b/test/test_mesh.py new file mode 100644 index 000000000..7edb9f661 --- /dev/null +++ b/test/test_mesh.py @@ -0,0 +1,38 @@ +import festim as F +from dolfinx import mesh as fenics_mesh +from mpi4py import MPI +import pytest + +mesh_1D = fenics_mesh.create_unit_interval(MPI.COMM_WORLD, 10) +mesh_2D = fenics_mesh.create_unit_square(MPI.COMM_WORLD, 10, 10) +mesh_3D = fenics_mesh.create_unit_cube(MPI.COMM_WORLD, 10, 10, 10) + + +@pytest.mark.parametrize("mesh", [mesh_1D, mesh_2D, mesh_3D]) +def test_get_fdim(mesh): + my_mesh = F.Mesh(mesh) + + assert my_mesh.fdim == mesh.topology.dim - 1 + + +def test_fdim_changes_when_mesh_changes(): + my_mesh = F.Mesh() + + for mesh in [mesh_1D, mesh_2D, mesh_3D]: + my_mesh.mesh = mesh + assert my_mesh.fdim == mesh.topology.dim - 1 + + +@pytest.mark.parametrize("mesh", [mesh_1D, mesh_2D, mesh_3D]) +def test_get_vdim(mesh): + my_mesh = F.Mesh(mesh) + + assert my_mesh.vdim == mesh.topology.dim + + +def test_vdim_changes_when_mesh_changes(): + my_mesh = F.Mesh() + + for mesh in [mesh_1D, mesh_2D, mesh_3D]: + my_mesh.mesh = mesh + assert my_mesh.vdim == mesh.topology.dim