From c234320ed4e6dc74c956f714f99df8debd842dae Mon Sep 17 00:00:00 2001 From: Ryan Grout Date: Mon, 5 Feb 2024 11:16:52 -0700 Subject: [PATCH 1/4] Use Priority Flood algorithm to fill depressions. --- pysheds/_sgrid.py | 131 +++++++++++++++++++++++++++++++++++++++++++++- pysheds/sgrid.py | 26 +++------ 2 files changed, 137 insertions(+), 20 deletions(-) diff --git a/pysheds/_sgrid.py b/pysheds/_sgrid.py index 7ec23da..4dc91d2 100644 --- a/pysheds/_sgrid.py +++ b/pysheds/_sgrid.py @@ -1,8 +1,9 @@ -from heapq import heappop, heappush +from heapq import heappop, heappush, heapify import math import numpy as np from numba import njit, prange from numba.types import float64, int64, uint32, uint16, uint8, boolean, UniTuple, Tuple, List, DictType, void +from numba.typed import typedlist # Functions for 'flowdir' @@ -1856,3 +1857,131 @@ def _fill_pits_numba(dem, pit_indices): adjustment = min(diff, adjustment) pits_filled.flat[k] += (adjustment) return pits_filled + +@njit(boundscheck=True, cache=True) +def _first_true1d(arr, start=0, end=None, step=1, invert=False): + if end is None: + end = len(arr) + + if invert: + for i in range(start, end, step): + if not arr[i]: + return i + else: + return -1 + else: + for i in range(start, end, step): + if arr[i]: + return i + else: + return -1 + +@njit(parallel=True, cache=True) +def _top(mask): + nc = mask.shape[1] + rv = np.zeros(nc, dtype='int64') + for i in prange(nc): + rv[i] = _first_true1d(mask[:, i], invert=True) + return rv + +@njit(parallel=True, cache=True) +def _bottom(mask): + nr, nc = mask.shape[0], mask.shape[1] + rv = np.zeros(nc, dtype='int64') + for i in prange(nc): + rv[i] = _first_true1d(mask[:, i], start=nr - 1, end=-1, step=-1, invert=True) + return rv + +@njit(parallel=True, cache=True) +def _left(mask): + nr = mask.shape[0] + rv = np.zeros(nr, dtype='int64') + for i in prange(nr): + rv[i] = _first_true1d(mask[i, :], invert=True) + return rv + +@njit(parallel=True, cache=True) +def _right(mask): + nr, nc = mask.shape[0], mask.shape[1] + rv = np.zeros(nr, dtype='int64') + for i in prange(nr): + rv[i] = _first_true1d(mask[i, :], start=nc - 1, end=-1, step=-1, invert=True) + return rv + + +@njit(cache=True) +def count(start=0, step=1): + # Numba accelerated count() from itertools + # count(10) --> 10 11 12 13 14 ... + # count(2.5, 0.5) --> 2.5 3.0 3.5 ... + n = start + while True: + yield n + n += step + +@njit +def _priority_flood(dem, dem_mask, tuple_type): + open_cells = typedlist.List.empty_list(tuple_type) # Priority queue + pits = typedlist.List.empty_list(tuple_type) # FIFO queue + closed_cells = dem_mask.copy() + + # Push the edges onto priority queue + y, x = dem.shape + + edge = _left(dem_mask)[:-1] + for row, col in zip(count(), edge): + if col >= 0: + open_cells.append((dem[row, col], row, col)) + closed_cells[row, col] = True + edge = _bottom(dem_mask)[:-1] + for row, col in zip(edge, count()): + if row >= 0: + open_cells.append((dem[row, col], row, col)) + closed_cells[row, col] = True + edge = np.flip(_right(dem_mask))[:-1] + for row, col in zip(count(y - 1, step=-1), edge): + if col >= 0: + open_cells.append((dem[row, col], row, col)) + closed_cells[row, col] = True + edge = np.flip(_top(dem_mask))[:-1] + for row, col in zip(edge, count(x - 1, step=-1)): + if row >= 0: + open_cells.append((dem[row, col], row, col)) + closed_cells[row, col] = True + heapify(open_cells) + + row_offsets = np.array([-1, -1, 0, 1, 1, 1, 0, -1]) + col_offsets = np.array([0, 1, 1, 1, 0, -1, -1, -1]) + + pits_pos = 0 + while open_cells or pits_pos < len(pits): + if pits_pos < len(pits): + elv, i, j = pits[pits_pos] + pits_pos += 1 + else: + elv, i, j = heappop(open_cells) + + for n in range(8): + row = i + row_offsets[n] + col = j + col_offsets[n] + + if row < 0 or row >= y or col < 0 or col >= x: + continue + + if dem_mask[row, col] or closed_cells[row, col]: + continue + + if dem[row, col] <= elv: + dem[row, col] = elv + pits.append((elv, row, col)) + else: + heappush(open_cells, (dem[row, col], row, col)) + closed_cells[row, col] = True + + # pits book-keeping + if pits_pos == len(pits) and len(pits) > 1024: + # Queue is empty, lets clear it out + pits.clear() + pits_pos = 0 + + return dem \ No newline at end of file diff --git a/pysheds/sgrid.py b/pysheds/sgrid.py index 1093b0c..f0b3229 100644 --- a/pysheds/sgrid.py +++ b/pysheds/sgrid.py @@ -6,9 +6,11 @@ import pandas as pd import geojson from affine import Affine +from numba.types import Tuple, int64 +from numba import from_dtype + try: import skimage.measure - import skimage.morphology _HAS_SKIMAGE = True except ModuleNotFoundError: _HAS_SKIMAGE = False @@ -2113,8 +2115,6 @@ def detect_depressions(self, dem, **kwargs): depressions : Raster Boolean Raster indicating locations of depressions. """ - if not _HAS_SKIMAGE: - raise ImportError('detect_depressions requires skimage.morphology module') input_overrides = {'dtype' : np.float64, 'nodata' : dem.nodata} kwargs.update(input_overrides) dem = self._input_handler(dem, **kwargs) @@ -2148,23 +2148,11 @@ def fill_depressions(self, dem, nodata_out=np.nan, **kwargs): Raster representing digital elevation data with multi-celled depressions removed. """ - if not _HAS_SKIMAGE: - raise ImportError('resolve_flats requires skimage.morphology module') - input_overrides = {'dtype' : np.float64, 'nodata' : dem.nodata} - kwargs.update(input_overrides) - dem = self._input_handler(dem, **kwargs) + # Implementation detail of priority flood algorithm. + tuple_type = Tuple([from_dtype(dem.dtype), int64, int64]) dem_mask = self._get_nodata_cells(dem) - dem_mask[0, :] = True - dem_mask[-1, :] = True - dem_mask[:, 0] = True - dem_mask[:, -1] = True - # Make sure nothing flows to the nodata cells - seed = np.copy(dem) - seed[~dem_mask] = np.nanmax(dem) - dem_out = skimage.morphology.reconstruction(seed, dem, method='erosion') - dem_out = self._output_handler(data=dem_out, viewfinder=dem.viewfinder, - metadata=dem.metadata, nodata=nodata_out) - return dem_out + return _self._priority_flood(dem, dem_mask, tuple_type) + def detect_flats(self, dem, **kwargs): """ From 91176eb1cc0ad21843e31a678332dabb0c79740f Mon Sep 17 00:00:00 2001 From: Ryan Grout Date: Mon, 5 Feb 2024 13:23:00 -0700 Subject: [PATCH 2/4] Cache jitted function. --- pysheds/_sgrid.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pysheds/_sgrid.py b/pysheds/_sgrid.py index 4dc91d2..e0db269 100644 --- a/pysheds/_sgrid.py +++ b/pysheds/_sgrid.py @@ -1919,7 +1919,7 @@ def count(start=0, step=1): yield n n += step -@njit +@njit(cache=True) def _priority_flood(dem, dem_mask, tuple_type): open_cells = typedlist.List.empty_list(tuple_type) # Priority queue pits = typedlist.List.empty_list(tuple_type) # FIFO queue From 04b061b943201fd59fdf10505c0a61984ad69402 Mon Sep 17 00:00:00 2001 From: Ryan Grout Date: Wed, 14 Feb 2024 12:48:18 -0700 Subject: [PATCH 3/4] Add wrapper function. The decorator hides an implementation detail of numba that the caller need not worry about. --- pysheds/_sgrid.py | 37 ++++++++++++++++++++++++++++--------- 1 file changed, 28 insertions(+), 9 deletions(-) diff --git a/pysheds/_sgrid.py b/pysheds/_sgrid.py index e0db269..f46865a 100644 --- a/pysheds/_sgrid.py +++ b/pysheds/_sgrid.py @@ -1,7 +1,8 @@ from heapq import heappop, heappush, heapify import math import numpy as np -from numba import njit, prange +from functools import wraps +from numba import njit, prange, from_dtype from numba.types import float64, int64, uint32, uint16, uint8, boolean, UniTuple, Tuple, List, DictType, void from numba.typed import typedlist @@ -1919,11 +1920,29 @@ def count(start=0, step=1): yield n n += step + +def pfwrapper(func): + # Implemenation detail of priority-flood algorithm + # Needed to define the types used in priority queue + @wraps(func) + def _wrapper(dem, mask, *args): + # Tuple elements: + # 0: dem data type (for elevation priority) + # 1: int64 for insertion index (to maintain total ordering) + # 2: int64 for row index + # 3: int64 for col index + tuple_type = Tuple([from_dtype(dem.dtype), int64, int64, int64]) + return func(dem, mask, tuple_type, *args) + return _wrapper + + +@pfwrapper @njit(cache=True) def _priority_flood(dem, dem_mask, tuple_type): open_cells = typedlist.List.empty_list(tuple_type) # Priority queue pits = typedlist.List.empty_list(tuple_type) # FIFO queue closed_cells = dem_mask.copy() + isertn = count() # Push the edges onto priority queue y, x = dem.shape @@ -1931,22 +1950,22 @@ def _priority_flood(dem, dem_mask, tuple_type): edge = _left(dem_mask)[:-1] for row, col in zip(count(), edge): if col >= 0: - open_cells.append((dem[row, col], row, col)) + open_cells.append((dem[row, col], next(isertn), row, col)) closed_cells[row, col] = True edge = _bottom(dem_mask)[:-1] for row, col in zip(edge, count()): if row >= 0: - open_cells.append((dem[row, col], row, col)) + open_cells.append((dem[row, col], next(isertn), row, col)) closed_cells[row, col] = True edge = np.flip(_right(dem_mask))[:-1] for row, col in zip(count(y - 1, step=-1), edge): if col >= 0: - open_cells.append((dem[row, col], row, col)) + open_cells.append((dem[row, col], next(isertn), row, col)) closed_cells[row, col] = True edge = np.flip(_top(dem_mask))[:-1] for row, col in zip(edge, count(x - 1, step=-1)): if row >= 0: - open_cells.append((dem[row, col], row, col)) + open_cells.append((dem[row, col], next(isertn), row, col)) closed_cells[row, col] = True heapify(open_cells) @@ -1956,10 +1975,10 @@ def _priority_flood(dem, dem_mask, tuple_type): pits_pos = 0 while open_cells or pits_pos < len(pits): if pits_pos < len(pits): - elv, i, j = pits[pits_pos] + elv, _, i, j = pits[pits_pos] pits_pos += 1 else: - elv, i, j = heappop(open_cells) + elv, _, i, j = heappop(open_cells) for n in range(8): row = i + row_offsets[n] @@ -1973,9 +1992,9 @@ def _priority_flood(dem, dem_mask, tuple_type): if dem[row, col] <= elv: dem[row, col] = elv - pits.append((elv, row, col)) + pits.append((elv, next(isertn), row, col)) else: - heappush(open_cells, (dem[row, col], row, col)) + heappush(open_cells, (dem[row, col], next(isertn), row, col)) closed_cells[row, col] = True # pits book-keeping From 4f0558e6ef1fd822852969dbb068efa1cf363cb4 Mon Sep 17 00:00:00 2001 From: Ryan Grout Date: Wed, 14 Feb 2024 12:48:54 -0700 Subject: [PATCH 4/4] Rework fill_depressions to use _output_handler. --- pysheds/sgrid.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/pysheds/sgrid.py b/pysheds/sgrid.py index f0b3229..71625e2 100644 --- a/pysheds/sgrid.py +++ b/pysheds/sgrid.py @@ -2148,11 +2148,13 @@ def fill_depressions(self, dem, nodata_out=np.nan, **kwargs): Raster representing digital elevation data with multi-celled depressions removed. """ - # Implementation detail of priority flood algorithm. - tuple_type = Tuple([from_dtype(dem.dtype), int64, int64]) dem_mask = self._get_nodata_cells(dem) - return _self._priority_flood(dem, dem_mask, tuple_type) - + result = _self._priority_flood(dem, dem_mask) + dem_filled = self._output_handler(data=result, + viewfinder=dem.viewfinder, + metadata=dem.metadata, + nodata=dem.nodata) + return dem_filled def detect_flats(self, dem, **kwargs): """