Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
183 changes: 108 additions & 75 deletions simpeg_drivers/utils/regularization.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,129 +40,154 @@ def cell_neighbors_along_axis(mesh: TreeMesh, axis: str) -> np.ndarray:
stencil = getattr(mesh, f"cell_gradient_{axis}")

ith_neighbor, jth_neighbor, _ = ssp.find(stencil)
n_stencils = int(ith_neighbor.shape[0] / 2)
stencil_indices = jth_neighbor[np.argsort(ith_neighbor)].reshape((n_stencils, 2))
stencil_indices = jth_neighbor[np.argsort(ith_neighbor)].reshape((-1, 2))

return np.sort(stencil_indices, axis=1)


def clean_index_array(index_array: np.ndarray) -> np.ndarray:
"""
Remove duplicate rows or rows with -1 in index array.

:param index_array: Array of index pairs.

:return: Cleaned array of index pairs.
"""
array = np.unique(index_array, axis=0)

# Remove all the -1 for TreeMesh
mask = ~np.any(array == -1, axis=1)
return array[mask, :]


def collect_all_neighbors(
neighbors: list[np.ndarray],
neighbors_backwards: list[np.ndarray],
adjacent: list[np.ndarray],
adjacent_backwards: list[np.ndarray],
) -> np.ndarray:
adjacent: np.ndarray,
adjacent_backwards: np.ndarray,
) -> list[np.ndarray]:
"""
Collect all neighbors for cells in the mesh.

:param neighbors: Direct neighbors in each principle axes.
:param neighbors_backwards: Direct neighbors in reverse order.
:param adjacent: Adjacent neighbors (corners).
:param adjacent_backwards: Adjacent neighbors in reverse order.
"""
all_neighbors = [] # Store

all_neighbors += [neighbors[0]]
all_neighbors += [neighbors[1]]

all_neighbors += [np.c_[neighbors[0][:, 0], adjacent[0][neighbors[0][:, 1]]]]
all_neighbors += [np.c_[neighbors[0][:, 1], adjacent[0][neighbors[0][:, 0]]]]

all_neighbors += [np.c_[adjacent[1][neighbors[1][:, 0]], neighbors[1][:, 1]]]
all_neighbors += [np.c_[adjacent[1][neighbors[1][:, 1]], neighbors[1][:, 0]]]

# Repeat backward for Treemesh
all_neighbors += [neighbors_backwards[0]]
all_neighbors += [neighbors_backwards[1]]

