Skip to content

Commit

Permalink
Merge pull request #243 from groutr/priority-flood
Browse files Browse the repository at this point in the history
Use Priority Flood algorithm to fill depressions
  • Loading branch information
mdbartos authored Feb 19, 2024
2 parents 174cafa + 4f0558e commit 56e6aea
Show file tree
Hide file tree
Showing 2 changed files with 159 additions and 21 deletions.
152 changes: 150 additions & 2 deletions pysheds/_sgrid.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
from heapq import heappop, heappush
from heapq import heappop, heappush, heapify
import math
import numpy as np
from numba import njit, prange
from functools import wraps
from numba import njit, prange, from_dtype
from numba.types import float64, int64, uint32, uint16, uint8, boolean, UniTuple, Tuple, List, DictType, void
from numba.typed import typedlist

# Functions for 'flowdir'

Expand Down Expand Up @@ -1856,3 +1858,149 @@ def _fill_pits_numba(dem, pit_indices):
adjustment = min(diff, adjustment)
pits_filled.flat[k] += (adjustment)
return pits_filled

@njit(boundscheck=True, cache=True)
def _first_true1d(arr, start=0, end=None, step=1, invert=False):
if end is None:
end = len(arr)

if invert:
for i in range(start, end, step):
if not arr[i]:
return i
else:
return -1
else:
for i in range(start, end, step):
if arr[i]:
return i
else:
return -1

@njit(parallel=True, cache=True)
def _top(mask):
nc = mask.shape[1]
rv = np.zeros(nc, dtype='int64')
for i in prange(nc):
rv[i] = _first_true1d(mask[:, i], invert=True)
return rv

@njit(parallel=True, cache=True)
def _bottom(mask):
nr, nc = mask.shape[0], mask.shape[1]
rv = np.zeros(nc, dtype='int64')
for i in prange(nc):
rv[i] = _first_true1d(mask[:, i], start=nr - 1, end=-1, step=-1, invert=True)
return rv

@njit(parallel=True, cache=True)
def _left(mask):
nr = mask.shape[0]
rv = np.zeros(nr, dtype='int64')
for i in prange(nr):
rv[i] = _first_true1d(mask[i, :], invert=True)
return rv

@njit(parallel=True, cache=True)
def _right(mask):
nr, nc = mask.shape[0], mask.shape[1]
rv = np.zeros(nr, dtype='int64')
for i in prange(nr):
rv[i] = _first_true1d(mask[i, :], start=nc - 1, end=-1, step=-1, invert=True)
return rv


@njit(cache=True)
def count(start=0, step=1):
# Numba accelerated count() from itertools
# count(10) --> 10 11 12 13 14 ...
# count(2.5, 0.5) --> 2.5 3.0 3.5 ...
n = start
while True:
yield n
n += step


def pfwrapper(func):
# Implemenation detail of priority-flood algorithm
# Needed to define the types used in priority queue
@wraps(func)
def _wrapper(dem, mask, *args):
# Tuple elements:
# 0: dem data type (for elevation priority)
# 1: int64 for insertion index (to maintain total ordering)
# 2: int64 for row index
# 3: int64 for col index
tuple_type = Tuple([from_dtype(dem.dtype), int64, int64, int64])
return func(dem, mask, tuple_type, *args)
return _wrapper


@pfwrapper
@njit(cache=True)
def _priority_flood(dem, dem_mask, tuple_type):
open_cells = typedlist.List.empty_list(tuple_type) # Priority queue
pits = typedlist.List.empty_list(tuple_type) # FIFO queue
closed_cells = dem_mask.copy()
isertn = count()

# Push the edges onto priority queue
y, x = dem.shape

