From 8d7b43bbe8b7ecdbf44e4d7a833c51adab55468e Mon Sep 17 00:00:00 2001 From: Scott Staniewicz Date: Wed, 6 Nov 2024 21:37:08 -0500 Subject: [PATCH] Add `create_nonzero_conncomp_counts` to create counts of valid unwrapped outputs (#485) * Add function to count nonzero conncomp regions by date * rename for clarity * rename to match other `create_` functions --- src/dolphin/timeseries.py | 120 ++++++++++++++++++++++++++++++++++++-- 1 file changed, 116 insertions(+), 4 deletions(-) diff --git a/src/dolphin/timeseries.py b/src/dolphin/timeseries.py index dd269f02..03007602 100644 --- a/src/dolphin/timeseries.py +++ b/src/dolphin/timeseries.py @@ -524,7 +524,9 @@ def invert_stack( def get_incidence_matrix( - ifg_pairs: Sequence[tuple[T, T]], sar_idxs: Sequence[T] | None = None + ifg_pairs: Sequence[tuple[T, T]], + sar_idxs: Sequence[T] | None = None, + delete_first_date_column: bool = True, ) -> np.ndarray: """Build the indicator matrix from a list of ifg pairs (index 1, index 2). @@ -538,6 +540,10 @@ def get_incidence_matrix( were formed from. Otherwise, created from the unique entries in `ifg_pairs`. Only provide if there are some dates which are not present in `ifg_pairs`. + delete_first_date_column : bool + If True, removes the first column of the matrix to make it full column rank. + Size will be `n_sar_dates - 1` columns. + Otherwise, the matrix will have `n_sar_dates`, but rank `n_sar_dates - 1`. Returns ------- @@ -553,13 +559,13 @@ def get_incidence_matrix( sar_idxs = sorted(set(flatten(ifg_pairs))) M = len(ifg_pairs) - N = len(sar_idxs) - 1 + col_iter = sar_idxs[1:] if delete_first_date_column else sar_idxs + N = len(col_iter) A = np.zeros((M, N)) # Create a dictionary mapping sar dates to matrix columns # We take the first SAR acquisition to be time 0, leave out of matrix - date_to_col = {date: i for i, date in enumerate(sar_idxs[1:])} - # Populate the matrix + date_to_col = {date: i for i, date in enumerate(col_iter)} for i, (early, later) in enumerate(ifg_pairs): if early in date_to_col: A[i, date_to_col[early]] = -1 @@ -1316,3 +1322,109 @@ def invert_stack_l1(A: ArrayLike, dphi: ArrayLike) -> Array: # residuals = jnp.sum(residual_vecs, axis=0) return phase, residuals + + +def create_nonzero_conncomp_counts( + conncomp_file_list: Sequence[PathOrStr], + output_dir: PathOrStr, + ifg_date_pairs: Sequence[Sequence[DateOrDatetime]] | None = None, + block_shape: tuple[int, int] = (256, 256), + num_threads: int = 4, +) -> list[Path]: + """Count the number of valid interferograms per date. + + Parameters + ---------- + conncomp_file_list : Sequence[PathOrStr] + List of connected component files + output_dir : PathOrStr + The directory to save the output files + ifg_date_pairs : Sequence[Sequence[DateOrDatetime]], optional + List of date pairs corresponding to the interferograms. + If not provided, will be parsed from filenames. + block_shape : tuple[int, int], optional + The shape of the blocks to process in parallel. + num_threads : int + The number of parallel blocks to process at once. + + Returns + ------- + out_paths : list[Path] + List of output files, one per unique date + + """ + output_dir = Path(output_dir) + output_dir.mkdir(exist_ok=True, parents=True) + + if ifg_date_pairs is None: + ifg_date_pairs = [get_dates(str(f))[:2] for f in conncomp_file_list] + try: + # Ensure it's a list of pairs + ifg_tuples = [(ref, sec) for (ref, sec) in ifg_date_pairs] # noqa: C416 + except ValueError as e: + raise ValueError( + "Each item in `ifg_date_pairs` must be a sequence of length 2" + ) from e + + # Get unique dates and create the counting matrix + sar_dates: list[DateOrDatetime] = sorted(set(utils.flatten(ifg_date_pairs))) + + date_counting_matrix = np.abs( + get_incidence_matrix(ifg_tuples, sar_dates, delete_first_date_column=False) + ) + + # Create output paths for each date + suffix = "_valid_count.tif" + out_paths = [output_dir / f"{d.strftime('%Y%m%d')}{suffix}" for d in sar_dates] + + if all(p.exists() for p in out_paths): + logger.info("All output files exist, skipping counting") + return out_paths + + logger.info("Counting valid interferograms per date") + + # Create VRT stack for reading + vrt_name = Path(output_dir) / "conncomp_network.vrt" + conncomp_reader = io.VRTStack( + file_list=conncomp_file_list, + outfile=vrt_name, + skip_size_check=True, + read_masked=True, + ) + + def count_by_date( + readers: Sequence[io.StackReader], rows: slice, cols: slice + ) -> tuple[np.ndarray, slice, slice]: + """Process each block by counting valid interferograms per date.""" + stack = readers[0][:, rows, cols] + valid_mask = stack.filled(0) != 0 # Shape: (n_ifgs, block_rows, block_cols) + + # Use the counting matrix to map from interferograms to dates + # For each pixel, multiply the valid_mask to get counts per date + # Reshape valid_mask to (n_ifgs, -1) to handle all pixels at once + valid_flat = valid_mask.reshape(valid_mask.shape[0], -1) + # Matrix multiply to get counts per date + # (date_counting_matrix.T) is shape (n_sar_dates, n_ifgs), and each row + # has a number of 1s equal to the nonzero conncomps for that date. + date_count_cols = date_counting_matrix.T @ valid_flat + date_counts = date_count_cols.reshape(-1, *valid_mask.shape[1:]) + + return date_counts, rows, cols + + # Setup writer for all output files + writer = io.BackgroundStackWriter( + out_paths, like_filename=conncomp_file_list[0], dtype=np.uint16, units="count" + ) + + # Process the blocks + io.process_blocks( + readers=[conncomp_reader], + writer=writer, + func=count_by_date, + block_shape=block_shape, + num_threads=num_threads, + ) + writer.notify_finished() + + logger.info("Completed counting valid interferograms per date") + return out_paths