all_neighbors += [
:return: List of arrays of cell neighbors in all principle directions. List
length is 8 for 2D meshes and 26 for 3D meshes.
"""
neighbours_lists = [
neighbors[0], # [i+1, j]
neighbors[1], # [i, j+1]
np.c_[neighbors[0][:, 0], adjacent[:, 0][neighbors[0][:, 1]]], # [i+1, j+1]
np.c_[neighbors[0][:, 1], adjacent[:, 0][neighbors[0][:, 0]]], # [i-1, j+1]
np.c_[adjacent[:, 1][neighbors[1][:, 1]], neighbors[1][:, 0]], # [i+1, j-1]
# Repeat backward for Treemesh
neighbors_backwards[0], # [i-1, j]
neighbors_backwards[1], # [i, j-1]
np.c_[
neighbors_backwards[0][:, 0],
adjacent_backwards[0][neighbors_backwards[0][:, 1]],
]
]
all_neighbors += [
np.c_[
neighbors_backwards[0][:, 1],
adjacent_backwards[0][neighbors_backwards[0][:, 0]],
]
adjacent_backwards[:, 0][neighbors_backwards[0][:, 1]],
], # [i-1, j-1]
]

# Stack all and keep only unique pairs
all_neighbors = np.vstack(all_neighbors)
all_neighbors = np.unique(all_neighbors, axis=0)

# Remove all the -1 for TreeMesh
all_neighbors = all_neighbors[
(all_neighbors[:, 0] != -1) & (all_neighbors[:, 1] != -1), :
]
all_neighbors = [clean_index_array(elem) for elem in neighbours_lists]

# Use all the neighbours on the xy plane to find neighbours in z
if len(neighbors) == 3:
all_neighbors_z = []
max_index = np.vstack(all_neighbors).max() + 1
neigh_z = np.c_[np.arange(max_index), np.full(max_index, -1)].astype("int")
neigh_z[neighbors[2][:, 0], 1] = neighbors[2][:, 1]

all_neighbors_z += [neighbors[2]]
all_neighbors_z += [neighbors_backwards[2]]
neigh_z_back = np.c_[np.arange(max_index), np.full(max_index, -1)].astype("int")
neigh_z_back[neighbors_backwards[2][:, 0], 1] = neighbors_backwards[2][:, 1]

all_neighbors_z += [
np.c_[all_neighbors[:, 0], adjacent[2][all_neighbors[:, 1]]]
]
all_neighbors_z += [
np.c_[all_neighbors[:, 1], adjacent[2][all_neighbors[:, 0]]]
z_list = [
neighbors[2], # z-positive
neighbors_backwards[2], # z-negative
]
for elem in all_neighbors: # All x and y neighbors
z_list.append(
clean_index_array(
np.c_[elem[:, 0], neigh_z[elem[:, 1], 1]]
) # [i, j, k+1]
)
z_list.append(
clean_index_array(
np.c_[elem[:, 0], neigh_z_back[elem[:, 1], 1]]
) # [i, j, k-1]
)

all_neighbors += z_list

all_neighbors_z += [
np.c_[all_neighbors[:, 0], adjacent_backwards[2][all_neighbors[:, 1]]]
]
all_neighbors_z += [
np.c_[all_neighbors[:, 1], adjacent_backwards[2][all_neighbors[:, 0]]]
]
return all_neighbors

# Stack all and keep only unique pairs
all_neighbors = np.vstack([all_neighbors, np.vstack(all_neighbors_z)])
all_neighbors = np.unique(all_neighbors, axis=0)

# Remove all the -1 for TreeMesh
all_neighbors = all_neighbors[
(all_neighbors[:, 0] != -1) & (all_neighbors[:, 1] != -1), :
]
def cell_adjacent(mesh: TreeMesh, backward: bool = False) -> list[np.ndarray]:
"""
Find all adjacent (corner) cells from cell neighbor array.

return all_neighbors
:param mesh: Input TreeMesh
:param backward: If True, find the opposite corner neighbors.

:return: Array of adjacent cell neighbors.
"""
neighbors = [
cell_neighbors_along_axis(mesh, "x"),
cell_neighbors_along_axis(mesh, "y"),
]

def cell_adjacent(neighbors: list[np.ndarray]) -> list[np.ndarray]:
"""Find all adjacent cells (corners) from cell neighbor array."""
if backward:
neighbors = [np.fliplr(k) for k in neighbors]

dim = len(neighbors)
max_index = np.max(np.vstack(neighbors))
corners = -1 * np.ones((dim, max_index + 1), dtype="int")
corners = -1 * np.ones((mesh.n_cells, 2), dtype="int")

corners[0, neighbors[1][:, 0]] = neighbors[1][:, 1]
corners[1, neighbors[0][:, 1]] = neighbors[0][:, 0]
if dim == 3:
corners[2, neighbors[2][:, 0]] = neighbors[2][:, 1]
corners[neighbors[1][:, 0], 0] = neighbors[1][:, 1]
corners[neighbors[0][:, 1], 1] = neighbors[0][:, 0]

return [np.array(k) for k in corners.tolist()]
return corners


def cell_neighbors(mesh: TreeMesh) -> np.ndarray:
"""Find all cell neighbors in a TreeMesh."""
def cell_neighbors_lists(mesh: TreeMesh) -> list[np.ndarray]:
"""
Find cell neighbors in all directions.

:param mesh: Input TreeMesh.

:return: List of arrays of cell neighbors in all principle directions. List
length is 8 for 2D meshes and 26 for 3D meshes.
"""
neighbors = [
cell_neighbors_along_axis(mesh, "x"),
cell_neighbors_along_axis(mesh, "y"),
]

neighbors = []
neighbors.append(cell_neighbors_along_axis(mesh, "x"))
neighbors.append(cell_neighbors_along_axis(mesh, "y"))
if mesh.dim == 3:
neighbors.append(cell_neighbors_along_axis(mesh, "z"))

neighbors_backwards = [np.fliplr(k) for k in neighbors]
corners = cell_adjacent(neighbors)
corners_backwards = cell_adjacent(neighbors_backwards)
corners = cell_adjacent(mesh)
corners_backwards = cell_adjacent(mesh, backward=True)

return collect_all_neighbors(
neighbors, neighbors_backwards, corners, corners_backwards
)


def cell_neighbors(mesh: TreeMesh) -> np.ndarray:
"""
Find all cell neighbors in a TreeMesh.

:param mesh: Input TreeMesh.

:return: Array of unique and sorted cell neighbor pairs.
"""
neighbors_lists = cell_neighbors_lists(mesh)
return np.unique(np.vstack(neighbors_lists), axis=1)


def rotate_xz_2d(mesh: TreeMesh, phi: np.ndarray) -> ssp.csr_matrix:
"""
Create a 2d ellipsoidal rotation matrix for the xz plane.
Expand All @@ -171,6 +196,8 @@ def rotate_xz_2d(mesh: TreeMesh, phi: np.ndarray) -> ssp.csr_matrix:
compensate for cell aspect ratio.
:param phi: Angle in radians for clockwise rotation about the
y-axis (xz plane).

:return: Sparse rotation matrix
"""

if mesh.dim != 2:
Expand All @@ -197,6 +224,8 @@ def rotate_yz_3d(mesh: TreeMesh, theta: np.ndarray) -> ssp.csr_matrix:
compensate for cell aspect ratio.
:param theta: Angle in radians for clockwise rotation about the
x-axis (yz plane).

:return: Sparse rotation matrix
"""
hy = mesh.h_gridded[:, 1]
hz = mesh.h_gridded[:, 2]
Expand All @@ -213,6 +242,8 @@ def rotate_xy_3d(mesh: TreeMesh, phi: np.ndarray) -> ssp.csr_matrix:
compensate for cell aspect ratio.
:param phi: Angle in radians for clockwise rotation about the
z-axis (xy plane).

:return: Sparse rotation matrix
"""
hx = mesh.h_gridded[:, 0]
hy = mesh.h_gridded[:, 1]
Expand All @@ -231,6 +262,8 @@ def get_cell_normals(n_cells: int, axis: str, outward: bool, dim: int) -> np.nda
False for inward facing normals.
:param dim: Dimension of the mesh. Either 2 for drape model or 3
for octree.

:return: Array of cell normals.
"""

ind = 1 if outward else -1
Expand Down
17 changes: 7 additions & 10 deletions tests/utils_regularization_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,9 @@
from discretize import TreeMesh

from simpeg_drivers.utils.regularization import (
cell_adjacent,
cell_neighbors,
cell_neighbors_along_axis,
collect_all_neighbors,
direction_and_dip,
cell_neighbors_lists,
ensure_dip_direction_convention,
)

Expand Down Expand Up @@ -56,13 +55,11 @@ def test_cell_neighbors_along_axis():
def test_collect_all_neighbors():
mesh = get_mesh()
centers = mesh.cell_centers
neighbors = [cell_neighbors_along_axis(mesh, k) for k in "xyz"]
neighbors_bck = [np.fliplr(k) for k in neighbors]
corners = cell_adjacent(neighbors)
corners_bck = cell_adjacent(neighbors_bck)
all_neighbors = collect_all_neighbors(
neighbors, neighbors_bck, corners, corners_bck
)
neighbors_lists = cell_neighbors_lists(mesh)
assert len(neighbors_lists) == 26

all_neighbors = cell_neighbors(mesh)

assert np.allclose(centers[7], [15.0, 15.0, 15.0])
neighbor_centers = centers[all_neighbors[all_neighbors[:, 0] == 7][:, 1]].tolist()
assert [5, 5, 5] in neighbor_centers
Expand Down