
Commit

Add create_nonzero_conncomp_counts to create counts of valid unwrapped outputs (#485)

* Add function to count nonzero conncomp regions by date

* rename for clarity

* rename to match other `create_` functions
scottstanie authored Nov 7, 2024
1 parent 1fa4962 commit 8d7b43b
Showing 1 changed file with 116 additions and 4 deletions.
120 changes: 116 additions & 4 deletions src/dolphin/timeseries.py
@@ -524,7 +524,9 @@ def invert_stack(


 def get_incidence_matrix(
-    ifg_pairs: Sequence[tuple[T, T]], sar_idxs: Sequence[T] | None = None
+    ifg_pairs: Sequence[tuple[T, T]],
+    sar_idxs: Sequence[T] | None = None,
+    delete_first_date_column: bool = True,
 ) -> np.ndarray:
     """Build the indicator matrix from a list of ifg pairs (index 1, index 2).
@@ -538,6 +540,10 @@ def get_incidence_matrix(
         were formed from.
         Otherwise, created from the unique entries in `ifg_pairs`.
         Only provide if there are some dates which are not present in `ifg_pairs`.
+    delete_first_date_column : bool
+        If True, removes the first column of the matrix to make it full column rank.
+        Size will be `n_sar_dates - 1` columns.
+        Otherwise, the matrix will have `n_sar_dates` columns, but rank `n_sar_dates - 1`.

     Returns
     -------
@@ -553,13 +559,13 @@
         sar_idxs = sorted(set(flatten(ifg_pairs)))

     M = len(ifg_pairs)
-    N = len(sar_idxs) - 1
+    col_iter = sar_idxs[1:] if delete_first_date_column else sar_idxs
+    N = len(col_iter)
     A = np.zeros((M, N))

     # Create a dictionary mapping sar dates to matrix columns
-    # We take the first SAR acquisition to be time 0, leave out of matrix
-    date_to_col = {date: i for i, date in enumerate(sar_idxs[1:])}
     # Populate the matrix
+    date_to_col = {date: i for i, date in enumerate(col_iter)}
     for i, (early, later) in enumerate(ifg_pairs):
         if early in date_to_col:
             A[i, date_to_col[early]] = -1
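
To make the new flag concrete, here is a minimal sketch (not part of the commit; it assumes the remainder of the loop, outside the hunk shown, assigns +1 to the `later` date of each pair, as the full function does):

import numpy as np
from dolphin.timeseries import get_incidence_matrix

# Hypothetical 4-date network, using integer indices as the SAR dates
pairs = [(0, 1), (0, 2), (1, 2), (2, 3)]

# Default behavior (unchanged): drop the first date's column -> full column rank
A = get_incidence_matrix(pairs)  # shape (4, 3)

# New option: keep every date's column; rank is still n_sar_dates - 1
A_full = get_incidence_matrix(pairs, delete_first_date_column=False)  # shape (4, 4)

# |A_full| marks which dates participate in each ifg, so its column sums
# count the interferograms per date: array([2., 2., 3., 1.]) here
counts = np.abs(A_full).sum(axis=0)

This column-sum view is exactly what the new `create_nonzero_conncomp_counts` below builds on.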
@@ -1316,3 +1322,109 @@ def invert_stack_l1(A: ArrayLike, dphi: ArrayLike) -> Array:
     # residuals = jnp.sum(residual_vecs, axis=0)

     return phase, residuals
+
+
+def create_nonzero_conncomp_counts(
+    conncomp_file_list: Sequence[PathOrStr],
+    output_dir: PathOrStr,
+    ifg_date_pairs: Sequence[Sequence[DateOrDatetime]] | None = None,
+    block_shape: tuple[int, int] = (256, 256),
+    num_threads: int = 4,
+) -> list[Path]:
+    """Count the number of valid interferograms per date.
+
+    Parameters
+    ----------
+    conncomp_file_list : Sequence[PathOrStr]
+        List of connected component files.
+    output_dir : PathOrStr
+        The directory to save the output files.
+    ifg_date_pairs : Sequence[Sequence[DateOrDatetime]], optional
+        List of date pairs corresponding to the interferograms.
+        If not provided, will be parsed from filenames.
+    block_shape : tuple[int, int], optional
+        The shape of the blocks to process in parallel.
+    num_threads : int
+        The number of parallel blocks to process at once.
+
+    Returns
+    -------
+    out_paths : list[Path]
+        List of output files, one per unique date.
+    """
+    output_dir = Path(output_dir)
+    output_dir.mkdir(exist_ok=True, parents=True)
+
+    if ifg_date_pairs is None:
+        ifg_date_pairs = [get_dates(str(f))[:2] for f in conncomp_file_list]
+    try:
+        # Ensure it's a list of pairs
+        ifg_tuples = [(ref, sec) for (ref, sec) in ifg_date_pairs]  # noqa: C416
+    except ValueError as e:
+        raise ValueError(
+            "Each item in `ifg_date_pairs` must be a sequence of length 2"
+        ) from e
+
+    # Get unique dates and create the counting matrix
+    sar_dates: list[DateOrDatetime] = sorted(set(utils.flatten(ifg_date_pairs)))
+
+    date_counting_matrix = np.abs(
+        get_incidence_matrix(ifg_tuples, sar_dates, delete_first_date_column=False)
+    )
+
+    # Create output paths for each date
+    suffix = "_valid_count.tif"
+    out_paths = [output_dir / f"{d.strftime('%Y%m%d')}{suffix}" for d in sar_dates]
+
+    if all(p.exists() for p in out_paths):
+        logger.info("All output files exist, skipping counting")
+        return out_paths
+
+    logger.info("Counting valid interferograms per date")
+
+    # Create VRT stack for reading
+    vrt_name = Path(output_dir) / "conncomp_network.vrt"
+    conncomp_reader = io.VRTStack(
+        file_list=conncomp_file_list,
+        outfile=vrt_name,
+        skip_size_check=True,
+        read_masked=True,
+    )
+
+    def count_by_date(
+        readers: Sequence[io.StackReader], rows: slice, cols: slice
+    ) -> tuple[np.ndarray, slice, slice]:
+        """Process each block by counting valid interferograms per date."""
+        stack = readers[0][:, rows, cols]
+        valid_mask = stack.filled(0) != 0  # Shape: (n_ifgs, block_rows, block_cols)
+
+        # Use the counting matrix to map from interferograms to dates
+        # For each pixel, multiply the valid_mask to get counts per date
+        # Reshape valid_mask to (n_ifgs, -1) to handle all pixels at once
+        valid_flat = valid_mask.reshape(valid_mask.shape[0], -1)
+        # Matrix multiply to get counts per date
+        # (date_counting_matrix.T) is shape (n_sar_dates, n_ifgs), and each row
+        # has a number of 1s equal to the nonzero conncomps for that date.
+        date_count_cols = date_counting_matrix.T @ valid_flat
+        date_counts = date_count_cols.reshape(-1, *valid_mask.shape[1:])
+
+        return date_counts, rows, cols
+
+    # Set up the writer for all output files
+    writer = io.BackgroundStackWriter(
+        out_paths, like_filename=conncomp_file_list[0], dtype=np.uint16, units="count"
+    )
+
+    # Process the blocks
+    io.process_blocks(
+        readers=[conncomp_reader],
+        writer=writer,
+        func=count_by_date,
+        block_shape=block_shape,
+        num_threads=num_threads,
+    )
+    writer.notify_finished()
+
+    logger.info("Completed counting valid interferograms per date")
+    return out_paths
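
The per-pixel counting inside `count_by_date` is a single matrix multiply; a standalone toy demonstration (made-up sizes, not from the commit):

import numpy as np

# Toy network: 3 ifgs over 3 dates -- pairs (0, 1), (1, 2), (0, 2)
date_counting_matrix = np.array(
    [
        [1.0, 1.0, 0.0],  # ifg (0, 1) involves dates 0 and 1
        [0.0, 1.0, 1.0],  # ifg (1, 2) involves dates 1 and 2
        [1.0, 0.0, 1.0],  # ifg (0, 2) involves dates 0 and 2
    ]
)

# Flattened validity mask for a 2-pixel block, shape (n_ifgs, n_pixels);
# ifg (1, 2) has conncomp == 0 at pixel 0
valid_flat = np.array([[1, 1],
                       [0, 1],
                       [1, 1]])

# (n_dates, n_ifgs) @ (n_ifgs, n_pixels) -> valid-ifg count per date, per pixel
date_counts = date_counting_matrix.T @ valid_flat
# pixel 0 -> [2., 1., 1.]; pixel 1 -> [2., 2., 2.]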

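A hedged usage sketch (the directory and file names here are hypothetical; when `ifg_date_pairs` is omitted, the only requirement is that each filename contain the two dates for `get_dates` to parse):

from pathlib import Path
from dolphin.timeseries import create_nonzero_conncomp_counts

# Hypothetical inputs: one connected-component raster per unwrapped ifg,
# e.g. unwrapped/20220101_20220113.unw.conncomp.tif
conncomp_files = sorted(Path("unwrapped").glob("*.conncomp.tif"))

out_paths = create_nonzero_conncomp_counts(
    conncomp_file_list=conncomp_files,
    output_dir="counts",
)
# Result: one uint16 raster per SAR date (e.g. counts/20220101_valid_count.tif),
# where each pixel holds how many interferograms using that date had a nonzero
# connected-component label there.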