edge = _left(dem_mask)[:-1]
for row, col in zip(count(), edge):
if col >= 0:
open_cells.append((dem[row, col], next(isertn), row, col))
closed_cells[row, col] = True
edge = _bottom(dem_mask)[:-1]
for row, col in zip(edge, count()):
if row >= 0:
open_cells.append((dem[row, col], next(isertn), row, col))
closed_cells[row, col] = True
edge = np.flip(_right(dem_mask))[:-1]
for row, col in zip(count(y - 1, step=-1), edge):
if col >= 0:
open_cells.append((dem[row, col], next(isertn), row, col))
closed_cells[row, col] = True
edge = np.flip(_top(dem_mask))[:-1]
for row, col in zip(edge, count(x - 1, step=-1)):
if row >= 0:
open_cells.append((dem[row, col], next(isertn), row, col))
closed_cells[row, col] = True
heapify(open_cells)

row_offsets = np.array([-1, -1, 0, 1, 1, 1, 0, -1])
col_offsets = np.array([0, 1, 1, 1, 0, -1, -1, -1])

pits_pos = 0
while open_cells or pits_pos < len(pits):
if pits_pos < len(pits):
elv, _, i, j = pits[pits_pos]
pits_pos += 1
else:
elv, _, i, j = heappop(open_cells)

for n in range(8):
row = i + row_offsets[n]
col = j + col_offsets[n]

if row < 0 or row >= y or col < 0 or col >= x:
continue

if dem_mask[row, col] or closed_cells[row, col]:
continue

if dem[row, col] <= elv:
dem[row, col] = elv
pits.append((elv, next(isertn), row, col))
else:
heappush(open_cells, (dem[row, col], next(isertn), row, col))
closed_cells[row, col] = True

# pits book-keeping
if pits_pos == len(pits) and len(pits) > 1024:
# Queue is empty, lets clear it out
pits.clear()
pits_pos = 0

return dem
28 changes: 9 additions & 19 deletions pysheds/sgrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,11 @@
import pandas as pd
import geojson
from affine import Affine
from numba.types import Tuple, int64
from numba import from_dtype

try:
import skimage.measure
import skimage.morphology
_HAS_SKIMAGE = True
except ModuleNotFoundError:
_HAS_SKIMAGE = False
Expand Down Expand Up @@ -2113,8 +2115,6 @@ def detect_depressions(self, dem, **kwargs):
depressions : Raster
Boolean Raster indicating locations of depressions.
"""
if not _HAS_SKIMAGE:
raise ImportError('detect_depressions requires skimage.morphology module')
input_overrides = {'dtype' : np.float64, 'nodata' : dem.nodata}
kwargs.update(input_overrides)
dem = self._input_handler(dem, **kwargs)
Expand Down Expand Up @@ -2148,23 +2148,13 @@ def fill_depressions(self, dem, nodata_out=np.nan, **kwargs):
Raster representing digital elevation data with multi-celled
depressions removed.
"""
if not _HAS_SKIMAGE:
raise ImportError('resolve_flats requires skimage.morphology module')
input_overrides = {'dtype' : np.float64, 'nodata' : dem.nodata}
kwargs.update(input_overrides)
dem = self._input_handler(dem, **kwargs)
dem_mask = self._get_nodata_cells(dem)
dem_mask[0, :] = True
dem_mask[-1, :] = True
dem_mask[:, 0] = True
dem_mask[:, -1] = True
# Make sure nothing flows to the nodata cells
seed = np.copy(dem)
seed[~dem_mask] = np.nanmax(dem)
dem_out = skimage.morphology.reconstruction(seed, dem, method='erosion')
dem_out = self._output_handler(data=dem_out, viewfinder=dem.viewfinder,
metadata=dem.metadata, nodata=nodata_out)
return dem_out
result = _self._priority_flood(dem, dem_mask)
dem_filled = self._output_handler(data=result,
viewfinder=dem.viewfinder,
metadata=dem.metadata,
nodata=dem.nodata)
return dem_filled

def detect_flats(self, dem, **kwargs):
"""
Expand Down

0 comments on commit 56e6aea

Please sign in to comment.