diff --git a/simpeg_drivers/utils/regularization.py b/simpeg_drivers/utils/regularization.py index 56d6e164..506b30c8 100644 --- a/simpeg_drivers/utils/regularization.py +++ b/simpeg_drivers/utils/regularization.py @@ -40,18 +40,32 @@ def cell_neighbors_along_axis(mesh: TreeMesh, axis: str) -> np.ndarray: stencil = getattr(mesh, f"cell_gradient_{axis}") ith_neighbor, jth_neighbor, _ = ssp.find(stencil) - n_stencils = int(ith_neighbor.shape[0] / 2) - stencil_indices = jth_neighbor[np.argsort(ith_neighbor)].reshape((n_stencils, 2)) + stencil_indices = jth_neighbor[np.argsort(ith_neighbor)].reshape((-1, 2)) return np.sort(stencil_indices, axis=1) +def clean_index_array(index_array: np.ndarray) -> np.ndarray: + """ + Remove duplicate rows or rows with -1 in index array. + + :param index_array: Array of index pairs. + + :return: Cleaned array of index pairs. + """ + array = np.unique(index_array, axis=0) + + # Remove all the -1 for TreeMesh + mask = ~np.any(array == -1, axis=1) + return array[mask, :] + + def collect_all_neighbors( neighbors: list[np.ndarray], neighbors_backwards: list[np.ndarray], - adjacent: list[np.ndarray], - adjacent_backwards: list[np.ndarray], -) -> np.ndarray: + adjacent: np.ndarray, + adjacent_backwards: np.ndarray, +) -> list[np.ndarray]: """ Collect all neighbors for cells in the mesh. @@ -59,110 +73,121 @@ def collect_all_neighbors( :param neighbors_backwards: Direct neighbors in reverse order. :param adjacent: Adjacent neighbors (corners). :param adjacent_backwards: Adjacent neighbors in reverse order. - """ - all_neighbors = [] # Store - - all_neighbors += [neighbors[0]] - all_neighbors += [neighbors[1]] - - all_neighbors += [np.c_[neighbors[0][:, 0], adjacent[0][neighbors[0][:, 1]]]] - all_neighbors += [np.c_[neighbors[0][:, 1], adjacent[0][neighbors[0][:, 0]]]] - - all_neighbors += [np.c_[adjacent[1][neighbors[1][:, 0]], neighbors[1][:, 1]]] - all_neighbors += [np.c_[adjacent[1][neighbors[1][:, 1]], neighbors[1][:, 0]]] - # Repeat backward for Treemesh - all_neighbors += [neighbors_backwards[0]] - all_neighbors += [neighbors_backwards[1]] - - all_neighbors += [ + :return: List of arrays of cell neighbors in all principle directions. List + length is 8 for 2D meshes and 26 for 3D meshes. + """ + neighbours_lists = [ + neighbors[0], # [i+1, j] + neighbors[1], # [i, j+1] + np.c_[neighbors[0][:, 0], adjacent[:, 0][neighbors[0][:, 1]]], # [i+1, j+1] + np.c_[neighbors[0][:, 1], adjacent[:, 0][neighbors[0][:, 0]]], # [i-1, j+1] + np.c_[adjacent[:, 1][neighbors[1][:, 1]], neighbors[1][:, 0]], # [i+1, j-1] + # Repeat backward for Treemesh + neighbors_backwards[0], # [i-1, j] + neighbors_backwards[1], # [i, j-1] np.c_[ neighbors_backwards[0][:, 0], - adjacent_backwards[0][neighbors_backwards[0][:, 1]], - ] - ] - all_neighbors += [ - np.c_[ - neighbors_backwards[0][:, 1], - adjacent_backwards[0][neighbors_backwards[0][:, 0]], - ] + adjacent_backwards[:, 0][neighbors_backwards[0][:, 1]], + ], # [i-1, j-1] ] # Stack all and keep only unique pairs - all_neighbors = np.vstack(all_neighbors) - all_neighbors = np.unique(all_neighbors, axis=0) - - # Remove all the -1 for TreeMesh - all_neighbors = all_neighbors[ - (all_neighbors[:, 0] != -1) & (all_neighbors[:, 1] != -1), : - ] + all_neighbors = [clean_index_array(elem) for elem in neighbours_lists] # Use all the neighbours on the xy plane to find neighbours in z if len(neighbors) == 3: - all_neighbors_z = [] + max_index = np.vstack(all_neighbors).max() + 1 + neigh_z = np.c_[np.arange(max_index), np.full(max_index, -1)].astype("int") + neigh_z[neighbors[2][:, 0], 1] = neighbors[2][:, 1] - all_neighbors_z += [neighbors[2]] - all_neighbors_z += [neighbors_backwards[2]] + neigh_z_back = np.c_[np.arange(max_index), np.full(max_index, -1)].astype("int") + neigh_z_back[neighbors_backwards[2][:, 0], 1] = neighbors_backwards[2][:, 1] - all_neighbors_z += [ - np.c_[all_neighbors[:, 0], adjacent[2][all_neighbors[:, 1]]] - ] - all_neighbors_z += [ - np.c_[all_neighbors[:, 1], adjacent[2][all_neighbors[:, 0]]] + z_list = [ + neighbors[2], # z-positive + neighbors_backwards[2], # z-negative ] + for elem in all_neighbors: # All x and y neighbors + z_list.append( + clean_index_array( + np.c_[elem[:, 0], neigh_z[elem[:, 1], 1]] + ) # [i, j, k+1] + ) + z_list.append( + clean_index_array( + np.c_[elem[:, 0], neigh_z_back[elem[:, 1], 1]] + ) # [i, j, k-1] + ) + + all_neighbors += z_list - all_neighbors_z += [ - np.c_[all_neighbors[:, 0], adjacent_backwards[2][all_neighbors[:, 1]]] - ] - all_neighbors_z += [ - np.c_[all_neighbors[:, 1], adjacent_backwards[2][all_neighbors[:, 0]]] - ] + return all_neighbors - # Stack all and keep only unique pairs - all_neighbors = np.vstack([all_neighbors, np.vstack(all_neighbors_z)]) - all_neighbors = np.unique(all_neighbors, axis=0) - # Remove all the -1 for TreeMesh - all_neighbors = all_neighbors[ - (all_neighbors[:, 0] != -1) & (all_neighbors[:, 1] != -1), : - ] +def cell_adjacent(mesh: TreeMesh, backward: bool = False) -> list[np.ndarray]: + """ + Find all adjacent (corner) cells from cell neighbor array. - return all_neighbors + :param mesh: Input TreeMesh + :param backward: If True, find the opposite corner neighbors. + :return: Array of adjacent cell neighbors. + """ + neighbors = [ + cell_neighbors_along_axis(mesh, "x"), + cell_neighbors_along_axis(mesh, "y"), + ] -def cell_adjacent(neighbors: list[np.ndarray]) -> list[np.ndarray]: - """Find all adjacent cells (corners) from cell neighbor array.""" + if backward: + neighbors = [np.fliplr(k) for k in neighbors] - dim = len(neighbors) - max_index = np.max(np.vstack(neighbors)) - corners = -1 * np.ones((dim, max_index + 1), dtype="int") + corners = -1 * np.ones((mesh.n_cells, 2), dtype="int") - corners[0, neighbors[1][:, 0]] = neighbors[1][:, 1] - corners[1, neighbors[0][:, 1]] = neighbors[0][:, 0] - if dim == 3: - corners[2, neighbors[2][:, 0]] = neighbors[2][:, 1] + corners[neighbors[1][:, 0], 0] = neighbors[1][:, 1] + corners[neighbors[0][:, 1], 1] = neighbors[0][:, 0] - return [np.array(k) for k in corners.tolist()] + return corners -def cell_neighbors(mesh: TreeMesh) -> np.ndarray: - """Find all cell neighbors in a TreeMesh.""" +def cell_neighbors_lists(mesh: TreeMesh) -> list[np.ndarray]: + """ + Find cell neighbors in all directions. + + :param mesh: Input TreeMesh. + + :return: List of arrays of cell neighbors in all principle directions. List + length is 8 for 2D meshes and 26 for 3D meshes. + """ + neighbors = [ + cell_neighbors_along_axis(mesh, "x"), + cell_neighbors_along_axis(mesh, "y"), + ] - neighbors = [] - neighbors.append(cell_neighbors_along_axis(mesh, "x")) - neighbors.append(cell_neighbors_along_axis(mesh, "y")) if mesh.dim == 3: neighbors.append(cell_neighbors_along_axis(mesh, "z")) neighbors_backwards = [np.fliplr(k) for k in neighbors] - corners = cell_adjacent(neighbors) - corners_backwards = cell_adjacent(neighbors_backwards) + corners = cell_adjacent(mesh) + corners_backwards = cell_adjacent(mesh, backward=True) return collect_all_neighbors( neighbors, neighbors_backwards, corners, corners_backwards ) +def cell_neighbors(mesh: TreeMesh) -> np.ndarray: + """ + Find all cell neighbors in a TreeMesh. + + :param mesh: Input TreeMesh. + + :return: Array of unique and sorted cell neighbor pairs. + """ + neighbors_lists = cell_neighbors_lists(mesh) + return np.unique(np.vstack(neighbors_lists), axis=1) + + def rotate_xz_2d(mesh: TreeMesh, phi: np.ndarray) -> ssp.csr_matrix: """ Create a 2d ellipsoidal rotation matrix for the xz plane. @@ -171,6 +196,8 @@ def rotate_xz_2d(mesh: TreeMesh, phi: np.ndarray) -> ssp.csr_matrix: compensate for cell aspect ratio. :param phi: Angle in radians for clockwise rotation about the y-axis (xz plane). + + :return: Sparse rotation matrix """ if mesh.dim != 2: @@ -197,6 +224,8 @@ def rotate_yz_3d(mesh: TreeMesh, theta: np.ndarray) -> ssp.csr_matrix: compensate for cell aspect ratio. :param theta: Angle in radians for clockwise rotation about the x-axis (yz plane). + + :return: Sparse rotation matrix """ hy = mesh.h_gridded[:, 1] hz = mesh.h_gridded[:, 2] @@ -213,6 +242,8 @@ def rotate_xy_3d(mesh: TreeMesh, phi: np.ndarray) -> ssp.csr_matrix: compensate for cell aspect ratio. :param phi: Angle in radians for clockwise rotation about the z-axis (xy plane). + + :return: Sparse rotation matrix """ hx = mesh.h_gridded[:, 0] hy = mesh.h_gridded[:, 1] @@ -231,6 +262,8 @@ def get_cell_normals(n_cells: int, axis: str, outward: bool, dim: int) -> np.nda False for inward facing normals. :param dim: Dimension of the mesh. Either 2 for drape model or 3 for octree. + + :return: Array of cell normals. """ ind = 1 if outward else -1 diff --git a/tests/utils_regularization_test.py b/tests/utils_regularization_test.py index 28b9298b..2cfd5905 100644 --- a/tests/utils_regularization_test.py +++ b/tests/utils_regularization_test.py @@ -12,10 +12,9 @@ from discretize import TreeMesh from simpeg_drivers.utils.regularization import ( - cell_adjacent, + cell_neighbors, cell_neighbors_along_axis, - collect_all_neighbors, - direction_and_dip, + cell_neighbors_lists, ensure_dip_direction_convention, ) @@ -56,13 +55,11 @@ def test_cell_neighbors_along_axis(): def test_collect_all_neighbors(): mesh = get_mesh() centers = mesh.cell_centers - neighbors = [cell_neighbors_along_axis(mesh, k) for k in "xyz"] - neighbors_bck = [np.fliplr(k) for k in neighbors] - corners = cell_adjacent(neighbors) - corners_bck = cell_adjacent(neighbors_bck) - all_neighbors = collect_all_neighbors( - neighbors, neighbors_bck, corners, corners_bck - ) + neighbors_lists = cell_neighbors_lists(mesh) + assert len(neighbors_lists) == 26 + + all_neighbors = cell_neighbors(mesh) + assert np.allclose(centers[7], [15.0, 15.0, 15.0]) neighbor_centers = centers[all_neighbors[all_neighbors[:, 0] == 7][:, 1]].tolist() assert [5, 5, 5] in neighbor_centers