diff --git a/cxx/isce3/Headers.cmake b/cxx/isce3/Headers.cmake index 4fa298f5f..0efc668b9 100644 --- a/cxx/isce3/Headers.cmake +++ b/cxx/isce3/Headers.cmake @@ -139,7 +139,8 @@ product/forward.h product/GeoGridParameters.h product/Metadata.h product/ProcessingInformation.h -product/Product.h +product/RadarGridProduct.h +product/GeoGridProduct.h product/RadarGridParameters.h product/Serialization.h product/Swath.h @@ -171,6 +172,14 @@ unwrap/icu/LabelMap.icc unwrap/icu/PhaseGrad.h unwrap/icu/SearchTable.h unwrap/icu/SearchTable.icc +unwrap/ortools/ebert_graph.h +unwrap/ortools/graph.h +unwrap/ortools/graphs.h +unwrap/ortools/iterators.h +unwrap/ortools/max_flow.h +unwrap/ortools/min_cost_flow.h +unwrap/ortools/permutation.h +unwrap/ortools/zvector.h unwrap/phass/ASSP.h unwrap/phass/BMFS.h unwrap/phass/CannyEdgeDetector.h diff --git a/cxx/isce3/Sources.cmake b/cxx/isce3/Sources.cmake index 0be06dc5b..666df335d 100644 --- a/cxx/isce3/Sources.cmake +++ b/cxx/isce3/Sources.cmake @@ -67,9 +67,10 @@ math/polyfunc.cpp math/RootFind1dNewton.cpp math/RootFind1dSecant.cpp polsar/symmetrize.cpp -product/GeoGridParameters.cpp -product/Product.cpp product/RadarGridParameters.cpp +product/GeoGridParameters.cpp +product/RadarGridProduct.cpp +product/GeoGridProduct.cpp signal/Covariance.cpp signal/Crossmul.cpp signal/CrossMultiply.cpp @@ -90,6 +91,8 @@ unwrap/icu/PhaseGrad.cpp unwrap/icu/Residue.cpp unwrap/icu/Tree.cpp unwrap/icu/Unwrap.cpp +unwrap/ortools/max_flow.cc +unwrap/ortools/min_cost_flow.cc unwrap/phass/ASSP.cc unwrap/phass/BMFS.cc unwrap/phass/CannyEdgeDetector.cc diff --git a/cxx/isce3/cuda/geometry/Geo2rdr.h b/cxx/isce3/cuda/geometry/Geo2rdr.h index c147967cf..2bb07e9bd 100644 --- a/cxx/isce3/cuda/geometry/Geo2rdr.h +++ b/cxx/isce3/cuda/geometry/Geo2rdr.h @@ -17,8 +17,8 @@ class isce3::cuda::geometry::Geo2rdr : public isce3::geometry::Geo2rdr { public: - /** Constructor from Product */ - inline Geo2rdr(const isce3::product::Product & product, + /** Constructor from 
RadarGridProduct */ + inline Geo2rdr(const isce3::product::RadarGridProduct & product, char frequency = 'A', bool nativeDoppler = false) : isce3::geometry::Geo2rdr(product, frequency, nativeDoppler) {} diff --git a/cxx/isce3/cuda/geometry/Topo.h b/cxx/isce3/cuda/geometry/Topo.h index 441e48029..39e5dfa4f 100644 --- a/cxx/isce3/cuda/geometry/Topo.h +++ b/cxx/isce3/cuda/geometry/Topo.h @@ -17,8 +17,8 @@ class isce3::cuda::geometry::Topo : public isce3::geometry::Topo { public: - /** Constructor from Product */ - inline Topo(const isce3::product::Product & product, + /** Constructor from RadarGridProduct */ + inline Topo(const isce3::product::RadarGridProduct & product, char frequency = 'A', bool nativeDoppler = false) : isce3::geometry::Topo(product, frequency, nativeDoppler){} diff --git a/cxx/isce3/cuda/geometry/gpuRTC.cu b/cxx/isce3/cuda/geometry/gpuRTC.cu index 05a5b54d1..8e239c347 100644 --- a/cxx/isce3/cuda/geometry/gpuRTC.cu +++ b/cxx/isce3/cuda/geometry/gpuRTC.cu @@ -20,7 +20,7 @@ #include #include #include -#include +#include #include #include @@ -271,7 +271,7 @@ namespace isce3 { namespace cuda { namespace geometry { -void computeRtc(isce3::product::Product& product, isce3::io::Raster& dem, +void computeRtc(isce3::product::RadarGridProduct& product, isce3::io::Raster& dem, isce3::io::Raster& out_raster, char frequency) { diff --git a/cxx/isce3/cuda/geometry/gpuRTC.h b/cxx/isce3/cuda/geometry/gpuRTC.h index 21ada8718..f1e7511a1 100644 --- a/cxx/isce3/cuda/geometry/gpuRTC.h +++ b/cxx/isce3/cuda/geometry/gpuRTC.h @@ -4,6 +4,6 @@ #include namespace isce3 { namespace cuda { namespace geometry { -void computeRtc(isce3::product::Product& product, isce3::io::Raster& dem, +void computeRtc(isce3::product::RadarGridProduct& product, isce3::io::Raster& dem, isce3::io::Raster& out_raster, char frequency = 'A'); }}} diff --git a/cxx/isce3/cuda/image/ResampSlc.h b/cxx/isce3/cuda/image/ResampSlc.h index 650d266a6..96dc12587 100644 --- a/cxx/isce3/cuda/image/ResampSlc.h +++ 
b/cxx/isce3/cuda/image/ResampSlc.h @@ -10,13 +10,13 @@ class isce3::cuda::image::ResampSlc : public isce3::image::ResampSlc { public: // Meta-methods - // Constructor from an isce3::product::Product - inline ResampSlc(const isce3::product::Product &product, char frequency = 'A') : + // Constructor from an isce3::product::RadarGridProduct + inline ResampSlc(const isce3::product::RadarGridProduct &product, char frequency = 'A') : isce3::image::ResampSlc(product, frequency) {} - // Constructor from an isce3::product::Product and reference product (flattening) - inline ResampSlc(const isce3::product::Product & product, - const isce3::product::Product & refProduct, + // Constructor from an isce3::product::RadarGridProduct and reference product (flattening) + inline ResampSlc(const isce3::product::RadarGridProduct & product, + const isce3::product::RadarGridProduct & refProduct, char frequency = 'A') : isce3::image::ResampSlc(product, refProduct, frequency) {} diff --git a/cxx/isce3/geocode/GeocodeCov.h b/cxx/isce3/geocode/GeocodeCov.h index f73d2f65a..dfe7fab2f 100644 --- a/cxx/isce3/geocode/GeocodeCov.h +++ b/cxx/isce3/geocode/GeocodeCov.h @@ -15,7 +15,7 @@ #include // isce3::product -#include +#include #include // isce3::geometry diff --git a/cxx/isce3/geocode/geocodeSlc.cpp b/cxx/isce3/geocode/geocodeSlc.cpp index 1b353d989..cf9f892c7 100644 --- a/cxx/isce3/geocode/geocodeSlc.cpp +++ b/cxx/isce3/geocode/geocodeSlc.cpp @@ -14,7 +14,7 @@ #include #include #include -#include +#include #include diff --git a/cxx/isce3/geometry/Geo2rdr.h b/cxx/isce3/geometry/Geo2rdr.h index c79a01f9b..9ef1f3463 100644 --- a/cxx/isce3/geometry/Geo2rdr.h +++ b/cxx/isce3/geometry/Geo2rdr.h @@ -22,7 +22,7 @@ #include // isce3::product -#include +#include #include #include @@ -50,7 +50,7 @@ class isce3::geometry::Geo2rdr { * @param[in] frequency Frequency designation * @param[in] nativeDoppler Flag for using native Doppler frequencies instead of zero-Doppler */ - Geo2rdr(const 
isce3::product::Product &, + Geo2rdr(const isce3::product::RadarGridProduct &, char frequency = 'A', bool nativeDoppler = false); diff --git a/cxx/isce3/geometry/Geo2rdr.icc b/cxx/isce3/geometry/Geo2rdr.icc index 078dde2d7..3ec82b1fc 100644 --- a/cxx/isce3/geometry/Geo2rdr.icc +++ b/cxx/isce3/geometry/Geo2rdr.icc @@ -10,7 +10,7 @@ inline isce3::geometry::Geo2rdr:: -Geo2rdr(const isce3::product::Product & product, +Geo2rdr(const isce3::product::RadarGridProduct & product, char frequency, bool nativeDoppler) : diff --git a/cxx/isce3/geometry/RTC.cpp b/cxx/isce3/geometry/RTC.cpp index 6624b511e..b1b1fa23f 100644 --- a/cxx/isce3/geometry/RTC.cpp +++ b/cxx/isce3/geometry/RTC.cpp @@ -563,7 +563,7 @@ void areaProjGetNBlocks(const int array_length, const int array_width, } } -void computeRtc(isce3::product::Product& product, isce3::io::Raster& dem_raster, +void computeRtc(isce3::product::RadarGridProduct& product, isce3::io::Raster& dem_raster, isce3::io::Raster& output_raster, char frequency, bool native_doppler, rtcInputTerrainRadiometry input_terrain_radiometry, rtcOutputTerrainRadiometry output_terrain_radiometry, diff --git a/cxx/isce3/geometry/RTC.h b/cxx/isce3/geometry/RTC.h index f5e789ec9..ac95303ed 100644 --- a/cxx/isce3/geometry/RTC.h +++ b/cxx/isce3/geometry/RTC.h @@ -96,7 +96,7 @@ void applyRtc(const isce3::product::RadarGridParameters& radarGrid, /** Generate radiometric terrain correction (RTC) area or area normalization * factor * - * @param[in] product Product + * @param[in] product RadarGridProduct * @param[in] dem_raster Input DEM raster * @param[out] output_raster Output raster * @param[in] frequency Product frequency @@ -116,7 +116,7 @@ void applyRtc(const isce3::product::RadarGridParameters& radarGrid, * looks associated with the geogrid will be saved * @param[in] rtc_memory_mode Select memory mode * */ -void computeRtc(isce3::product::Product& product, isce3::io::Raster& dem_raster, +void computeRtc(isce3::product::RadarGridProduct& product, 
isce3::io::Raster& dem_raster, isce3::io::Raster& output_raster, char frequency = 'A', bool native_doppler = false, rtcInputTerrainRadiometry inputTerrainRadiometry = diff --git a/cxx/isce3/geometry/Topo.cpp b/cxx/isce3/geometry/Topo.cpp index 6db7e218f..26e68a29d 100644 --- a/cxx/isce3/geometry/Topo.cpp +++ b/cxx/isce3/geometry/Topo.cpp @@ -17,7 +17,7 @@ #include #include -#include +#include // isce3::geometry #include @@ -32,7 +32,7 @@ using isce3::core::Vec3; using isce3::io::Raster; isce3::geometry::Topo:: -Topo(const isce3::product::Product & product, +Topo(const isce3::product::RadarGridProduct & product, char frequency, bool nativeDoppler) : diff --git a/cxx/isce3/geometry/Topo.h b/cxx/isce3/geometry/Topo.h index c2af4430b..89798012e 100644 --- a/cxx/isce3/geometry/Topo.h +++ b/cxx/isce3/geometry/Topo.h @@ -30,11 +30,11 @@ class isce3::geometry::Topo { /** * Constructor using a product * - * @param[in] product Input Product + * @param[in] product Input RadarGridProduct * @param[in] frequency Frequency designation * @param[in] nativeDoppler Flag for using native Doppler frequencies instead of zero-Doppler */ - Topo(const isce3::product::Product &, + Topo(const isce3::product::RadarGridProduct &, char frequency = 'A', bool nativeDoppler = false); diff --git a/cxx/isce3/image/ResampSlc.h b/cxx/isce3/image/ResampSlc.h index 3aa8941ff..666ee130c 100644 --- a/cxx/isce3/image/ResampSlc.h +++ b/cxx/isce3/image/ResampSlc.h @@ -12,7 +12,7 @@ #include #include -#include +#include #include namespace isce3 { namespace image { @@ -21,15 +21,15 @@ class ResampSlc { public: typedef Tile> Tile_t; - /** Constructor from an isce3::product::Product (no flattening) */ - ResampSlc(const isce3::product::Product& product, char frequency = 'A'); + /** Constructor from an isce3::product::RadarGridProduct (no flattening) */ + ResampSlc(const isce3::product::RadarGridProduct& product, char frequency = 'A'); /** - * Constructor from an isce3::product::Product and reference product + * 
Constructor from an isce3::product::RadarGridProduct and reference product * (flattening) */ - ResampSlc(const isce3::product::Product& product, - const isce3::product::Product& refProduct, char frequency = 'A'); + ResampSlc(const isce3::product::RadarGridProduct& product, + const isce3::product::RadarGridProduct& refProduct, char frequency = 'A'); /** Constructor from an isce3::product::Swath (no flattening) */ ResampSlc(const isce3::product::Swath& swath); @@ -97,7 +97,7 @@ class ResampSlc { void doppler(const isce3::core::LUT2d&); // Set reference product for flattening - void referenceProduct(const isce3::product::Product& product, + void referenceProduct(const isce3::product::RadarGridProduct& product, char frequency = 'A'); // Get/set number of lines per processing tile diff --git a/cxx/isce3/image/ResampSlc.icc b/cxx/isce3/image/ResampSlc.icc index b72303b84..d456caf37 100644 --- a/cxx/isce3/image/ResampSlc.icc +++ b/cxx/isce3/image/ResampSlc.icc @@ -4,8 +4,8 @@ namespace isce3 { namespace image { -// Constructor from an isce3::product::Product -inline ResampSlc::ResampSlc(const isce3::product::Product& product, +// Constructor from an isce3::product::RadarGridProduct +inline ResampSlc::ResampSlc(const isce3::product::RadarGridProduct& product, char frequency) : ResampSlc(product.swath(frequency)) { @@ -15,10 +15,10 @@ inline ResampSlc::ResampSlc(const isce3::product::Product& product, _filename = product.filename(); } -// Constructor from an isce3::product::Product and reference product +// Constructor from an isce3::product::RadarGridProduct and reference product // (flattening) -inline ResampSlc::ResampSlc(const isce3::product::Product& product, - const isce3::product::Product& refProduct, +inline ResampSlc::ResampSlc(const isce3::product::RadarGridProduct& product, + const isce3::product::RadarGridProduct& refProduct, char frequency) : ResampSlc(product.swath(frequency), refProduct.swath(frequency)) { @@ -146,7 +146,7 @@ inline void 
ResampSlc::doppler(const isce3::core::LUT2d& lut) // Set reference product inline void -ResampSlc::referenceProduct(const isce3::product::Product& refProduct, +ResampSlc::referenceProduct(const isce3::product::RadarGridProduct& refProduct, char frequency) { _setRefDataFromSwath(refProduct.swath(frequency)); diff --git a/cxx/isce3/product/GeoGridProduct.cpp b/cxx/isce3/product/GeoGridProduct.cpp new file mode 100644 index 000000000..9c7eb1544 --- /dev/null +++ b/cxx/isce3/product/GeoGridProduct.cpp @@ -0,0 +1,58 @@ +#include "GeoGridProduct.h" +#include +#include + +/** @param[in] file IH5File object for product. */ +isce3::product::GeoGridProduct:: +GeoGridProduct(isce3::io::IH5File & file) { + + std::string base_dir = "/science/"; + + isce3::io::IGroup base_group = file.openGroup(base_dir); + std::vector key_vector = {"grids"}; + + std::string image_group_str = "", metadata_group_str; + setImageMetadataGroupStr(file, base_dir, base_group, key_vector, + image_group_str, metadata_group_str); + + // If did not find HDF5 groups grids + if (image_group_str.size() == 0) { + std::string error_msg = ("ERROR grids groups not found in " + + file.getFileName()); + throw isce3::except::RuntimeError(ISCE_SRCINFO(), error_msg); + } + + // Get grids group + isce3::io::IGroup imGroup = file.openGroup(image_group_str); + + // Configure grids + loadFromH5(imGroup, _grids); + + // Get metadata group + isce3::io::IGroup metaGroup = file.openGroup(metadata_group_str); + // Configure metadata + + loadFromH5(metaGroup, _metadata); + + // Get look direction + auto identification_vector = isce3::product::findGroupPath(base_group, "identification"); + if (identification_vector.size() == 0) { + std::string error_msg = ("ERROR identification group not found in " + + file.getFileName()); + throw isce3::except::RuntimeError(ISCE_SRCINFO(), error_msg); + } else if (identification_vector.size() > 1) { + std::string error_msg = ("ERROR there should be only one identification" + " group in " + + 
file.getFileName()); +        throw isce3::except::RuntimeError(ISCE_SRCINFO(), error_msg); +    } + +    std::string identification_group_str = base_dir + identification_vector[0]; +    std::string lookDir; +    isce3::io::loadFromH5( +            file, identification_group_str + "/lookDirection", lookDir); +    lookSide(lookDir); + +    // Save the filename +    _filename = file.filename(); +} diff --git a/cxx/isce3/product/Product.h b/cxx/isce3/product/GeoGridProduct.h similarity index 52% rename from cxx/isce3/product/Product.h rename to cxx/isce3/product/GeoGridProduct.h index fffe906ef..6e04cade9 100644 --- a/cxx/isce3/product/Product.h +++ b/cxx/isce3/product/GeoGridProduct.h @@ -8,44 +8,49 @@ // std #include -#include -#include #include #include #include #include #include -#include +#include // Declarations namespace isce3 { namespace product { - class Product; + class GeoGridProduct; } } -// Product class declaration -class isce3::product::Product { +/** GeoGridProduct class declaration + * + * The GeoGridProduct attribute Grids map, i.e. _grids, associates the + * frequency (key) with the Grid object (value). The GeoGridProduct object + * is usually initiated with an empty map and the serialization of + * the GeoGridProduct is responsible for populating the Grid map + * from the GeoGridProduct's metadata. + */ +class isce3::product::GeoGridProduct { public: /** Constructor from IH5File object. */ - Product(isce3::io::IH5File &); + GeoGridProduct(isce3::io::IH5File &); - /** Constructor with Metadata and Swath map. */ - inline Product(const Metadata &, const std::map &); + /** Constructor with Metadata and Grid map. */ + inline GeoGridProduct(const Metadata &, const std::map &); /** Get a read-only reference to the metadata */ inline const Metadata & metadata() const { return _metadata; } /** Get a reference to the metadata. 
*/ inline Metadata & metadata() { return _metadata; } - /** Get a read-only reference to a swath */ - inline const Swath & swath(char freq) const { return _swaths.at(freq); } - /** Get a reference to a swath */ - inline Swath & swath(char freq) { return _swaths[freq]; } - /** Set a swath */ - inline void swath(const Swath & s, char freq) { _swaths[freq] = s; } + /** Get a read-only reference to a grid */ + inline const Grid & grid(char freq) const { return _grids.at(freq); } + /** Get a reference to a grid */ + inline Grid & grid(char freq) { return _grids[freq]; } + /** Set a grid */ + inline void grid(const Grid & s, char freq) { _grids[freq] = s; } /** Get the look direction */ inline isce3::core::LookSide lookSide() const { return _lookSide; } @@ -59,20 +64,21 @@ class isce3::product::Product { private: isce3::product::Metadata _metadata; - std::map _swaths; + std::map _grids; std::string _filename; isce3::core::LookSide _lookSide; }; /** @param[in] meta Metadata object - * @param[in] swaths Map of Swath objects per frequency */ -isce3::product::Product:: -Product(const Metadata & meta, const std::map & swaths) : - _metadata(meta), _swaths(swaths) {} + * @param[in] grids Map of grid objects per frequency */ +isce3::product::GeoGridProduct:: +GeoGridProduct(const Metadata & meta, const std::map & grids) : + _metadata(meta), _grids(grids) {} + /** @param[in] look String representation of look side */ void -isce3::product::Product:: +isce3::product::GeoGridProduct:: lookSide(const std::string & inputLook) { _lookSide = isce3::core::parseLookSide(inputLook); } diff --git a/cxx/isce3/product/Grid.h b/cxx/isce3/product/Grid.h new file mode 100644 index 000000000..881844163 --- /dev/null +++ b/cxx/isce3/product/Grid.h @@ -0,0 +1,131 @@ +//-*- C++ -*- +//-*- coding: utf-8 -*- + +#pragma once + +// std +#include + +// isce3::core +#include +#include +#include + +// isce3::io +#include + +// isce3::product +#include + +// Declaration +namespace isce3 { + namespace 
product { + + +/** + * A class for representing Grid metadata originally based on + NISAR L2 products. + */ +class Grid { + + public: + // Constructors + inline Grid() {}; + + /** Get acquired range bandwidth in Hz */ + inline double rangeBandwidth() const { return _rangeBandwidth; } + /** Set acquired range bandwidth in Hz */ + inline void rangeBandwidth(double b) { _rangeBandwidth = b; } + + /** Get acquired azimuth bandwidth in Hz */ + inline double azimuthBandwidth() const { return _azimuthBandwidth; } + /** Set acquired azimuth bandwidth in Hz */ + inline void azimuthBandwidth(double b) { _azimuthBandwidth = b; } + + /** Get processed center frequency in Hz */ + inline double centerFrequency() const { return _centerFrequency; } + /** Set processed center frequency in Hz */ + inline void centerFrequency(double f) { _centerFrequency = f; } + + /** Get processed wavelength in meters */ + inline double wavelength() const { + return isce3::core::speed_of_light / _centerFrequency; + } + + /** Get scene center ground range spacing in meters */ + inline double slantRangeSpacing() const { + return _slantRangeSpacing; + } + /** Set scene center ground range spacing in meters */ + inline void slantRangeSpacing(double s) { + _slantRangeSpacing = s; + } + + /** Get geogrid */ + inline isce3::product::GeoGridParameters geogrid() { + return _geogrid; + } + + /** Set geogrid */ + inline void geogrid(isce3::product::GeoGridParameters geogrid) { + _geogrid = geogrid; + } + + /** Get time spacing of raster grid in seconds */ + inline double zeroDopplerTimeSpacing() const { return _zeroDopplerTimeSpacing; } + /** Set time spacing of raster grid in seconds */ + inline void zeroDopplerTimeSpacing(double dt) { _zeroDopplerTimeSpacing = dt; } + + /* Geogrid parameters */ + + /** Get the X-coordinate start */ + inline double startX() const { return _geogrid.startX(); } + /** Set the X-coordinate start */ + inline void startX(double val) { _geogrid.startX(val); } + + /** Get the 
y-coordinate start */ + inline double startY() const { return _geogrid.startY(); } + /** Set the y-coordinate start */ + inline void startY(double val) { _geogrid.startY(val);} + + /** Get the X-coordinate spacing */ + inline double spacingX() const { return _geogrid.spacingX(); } + /** Set the X-coordinate spacing */ + inline void spacingX(double val) { _geogrid.spacingX(val); } + + /** Get the y-coordinate spacing */ + inline double spacingY() const { return _geogrid.spacingY(); } + /** Set the y-coordinate spacing */ + inline void spacingY(double val) { _geogrid.spacingY(val);} + + /** Get number of pixels in east-west/x direction for geocoded grid */ + inline size_t width() const { return _geogrid.width(); } + /** Set number of pixels in east-west/x direction for geocoded grid */ + inline void width(int w) { _geogrid.width(w); } + + /** Get number of pixels in north-south/y direction for geocoded grid */ + inline size_t length() const { return _geogrid.length(); } + /** Set number of pixels in north-south/y direction for geocoded grid */ + inline void length(int l) { _geogrid.length(l); } + + /** Get epsg code for geocoded grid */ + inline size_t epsg() const { return _geogrid.epsg(); } + /** Set epsg code for geocoded grid */ + inline void epsg(int l) { _geogrid.epsg(l); } + + + private: + + // Other metadata + isce3::product::GeoGridParameters _geogrid; + double _rangeBandwidth; + double _azimuthBandwidth; + double _slantRangeSpacing; + double _zeroDopplerTimeSpacing; + double _centerFrequency; + +}; + + + } +} \ No newline at end of file diff --git a/cxx/isce3/product/RadarGridParameters.cpp b/cxx/isce3/product/RadarGridParameters.cpp index 443b9c30a..3d0a6d713 100644 --- a/cxx/isce3/product/RadarGridParameters.cpp +++ b/cxx/isce3/product/RadarGridParameters.cpp @@ -1,9 +1,9 @@ #include "RadarGridParameters.h" -#include "Product.h" +#include "RadarGridProduct.h" isce3::product::RadarGridParameters:: -RadarGridParameters(const Product & product, char 
frequency) : +RadarGridParameters(const RadarGridProduct & product, char frequency) : RadarGridParameters(product.swath(frequency), product.lookSide()) { validate(); diff --git a/cxx/isce3/product/RadarGridParameters.h b/cxx/isce3/product/RadarGridParameters.h index 2f5430aa8..e7966e854 100644 --- a/cxx/isce3/product/RadarGridParameters.h +++ b/cxx/isce3/product/RadarGridParameters.h @@ -20,10 +20,10 @@ class isce3::product::RadarGridParameters { /** * Constructor with a product - * @param[in] product Input Product + * @param[in] product Input RadarGridProduct * @param[in] frequency Frequency designation */ - RadarGridParameters(const isce3::product::Product & product, + RadarGridParameters(const isce3::product::RadarGridProduct & product, char frequency = 'A'); /** diff --git a/cxx/isce3/product/Product.cpp b/cxx/isce3/product/RadarGridProduct.cpp similarity index 74% rename from cxx/isce3/product/Product.cpp rename to cxx/isce3/product/RadarGridProduct.cpp index ba9642081..6643f5b52 100644 --- a/cxx/isce3/product/Product.cpp +++ b/cxx/isce3/product/RadarGridProduct.cpp @@ -1,10 +1,14 @@ -#include "Product.h" +#include "RadarGridProduct.h" #include -/** Find unique group path excluding repeated occurrences +namespace isce3 { namespace product { + +/** + * Return the path to each child group of `group` that ends with the substring + * `group_name`. */ -std::vector _findGroupPath( - isce3::io::IGroup& group, const std::string group_name) +std::vector findGroupPath( + isce3::io::IGroup& group, const std::string& group_name) { auto group_vector = group.find(group_name, ".", "GROUP"); @@ -31,27 +35,26 @@ std::vector _findGroupPath( return filtered_group_vector; } -/** @param[in] file IH5File object for product. */ -isce3::product::Product:: -Product(isce3::io::IH5File & file) { - - std::string base_dir = "/science/"; - - isce3::io::IGroup base_group = file.openGroup(base_dir); +/** + * Return grids or swaths group paths within the base_group. 
+ * Start by assigning an empty string to image_group_str in case + * grids and swaths group are not found. + */ +void setImageMetadataGroupStr( + isce3::io::IH5File & file, + std::string& base_dir, + isce3::io::IGroup& base_group, + std::vector& key_vector, + std::string &image_group_str, + std::string &metadata_group_str) +{ bool flag_has_swaths = false; - /* - In this section, we look for grids or swaths group within base_group. - We start by assigning an empty string to image_group_str in case - grids and swaths group are not found. - */ - std::vector key_vector = {"grids", "swaths"}; - std::string image_group_str = "", metadata_group_str; for (const auto& key : key_vector) { // Look for HDF5 groups that match key (i.e., "grids" or "swaths") - auto group_vector = _findGroupPath(base_group, key); + auto group_vector = findGroupPath(base_group, key); if (group_vector.size() > 1) { /* @@ -84,28 +87,32 @@ Product(isce3::io::IH5File & file) { break; } } +} - // If did not find HDF5 groups swaths or grids +/** @param[in] file IH5File object for product. 
*/ +RadarGridProduct:: +RadarGridProduct(isce3::io::IH5File & file) { + + std::string base_dir = "/science/"; + isce3::io::IGroup base_group = file.openGroup(base_dir); + std::vector key_vector = {"swaths"}; + + std::string image_group_str = "", metadata_group_str; + setImageMetadataGroupStr(file, base_dir, base_group, key_vector, + image_group_str, metadata_group_str); + + // If did not find HDF5 groups swaths if (image_group_str.size() == 0) { - std::string error_msg = ("ERROR swaths and grids groups" - " not found in " + + std::string error_msg = ("ERROR swaths group not found in " + file.getFileName()); throw isce3::except::RuntimeError(ISCE_SRCINFO(), error_msg); } - // Get swaths/grids group + // Get swaths group isce3::io::IGroup imGroup = file.openGroup(image_group_str); // Configure swaths - if (flag_has_swaths) { - loadFromH5(imGroup, _swaths); - } - /* - Not implemented yet: - else { - loadFromH5(imGroup, _grids); - } - */ + loadFromH5(imGroup, _swaths); // Get metadata group isce3::io::IGroup metaGroup = file.openGroup(metadata_group_str); @@ -114,7 +121,7 @@ Product(isce3::io::IH5File & file) { loadFromH5(metaGroup, _metadata); // Get look direction - auto identification_vector = _findGroupPath(base_group, "identification"); + auto identification_vector = findGroupPath(base_group, "identification"); if (identification_vector.size() == 0) { std::string error_msg = ("ERROR identification group not found in " + file.getFileName()); @@ -135,3 +142,5 @@ Product(isce3::io::IH5File & file) { // Save the filename _filename = file.filename(); } + +}} \ No newline at end of file diff --git a/cxx/isce3/product/RadarGridProduct.h b/cxx/isce3/product/RadarGridProduct.h new file mode 100644 index 000000000..1802b6bda --- /dev/null +++ b/cxx/isce3/product/RadarGridProduct.h @@ -0,0 +1,121 @@ +// -*- C++ -*- +// -*- coding: utf-8 -*- +// +// Source Author: Bryan Riel +// Copyright 2017-2018 + +#pragma once + +// std +#include +#include +#include +#include + +#include 
+#include +#include +#include +#include + +namespace isce3 { namespace product { + +/** Find unique group path excluding repeated occurrences +*/ + + +/** + * Return the path to each child group of `group` that ends with the substring + * `group_name`. + * + * \param[in] group Parent group + * \param[in] group_name Search string + * \returns List of child group paths + */ +std::vector findGroupPath( + isce3::io::IGroup& group, const std::string& group_name); + +/** + * Return grids or swaths group paths within the base_group. + * Start by assigning an empty string to image_group_str in case + * grids and swaths group are not found. + * + * \param[in] file HDF5 file object for the product + * \param[in] base_dir Path to `base_group` object (e.g. '/science/') + * \param[in] base_group Base group + * \param[in] key_vector Vector containing possible image groups + * (e.g., 'swaths', 'grids', or both) to look for + * \param[out] image_group_str Path to first image group found containing + * one of the `key_vector` keys (e.g., '/science/LSAR/RSLC/swaths') + * \param[out] metadata_group_str Path to first metadata group found by + * substituting `key` with `metadata` in `image_group_str` + * (e.g., '/science/LSAR/RSLC/metadata') + */ +void setImageMetadataGroupStr( + isce3::io::IH5File & file, + std::string& base_dir, + isce3::io::IGroup& base_group, + std::vector& key_vector, + std::string &image_group_str, + std::string &metadata_group_str); + +/** RadarGridProduct class declaration + * + * The RadarGridProduct attribute Swaths map, i.e. _swaths, associates the + * frequency (key) with the Swath object (value). The RadarGridProduct object + * is usually initiated with an empty map and the serialization of + * the SAR product is responsible for populating the Swath map + * from the product's metadata. + * + */ +class RadarGridProduct { + + public: + /** Constructor from IH5File object. */ + RadarGridProduct(isce3::io::IH5File &); + /** Constructor with Metadata and Swath map. 
*/ + inline RadarGridProduct(const Metadata &, const std::map &); + + /** Get a read-only reference to the metadata */ + inline const Metadata & metadata() const { return _metadata; } + /** Get a reference to the metadata. */ + inline Metadata & metadata() { return _metadata; } + + /** Get a read-only reference to a swath */ + inline const Swath & swath(char freq) const { return _swaths.at(freq); } + /** Get a reference to a swath */ + inline Swath & swath(char freq) { return _swaths[freq]; } + /** Set a swath */ + inline void swath(const Swath & s, char freq) { _swaths[freq] = s; } + + /** Get the look direction */ + inline isce3::core::LookSide lookSide() const { return _lookSide; } + /** Set look direction using enum */ + inline void lookSide(isce3::core::LookSide side) { _lookSide = side; } + /** Set look direction from a string */ + inline void lookSide(const std::string &); + + /** Get the filename of the HDF5 file. */ + inline std::string filename() const { return _filename; } + + private: + isce3::product::Metadata _metadata; + std::map _swaths; + std::string _filename; + isce3::core::LookSide _lookSide; +}; + +/** @param[in] meta Metadata object + * @param[in] swaths Map of Swath objects per frequency */ +isce3::product::RadarGridProduct:: +RadarGridProduct(const Metadata & meta, const std::map & swaths) : + _metadata(meta), _swaths(swaths) {} + +/** @param[in] look String representation of look side */ +void +isce3::product::RadarGridProduct:: +lookSide(const std::string & inputLook) { + _lookSide = isce3::core::parseLookSide(inputLook); +} + +}} diff --git a/cxx/isce3/product/Serialization.h b/cxx/isce3/product/Serialization.h index 4b6050c25..39eb1337b 100644 --- a/cxx/isce3/product/Serialization.h +++ b/cxx/isce3/product/Serialization.h @@ -22,6 +22,7 @@ // isce3::product #include #include +#include //! The isce namespace namespace isce3 { @@ -128,12 +129,129 @@ namespace isce3 { * @param[in] group HDF5 group object. 
* @param[in] swaths Map of Swaths to be configured. */ inline void loadFromH5(isce3::io::IGroup & group, std::map & swaths) { - loadFromH5(group, swaths['A'], 'A'); + if (isce3::io::exists(group, "frequencyA")) { + loadFromH5(group, swaths['A'], 'A'); + } if (isce3::io::exists(group, "frequencyB")) { loadFromH5(group, swaths['B'], 'B'); } } + /** Load Grid from HDF5 + * + * @param[in] group HDF5 group object. + * @param[in] grid Grid object to be configured. + * @param[in] freq Frequency designation (e.g., A or B) */ + inline void loadFromH5(isce3::io::IGroup & group, Grid & grid, char freq) { + + // Open appropriate frequency group + std::string freqString("frequency"); + freqString.push_back(freq); + isce3::io::IGroup fgroup = group.openGroup(freqString); + + // Load X-coordinates + std::valarray x_array; + isce3::io::loadFromH5(fgroup, "xCoordinates", x_array); + grid.startX(x_array[0]); + grid.width(x_array.size()); + + // Load Y-coordinates + std::valarray y_array; + isce3::io::loadFromH5(fgroup, "yCoordinates", y_array); + grid.startY(y_array[0]); + grid.length(y_array.size()); + + // Get X-coordinate spacing + double value; + isce3::io::loadFromH5(fgroup, "xCoordinateSpacing", value); + grid.spacingX(value); + + isce3::io::loadFromH5(fgroup, "yCoordinateSpacing", value); + grid.spacingY(value); + + isce3::io::loadFromH5(fgroup, "rangeBandwidth", value); + grid.rangeBandwidth(value); + + isce3::io::loadFromH5(fgroup, "azimuthBandwidth", value); + grid.azimuthBandwidth(value); + + isce3::io::loadFromH5(fgroup, "centerFrequency", value); + grid.centerFrequency(value); + + isce3::io::loadFromH5(fgroup, "slantRangeSpacing", value); + grid.slantRangeSpacing(value); + + auto zero_dop_freq_vect = fgroup.find("zeroDopplerTimeSpacing", + ".", "DATASET"); + + /* + Look for zeroDopplerTimeSpacing in frequency group + (GCOV and GSLC products) + */ + if (zero_dop_freq_vect.size() > 0) { + isce3::io::loadFromH5(fgroup, "zeroDopplerTimeSpacing", value); + 
grid.zeroDopplerTimeSpacing(value); + } else { + + /* + Look for zeroDopplerTimeSpacing within grid group + (GUNW products) + */ + auto zero_dop_vect = group.find("zeroDopplerTimeSpacing", + ".", "DATASET"); + if (zero_dop_vect.size() > 0) { + isce3::io::loadFromH5(group, "zeroDopplerTimeSpacing", value); + grid.zeroDopplerTimeSpacing(value); + } else { + grid.zeroDopplerTimeSpacing( + std::numeric_limits::quiet_NaN()); + } + } + + auto epsg_freq_vect = fgroup.find("epsg", ".", "DATASET"); + + int epsg = -1; + // Look for epsg in frequency group + if (epsg_freq_vect.size() > 0) { + isce3::io::loadFromH5(fgroup, "epsg", epsg); + grid.epsg(epsg); + } else { + + // Look for epsg in dataset projection + auto projection_vect = fgroup.find("projection", + ".", "DATASET"); + if (projection_vect.size() > 0) { + for (auto projection_str: projection_vect) { + auto projection_obj = fgroup.openDataSet(projection_str); + if (projection_obj.attrExists("epsg_code")) { + auto attr = projection_obj.openAttribute("epsg_code"); + attr.read(getH5Type(), &epsg); + grid.epsg(epsg); + } + } + } + } + + if (epsg == -1) { + throw isce3::except::RuntimeError(ISCE_SRCINFO(), + "ERROR could not infer EPSG code from input HDF5 file"); + } + + } + + /** Load multiple grids from HDF5 + * + * @param[in] group HDF5 group object. + * @param[in] grids Map of Grids to be configured. */ + inline void loadFromH5(isce3::io::IGroup & group, std::map & grids) { + if (isce3::io::exists(group, "frequencyA")) { + loadFromH5(group, grids['A'], 'A'); + } + if (isce3::io::exists(group, "frequencyB")) { + loadFromH5(group, grids['B'], 'B'); + } + } + /** Load Metadata parameters from HDF5. * * @param[in] group HDF5 group object. 
diff --git a/cxx/isce3/product/forward.h b/cxx/isce3/product/forward.h index d8a5f166d..5e112440b 100644 --- a/cxx/isce3/product/forward.h +++ b/cxx/isce3/product/forward.h @@ -2,9 +2,11 @@ namespace isce3 { namespace product { - class Product; + class RadarGridProduct; + class GeoGridProduct; class RadarGridParameters; class GeoGridParameters; class Swath; + class Grid; }} diff --git a/cxx/isce3/signal/Crossmul.cpp b/cxx/isce3/signal/Crossmul.cpp index 0a7bf442f..5f0af5fca 100644 --- a/cxx/isce3/signal/Crossmul.cpp +++ b/cxx/isce3/signal/Crossmul.cpp @@ -22,9 +22,9 @@ size_t omp_thread_count() { /* isce3::signal::Crossmul:: -Crossmul(const isce3::product::Product& referenceSlcProduct, - const isce3::product::Product& secondarySlcProduct, - isce3::product::Product& outputInterferogramProduct) +Crossmul(const isce3::product::RadarGridProduct& referenceSlcProduct, + const isce3::product::RadarGridProduct& secondarySlcProduct, + isce3::product::RadarGridProduct& outputInterferogramProduct) */ /** diff --git a/cxx/isce3/signal/Crossmul.h b/cxx/isce3/signal/Crossmul.h index 407229a05..fb56876c0 100644 --- a/cxx/isce3/signal/Crossmul.h +++ b/cxx/isce3/signal/Crossmul.h @@ -25,9 +25,9 @@ class isce3::signal::Crossmul { ~Crossmul() {}; /* - void Crossmul(const isce3::product::Product& referenceSLC, - const isce3::product::Product& secondarySLC, - const isce3::product::Product& outputInterferogram); + void Crossmul(const isce3::product::RadarGridProduct& referenceSLC, + const isce3::product::RadarGridProduct& secondarySLC, + const isce3::product::RadarGridProduct& outputInterferogram); */ diff --git a/cxx/isce3/unwrap/ortools/LICENSE-2.0.txt b/cxx/isce3/unwrap/ortools/LICENSE-2.0.txt new file mode 100644 index 000000000..0f3822be4 --- /dev/null +++ b/cxx/isce3/unwrap/ortools/LICENSE-2.0.txt @@ -0,0 +1,202 @@ + + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. 
+ + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. 
+ + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. 
Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of 
the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. 
Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. 
+ + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright 2010 Google LLC + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/cxx/isce3/unwrap/ortools/ebert_graph.h b/cxx/isce3/unwrap/ortools/ebert_graph.h new file mode 100644 index 000000000..285dea0b2 --- /dev/null +++ b/cxx/isce3/unwrap/ortools/ebert_graph.h @@ -0,0 +1,2131 @@ +// Copyright 2010-2021 Google LLC +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#ifndef OR_TOOLS_GRAPH_EBERT_GRAPH_H_ +#define OR_TOOLS_GRAPH_EBERT_GRAPH_H_ + +// A few variations on a theme of the "star" graph representation by +// Ebert, as described in J. Ebert, "A versatile data structure for +// edge-oriented graph algorithms." Communications of the ACM +// 30(6):513-519 (June 1987). +// http://portal.acm.org/citation.cfm?id=214769 +// +// In this file there are three representations that have much in +// common. The general one, called simply EbertGraph, contains both +// forward- and backward-star representations. The other, called +// ForwardEbertGraph, contains only the forward-star representation of +// the graph, and is appropriate for applications where the reverse +// arcs are not needed. +// +// The point of including all the representations in this one file is +// to capitalize, where possible, on the commonalities among them, and +// those commonalities are mostly factored out into base classes as +// described below. Despite the commonalities, however, each of the +// three representations presents a somewhat different interface +// because of their different underlying semantics. A quintessential +// example is that the AddArc() method, very natural for the +// EbertGraph representation, cannot exist for an inherently static +// representation like ForwardStaticGraph. +// +// Many clients are expected to use the interfaces to the graph +// objects directly, but some clients are parameterized by graph type +// and need a consistent interface for their underlying graph +// objects. For such clients, a small library of class templates is +// provided to give a consistent interface to clients where the +// underlying graph interfaces differ. 
Examples are the +// AnnotatedGraphBuildManager<> template, which provides a uniform +// interface for building the various types of graphs; and the +// TailArrayManager<> template, which provides a uniform interface for +// applications that need to map from arc indices to arc tail nodes, +// accounting for the fact that such a mapping has to be requested +// explicitly from the ForwardStaticGraph and ForwardStarGraph +// representations. +// +// There are two base class templates, StarGraphBase, and +// EbertGraphBase; their purpose is to hold methods and data +// structures that are in common among their descendants. Only classes +// that are leaves in the following hierarchy tree are eligible for +// free-standing instantiation and use by clients. The parentheses +// around StarGraphBase and EbertGraphBase indicate that they should +// not normally be instantiated by clients: +// +// (StarGraphBase) | +// / \ | +// / \ | +// / \ | +// / \ | +// (EbertGraphBase) ForwardStaticGraph | +// / \ | +// / \ | +// EbertGraph ForwardEbertGraph | +// +// In the general EbertGraph case, the graph is represented with three +// arrays. +// Let n be the number of nodes and m be the number of arcs. +// Let i be an integer in [0..m-1], denoting the index of an arc. +// * head_[i] contains the end-node of arc i, +// * head_[-i-1] contains the start-node of arc i. +// Note that in two's-complement arithmetic, -i-1 = ~i. +// Consequently: +// * head_[~i] contains the end-node of the arc reverse to arc i, +// * head_[i] contains the start-node of the arc reverse to arc i. +// Note that if arc (u, v) is defined, then the data structure also stores +// (v, u). +// Arc ~i thus denotes the arc reverse to arc i. +// This is what makes this representation useful for undirected graphs and for +// implementing algorithms like bidirectional shortest paths. +// Also note that the representation handles multi-graphs. 
If several arcs +// going from node u to node v are added to the graph, they will be handled as +// separate arcs. +// +// Now, for an integer u in [0..n-1] denoting the index of a node: +// * first_incident_arc_[u] denotes the first arc in the adjacency list of u. +// * going from an arc i, the adjacency list can be traversed using +// j = next_adjacent_arc_[i]. +// +// The EbertGraph implementation has the following benefits: +// * It is able to handle both directed or undirected graphs. +// * Being based on indices, it is easily serializable. Only the contents +// of the head_ array need to be stored. Even so, serialization is +// currently not implemented. +// * The node indices and arc indices can be stored in 32 bits, while +// still allowing to go a bit further than the 4-gigabyte +// limitation (48 gigabytes for a pure graph, without capacities or +// costs.) +// * The representation can be recomputed if edges have been loaded from +// external memory or if edges have been re-ordered. +// * The memory consumption is: 2 * m * sizeof(NodeIndexType) +// + 2 * m * sizeof(ArcIndexType) +// + n * sizeof(ArcIndexType) +// plus a small constant. +// +// The EbertGraph implementation differs from the implementation described in +// [Ebert 1987] in the following respects: +// * arcs are represented using an (i, ~i) approach, whereas Ebert used +// (i, -i). Indices for direct arcs thus start at 0, in a fashion that is +// compatible with the index numbering in C and C++. Note that we also tested +// a (2*i, 2*i+1) storage pattern, which did not show any speed benefit, and +// made the use of the API much more difficult. +// * because of this, the 'nil' values for nodes and arcs are not 0, as Ebert +// first described. The value for the 'nil' node is set to -1, while the +// value for the 'nil' arc is set to the smallest integer representable with +// ArcIndexSize bytes.
+// * it is possible to add arcs to the graph, with AddArc, in a much simpler +// way than described by Ebert. +// * TODO(user) although it is already possible, using the +// GroupForwardArcsByFunctor method, to group all the outgoing (resp. +// incoming) arcs of a node, the iterator logic could still be improved to +// allow traversing the outgoing (resp. incoming) arcs in O(out_degree(node)) +// (resp. O(in_degree(node))) instead of O(degree(node)). +// * TODO(user) it is possible to implement arc deletion and garbage collection +// in an efficient (relatively) manner. For the time being we haven't seen an +// application for this. +// +// The ForwardEbertGraph representation is like the EbertGraph case described +// above, with the following modifications: +// * The part of the head_[] array with negative indices is absent. In its +// place is a pointer tail_ which, if assigned, points to an array of tail +// nodes indexed by (nonnegative) arc index. In typical usage tail_ is NULL +// and the memory for the tail nodes need not be allocated. +// * The array of arc tails can be allocated as needed and populated from the +// adjacency lists of the graph. +// * Representing only the forward star of each node implies that the graph +// cannot be serialized directly nor rebuilt from scratch from just the head_ +// array. Rebuilding from scratch requires constructing the array of arc +// tails from the adjacency lists first, and serialization can be done either +// by first constructing the array of arc tails from the adjacency lists, or +// by serializing directly from the adjacency lists. +// * The memory consumption is: m * sizeof(NodeIndexType) +// + m * sizeof(ArcIndexType) +// + n * sizeof(ArcIndexType) +// plus a small constant when the array of arc tails is absent. Allocating +// the arc tail array adds another m * sizeof(NodeIndexType). 
+// +// The ForwardStaticGraph representation is restricted yet farther +// than ForwardEbertGraph, with the benefit that it provides higher +// performance to those applications that can use it. +// * As with ForwardEbertGraph, the presence of the array of arc +// tails is optional. +// * The outgoing adjacency list for each node is stored in a +// contiguous segment of the head_[] array, obviating the +// next_adjacent_arc_ structure entirely and ensuring good locality +// of reference for applications that iterate over outgoing +// adjacency lists. +// * The memory consumption is: m * sizeof(NodeIndexType) +// + n * sizeof(ArcIndexType) +// plus a small constant when the array of arc tails is absent. Allocating +// the arc tail array adds another m * sizeof(NodeIndexType). + +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +#include "permutation.h" +#include "zvector.h" + +namespace operations_research { + +// Forward declarations. +template +class EbertGraph; +template +class ForwardEbertGraph; +template +class ForwardStaticGraph; + +// Standard instantiation of ForwardEbertGraph (named 'ForwardStarGraph') of +// EbertGraph (named 'StarGraph'); and relevant type shortcuts. Unless their use +// cases prevent them from doing so, users are encouraged to use StarGraph or +// ForwardStarGraph according to whether or not they require reverse arcs to be +// represented explicitly. Along with either graph representation, the other +// type shortcuts here will often come in handy. 
+typedef int32_t NodeIndex; +typedef int32_t ArcIndex; +typedef int64_t FlowQuantity; +typedef int64_t CostValue; +typedef EbertGraph StarGraph; +typedef ForwardEbertGraph ForwardStarGraph; +typedef ForwardStaticGraph ForwardStarStaticGraph; +typedef ZVector NodeIndexArray; +typedef ZVector ArcIndexArray; +typedef ZVector QuantityArray; +typedef ZVector CostArray; + +template +class StarGraphBase { + public: + // The index of the 'nil' node in the graph. + static const NodeIndexType kNilNode; + + // The index of the 'nil' arc in the graph. + static const ArcIndexType kNilArc; + + // The index of the first node in the graph. + static const NodeIndexType kFirstNode; + + // The index of the first arc in the graph. + static const ArcIndexType kFirstArc; + + // The maximum possible number of nodes in the graph. (The maximum + // index is kMaxNumNodes-1, since indices start at 0. Unfortunately + // we waste a value representing this and the max_num_nodes_ member.) + static const NodeIndexType kMaxNumNodes; + + // The maximum possible number of arcs in the graph. (The maximum + // index is kMaxNumArcs-1, since indices start at 0. Unfortunately + // we waste a value representing this and the max_num_arcs_ member.) + static const ArcIndexType kMaxNumArcs; + // Returns the number of nodes in the graph. + NodeIndexType num_nodes() const { return num_nodes_; } + + // Returns the number of original arcs in the graph + // (The ones with positive indices.) + ArcIndexType num_arcs() const { return num_arcs_; } + + // Returns one more than the largest index of an extant node, + // meaning a node that is mentioned as the head or tail of some arc + // in the graph. To be used as a helper when clients need to + // dimension or iterate over arrays of node annotation information. + NodeIndexType end_node_index() const { return kFirstNode + num_nodes_; } + + // Returns one more than the largest index of an extant direct + // arc. 
To be used as a helper when clients need to dimension or + // iterate over arrays of arc annotation information. + ArcIndexType end_arc_index() const { return kFirstArc + num_arcs_; } + + // Returns the maximum possible number of nodes in the graph. + NodeIndexType max_num_nodes() const { return max_num_nodes_; } + + // Returns the maximum possible number of original arcs in the graph. + // (The ones with positive indices.) + ArcIndexType max_num_arcs() const { return max_num_arcs_; } + + // Returns one more than the largest valid index of a node. To be + // used as a helper when clients need to dimension or iterate over + // arrays of node annotation information. + NodeIndexType max_end_node_index() const { + return kFirstNode + max_num_nodes_; + } + + // Returns one more than the largest valid index of a direct arc. To + // be used as a helper when clients need to dimension or iterate + // over arrays of arc annotation information. + ArcIndexType max_end_arc_index() const { return kFirstArc + max_num_arcs_; } + + // Utility function to check that a node index is within the bounds AND + // different from kNilNode. + // Returns true if node is in the range [kFirstNode .. max_num_nodes_). + // It is exported so that users of the DerivedGraph class can use it. + // To be used in a DCHECK; also used internally to validate + // arguments passed to our methods from clients (e.g., AddArc()). + bool IsNodeValid(NodeIndexType node) const { + return node >= kFirstNode && node < max_num_nodes_; + } + + // Returns the first arc going from tail to head, if it exists, or kNilArc + // if such an arc does not exist. + ArcIndexType LookUpArc(const NodeIndexType tail, + const NodeIndexType head) const { + for (ArcIndexType arc = FirstOutgoingArc(tail); arc != kNilArc; + arc = ThisAsDerived()->NextOutgoingArc(tail, arc)) { + if (Head(arc) == head) { + return arc; + } + } + return kNilArc; + } + + // Returns the head or end-node of arc. 
+ NodeIndexType Head(const ArcIndexType arc) const { + assert(ThisAsDerived()->CheckArcValidity(arc)); + return head_[arc]; + } + + std::string NodeDebugString(const NodeIndexType node) const { + if (node == kNilNode) { + return "NilNode"; + } else { + return std::to_string(static_cast(node)); + } + } + + std::string ArcDebugString(const ArcIndexType arc) const { + if (arc == kNilArc) { + return "NilArc"; + } else { + return std::to_string(static_cast(arc)); + } + } + + // Iterator class for traversing all the nodes in the graph. + class NodeIterator { + public: + explicit NodeIterator(const DerivedGraph& graph) + : graph_(graph), head_(graph_.StartNode(kFirstNode)) {} + + // Returns true unless all the nodes have been traversed. + bool Ok() const { return head_ != kNilNode; } + + // Advances the current node index. + void Next() { head_ = graph_.NextNode(head_); } + + // Returns the index of the node currently pointed to by the iterator. + NodeIndexType Index() const { return head_; } + + private: + // A reference to the current DerivedGraph considered. + const DerivedGraph& graph_; + + // The index of the current node considered. + NodeIndexType head_; + }; + + // Iterator class for traversing the arcs in the graph. + class ArcIterator { + public: + explicit ArcIterator(const DerivedGraph& graph) + : graph_(graph), arc_(graph_.StartArc(kFirstArc)) {} + + // Returns true unless all the arcs have been traversed. + bool Ok() const { return arc_ != kNilArc; } + + // Advances the current arc index. + void Next() { arc_ = graph_.NextArc(arc_); } + + // Returns the index of the arc currently pointed to by the iterator. + ArcIndexType Index() const { return arc_; } + + private: + // A reference to the current DerivedGraph considered. + const DerivedGraph& graph_; + + // The index of the current arc considered. + ArcIndexType arc_; + }; + + // Iterator class for traversing the outgoing arcs associated to a given node. 
+ class OutgoingArcIterator { + public: + OutgoingArcIterator(const DerivedGraph& graph, NodeIndexType node) + : graph_(graph), + node_(graph_.StartNode(node)), + arc_(graph_.StartArc(graph_.FirstOutgoingArc(node))) { + assert(CheckInvariant()); + } + + // This constructor takes an arc as extra argument and makes the iterator + // start at arc. + OutgoingArcIterator(const DerivedGraph& graph, NodeIndexType node, + ArcIndexType arc) + : graph_(graph), + node_(graph_.StartNode(node)), + arc_(graph_.StartArc(arc)) { + assert(CheckInvariant()); + } + + // Can only assign from an iterator on the same graph. + void operator=(const OutgoingArcIterator& iterator) { + assert(&iterator.graph_ == &graph_); + node_ = iterator.node_; + arc_ = iterator.arc_; + } + + // Returns true unless all the outgoing arcs have been traversed. + bool Ok() const { return arc_ != kNilArc; } + + // Advances the current outgoing arc index. + void Next() { + arc_ = graph_.NextOutgoingArc(node_, arc_); + assert(CheckInvariant()); + } + + // Returns the index of the arc currently pointed to by the iterator. + ArcIndexType Index() const { return arc_; } + + private: + // Returns true if the invariant for the iterator is verified. + // To be used in a DCHECK. + bool CheckInvariant() const { + if (arc_ == kNilArc) { + return true; // This occurs when the iterator has reached the end. + } + assert(graph_.IsOutgoing(arc_, node_)); + return true; + } + + // A reference to the current DerivedGraph considered. + const DerivedGraph& graph_; + + // The index of the node on which arcs are iterated. + NodeIndexType node_; + + // The index of the current arc considered. + ArcIndexType arc_; + }; + + protected: + StarGraphBase() + : max_num_nodes_(0), + max_num_arcs_(0), + num_nodes_(0), + num_arcs_(0), + first_incident_arc_() {} + + ~StarGraphBase() {} + + // Returns kNilNode if the graph has no nodes or node if it has at least one + // node. 
Useful for initializing iterators correctly in the case of empty + // graphs. + NodeIndexType StartNode(NodeIndexType node) const { + return num_nodes_ == 0 ? kNilNode : node; + } + + // Returns kNilArc if the graph has no arcs arc if it has at least one arc. + // Useful for initializing iterators correctly in the case of empty graphs. + ArcIndexType StartArc(ArcIndexType arc) const { + return num_arcs_ == 0 ? kNilArc : arc; + } + + // Returns the node following the argument in the graph. + // Returns kNilNode (= end) if the range of nodes has been exhausted. + // It is called by NodeIterator::Next() and as such does not expect to be + // passed an argument equal to kNilNode. + // This is why the return line is simplified from + // return (node == kNilNode || next_node >= num_nodes_) + // ? kNilNode : next_node; + // to + // return next_node < num_nodes_ ? next_node : kNilNode; + NodeIndexType NextNode(const NodeIndexType node) const { + assert(IsNodeValid(node)); + const NodeIndexType next_node = node + 1; + return next_node < num_nodes_ ? next_node : kNilNode; + } + + // Returns the arc following the argument in the graph. + // Returns kNilArc (= end) if the range of arcs has been exhausted. + // It is called by ArcIterator::Next() and as such does not expect to be + // passed an argument equal to kNilArc. + // This is why the return line is simplified from + // return ( arc == kNilArc || next_arc >= num_arcs_) ? kNilArc : next_arc; + // to + // return next_arc < num_arcs_ ? next_arc : kNilArc; + ArcIndexType NextArc(const ArcIndexType arc) const { + assert(ThisAsDerived()->CheckArcValidity(arc)); + const ArcIndexType next_arc = arc + 1; + return next_arc < num_arcs_ ? next_arc : kNilArc; + } + + // Returns the first outgoing arc for node. 
+ ArcIndexType FirstOutgoingArc(const NodeIndexType node) const { + assert(IsNodeValid(node)); + return ThisAsDerived()->FindNextOutgoingArc( + ThisAsDerived()->FirstOutgoingOrOppositeIncomingArc(node)); + } + + // The maximum number of nodes that the graph can hold. + NodeIndexType max_num_nodes_; + + // The maximum number of arcs that the graph can hold. + ArcIndexType max_num_arcs_; + + // The maximum index of the node currently held by the graph. + NodeIndexType num_nodes_; + + // The current number of arcs held by the graph. + ArcIndexType num_arcs_; + + // Array of node indices. head_[i] contains the head node of arc i. + ZVector head_; + + // Array of arc indices. first_incident_arc_[i] contains the first arc + // incident to node i. + ZVector first_incident_arc_; + + private: + // Shorthand: returns a const DerivedGraph*-typed version of our + // "this" pointer. + inline const DerivedGraph* ThisAsDerived() const { + return static_cast(this); + } + + // Shorthand: returns a DerivedGraph*-typed version of our "this" + // pointer. 
+ inline DerivedGraph* ThisAsDerived() { + return static_cast(this); + } +}; + +template +class PermutationIndexComparisonByArcHead { + public: + explicit PermutationIndexComparisonByArcHead( + const ZVector& head) + : head_(head) {} + + bool operator()(ArcIndexType a, ArcIndexType b) const { + return head_[a] < head_[b]; + } + + private: + const ZVector& head_; +}; + +template +class ForwardStaticGraph + : public StarGraphBase > { + typedef StarGraphBase > + Base; + friend class StarGraphBase >; + + using Base::ArcDebugString; + using Base::NodeDebugString; + + using Base::first_incident_arc_; + using Base::head_; + using Base::max_num_arcs_; + using Base::max_num_nodes_; + using Base::num_arcs_; + using Base::num_nodes_; + + public: + using Base::end_arc_index; + using Base::Head; + using Base::IsNodeValid; + + using Base::kFirstArc; + using Base::kFirstNode; + using Base::kNilArc; + + typedef NodeIndexType NodeIndex; + typedef ArcIndexType ArcIndex; + + class CycleHandlerForAnnotatedArcs + : public ArrayIndexCycleHandler { + typedef ArrayIndexCycleHandler Base; + + public: + CycleHandlerForAnnotatedArcs( + PermutationCycleHandler* annotation_handler, + NodeIndexType* data) + : ArrayIndexCycleHandler(&data[kFirstArc]), + annotation_handler_(annotation_handler) {} + + void SetTempFromIndex(ArcIndexType source) override { + Base::SetTempFromIndex(source); + annotation_handler_->SetTempFromIndex(source); + } + + void SetIndexFromIndex(ArcIndexType source, + ArcIndexType destination) const override { + Base::SetIndexFromIndex(source, destination); + annotation_handler_->SetIndexFromIndex(source, destination); + } + + void SetIndexFromTemp(ArcIndexType destination) const override { + Base::SetIndexFromTemp(destination); + annotation_handler_->SetIndexFromTemp(destination); + } + + private: + PermutationCycleHandler* annotation_handler_; + + CycleHandlerForAnnotatedArcs(const CycleHandlerForAnnotatedArcs&); + CycleHandlerForAnnotatedArcs& operator=(const 
CycleHandlerForAnnotatedArcs&); + }; + + // Constructor for use by GraphBuilderFromArcs instances and direct + // clients that want to materialize a graph in one step. + // Materializing all at once is the only choice available with a + // static graph. + // + // Args: + // sort_arcs_by_head: determines whether arcs incident to each tail + // node are sorted by head node. + // client_cycle_handler: if non-NULL, mediates the permutation of + // arbitrary annotation data belonging to the client according + // to the permutation applied to the arcs in forming the + // graph. Two permutations may be composed to form the final one + // that affects the arcs. First, the arcs are always permuted to + // group them by tail node because ForwardStaticGraph requires + // this. Second, if each node's outgoing arcs are sorted by head + // node (according to sort_arcs_by_head), that sorting implies + // an additional permutation on the arcs. + ForwardStaticGraph( + const NodeIndexType num_nodes, const ArcIndexType num_arcs, + const bool sort_arcs_by_head, + std::vector >* client_input_arcs, + operations_research::PermutationCycleHandler* const + client_cycle_handler) { + max_num_arcs_ = num_arcs; + num_arcs_ = num_arcs; + max_num_nodes_ = num_nodes; + // A more convenient name for a parameter required by style to be + // a pointer, because we modify its referent. + std::vector >& input_arcs = + *client_input_arcs; + + // We coopt the first_incident_arc_ array as a node-indexed vector + // used for two purposes related to degree before setting up its + // final values. First, it counts the out-degree of each + // node. Second, it is reused to count the number of arcs outgoing + // from each node that have already been put in place from the + // given input_arcs. We reserve an extra entry as a sentinel at + // the end. 
+ first_incident_arc_.Reserve(kFirstNode, kFirstNode + num_nodes); + first_incident_arc_.SetAll(0); + for (ArcIndexType arc = kFirstArc; arc < kFirstArc + num_arcs; ++arc) { + first_incident_arc_[kFirstNode + input_arcs[arc].first] += 1; + // Take this opportunity to see how many nodes are really + // mentioned in the arc list. + num_nodes_ = std::max( + num_nodes_, static_cast(input_arcs[arc].first + 1)); + num_nodes_ = std::max( + num_nodes_, static_cast(input_arcs[arc].second + 1)); + } + ArcIndexType next_arc = kFirstArc; + for (NodeIndexType node = 0; node < num_nodes; ++node) { + ArcIndexType degree = first_incident_arc_[kFirstNode + node]; + first_incident_arc_[kFirstNode + node] = next_arc; + next_arc += degree; + } + assert(num_arcs == next_arc); + head_.Reserve(kFirstArc, kFirstArc + num_arcs - 1); + std::unique_ptr arc_permutation; + if (client_cycle_handler != nullptr) { + arc_permutation.reset(new ArcIndexType[end_arc_index()]); + for (ArcIndexType input_arc = 0; input_arc < num_arcs; ++input_arc) { + NodeIndexType tail = input_arcs[input_arc].first; + NodeIndexType head = input_arcs[input_arc].second; + ArcIndexType arc = first_incident_arc_[kFirstNode + tail]; + // The head_ entry will get permuted into the right place + // later. + head_[kFirstArc + input_arc] = kFirstNode + head; + arc_permutation[kFirstArc + arc] = input_arc; + first_incident_arc_[kFirstNode + tail] += 1; + } + } else { + if (sizeof(input_arcs[0].first) >= sizeof(first_incident_arc_[0])) { + // We reuse the input_arcs[].first entries to hold our + // mapping to the head_ array. This allows us to spread out + // cache badness. 
+ for (ArcIndexType input_arc = 0; input_arc < num_arcs; ++input_arc) { + NodeIndexType tail = input_arcs[input_arc].first; + ArcIndexType arc = first_incident_arc_[kFirstNode + tail]; + first_incident_arc_[kFirstNode + tail] = arc + 1; + input_arcs[input_arc].first = static_cast(arc); + } + for (ArcIndexType input_arc = 0; input_arc < num_arcs; ++input_arc) { + ArcIndexType arc = + static_cast(input_arcs[input_arc].first); + NodeIndexType head = input_arcs[input_arc].second; + head_[kFirstArc + arc] = kFirstNode + head; + } + } else { + // We cannot reuse the input_arcs[].first entries so we map to + // the head_ array in a single loop. + for (ArcIndexType input_arc = 0; input_arc < num_arcs; ++input_arc) { + NodeIndexType tail = input_arcs[input_arc].first; + NodeIndexType head = input_arcs[input_arc].second; + ArcIndexType arc = first_incident_arc_[kFirstNode + tail]; + first_incident_arc_[kFirstNode + tail] = arc + 1; + head_[kFirstArc + arc] = kFirstNode + head; + } + } + } + // Shift the entries in first_incident_arc_ to compensate for the + // counting each one has done through its incident arcs. Note that + // there is a special sentry element at the end of + // first_incident_arc_. 
+ for (NodeIndexType node = kFirstNode + num_nodes; node > /* kFirstNode */ 0; + --node) { + first_incident_arc_[node] = first_incident_arc_[node - 1]; + } + first_incident_arc_[kFirstNode] = kFirstArc; + if (sort_arcs_by_head) { + ArcIndexType begin = first_incident_arc_[kFirstNode]; + if (client_cycle_handler != nullptr) { + for (NodeIndexType node = 0; node < num_nodes; ++node) { + ArcIndexType end = first_incident_arc_[node + 1]; + std::sort( + &arc_permutation[begin], &arc_permutation[end], + PermutationIndexComparisonByArcHead( + head_)); + begin = end; + } + } else { + for (NodeIndexType node = 0; node < num_nodes; ++node) { + ArcIndexType end = first_incident_arc_[node + 1]; + // The second argument in the following has a strange index + // expression because ZVector claims that no index is valid + // unless it refers to an element in the vector. In particular + // an index one past the end is invalid. + ArcIndexType begin_index = (begin < num_arcs ? begin : begin - 1); + ArcIndexType begin_offset = (begin < num_arcs ? 0 : 1); + ArcIndexType end_index = (end > 0 ? end - 1 : end); + ArcIndexType end_offset = (end > 0 ? 1 : 0); + std::sort(&head_[begin_index] + begin_offset, + &head_[end_index] + end_offset); + begin = end; + } + } + } + if (client_cycle_handler != nullptr && num_arcs > 0) { + // Apply the computed permutation if we haven't already. + CycleHandlerForAnnotatedArcs handler_for_constructor( + client_cycle_handler, &head_[kFirstArc] - kFirstArc); + // We use a permutation cycle handler to place the head array + // indices and permute the client's arc annotation data along + // with them. + PermutationApplier permutation(&handler_for_constructor); + permutation.Apply(&arc_permutation[0], kFirstArc, end_arc_index()); + } + } + + // Returns the tail or start-node of arc. 
+ NodeIndexType Tail(const ArcIndexType arc) const { + assert(CheckArcValidity(arc)); + assert(CheckTailIndexValidity(arc)); + return (*tail_)[arc]; + } + + // Returns true if arc is incoming to node. + bool IsIncoming(ArcIndexType arc, NodeIndexType node) const { + return Head(arc) == node; + } + + // Utility function to check that an arc index is within the bounds. + // It is exported so that users of the ForwardStaticGraph class can use it. + // To be used in a DCHECK. + bool CheckArcBounds(const ArcIndexType arc) const { + return ((arc == kNilArc) || (arc >= kFirstArc && arc < max_num_arcs_)); + } + + // Utility function to check that an arc index is within the bounds AND + // different from kNilArc. + // It is exported so that users of the ForwardStaticGraph class can use it. + // To be used in a DCHECK. + bool CheckArcValidity(const ArcIndexType arc) const { + return ((arc != kNilArc) && (arc >= kFirstArc && arc < max_num_arcs_)); + } + + // Returns true if arc is a valid index into the (*tail_) array. + bool CheckTailIndexValidity(const ArcIndexType arc) const { + return ((tail_ != nullptr) && (arc >= kFirstArc) && + (arc <= tail_->max_index())); + } + + ArcIndexType NextOutgoingArc(const NodeIndexType node, + ArcIndexType arc) const { + assert(IsNodeValid(node)); + assert(CheckArcValidity(arc)); + ++arc; + if (arc < first_incident_arc_[node + 1]) { + return arc; + } else { + return kNilArc; + } + } + + // Returns a debug string containing all the information contained in the + // data structure in raw form. 
+ std::string DebugString() const { + std::string result = "Arcs:(node) :\n"; + for (ArcIndexType arc = kFirstArc; arc < num_arcs_; ++arc) { + result += " " + ArcDebugString(arc) + ":(" + NodeDebugString(head_[arc]) + + ")\n"; + } + result += "Node:First arc :\n"; + for (NodeIndexType node = kFirstNode; node <= num_nodes_; ++node) { + result += " " + NodeDebugString(node) + ":" + + ArcDebugString(first_incident_arc_[node]) + "\n"; + } + return result; + } + + bool BuildTailArray() { + // If (*tail_) is already allocated, we have the invariant that + // its contents are canonical, so we do not need to do anything + // here in that case except return true. + if (tail_ == nullptr) { + if (!RepresentationClean()) { + // We have been asked to build the (*tail_) array, but we have + // no valid information from which to build it. The graph is + // in an unrecoverable, inconsistent state. + return false; + } + // Reallocate (*tail_) and rebuild its contents from the + // adjacency lists. + tail_.reset(new ZVector); + tail_->Reserve(kFirstArc, max_num_arcs_ - 1); + typename Base::NodeIterator node_it(*this); + for (; node_it.Ok(); node_it.Next()) { + NodeIndexType node = node_it.Index(); + typename Base::OutgoingArcIterator arc_it(*this, node); + for (; arc_it.Ok(); arc_it.Next()) { + (*tail_)[arc_it.Index()] = node; + } + } + } + assert(TailArrayComplete()); + return true; + } + + void ReleaseTailArray() { tail_.reset(nullptr); } + + // To be used in a DCHECK(). 
+ bool TailArrayComplete() const { + if (!tail_) { + throw isce3::except::RuntimeError(ISCE_SRCINFO(), "tail_ is nullptr"); + } + for (ArcIndexType arc = kFirstArc; arc < num_arcs_; ++arc) { + if (!CheckTailIndexValidity(arc)) { + throw isce3::except::RuntimeError( + ISCE_SRCINFO(), "CheckTailIndexValidity(arc) failed"); + } + if (!IsNodeValid((*tail_)[arc])) { + throw isce3::except::RuntimeError( + ISCE_SRCINFO(), "IsNodeValid((*tail_)[arc]) failed"); + } + } + return true; + } + + private: + bool IsDirect() const { return true; } + bool RepresentationClean() const { return true; } + bool IsOutgoing(const NodeIndexType node, + const ArcIndexType unused_arc) const { + return true; + } + + // Returns the first arc in node's incidence list. + ArcIndexType FirstOutgoingOrOppositeIncomingArc(NodeIndexType node) const { + assert(RepresentationClean()); + assert(IsNodeValid(node)); + ArcIndexType result = first_incident_arc_[node]; + return ((result != first_incident_arc_[node + 1]) ? result : kNilArc); + } + + // Utility method that finds the next outgoing arc. + ArcIndexType FindNextOutgoingArc(ArcIndexType arc) const { + assert(CheckArcBounds(arc)); + return arc; + } + + // Array of node indices, not always present. (*tail_)[i] contains + // the tail node of arc i. This array is not needed for normal graph + // traversal operations, but is used in optimizing the graph's + // layout so arcs are grouped by tail node, and can be used in one + // approach to serializing the graph. + // + // Invariants: At any time when we are not executing a method of + // this class, either tail_ == NULL or the tail_ array's contents + // are kept canonical. If tail_ != NULL, any method that modifies + // adjacency lists must also ensure (*tail_) is modified + // correspondingly. The converse does not hold: Modifications to + // (*tail_) are allowed without updating the adjacency lists. 
If + // such modifications take place, representation_clean_ must be set + // to false, of course, to indicate that the adjacency lists are no + // longer current. + std::unique_ptr > tail_; +}; + +// The index of the 'nil' node in the graph. +template +const NodeIndexType + StarGraphBase::kNilNode = -1; + +// The index of the 'nil' arc in the graph. +template +const ArcIndexType + StarGraphBase::kNilArc = + std::numeric_limits::min(); + +// The index of the first node in the graph. +template +const NodeIndexType + StarGraphBase::kFirstNode = 0; + +// The index of the first arc in the graph. +template +const ArcIndexType + StarGraphBase::kFirstArc = 0; + +// The maximum possible node index in the graph. +template +const NodeIndexType + StarGraphBase::kMaxNumNodes = + std::numeric_limits::max(); + +// The maximum possible number of arcs in the graph. +// (The maximum index is kMaxNumArcs-1, since indices start at 0.) +template +const ArcIndexType + StarGraphBase::kMaxNumArcs = + std::numeric_limits::max(); + +// A template for the base class that holds the functionality that exists in +// common between the EbertGraph<> template and the ForwardEbertGraph<> +// template. +// +// This template is for internal use only, and this is enforced by making all +// constructors for this class template protected. Clients should use one of the +// two derived-class templates. Most clients will not even use those directly, +// but will use the StarGraph and ForwardStarGraph typenames declared above. +// +// The DerivedGraph template argument must be the type of the class (typically +// itself built from a template) that: +// 1. implements the full interface expected for either ForwardEbertGraph or +// EbertGraph, and +// 2. inherits from an instance of this template. +// The base class needs access to some members of the derived class such as, for +// example, NextOutgoingArc(), and it gets this access via the DerivedGraph +// template argument. 
+template +class EbertGraphBase + : public StarGraphBase { + typedef StarGraphBase Base; + friend class StarGraphBase; + + protected: + using Base::first_incident_arc_; + using Base::head_; + using Base::max_num_arcs_; + using Base::max_num_nodes_; + using Base::num_arcs_; + using Base::num_nodes_; + + public: + using Base::end_arc_index; + using Base::IsNodeValid; + + using Base::kFirstArc; + using Base::kFirstNode; + using Base::kMaxNumArcs; + using Base::kMaxNumNodes; + using Base::kNilArc; + using Base::kNilNode; + + // Reserves memory needed for max_num_nodes nodes and max_num_arcs arcs. + // Returns false if the parameters passed are not OK. + // It can be used to enlarge the graph, but does not shrink memory + // if called with smaller values. + bool Reserve(NodeIndexType new_max_num_nodes, ArcIndexType new_max_num_arcs) { + if (new_max_num_nodes < 0 || new_max_num_nodes > kMaxNumNodes) { + return false; + } + if (new_max_num_arcs < 0 || new_max_num_arcs > kMaxNumArcs) { + return false; + } + first_incident_arc_.Reserve(kFirstNode, new_max_num_nodes - 1); + for (NodeIndexType node = max_num_nodes_; + node <= first_incident_arc_.max_index(); ++node) { + first_incident_arc_.Set(node, kNilArc); + } + ThisAsDerived()->ReserveInternal(new_max_num_nodes, new_max_num_arcs); + max_num_nodes_ = new_max_num_nodes; + max_num_arcs_ = new_max_num_arcs; + return true; + } + + // Adds an arc to the graph and returns its index. + // Returns kNilArc if the arc could not be added. + // Note that for a given pair (tail, head) AddArc does not overwrite an + // already-existing arc between tail and head: Another arc is created + // instead. This makes it possible to handle multi-graphs. + ArcIndexType AddArc(NodeIndexType tail, NodeIndexType head) { + if (num_arcs_ >= max_num_arcs_ || !IsNodeValid(tail) || + !IsNodeValid(head)) { + return kNilArc; + } + if (tail + 1 > num_nodes_) { + num_nodes_ = tail + 1; // max does not work on int16_t. 
+ } + if (head + 1 > num_nodes_) { + num_nodes_ = head + 1; + } + ArcIndexType arc = num_arcs_; + ++num_arcs_; + ThisAsDerived()->RecordArc(arc, tail, head); + return arc; + } + + template + void GroupForwardArcsByFunctor( + const ArcIndexTypeStrictWeakOrderingFunctor& compare, + PermutationCycleHandler* annotation_handler) { + std::unique_ptr arc_permutation( + new ArcIndexType[end_arc_index()]); + + // Determine the permutation that groups arcs by their tail nodes. + for (ArcIndexType i = 0; i < end_arc_index(); ++i) { + // Start with the identity permutation. + arc_permutation[i] = i; + } + std::sort(&arc_permutation[kFirstArc], &arc_permutation[end_arc_index()], + compare); + + // Now we actually permute the head_ array and the + // scaled_arc_cost_ array according to the sorting permutation. + CycleHandlerForAnnotatedArcs cycle_handler(annotation_handler, + ThisAsDerived()); + PermutationApplier permutation(&cycle_handler); + permutation.Apply(&arc_permutation[0], kFirstArc, end_arc_index()); + + // Finally, rebuild the graph from its permuted head_ array. 
+ ThisAsDerived()->BuildRepresentation(); + } + + class CycleHandlerForAnnotatedArcs + : public PermutationCycleHandler { + public: + CycleHandlerForAnnotatedArcs( + PermutationCycleHandler* annotation_handler, + DerivedGraph* graph) + : annotation_handler_(annotation_handler), + graph_(graph), + head_temp_(kNilNode), + tail_temp_(kNilNode) {} + + void SetTempFromIndex(ArcIndexType source) override { + if (annotation_handler_ != nullptr) { + annotation_handler_->SetTempFromIndex(source); + } + head_temp_ = graph_->Head(source); + tail_temp_ = graph_->Tail(source); + } + + void SetIndexFromIndex(ArcIndexType source, + ArcIndexType destination) const override { + if (annotation_handler_ != nullptr) { + annotation_handler_->SetIndexFromIndex(source, destination); + } + graph_->SetHead(destination, graph_->Head(source)); + graph_->SetTail(destination, graph_->Tail(source)); + } + + void SetIndexFromTemp(ArcIndexType destination) const override { + if (annotation_handler_ != nullptr) { + annotation_handler_->SetIndexFromTemp(destination); + } + graph_->SetHead(destination, head_temp_); + graph_->SetTail(destination, tail_temp_); + } + + // Since we are free to destroy the permutation array we use the + // kNilArc value to mark entries in the array that have been + // processed already. There is no need to be able to recover the + // original permutation array entries once they have been seen. 
+ void SetSeen(ArcIndexType* permutation_element) const override { + *permutation_element = kNilArc; + } + + bool Unseen(ArcIndexType permutation_element) const override { + return permutation_element != kNilArc; + } + + ~CycleHandlerForAnnotatedArcs() override {} + + private: + PermutationCycleHandler* annotation_handler_; + DerivedGraph* graph_; + NodeIndexType head_temp_; + NodeIndexType tail_temp_; + + CycleHandlerForAnnotatedArcs(const CycleHandlerForAnnotatedArcs&); + CycleHandlerForAnnotatedArcs& operator=(const CycleHandlerForAnnotatedArcs&); + }; + + protected: + EbertGraphBase() : next_adjacent_arc_(), representation_clean_(true) {} + + ~EbertGraphBase() {} + + void Initialize(NodeIndexType max_num_nodes, ArcIndexType max_num_arcs) { + if (!Reserve(max_num_nodes, max_num_arcs)) { + pyre::journal::firewall_t channel("isce3.unwrap.ortools.ebert_graph"); + channel << pyre::journal::at(__HERE__) + << "Could not reserve memory for " + << static_cast(max_num_nodes) << " nodes and " + << static_cast(max_num_arcs) << " arcs." + << pyre::journal::endl; + } + first_incident_arc_.SetAll(kNilArc); + ThisAsDerived()->InitializeInternal(max_num_nodes, max_num_arcs); + } + + // Returns the first arc in node's incidence list. + ArcIndexType FirstOutgoingOrOppositeIncomingArc( + const NodeIndexType node) const { + assert(representation_clean_); + assert(IsNodeValid(node)); + return first_incident_arc_[node]; + } + + // Returns the next arc following the passed argument in its adjacency list. + ArcIndexType NextAdjacentArc(const ArcIndexType arc) const { + assert(representation_clean_); + assert(ThisAsDerived()->CheckArcValidity(arc)); + return next_adjacent_arc_[arc]; + } + + // Returns the outgoing arc following the argument in the adjacency list. 
+ ArcIndexType NextOutgoingArc(const NodeIndexType /*unused_node*/, + const ArcIndexType arc) const { + assert(ThisAsDerived()->CheckArcValidity(arc)); + assert(ThisAsDerived()->IsDirect(arc)); + return ThisAsDerived()->FindNextOutgoingArc(NextAdjacentArc(arc)); + } + + // Array of next indices. + // next_adjacent_arc_[i] contains the next arc in the adjacency list of arc i. + ZVector next_adjacent_arc_; + + // Flag to indicate that BuildRepresentation() needs to be called + // before the adjacency lists are examined. Only for DCHECK in debug + // builds. + bool representation_clean_; + + private: + // Shorthand: returns a const DerivedGraph*-typed version of our + // "this" pointer. + inline const DerivedGraph* ThisAsDerived() const { + return static_cast(this); + } + + // Shorthand: returns a DerivedGraph*-typed version of our "this" + // pointer. + inline DerivedGraph* ThisAsDerived() { + return static_cast(this); + } + + void InitializeInternal(NodeIndexType /*max_num_nodes*/, + ArcIndexType /*max_num_arcs*/) { + next_adjacent_arc_.SetAll(kNilArc); + } + + bool RepresentationClean() const { return representation_clean_; } + + // Using the SetHead() method implies that the BuildRepresentation() + // method must be called to restore consistency before the graph is + // used. + void SetHead(const ArcIndexType arc, const NodeIndexType head) { + representation_clean_ = false; + head_.Set(arc, head); + } +}; + +// Most users should only use StarGraph, which is EbertGraph, +// and other type shortcuts; see the bottom of this file. 
+template +class EbertGraph + : public EbertGraphBase > { + typedef EbertGraphBase > + Base; + friend class EbertGraphBase >; + friend class StarGraphBase >; + + using Base::ArcDebugString; + using Base::FirstOutgoingOrOppositeIncomingArc; + using Base::Initialize; + using Base::NextAdjacentArc; + using Base::NodeDebugString; + + using Base::first_incident_arc_; + using Base::head_; + using Base::max_num_arcs_; + using Base::max_num_nodes_; + using Base::next_adjacent_arc_; + using Base::num_arcs_; + using Base::num_nodes_; + using Base::representation_clean_; + + public: + using Base::Head; + using Base::IsNodeValid; + + using Base::kFirstArc; + using Base::kFirstNode; + using Base::kNilArc; + using Base::kNilNode; + + typedef NodeIndexType NodeIndex; + typedef ArcIndexType ArcIndex; + + EbertGraph() {} + + EbertGraph(NodeIndexType max_num_nodes, ArcIndexType max_num_arcs) { + Initialize(max_num_nodes, max_num_arcs); + } + + ~EbertGraph() {} + + // Iterator class for traversing the arcs incident to a given node in the + // graph. + class OutgoingOrOppositeIncomingArcIterator { + public: + OutgoingOrOppositeIncomingArcIterator(const EbertGraph& graph, + NodeIndexType node) + : graph_(graph), + node_(graph_.StartNode(node)), + arc_(graph_.StartArc( + graph_.FirstOutgoingOrOppositeIncomingArc(node))) { + assert(CheckInvariant()); + } + + // This constructor takes an arc as extra argument and makes the iterator + // start at arc. + OutgoingOrOppositeIncomingArcIterator(const EbertGraph& graph, + NodeIndexType node, ArcIndexType arc) + : graph_(graph), + node_(graph_.StartNode(node)), + arc_(graph_.StartArc(arc)) { + assert(CheckInvariant()); + } + + // Can only assign from an iterator on the same graph. + void operator=(const OutgoingOrOppositeIncomingArcIterator& iterator) { + assert(&iterator.graph_ == &graph_); + node_ = iterator.node_; + arc_ = iterator.arc_; + } + + // Returns true unless all the adjancent arcs have been traversed. 
+ bool Ok() const { return arc_ != kNilArc; } + + // Advances the current adjacent arc index. + void Next() { + arc_ = graph_.NextAdjacentArc(arc_); + assert(CheckInvariant()); + } + + // Returns the index of the arc currently pointed to by the iterator. + ArcIndexType Index() const { return arc_; } + + private: + // Returns true if the invariant for the iterator is verified. + // To be used in a DCHECK. + bool CheckInvariant() const { + if (arc_ == kNilArc) { + return true; // This occurs when the iterator has reached the end. + } + assert(graph_.IsOutgoingOrOppositeIncoming(arc_, node_)); + return true; + } + // A reference to the current EbertGraph considered. + const EbertGraph& graph_; + + // The index of the node on which arcs are iterated. + NodeIndexType node_; + + // The index of the current arc considered. + ArcIndexType arc_; + }; + + // Iterator class for traversing the incoming arcs associated to a given node. + class IncomingArcIterator { + public: + IncomingArcIterator(const EbertGraph& graph, NodeIndexType node) + : graph_(graph), + node_(graph_.StartNode(node)), + arc_(graph_.StartArc(graph_.FirstIncomingArc(node))) { + assert(CheckInvariant()); + } + + // This constructor takes an arc as extra argument and makes the iterator + // start at arc. + IncomingArcIterator(const EbertGraph& graph, NodeIndexType node, + ArcIndexType arc) + : graph_(graph), + node_(graph_.StartNode(node)), + arc_(arc == kNilArc ? kNilArc + : graph_.StartArc(graph_.Opposite(arc))) { + assert(CheckInvariant()); + } + + // Can only assign from an iterator on the same graph. + void operator=(const IncomingArcIterator& iterator) { + assert(&iterator.graph_ == &graph_); + node_ = iterator.node_; + arc_ = iterator.arc_; + } + + // Returns true unless all the incoming arcs have been traversed. + bool Ok() const { return arc_ != kNilArc; } + + // Advances the current incoming arc index. 
+ void Next() { + arc_ = graph_.NextIncomingArc(arc_); + assert(CheckInvariant()); + } + + // Returns the index of the arc currently pointed to by the iterator. + ArcIndexType Index() const { + return arc_ == kNilArc ? kNilArc : graph_.Opposite(arc_); + } + + private: + // Returns true if the invariant for the iterator is verified. + // To be used in a DCHECK. + bool CheckInvariant() const { + if (arc_ == kNilArc) { + return true; // This occurs when the iterator has reached the end. + } + assert(graph_.IsIncoming(Index(), node_)); + return true; + } + // A reference to the current EbertGraph considered. + const EbertGraph& graph_; + + // The index of the node on which arcs are iterated. + NodeIndexType node_; + + // The index of the current arc considered. + ArcIndexType arc_; + }; + + // Utility function to check that an arc index is within the bounds. + // It is exported so that users of the EbertGraph class can use it. + // To be used in a DCHECK. + bool CheckArcBounds(const ArcIndexType arc) const { + return (arc == kNilArc) || (arc >= -max_num_arcs_ && arc < max_num_arcs_); + } + + // Utility function to check that an arc index is within the bounds AND + // different from kNilArc. + // It is exported so that users of the EbertGraph class can use it. + // To be used in a DCHECK. + bool CheckArcValidity(const ArcIndexType arc) const { + return (arc != kNilArc) && (arc >= -max_num_arcs_ && arc < max_num_arcs_); + } + + // Returns the tail or start-node of arc. + NodeIndexType Tail(const ArcIndexType arc) const { + assert(CheckArcValidity(arc)); + return head_[Opposite(arc)]; + } + + // Returns the tail or start-node of arc if it is positive + // (i.e. it is taken in the direction it was entered in the graph), + // and the head or end-node otherwise. 'This' in Ebert's paper. + NodeIndexType DirectArcTail(const ArcIndexType arc) const { + return Tail(DirectArc(arc)); + } + + // Returns the head or end-node of arc if it is positive + // (i.e. 
it is taken in the direction it was entered in the graph), + // and the tail or start-node otherwise. 'That' in Ebert's paper. + NodeIndexType DirectArcHead(const ArcIndexType arc) const { + return Head(DirectArc(arc)); + } + + // Returns the arc in normal/direct direction. + ArcIndexType DirectArc(const ArcIndexType arc) const { + assert(CheckArcValidity(arc)); + return std::max(arc, Opposite(arc)); + } + + // Returns the arc in reverse direction. + ArcIndexType ReverseArc(const ArcIndexType arc) const { + assert(CheckArcValidity(arc)); + return std::min(arc, Opposite(arc)); + } + + // Returns the opposite arc, i.e the direct arc is the arc is in reverse + // direction, and the reverse arc if the arc is direct. + ArcIndexType Opposite(const ArcIndexType arc) const { + const ArcIndexType opposite = ~arc; + assert(CheckArcValidity(arc)); + assert(CheckArcValidity(opposite)); + return opposite; + } + + // Returns true if the arc is direct. + bool IsDirect(const ArcIndexType arc) const { + assert(CheckArcBounds(arc)); + return arc != kNilArc && arc >= 0; + } + + // Returns true if the arc is in the reverse direction. + bool IsReverse(const ArcIndexType arc) const { + assert(CheckArcBounds(arc)); + return arc != kNilArc && arc < 0; + } + + // Returns true if arc is incident to node. + bool IsOutgoingOrOppositeIncoming(ArcIndexType arc, + NodeIndexType node) const { + return Tail(arc) == node; + } + + // Returns true if arc is incoming to node. + bool IsIncoming(ArcIndexType arc, NodeIndexType node) const { + return IsDirect(arc) && Head(arc) == node; + } + + // Returns true if arc is outgoing from node. + bool IsOutgoing(ArcIndexType arc, NodeIndexType node) const { + return IsDirect(arc) && Tail(arc) == node; + } + + // Recreates the next_adjacent_arc_ and first_incident_arc_ variables from + // the array head_ in O(n + m) time. + // This is useful if head_ array has been sorted according to a given + // criterion, for example. 
+ void BuildRepresentation() { + first_incident_arc_.SetAll(kNilArc); + for (ArcIndexType arc = kFirstArc; arc < max_num_arcs_; ++arc) { + Attach(arc); + } + representation_clean_ = true; + } + + // Returns a debug string containing all the information contained in the + // data structure in raw form. + std::string DebugString() const { + assert(representation_clean_); + std::string result = "Arcs:(node, next arc) :\n"; + for (ArcIndexType arc = -num_arcs_; arc < num_arcs_; ++arc) { + result += " " + ArcDebugString(arc) + ":(" + NodeDebugString(head_[arc]) + + "," + ArcDebugString(next_adjacent_arc_[arc]) + ")\n"; + } + result += "Node:First arc :\n"; + for (NodeIndexType node = kFirstNode; node < num_nodes_; ++node) { + result += " " + NodeDebugString(node) + ":" + + ArcDebugString(first_incident_arc_[node]) + "\n"; + } + return result; + } + + private: + // Handles reserving space in the next_adjacent_arc_ and head_ + // arrays, which are always present and are therefore in the base + // class. Although they reside in the base class, those two arrays + // are maintained differently by different derived classes, + // depending on whether the derived class stores reverse arcs. Hence + // the code to set those arrays up is in a method of the derived + // class. + void ReserveInternal(NodeIndexType /*new_max_num_nodes*/, + ArcIndexType new_max_num_arcs) { + head_.Reserve(-new_max_num_arcs, new_max_num_arcs - 1); + next_adjacent_arc_.Reserve(-new_max_num_arcs, new_max_num_arcs - 1); + for (ArcIndexType arc = -new_max_num_arcs; arc < -max_num_arcs_; ++arc) { + head_.Set(arc, kNilNode); + next_adjacent_arc_.Set(arc, kNilArc); + } + for (ArcIndexType arc = max_num_arcs_; arc < new_max_num_arcs; ++arc) { + head_.Set(arc, kNilNode); + next_adjacent_arc_.Set(arc, kNilArc); + } + } + + // Returns the first incoming arc for node. 
+ ArcIndexType FirstIncomingArc(const NodeIndexType node) const { + assert(kFirstNode <= node); + assert(max_num_nodes_ >= node); + return FindNextIncomingArc(FirstOutgoingOrOppositeIncomingArc(node)); + } + + // Returns the incoming arc following the argument in the adjacency list. + ArcIndexType NextIncomingArc(const ArcIndexType arc) const { + assert(CheckArcValidity(arc)); + assert(IsReverse(arc)); + return FindNextIncomingArc(NextAdjacentArc(arc)); + } + + // Handles the part of AddArc() that is not in common with other + // graph classes based on the EbertGraphBase template. + void RecordArc(ArcIndexType arc, NodeIndexType tail, NodeIndexType head) { + head_.Set(Opposite(arc), tail); + head_.Set(arc, head); + Attach(arc); + } + + // Using the SetTail() method implies that the BuildRepresentation() + // method must be called to restore consistency before the graph is + // used. + void SetTail(const ArcIndexType arc, const NodeIndexType tail) { + representation_clean_ = false; + head_.Set(Opposite(arc), tail); + } + + // Utility method to attach a new arc. + void Attach(ArcIndexType arc) { + assert(CheckArcValidity(arc)); + const NodeIndexType tail = head_[Opposite(arc)]; + assert(IsNodeValid(tail)); + next_adjacent_arc_.Set(arc, first_incident_arc_[tail]); + first_incident_arc_.Set(tail, arc); + const NodeIndexType head = head_[arc]; + assert(IsNodeValid(head)); + next_adjacent_arc_.Set(Opposite(arc), first_incident_arc_[head]); + first_incident_arc_.Set(head, Opposite(arc)); + } + + // Utility method that finds the next outgoing arc. + ArcIndexType FindNextOutgoingArc(ArcIndexType arc) const { + assert(CheckArcBounds(arc)); + while (IsReverse(arc)) { + arc = NextAdjacentArc(arc); + assert(CheckArcBounds(arc)); + } + return arc; + } + + // Utility method that finds the next incoming arc. 
+ ArcIndexType FindNextIncomingArc(ArcIndexType arc) const { + assert(CheckArcBounds(arc)); + while (IsDirect(arc)) { + arc = NextAdjacentArc(arc); + assert(CheckArcBounds(arc)); + } + return arc; + } +}; + +// A forward-star-only graph representation for greater efficiency in +// those algorithms that don't need reverse arcs. +template +class ForwardEbertGraph + : public EbertGraphBase > { + typedef EbertGraphBase > + Base; + friend class EbertGraphBase >; + friend class StarGraphBase >; + + using Base::ArcDebugString; + using Base::Initialize; + using Base::NextAdjacentArc; + using Base::NodeDebugString; + + using Base::first_incident_arc_; + using Base::head_; + using Base::max_num_arcs_; + using Base::max_num_nodes_; + using Base::next_adjacent_arc_; + using Base::num_arcs_; + using Base::num_nodes_; + using Base::representation_clean_; + + public: + using Base::Head; + using Base::IsNodeValid; + + using Base::kFirstArc; + using Base::kFirstNode; + using Base::kNilArc; + using Base::kNilNode; + + typedef NodeIndexType NodeIndex; + typedef ArcIndexType ArcIndex; + + ForwardEbertGraph() {} + + ForwardEbertGraph(NodeIndexType max_num_nodes, ArcIndexType max_num_arcs) { + Initialize(max_num_nodes, max_num_arcs); + } + + ~ForwardEbertGraph() {} + + // Utility function to check that an arc index is within the bounds. + // It is exported so that users of the ForwardEbertGraph class can use it. + // To be used in a DCHECK. + bool CheckArcBounds(const ArcIndexType arc) const { + return (arc == kNilArc) || (arc >= kFirstArc && arc < max_num_arcs_); + } + + // Utility function to check that an arc index is within the bounds AND + // different from kNilArc. + // It is exported so that users of the ForwardEbertGraph class can use it. + // To be used in a DCHECK. + bool CheckArcValidity(const ArcIndexType arc) const { + return (arc != kNilArc) && (arc >= kFirstArc && arc < max_num_arcs_); + } + + // Returns true if arc is a valid index into the (*tail_) array. 
+ bool CheckTailIndexValidity(const ArcIndexType arc) const { + return (tail_ != nullptr) && (arc >= kFirstArc) && + (arc <= tail_->max_index()); + } + + // Returns the tail or start-node of arc. + NodeIndexType Tail(const ArcIndexType arc) const { + assert(CheckArcValidity(arc)); + assert(CheckTailIndexValidity(arc)); + return (*tail_)[arc]; + } + + // Returns true if arc is incoming to node. + bool IsIncoming(ArcIndexType arc, NodeIndexType node) const { + return IsDirect(arc) && Head(arc) == node; + } + + // Recreates the next_adjacent_arc_ and first_incident_arc_ + // variables from the arrays head_ and tail_ in O(n + m) time. This + // is useful if the head_ and tail_ arrays have been sorted + // according to a given criterion, for example. + void BuildRepresentation() { + first_incident_arc_.SetAll(kNilArc); + assert(TailArrayComplete()); + for (ArcIndexType arc = kFirstArc; arc < max_num_arcs_; ++arc) { + assert(CheckTailIndexValidity(arc)); + Attach((*tail_)[arc], arc); + } + representation_clean_ = true; + } + + bool BuildTailArray() { + // If (*tail_) is already allocated, we have the invariant that + // its contents are canonical, so we do not need to do anything + // here in that case except return true. + if (tail_ == nullptr) { + if (!representation_clean_) { + // We have been asked to build the (*tail_) array, but we have + // no valid information from which to build it. The graph is + // in an unrecoverable, inconsistent state. + return false; + } + // Reallocate (*tail_) and rebuild its contents from the + // adjacency lists. 
+ tail_.reset(new ZVector); + tail_->Reserve(kFirstArc, max_num_arcs_ - 1); + typename Base::NodeIterator node_it(*this); + for (; node_it.Ok(); node_it.Next()) { + NodeIndexType node = node_it.Index(); + typename Base::OutgoingArcIterator arc_it(*this, node); + for (; arc_it.Ok(); arc_it.Next()) { + (*tail_)[arc_it.Index()] = node; + } + } + } + assert(TailArrayComplete()); + return true; + } + + void ReleaseTailArray() { tail_.reset(nullptr); } + + // To be used in a DCHECK(). + bool TailArrayComplete() const { + if (!tail_) { + throw isce3::except::RuntimeError(ISCE_SRCINFO(), "tail_ is nullptr"); + } + for (ArcIndexType arc = kFirstArc; arc < num_arcs_; ++arc) { + if (!CheckTailIndexValidity(arc)) { + throw isce3::except::RuntimeError( + ISCE_SRCINFO(), "CheckTailIndexValidity(arc) failed"); + } + if (!IsNodeValid((*tail_)[arc])) { + throw isce3::except::RuntimeError( + ISCE_SRCINFO(), "IsNodeValid((*tail_)[arc]) failed"); + } + } + return true; + } + + // Returns a debug string containing all the information contained in the + // data structure in raw form. + std::string DebugString() const { + assert(representation_clean_); + std::string result = "Arcs:(node, next arc) :\n"; + for (ArcIndexType arc = kFirstArc; arc < num_arcs_; ++arc) { + result += " " + ArcDebugString(arc) + ":(" + NodeDebugString(head_[arc]) + + "," + ArcDebugString(next_adjacent_arc_[arc]) + ")\n"; + } + result += "Node:First arc :\n"; + for (NodeIndexType node = kFirstNode; node < num_nodes_; ++node) { + result += " " + NodeDebugString(node) + ":" + + ArcDebugString(first_incident_arc_[node]) + "\n"; + } + return result; + } + + private: + // Reserves space for the (*tail_) array. + // + // This method is separate from ReserveInternal() because our + // practice of making the (*tail_) array optional implies that the + // tail_ pointer might not be constructed when the ReserveInternal() + // method is called. 
Therefore we have this method also, and we + // ensure that it is called only when tail_ is guaranteed to have + // been initialized. + void ReserveTailArray(ArcIndexType new_max_num_arcs) { + if (tail_ != nullptr) { + // The (*tail_) values are already canonical, so we're just + // reserving additional space for new arcs that haven't been + // added yet. + if (tail_->Reserve(kFirstArc, new_max_num_arcs - 1)) { + for (ArcIndexType arc = tail_->max_index() + 1; arc < new_max_num_arcs; + ++arc) { + tail_->Set(arc, kNilNode); + } + } + } + } + + // Reserves space for the arrays indexed by arc indices, except + // (*tail_) even if it is present. We cannot grow the (*tail_) array + // in this method because this method is called from + // Base::Reserve(), which in turn is called from the base template + // class constructor. That base class constructor is called on *this + // before tail_ is constructed. Hence when this method is called, + // tail_ might contain garbage. This method can safely refer only to + // fields of the base template class, not to fields of *this outside + // the base template class. + // + // The strange situation in which this method of a derived class can + // refer only to members of the base class arises because different + // derived classes use the data members of the base class in + // slightly different ways. The purpose of this derived class + // method, then, is only to encode the derived-class-specific + // conventions for how the derived class uses the data members of + // the base class. + // + // To be specific, the forward-star graph representation, lacking + // reverse arcs, allocates only the positive index range for the + // head_ and next_adjacent_arc_ arrays, while the general + // representation allocates space for both positive- and + // negative-indexed arcs (i.e., both forward and reverse arcs). 
+  void ReserveInternal(NodeIndexType new_max_num_nodes,
+                       ArcIndexType new_max_num_arcs) {
+    head_.Reserve(kFirstArc, new_max_num_arcs - 1);
+    next_adjacent_arc_.Reserve(kFirstArc, new_max_num_arcs - 1);
+    for (ArcIndexType arc = max_num_arcs_; arc < new_max_num_arcs; ++arc) {
+      head_.Set(arc, kNilNode);
+      next_adjacent_arc_.Set(arc, kNilArc);
+    }
+    ReserveTailArray(new_max_num_arcs);
+  }
+
+  // Handles the part of AddArc() that is not in common with other
+  // graph classes based on the EbertGraphBase template.
+  void RecordArc(ArcIndexType arc, NodeIndexType tail, NodeIndexType head) {
+    head_.Set(arc, head);
+    Attach(tail, arc);
+  }
+
+  // Using the SetTail() method implies that the BuildRepresentation()
+  // method must be called to restore consistency before the graph is
+  // used.
+  void SetTail(const ArcIndexType arc, const NodeIndexType tail) {
+    assert(CheckTailIndexValidity(arc));
+    if (!tail_) {
+      throw isce3::except::RuntimeError(ISCE_SRCINFO(), "tail_ is nullptr");
+    }
+    representation_clean_ = false;
+    tail_->Set(arc, tail);
+  }
+
+  // Utility method to attach a new arc.
+  void Attach(NodeIndexType tail, ArcIndexType arc) {
+    assert(CheckArcValidity(arc));
+    assert(IsNodeValid(tail));
+    next_adjacent_arc_.Set(arc, first_incident_arc_[tail]);
+    first_incident_arc_.Set(tail, arc);
+    const NodeIndexType head = head_[arc];
+    assert(IsNodeValid(head));
+    // Because Attach() is a public method, keeping (*tail_) canonical
+    // requires us to record the new arc's tail here.
+    if (tail_ != nullptr) {
+      assert(CheckTailIndexValidity(arc));
+      tail_->Set(arc, tail);
+    }
+  }
+
+  // Utility method that finds the next outgoing arc.
+  ArcIndexType FindNextOutgoingArc(ArcIndexType arc) const {
+    assert(CheckArcBounds(arc));
+    return arc;
+  }
+
+ private:
+  // Always returns true because for any ForwardEbertGraph, only
+  // direct arcs are represented, so all valid arc indices refer to
+  // arcs that are outgoing from their tail nodes.
+ bool IsOutgoing(const ArcIndex unused_arc, + const NodeIndex unused_node) const { + return true; + } + + // Always returns true because for any ForwardEbertGraph, only + // outgoing arcs are represented, so all valid arc indices refer to + // direct arcs. + bool IsDirect(const ArcIndex unused_arc) const { return true; } + + // Array of node indices, not always present. (*tail_)[i] contains + // the tail node of arc i. This array is not needed for normal graph + // traversal operations, but is used in optimizing the graph's + // layout so arcs are grouped by tail node, and can be used in one + // approach to serializing the graph. + // + // Invariants: At any time when we are not executing a method of + // this class, either tail_ == NULL or the tail_ array's contents + // are kept canonical. If tail_ != NULL, any method that modifies + // adjacency lists must also ensure (*tail_) is modified + // correspondingly. The converse does not hold: Modifications to + // (*tail_) are allowed without updating the adjacency lists. If + // such modifications take place, representation_clean_ must be set + // to false, of course, to indicate that the adjacency lists are no + // longer current. + std::unique_ptr > tail_; +}; + +// Traits for EbertGraphBase types, for use in testing and clients +// that work with both forward-only and forward/reverse graphs. +// +// The default is to assume reverse arcs so if someone forgets to +// specialize the traits of a new forward-only graph type, they will +// get errors from tests rather than incomplete testing. 
+template +struct graph_traits { + static constexpr bool has_reverse_arcs = true; + static constexpr bool is_dynamic = true; +}; + +template +struct graph_traits > { + static constexpr bool has_reverse_arcs = false; + static constexpr bool is_dynamic = true; +}; + +template +struct graph_traits > { + static constexpr bool has_reverse_arcs = false; + static constexpr bool is_dynamic = false; +}; + +namespace or_internal { + +// The TailArrayBuilder class template is not expected to be used by +// clients. It is a helper for the TailArrayManager template. +// +// The TailArrayBuilder for graphs with reverse arcs does nothing. +template +struct TailArrayBuilder { + explicit TailArrayBuilder(GraphType* unused_graph) {} + + bool BuildTailArray() const { return true; } +}; + +// The TailArrayBuilder for graphs without reverse arcs calls the +// appropriate method on the graph from the TailArrayBuilder +// constructor. +template +struct TailArrayBuilder { + explicit TailArrayBuilder(GraphType* graph) : graph_(graph) {} + + bool BuildTailArray() const { return graph_->BuildTailArray(); } + + GraphType* const graph_; +}; + +// The TailArrayReleaser class template is not expected to be used by +// clients. It is a helper for the TailArrayManager template. +// +// The TailArrayReleaser for graphs with reverse arcs does nothing. +template +struct TailArrayReleaser { + explicit TailArrayReleaser(GraphType* unused_graph) {} + + void ReleaseTailArray() const {} +}; + +// The TailArrayReleaser for graphs without reverse arcs calls the +// appropriate method on the graph from the TailArrayReleaser +// constructor. 
+template +struct TailArrayReleaser { + explicit TailArrayReleaser(GraphType* graph) : graph_(graph) {} + + void ReleaseTailArray() const { graph_->ReleaseTailArray(); } + + GraphType* const graph_; +}; + +} // namespace or_internal + +template +class TailArrayManager { + public: + explicit TailArrayManager(GraphType* g) : graph_(g) {} + + bool BuildTailArrayFromAdjacencyListsIfForwardGraph() const { + or_internal::TailArrayBuilder::has_reverse_arcs> + tail_array_builder(graph_); + return tail_array_builder.BuildTailArray(); + } + + void ReleaseTailArrayIfForwardGraph() const { + or_internal::TailArrayReleaser::has_reverse_arcs> + tail_array_releaser(graph_); + tail_array_releaser.ReleaseTailArray(); + } + + private: + GraphType* graph_; +}; + +template +class ArcFunctorOrderingByTailAndHead { + public: + explicit ArcFunctorOrderingByTailAndHead(const GraphType& graph) + : graph_(graph) {} + + bool operator()(typename GraphType::ArcIndex a, + typename GraphType::ArcIndex b) const { + return ((graph_.Tail(a) < graph_.Tail(b)) || + ((graph_.Tail(a) == graph_.Tail(b)) && + (graph_.Head(a) < graph_.Head(b)))); + } + + private: + const GraphType& graph_; +}; + +namespace or_internal { + +// The GraphBuilderFromArcs class template is not expected to be used +// by clients. It is a helper for the AnnotatedGraphBuildManager +// template. +// +// Deletes itself upon returning the graph! 
+template +class GraphBuilderFromArcs { + public: + GraphBuilderFromArcs(typename GraphType::NodeIndex max_num_nodes, + typename GraphType::ArcIndex max_num_arcs, + bool sort_arcs) + : num_arcs_(0), sort_arcs_(sort_arcs) { + Reserve(max_num_nodes, max_num_arcs); + } + + typename GraphType::ArcIndex AddArc(typename GraphType::NodeIndex tail, + typename GraphType::NodeIndex head) { + assert(num_arcs_ < max_num_arcs_); + assert(tail < GraphType::kFirstNode + max_num_nodes_); + assert(head < GraphType::kFirstNode + max_num_nodes_); + if (num_arcs_ < max_num_arcs_ && + tail < GraphType::kFirstNode + max_num_nodes_ && + head < GraphType::kFirstNode + max_num_nodes_) { + typename GraphType::ArcIndex result = GraphType::kFirstArc + num_arcs_; + arcs_.push_back(std::make_pair(tail, head)); + num_arcs_ += 1; + return result; + } else { + // Too many arcs or node index out of bounds! + return GraphType::kNilArc; + } + } + + // Builds the graph from the given arcs. + GraphType* Graph(PermutationCycleHandler* + client_cycle_handler) { + GraphType* graph = new GraphType(max_num_nodes_, num_arcs_, sort_arcs_, + &arcs_, client_cycle_handler); + delete this; + return graph; + } + + private: + bool Reserve(typename GraphType::NodeIndex new_max_num_nodes, + typename GraphType::ArcIndex new_max_num_arcs) { + max_num_nodes_ = new_max_num_nodes; + max_num_arcs_ = new_max_num_arcs; + arcs_.reserve(new_max_num_arcs); + return true; + } + + typename GraphType::NodeIndex max_num_nodes_; + typename GraphType::ArcIndex max_num_arcs_; + typename GraphType::ArcIndex num_arcs_; + + std::vector< + std::pair > + arcs_; + + const bool sort_arcs_; +}; + +// Trivial delegating specialization for dynamic graphs. +// +// Deletes itself upon returning the graph! 
+template +class GraphBuilderFromArcs { + public: + GraphBuilderFromArcs(typename GraphType::NodeIndex max_num_nodes, + typename GraphType::ArcIndex max_num_arcs, + bool sort_arcs) + : graph_(new GraphType(max_num_nodes, max_num_arcs)), + sort_arcs_(sort_arcs) {} + + bool Reserve(const typename GraphType::NodeIndex new_max_num_nodes, + const typename GraphType::ArcIndex new_max_num_arcs) { + return graph_->Reserve(new_max_num_nodes, new_max_num_arcs); + } + + typename GraphType::ArcIndex AddArc( + const typename GraphType::NodeIndex tail, + const typename GraphType::NodeIndex head) { + return graph_->AddArc(tail, head); + } + + GraphType* Graph(PermutationCycleHandler* + client_cycle_handler) { + if (sort_arcs_) { + TailArrayManager tail_array_manager(graph_); + tail_array_manager.BuildTailArrayFromAdjacencyListsIfForwardGraph(); + ArcFunctorOrderingByTailAndHead arc_ordering(*graph_); + graph_->GroupForwardArcsByFunctor(arc_ordering, client_cycle_handler); + tail_array_manager.ReleaseTailArrayIfForwardGraph(); + } + GraphType* result = graph_; + delete this; + return result; + } + + private: + GraphType* const graph_; + const bool sort_arcs_; +}; + +} // namespace or_internal + +template +class AnnotatedGraphBuildManager + : public or_internal::GraphBuilderFromArcs< + GraphType, graph_traits::is_dynamic> { + public: + AnnotatedGraphBuildManager(typename GraphType::NodeIndex num_nodes, + typename GraphType::ArcIndex num_arcs, + bool sort_arcs) + : or_internal::GraphBuilderFromArcs::is_dynamic>( + num_nodes, num_arcs, sort_arcs) {} +}; + +// Builds a directed line graph for 'graph' (see "directed line graph" in +// http://en.wikipedia.org/wiki/Line_graph). Arcs of the original graph +// become nodes and the new graph contains only nodes created from arcs in the +// original graph (we use the notation (a->b) for these new nodes); the index +// of the node (a->b) in the new graph is exactly the same as the index of the +// arc a->b in the original graph. 
+// An arc from node (a->b) to node (c->d) in the new graph is added if and only +// if b == c in the original graph. +// This method expects that 'line_graph' is an empty graph (it has no nodes +// and no arcs). +// Returns false on an error. +template +bool BuildLineGraph(const GraphType& graph, GraphType* const line_graph) { + if (line_graph == nullptr) { + pyre::journal::firewall_t channel("isce3.unwrap.ortools.ebert_graph"); + channel << pyre::journal::at(__HERE__) + << "line_graph must not be NULL" + << pyre::journal::endl; + return false; + } + if (line_graph->num_nodes() != 0) { + pyre::journal::firewall_t channel("isce3.unwrap.ortools.ebert_graph"); + channel << pyre::journal::at(__HERE__) + << "line_graph must be empty" + << pyre::journal::endl; + return false; + } + typedef typename GraphType::ArcIterator ArcIterator; + typedef typename GraphType::OutgoingArcIterator OutgoingArcIterator; + // Sizing then filling. + typename GraphType::ArcIndex num_arcs = 0; + for (ArcIterator arc_iterator(graph); arc_iterator.Ok(); + arc_iterator.Next()) { + const typename GraphType::ArcIndex arc = arc_iterator.Index(); + const typename GraphType::NodeIndex head = graph.Head(arc); + for (OutgoingArcIterator iterator(graph, head); iterator.Ok(); + iterator.Next()) { + ++num_arcs; + } + } + line_graph->Reserve(graph.num_arcs(), num_arcs); + for (ArcIterator arc_iterator(graph); arc_iterator.Ok(); + arc_iterator.Next()) { + const typename GraphType::ArcIndex arc = arc_iterator.Index(); + const typename GraphType::NodeIndex head = graph.Head(arc); + for (OutgoingArcIterator iterator(graph, head); iterator.Ok(); + iterator.Next()) { + line_graph->AddArc(arc, iterator.Index()); + } + } + return true; +} + +} // namespace operations_research +#endif // OR_TOOLS_GRAPH_EBERT_GRAPH_H_ diff --git a/cxx/isce3/unwrap/ortools/graph.h b/cxx/isce3/unwrap/ortools/graph.h new file mode 100644 index 000000000..40986c3a9 --- /dev/null +++ b/cxx/isce3/unwrap/ortools/graph.h @@ -0,0 +1,2377 
@@ +// Copyright 2010-2021 Google LLC +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// +// +// This file defines a generic graph interface on which most algorithms can be +// built and provides a few efficient implementations with a fast construction +// time. Its design is based on the experience acquired by the Operations +// Research team in their various graph algorithm implementations. +// +// The main ideas are: +// - Graph nodes and arcs are represented by integers. +// - Node or arc annotations (weight, cost, ...) are not part of the graph +// class, they can be stored outside in one or more arrays and can be easily +// retrieved using a node or arc as an index. +// +// Terminology: +// - An arc of a graph is directed and going from a tail node to a head node. +// - Some implementations also store 'reverse' arcs and can be used for +// undirected graph or flow-like algorithm. +// - A node or arc index is 'valid' if it represents a node or arc of +// the graph. The validity ranges are always [0, num_nodes()) for nodes and +// [0, num_arcs()) for forward arcs. Reverse arcs are elements of +// [-num_arcs(), 0) and are also considered valid by the implementations that +// store them. +// +// Provided implementations: +// - ListGraph<> for the simplest api. Also aliased to util::Graph. 
+// - StaticGraph<> for performance, but require calling Build(), see below +// - CompleteGraph<> if you need a fully connected graph +// - CompleteBipartiteGraph<> if you need a fully connected bipartite graph +// - ReverseArcListGraph<> to add reverse arcs to ListGraph<> +// - ReverseArcStaticGraph<> to add reverse arcs to StaticGraph<> +// - ReverseArcMixedGraph<> for a smaller memory footprint +// +// Utility classes & functions: +// - Permute() to permute an array according to a given permutation. +// - SVector<> vector with index range [-size(), size()) for ReverseArcGraph. +// +// Basic usage: +// typedef ListGraph<> Graph; // Choose a graph implementation. +// Graph graph; +// for (...) { +// graph.AddArc(tail, head); +// } +// ... +// for (int node = 0; node < graph.num_nodes(); ++node) { +// for (const int arc : graph.OutgoingArcs(node)) { +// head = graph.Head(arc); +// tail = node; // or graph.Tail(arc) which is fast but not as much. +// } +// } +// +// Iteration over the arcs touching a node: +// +// - OutgoingArcs(node): All the forward arcs leaving the node. +// - IncomingArcs(node): All the forward arcs arriving at the node. +// +// And two more involved ones: +// +// - OutgoingOrOppositeIncomingArcs(node): This returns both the forward arcs +// leaving the node (i.e. OutgoingArcs(node)) and the reverse arcs leaving the +// node (i.e. the opposite arcs of the ones returned by IncomingArcs(node)). +// - OppositeIncomingArcs(node): This returns the reverse arcs leaving the node. +// +// Note on iteration efficiency: When re-indexing the arcs it is not possible to +// have both the outgoing arcs and the incoming ones form a consecutive range. +// +// It is however possible to do so for the outgoing arcs and the opposite +// incoming arcs. It is why the OutgoingOrOppositeIncomingArcs() and +// OutgoingArcs() iterations are more efficient than the IncomingArcs() one. 
+// +// If you know the graph size in advance, this already set the number of nodes, +// reserve space for the arcs and check in DEBUG mode that you don't go over the +// bounds: +// Graph graph(num_nodes, arc_capacity); +// +// Storing and using node annotations: +// vector is_visited(graph.num_nodes(), false); +// ... +// for (int node = 0; node < graph.num_nodes(); ++node) { +// if (!is_visited[node]) ... +// } +// +// Storing and using arc annotations: +// vector weights; +// for (...) { +// graph.AddArc(tail, head); +// weights.push_back(arc_weight); +// } +// ... +// for (const int arc : graph.OutgoingArcs(node)) { +// ... weights[arc] ...; +// } +// +// More efficient version: +// typedef StaticGraph<> Graph; +// Graph graph(num_nodes, arc_capacity); // Optional, but help memory usage. +// vector weights; +// weights.reserve(arc_capacity); // Optional, but help memory usage. +// for (...) { +// graph.AddArc(tail, head); +// weights.push_back(arc_weight); +// } +// ... +// vector permutation; +// graph.Build(&permutation); // A static graph must be Build() before usage. +// Permute(permutation, &weights); // Build() may permute the arc index. +// ... +// +// Encoding an undirected graph: If you don't need arc annotation, then the best +// is to add two arcs for each edge (one in each direction) to a directed graph. +// Otherwise you can do the following. +// +// typedef ReverseArc... Graph; +// Graph graph; +// for (...) { +// graph.AddArc(tail, head); // or graph.AddArc(head, tail) but not both. +// edge_annotations.push_back(value); +// } +// ... +// for (const Graph::NodeIndex node : graph.AllNodes()) { +// for (const Graph::ArcIndex arc : +// graph.OutgoingOrOppositeIncomingArcs(node)) { +// destination = graph.Head(arc); +// annotation = edge_annotations[arc < 0 ? graph.OppositeArc(arc) : arc]; +// } +// } +// +// +// Note: The graphs are primarily designed to be constructed first and then used +// because it covers most of the use cases. 
 It is possible to extend the
+// interface with more dynamicity (like removing arcs), but this is not done at
+// this point. Note that a "dynamic" implementation will break some assumptions
+// we make on what node or arc are valid and also on the indices returned by
+// AddArc(). Some arguments for simplifying the interface at the cost of
+// dynamicity are:
+//
+// - It is always possible to construct a static graph from a dynamic one
+// before calling a complex algo.
+// - If you really need a dynamic graph, maybe it is better to compute a graph
+// property incrementally rather than calling an algorithm that starts from
+// scratch each time.
+
+#ifndef UTIL_GRAPH_GRAPH_H_
+#define UTIL_GRAPH_GRAPH_H_
+
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+
+#include
+#include
+
+#include "iterators.h"
+
+namespace util {
+
+// Forward declaration.
+template
+class SVector;
+
+// Base class of all Graphs implemented here. The default value for the graph
+// index types is int32_t since almost all graphs that fit into memory do not
+// need bigger indices.
+//
+// Note: The type can be unsigned, except for the graphs with reverse arcs
+// where the ArcIndexType must be signed, but not necessarily the NodeIndexType.
+template
+class BaseGraph {
+ public:
+  // Typedef so you can use Graph::NodeIndex and Graph::ArcIndex to be generic
+  // but also to improve the readability of your code. We also recommend
+  // that you define a typedef ... Graph; for readability.
+  typedef NodeIndexType NodeIndex;
+  typedef ArcIndexType ArcIndex;
+
+  BaseGraph()
+      : num_nodes_(0),
+        node_capacity_(0),
+        num_arcs_(0),
+        arc_capacity_(0),
+        const_capacities_(false) {}
+  virtual ~BaseGraph() {}
+
+  // Returns the number of valid nodes in the graph.
+  NodeIndexType num_nodes() const { return num_nodes_; }
+
+  // Returns the number of valid arcs in the graph.
+ ArcIndexType num_arcs() const { return num_arcs_; } + + // Allows nice range-based for loop: + // for (const NodeIndex node : graph.AllNodes()) { ... } + // for (const ArcIndex arc : graph.AllForwardArcs()) { ... } + IntegerRange AllNodes() const; + IntegerRange AllForwardArcs() const; + + // Returns true if the given node is a valid node of the graph. + bool IsNodeValid(NodeIndexType node) const { + return node >= 0 && node < num_nodes_; + } + + // Returns true if the given arc is a valid arc of the graph. + // Note that the arc validity range changes for graph with reverse arcs. + bool IsArcValid(ArcIndexType arc) const { + return (HasReverseArcs ? -num_arcs_ : 0) <= arc && arc < num_arcs_; + } + + // Capacity reserved for future nodes, always >= num_nodes_. + NodeIndexType node_capacity() const; + + // Capacity reserved for future arcs, always >= num_arcs_. + ArcIndexType arc_capacity() const; + + // Changes the graph capacities. The functions will fail in debug mode if: + // - const_capacities_ is true. + // - A valid node does not fall into the new node range. + // - A valid arc does not fall into the new arc range. + // In non-debug mode, const_capacities_ is ignored and nothing will happen + // if the new capacity value for the arcs or the nodes is too small. + virtual void ReserveNodes(NodeIndexType bound) { + assert(!const_capacities_); + assert(bound >= num_nodes_); + if (bound <= num_nodes_) return; + node_capacity_ = bound; + } + virtual void ReserveArcs(ArcIndexType bound) { + assert(!const_capacities_); + assert(bound >= num_arcs_); + if (bound <= num_arcs_) return; + arc_capacity_ = bound; + } + void Reserve(NodeIndexType node_capacity, ArcIndexType arc_capacity) { + ReserveNodes(node_capacity); + ReserveArcs(arc_capacity); + } + + // FreezeCapacities() makes any future attempt to change the graph capacities + // crash in DEBUG mode. + void FreezeCapacities(); + + // Constants that will never be a valid node or arc. 
+ // They are the maximum possible node and arc capacity. + static const NodeIndexType kNilNode; + static const ArcIndexType kNilArc; + + // TODO(user): remove the public functions below. They are just here during + // the transition from the old ebert_graph api to this new graph api. + template + void GroupForwardArcsByFunctor(const A& a, B* b) { + pyre::journal::error_t channel("isce3.unwrap.ortools.graph"); + channel << pyre::journal::at(__HERE__) << "Not supported" + << pyre::journal::endl; + } + ArcIndexType max_end_arc_index() const { return arc_capacity_; } + + protected: + // Functions commented when defined because they are implementation details. + void ComputeCumulativeSum(std::vector* v); + void BuildStartAndForwardHead(SVector* head, + std::vector* start, + std::vector* permutation); + + NodeIndexType num_nodes_; + NodeIndexType node_capacity_; + ArcIndexType num_arcs_; + ArcIndexType arc_capacity_; + bool const_capacities_; +}; + +// Basic graph implementation without reverse arc. This class also serves as a +// documentation for the generic graph interface (minus the part related to +// reverse arcs). +// +// This implementation uses a linked list and compared to StaticGraph: +// - Is a bit faster to construct (if the arcs are not ordered by tail). +// - Does not require calling Build(). +// - Has slower outgoing arc iteration. +// - Uses more memory: ArcIndexType * node_capacity() +// + (ArcIndexType + NodeIndexType) * arc_capacity(). +// - Has an efficient Tail() but need an extra NodeIndexType/arc memory for it. +// - Never changes the initial arc index returned by AddArc(). 
+// +template +class ListGraph : public BaseGraph { + typedef BaseGraph Base; + using Base::arc_capacity_; + using Base::const_capacities_; + using Base::node_capacity_; + using Base::num_arcs_; + using Base::num_nodes_; + + public: + using Base::IsArcValid; + ListGraph() {} + + // Reserve space for the graph at construction and do not allow it to grow + // beyond that, see FreezeCapacities(). This constructor also makes any nodes + // in [0, num_nodes) valid. + ListGraph(NodeIndexType num_nodes, ArcIndexType arc_capacity) { + this->Reserve(num_nodes, arc_capacity); + this->FreezeCapacities(); + this->AddNode(num_nodes - 1); + } + + // If node is not a valid node, sets num_nodes_ to node + 1 so that the given + // node becomes valid. It will fail in DEBUG mode if the capacities are fixed + // and the new node is out of range. + void AddNode(NodeIndexType node); + + // Adds an arc to the graph and returns its current index which will always + // be num_arcs() - 1. It will also automatically call AddNode(tail) + // and AddNode(head). It will fail in DEBUG mode if the capacities + // are fixed and this cause the graph to grow beyond them. + // + // Note: Self referencing arcs and duplicate arcs are supported. + ArcIndexType AddArc(NodeIndexType tail, NodeIndexType head); + + // Some graph implementations need to be finalized with Build() before they + // can be used. After Build() is called, the arc indices (which had been the + // return values of previous AddArc() calls) may change: the new index of + // former arc #i will be stored in permutation[i] if #i is smaller than + // permutation.size() or will be unchanged otherwise. If you don't care about + // these, just call the simple no-output version Build(). + // + // Note that some implementations become immutable after calling Build(). + void Build() { Build(nullptr); } + void Build(std::vector* permutation); + + // Do not use directly. 
+ class OutgoingArcIterator; + class OutgoingHeadIterator; + + // Graph jargon: the "degree" of a node is its number of arcs. The out-degree + // is the number of outgoing arcs. The in-degree is the number of incoming + // arcs, and is only available for some graph implementations, below. + // + // ListGraph<>::OutDegree() works in O(degree). + ArcIndexType OutDegree(NodeIndexType node) const; + + // Allows to iterate over the forward arcs that verify Tail(arc) == node. + // This is meant to be used as: + // for (const ArcIndex arc : graph.OutgoingArcs(node)) { ... } + BeginEndWrapper OutgoingArcs(NodeIndexType node) const; + + // Advanced usage. Same as OutgoingArcs(), but allows to restart the iteration + // from an already known outgoing arc of the given node. + BeginEndWrapper OutgoingArcsStartingFrom( + NodeIndexType node, ArcIndexType from) const; + + // This loops over the heads of the OutgoingArcs(node). It is just a more + // convenient way to achieve this. Moreover this interface is used by some + // graph algorithms. + BeginEndWrapper operator[](NodeIndexType node) const; + + // Returns the tail/head of a valid arc. + NodeIndexType Tail(ArcIndexType arc) const; + NodeIndexType Head(ArcIndexType arc) const; + + void ReserveNodes(NodeIndexType bound) override; + void ReserveArcs(ArcIndexType bound) override; + + private: + std::vector start_; + std::vector next_; + std::vector head_; + std::vector tail_; +}; + +// Most efficient implementation of a graph without reverse arcs: +// - Build() needs to be called after the arc and node have been added. +// - The graph is really compact memory wise: +// ArcIndexType * node_capacity() + 2 * NodeIndexType * arc_capacity(), +// but when Build() is called it uses a temporary extra space of +// ArcIndexType * arc_capacity(). +// - The construction is really fast. 
+// +// NOTE(user): if the need arises for very-well compressed graphs, we could +// shave NodeIndexType * arc_capacity() off the permanent memory requirement +// with a similar class that doesn't support Tail(), i.e. +// StaticGraphWithoutTail<>. This almost corresponds to a past implementation +// of StaticGraph<> @CL 116144340. +template +class StaticGraph : public BaseGraph { + typedef BaseGraph Base; + using Base::arc_capacity_; + using Base::const_capacities_; + using Base::node_capacity_; + using Base::num_arcs_; + using Base::num_nodes_; + + public: + using Base::IsArcValid; + StaticGraph() : is_built_(false), arc_in_order_(true), last_tail_seen_(0) {} + StaticGraph(NodeIndexType num_nodes, ArcIndexType arc_capacity) + : is_built_(false), arc_in_order_(true), last_tail_seen_(0) { + this->Reserve(num_nodes, arc_capacity); + this->FreezeCapacities(); + this->AddNode(num_nodes - 1); + } + + // Do not use directly. See instead the arc iteration functions below. + class OutgoingArcIterator; + + NodeIndexType Head(ArcIndexType arc) const; + NodeIndexType Tail(ArcIndexType arc) const; + ArcIndexType OutDegree(NodeIndexType node) const; // Work in O(1). + BeginEndWrapper OutgoingArcs(NodeIndexType node) const; + BeginEndWrapper OutgoingArcsStartingFrom( + NodeIndexType node, ArcIndexType from) const; + + // This loops over the heads of the OutgoingArcs(node). It is just a more + // convenient way to achieve this. Moreover this interface is used by some + // graph algorithms. + BeginEndWrapper operator[](NodeIndexType node) const; + + void ReserveNodes(NodeIndexType bound) override; + void ReserveArcs(ArcIndexType bound) override; + void AddNode(NodeIndexType node); + ArcIndexType AddArc(NodeIndexType tail, NodeIndexType head); + + void Build() { Build(nullptr); } + void Build(std::vector* permutation); + + private: + ArcIndexType DirectArcLimit(NodeIndexType node) const { + assert(is_built_); + assert(Base::IsNodeValid(node)); + return node + 1 < num_nodes_ ? 
start_[node + 1] : num_arcs_; + } + + bool is_built_; + bool arc_in_order_; + NodeIndexType last_tail_seen_; + std::vector start_; + std::vector head_; + std::vector tail_; +}; + +// Extends the ListGraph by also storing the reverse arcs. +// This class also documents the Graph interface related to reverse arc. +// - NodeIndexType can be unsigned, but ArcIndexType must be signed. +// - It has most of the same advantanges and disadvantages as ListGraph. +// - It takes 2 * ArcIndexType * node_capacity() +// + 2 * (ArcIndexType + NodeIndexType) * arc_capacity() memory. +template +class ReverseArcListGraph + : public BaseGraph { + typedef BaseGraph Base; + using Base::arc_capacity_; + using Base::const_capacities_; + using Base::node_capacity_; + using Base::num_arcs_; + using Base::num_nodes_; + + public: + using Base::IsArcValid; + ReverseArcListGraph() {} + ReverseArcListGraph(NodeIndexType num_nodes, ArcIndexType arc_capacity) { + this->Reserve(num_nodes, arc_capacity); + this->FreezeCapacities(); + this->AddNode(num_nodes - 1); + } + + // Returns the opposite arc of a given arc. That is the reverse arc of the + // given forward arc or the forward arc of a given reverse arc. + ArcIndexType OppositeArc(ArcIndexType arc) const; + + // Do not use directly. See instead the arc iteration functions below. + class OutgoingOrOppositeIncomingArcIterator; + class OppositeIncomingArcIterator; + class IncomingArcIterator; + class OutgoingArcIterator; + class OutgoingHeadIterator; + + // ReverseArcListGraph<>::OutDegree() and ::InDegree() work in O(degree). + ArcIndexType OutDegree(NodeIndexType node) const; + ArcIndexType InDegree(NodeIndexType node) const; + + // Arc iterations functions over the arcs touching a node (see the top-level + // comment for the different types). To be used as follows: + // for (const Graph::ArcIndex arc : IterationFunction(node)) { ... 
} + // + // The StartingFrom() version are similar, but restart the iteration from a + // given arc position (which must be valid in the iteration context). + BeginEndWrapper OutgoingArcs(NodeIndexType node) const; + BeginEndWrapper IncomingArcs(NodeIndexType node) const; + BeginEndWrapper + OutgoingOrOppositeIncomingArcs(NodeIndexType node) const; + BeginEndWrapper OppositeIncomingArcs( + NodeIndexType node) const; + BeginEndWrapper OutgoingArcsStartingFrom( + NodeIndexType node, ArcIndexType from) const; + BeginEndWrapper IncomingArcsStartingFrom( + NodeIndexType node, ArcIndexType from) const; + BeginEndWrapper + OutgoingOrOppositeIncomingArcsStartingFrom(NodeIndexType node, + ArcIndexType from) const; + BeginEndWrapper OppositeIncomingArcsStartingFrom( + NodeIndexType node, ArcIndexType from) const; + + // This loops over the heads of the OutgoingArcs(node). It is just a more + // convenient way to achieve this. Moreover this interface is used by some + // graph algorithms. + BeginEndWrapper operator[](NodeIndexType node) const; + + NodeIndexType Head(ArcIndexType arc) const; + NodeIndexType Tail(ArcIndexType arc) const; + + void ReserveNodes(NodeIndexType bound) override; + void ReserveArcs(ArcIndexType bound) override; + void AddNode(NodeIndexType node); + ArcIndexType AddArc(NodeIndexType tail, NodeIndexType head); + + void Build() { Build(nullptr); } + void Build(std::vector* permutation); + + private: + std::vector start_; + std::vector reverse_start_; + SVector next_; + SVector head_; +}; + +// StaticGraph with reverse arc. +// - NodeIndexType can be unsigned, but ArcIndexType must be signed. +// - It has most of the same advantanges and disadvantages as StaticGraph. +// - It takes 2 * ArcIndexType * node_capacity() +// + 2 * (ArcIndexType + NodeIndexType) * arc_capacity() memory. +// - If the ArcIndexPermutation is needed, then an extra ArcIndexType * +// arc_capacity() is needed for it. 
+// - The reverse arcs from a node are sorted by head (so we could add a log() +// time lookup function). +template +class ReverseArcStaticGraph + : public BaseGraph { + typedef BaseGraph Base; + using Base::arc_capacity_; + using Base::const_capacities_; + using Base::node_capacity_; + using Base::num_arcs_; + using Base::num_nodes_; + + public: + using Base::IsArcValid; + ReverseArcStaticGraph() : is_built_(false) {} + ReverseArcStaticGraph(NodeIndexType num_nodes, ArcIndexType arc_capacity) + : is_built_(false) { + this->Reserve(num_nodes, arc_capacity); + this->FreezeCapacities(); + this->AddNode(num_nodes - 1); + } + + // Deprecated. + class OutgoingOrOppositeIncomingArcIterator; + class OppositeIncomingArcIterator; + class IncomingArcIterator; + class OutgoingArcIterator; + + // ReverseArcStaticGraph<>::OutDegree() and ::InDegree() work in O(1). + ArcIndexType OutDegree(NodeIndexType node) const; + ArcIndexType InDegree(NodeIndexType node) const; + + BeginEndWrapper OutgoingArcs(NodeIndexType node) const; + BeginEndWrapper IncomingArcs(NodeIndexType node) const; + BeginEndWrapper + OutgoingOrOppositeIncomingArcs(NodeIndexType node) const; + BeginEndWrapper OppositeIncomingArcs( + NodeIndexType node) const; + BeginEndWrapper OutgoingArcsStartingFrom( + NodeIndexType node, ArcIndexType from) const; + BeginEndWrapper IncomingArcsStartingFrom( + NodeIndexType node, ArcIndexType from) const; + BeginEndWrapper + OutgoingOrOppositeIncomingArcsStartingFrom(NodeIndexType node, + ArcIndexType from) const; + BeginEndWrapper OppositeIncomingArcsStartingFrom( + NodeIndexType node, ArcIndexType from) const; + + // This loops over the heads of the OutgoingArcs(node). It is just a more + // convenient way to achieve this. Moreover this interface is used by some + // graph algorithms. + BeginEndWrapper operator[](NodeIndexType node) const; + + ArcIndexType OppositeArc(ArcIndexType arc) const; + // TODO(user): support Head() and Tail() before Build(), like StaticGraph<>. 
+ NodeIndexType Head(ArcIndexType arc) const; + NodeIndexType Tail(ArcIndexType arc) const; + + void ReserveArcs(ArcIndexType bound) override; + void AddNode(NodeIndexType node); + ArcIndexType AddArc(NodeIndexType tail, NodeIndexType head); + + void Build() { Build(nullptr); } + void Build(std::vector* permutation); + + private: + ArcIndexType DirectArcLimit(NodeIndexType node) const { + assert(is_built_); + assert(Base::IsNodeValid(node)); + return node + 1 < num_nodes_ ? start_[node + 1] : num_arcs_; + } + ArcIndexType ReverseArcLimit(NodeIndexType node) const { + assert(is_built_); + assert(Base::IsNodeValid(node)); + return node + 1 < num_nodes_ ? reverse_start_[node + 1] : 0; + } + + bool is_built_; + std::vector start_; + std::vector reverse_start_; + SVector head_; + SVector opposite_; +}; + +// This graph is a mix between the ReverseArcListGraph and the +// ReverseArcStaticGraph. It uses less memory: +// - It takes 2 * ArcIndexType * node_capacity() +// + (2 * NodeIndexType + ArcIndexType) * arc_capacity() memory. +// - If the ArcIndexPermutation is needed, then an extra ArcIndexType * +// arc_capacity() is needed for it. +template +class ReverseArcMixedGraph + : public BaseGraph { + typedef BaseGraph Base; + using Base::arc_capacity_; + using Base::const_capacities_; + using Base::node_capacity_; + using Base::num_arcs_; + using Base::num_nodes_; + + public: + using Base::IsArcValid; + ReverseArcMixedGraph() : is_built_(false) {} + ReverseArcMixedGraph(NodeIndexType num_nodes, ArcIndexType arc_capacity) + : is_built_(false) { + this->Reserve(num_nodes, arc_capacity); + this->FreezeCapacities(); + this->AddNode(num_nodes - 1); + } + + // Deprecated. 
+ class OutgoingOrOppositeIncomingArcIterator; + class OppositeIncomingArcIterator; + class IncomingArcIterator; + class OutgoingArcIterator; + + ArcIndexType OutDegree(NodeIndexType node) const; // O(1) + ArcIndexType InDegree(NodeIndexType node) const; // O(in-degree) + + BeginEndWrapper OutgoingArcs(NodeIndexType node) const; + BeginEndWrapper IncomingArcs(NodeIndexType node) const; + BeginEndWrapper + OutgoingOrOppositeIncomingArcs(NodeIndexType node) const; + BeginEndWrapper OppositeIncomingArcs( + NodeIndexType node) const; + BeginEndWrapper OutgoingArcsStartingFrom( + NodeIndexType node, ArcIndexType from) const; + BeginEndWrapper IncomingArcsStartingFrom( + NodeIndexType node, ArcIndexType from) const; + BeginEndWrapper + OutgoingOrOppositeIncomingArcsStartingFrom(NodeIndexType node, + ArcIndexType from) const; + BeginEndWrapper OppositeIncomingArcsStartingFrom( + NodeIndexType node, ArcIndexType from) const; + + // This loops over the heads of the OutgoingArcs(node). It is just a more + // convenient way to achieve this. Moreover this interface is used by some + // graph algorithms. + BeginEndWrapper operator[](NodeIndexType node) const; + + ArcIndexType OppositeArc(ArcIndexType arc) const; + // TODO(user): support Head() and Tail() before Build(), like StaticGraph<>. + NodeIndexType Head(ArcIndexType arc) const; + NodeIndexType Tail(ArcIndexType arc) const; + + void ReserveArcs(ArcIndexType bound) override; + void AddNode(NodeIndexType node); + ArcIndexType AddArc(NodeIndexType tail, NodeIndexType head); + + void Build() { Build(nullptr); } + void Build(std::vector* permutation); + + private: + ArcIndexType DirectArcLimit(NodeIndexType node) const { + assert(is_built_); + assert(Base::IsNodeValid(node)); + return node + 1 < num_nodes_ ? 
start_[node + 1] : num_arcs_; + } + + bool is_built_; + std::vector start_; + std::vector reverse_start_; + std::vector next_; + SVector head_; +}; + +// Permutes the elements of array_to_permute: element #i will be moved to +// position permutation[i]. permutation must be either empty (in which case +// nothing happens), or a permutation of [0, permutation.size()). +// +// The algorithm is fast but need extra memory for a copy of the permuted part +// of array_to_permute. +// +// TODO(user): consider slower but more memory efficient implementations that +// follow the cycles of the permutation and use a bitmap to indicate what has +// been permuted or to mark the beginning of each cycle. + +// Some compiler do not know typeof(), so we have to use this extra function +// internally. +template +void PermuteWithExplicitElementType(const IntVector& permutation, + Array* array_to_permute, + ElementType unused) { + std::vector temp(permutation.size()); + for (int i = 0; i < permutation.size(); ++i) { + temp[i] = (*array_to_permute)[i]; + } + for (int i = 0; i < permutation.size(); ++i) { + (*array_to_permute)[permutation[i]] = temp[i]; + } +} + +template +void Permute(const IntVector& permutation, Array* array_to_permute) { + if (permutation.empty()) { + return; + } + PermuteWithExplicitElementType(permutation, array_to_permute, + (*array_to_permute)[0]); +} + +// We need a specialization for vector, because the default code uses +// (*array_to_permute)[0] as ElementType, which isn't 'bool' in that case. +template +void Permute(const IntVector& permutation, + std::vector* array_to_permute) { + if (permutation.empty()) { + return; + } + bool unused = false; + PermuteWithExplicitElementType(permutation, array_to_permute, unused); +} + +// A vector-like class where valid indices are in [- size_, size_) and reserved +// indices for future growth are in [- capacity_, capacity_). It is used to hold +// arc related information for graphs with reverse arcs. 
+// It supports only up to 2^31-1 elements, for compactness. If you ever need +// more, consider using templates for the size/capacity integer types. +// +// Sample usage: +// +// SVector v; +// v.grow(left_value, right_value); +// v.resize(10); +// v.clear(); +// v.swap(new_v); +// std:swap(v[i], v[~i]); +template +class SVector { + public: + SVector() : base_(nullptr), size_(0), capacity_(0) {} + + ~SVector() { clear_and_dealloc(); } + + // Copy constructor and assignment operator. + SVector(const SVector& other) : SVector() { *this = other; } + SVector& operator=(const SVector& other) { + if (capacity_ < other.size_) { + clear_and_dealloc(); + // NOTE(user): Alternatively, our capacity could inherit from the other + // vector's capacity, which can be (much) greater than its size. + capacity_ = other.size_; + base_ = Allocate(capacity_); + if (base_ == nullptr) { + throw isce3::except::RuntimeError(ISCE_SRCINFO(), "base_ == nullptr"); + } + base_ += capacity_; + } else { // capacity_ >= other.size + clear(); + } + // Perform the actual copy of the payload. + size_ = other.size_; + for (int i = -size_; i < size_; ++i) { + new (base_ + i) T(other.base_[i]); + } + return *this; + } + + // Move constructor and move assignment operator. + SVector(SVector&& other) : SVector() { swap(other); } + SVector& operator=(SVector&& other) { + // NOTE(user): We could just swap() and let the other's destruction take + // care of the clean-up, but it is probably less bug-prone to perform the + // destruction immediately. 
+ clear_and_dealloc(); + swap(other); + return *this; + } + + T& operator[](int n) { + assert(n < size_); + assert(n >= -size_); + return base_[n]; + } + + const T& operator[](int n) const { + assert(n < size_); + assert(n >= -size_); + return base_[n]; + } + + void resize(int n) { + reserve(n); + for (int i = -n; i < -size_; ++i) { + new (base_ + i) T(); + } + for (int i = size_; i < n; ++i) { + new (base_ + i) T(); + } + for (int i = -size_; i < -n; ++i) { + base_[i].~T(); + } + for (int i = n; i < size_; ++i) { + base_[i].~T(); + } + size_ = n; + } + + void clear() { resize(0); } + + T* data() const { return base_; } + + void swap(SVector& x) { + std::swap(base_, x.base_); + std::swap(size_, x.size_); + std::swap(capacity_, x.capacity_); + } + + void reserve(int n) { + assert(n >= 0); + assert(n <= max_size()); + if (n > capacity_) { + const int new_capacity = std::min(n, max_size()); + T* new_storage = Allocate(new_capacity); + if (new_storage == nullptr) { + throw isce3::except::RuntimeError( + ISCE_SRCINFO(), "new_storage == nullptr"); + } + T* new_base = new_storage + new_capacity; + // TODO(user): in C++17 we could use std::uninitialized_move instead + // of this loop. + for (int i = -size_; i < size_; ++i) { + new (new_base + i) T(std::move(base_[i])); + } + int saved_size = size_; + clear_and_dealloc(); + size_ = saved_size; + base_ = new_base; + capacity_ = new_capacity; + } + } + + // NOTE(user): This doesn't currently support movable-only objects, but we + // could fix that. + void grow(const T& left = T(), const T& right = T()) { + if (size_ == capacity_) { + // We have to copy the elements because they are allowed to be element of + // *this. 
+ T left_copy(left); // NOLINT + T right_copy(right); // NOLINT + reserve(NewCapacity(1)); + new (base_ + size_) T(right_copy); + new (base_ - size_ - 1) T(left_copy); + ++size_; + } else { + new (base_ + size_) T(right); + new (base_ - size_ - 1) T(left); + ++size_; + } + } + + int size() const { return size_; } + + int capacity() const { return capacity_; } + + int max_size() const { return std::numeric_limits::max(); } + + void clear_and_dealloc() { + if (base_ == nullptr) return; + clear(); + if (capacity_ > 0) { + free(base_ - capacity_); + } + capacity_ = 0; + base_ = nullptr; + } + + private: + T* Allocate(int capacity) const { + return static_cast(malloc(2LL * capacity * sizeof(T))); + } + + int NewCapacity(int delta) { + // TODO(user): check validity. + double candidate = 1.3 * static_cast(capacity_); + if (candidate > static_cast(max_size())) { + candidate = static_cast(max_size()); + } + int new_capacity = static_cast(candidate); + if (new_capacity > capacity_ + delta) { + return new_capacity; + } + return capacity_ + delta; + } + + T* base_; // Pointer to the element of index 0. + int size_; // Valid index are [- size_, size_). + int capacity_; // Reserved index are [- capacity_, capacity_). +}; + +// BaseGraph implementation ---------------------------------------------------- + +template +IntegerRange +BaseGraph::AllNodes() const { + return IntegerRange(0, num_nodes_); +} + +template +IntegerRange +BaseGraph::AllForwardArcs() const { + return IntegerRange(0, num_arcs_); +} + +template +const NodeIndexType + BaseGraph::kNilNode = + std::numeric_limits::max(); + +template +const ArcIndexType + BaseGraph::kNilArc = + std::numeric_limits::max(); + +template +NodeIndexType +BaseGraph::node_capacity() const { + // TODO(user): Is it needed? remove completely? return the real capacities + // at the cost of having a different implementation for each graphs? + return node_capacity_ > num_nodes_ ? 
node_capacity_ : num_nodes_; +} + +template +ArcIndexType +BaseGraph::arc_capacity() const { + // TODO(user): Same questions as the ones in node_capacity(). + return arc_capacity_ > num_arcs_ ? arc_capacity_ : num_arcs_; +} + +template +void BaseGraph::FreezeCapacities() { + // TODO(user): Only define this in debug mode at the cost of having a lot + // of ifndef NDEBUG all over the place? remove the function completely ? + const_capacities_ = true; + node_capacity_ = std::max(node_capacity_, num_nodes_); + arc_capacity_ = std::max(arc_capacity_, num_arcs_); +} + +// Computes the cumulative sum of the entry in v. We only use it with +// in/out degree distribution, hence the Check() at the end. +template +void BaseGraph:: + ComputeCumulativeSum(std::vector* v) { + ArcIndexType sum = 0; + for (int i = 0; i < num_nodes_; ++i) { + ArcIndexType temp = (*v)[i]; + (*v)[i] = sum; + sum += temp; + } + assert(sum == num_arcs_); +} + +// Given the tail of arc #i in (*head)[i] and the head of arc #i in (*head)[~i] +// - Reorder the arc by increasing tail. +// - Put the head of the new arc #i in (*head)[i]. +// - Put in start[i] the index of the first arc with tail >= i. +// - Update "permutation" to reflect the change, unless it is NULL. +template +void BaseGraph:: + BuildStartAndForwardHead(SVector* head, + std::vector* start, + std::vector* permutation) { + // Computes the outgoing degree of each nodes and check if we need to permute + // something or not. Note that the tails are currently stored in the positive + // range of the SVector head. + start->assign(num_nodes_, 0); + int last_tail_seen = 0; + bool permutation_needed = false; + for (int i = 0; i < num_arcs_; ++i) { + NodeIndexType tail = (*head)[i]; + if (!permutation_needed) { + permutation_needed = tail < last_tail_seen; + last_tail_seen = tail; + } + (*start)[tail]++; + } + ComputeCumulativeSum(start); + + // Abort early if we do not need the permutation: we only need to put the + // heads in the positive range. 
+ if (!permutation_needed) { + for (int i = 0; i < num_arcs_; ++i) { + (*head)[i] = (*head)[~i]; + } + if (permutation != nullptr) { + permutation->clear(); + } + return; + } + + // Computes the forward arc permutation. + // Note that this temporarily alters the start vector. + std::vector perm(num_arcs_); + for (int i = 0; i < num_arcs_; ++i) { + perm[i] = (*start)[(*head)[i]]++; + } + + // Restore in (*start)[i] the index of the first arc with tail >= i. + for (int i = num_nodes_ - 1; i > 0; --i) { + (*start)[i] = (*start)[i - 1]; + } + (*start)[0] = 0; + + // Permutes the head into their final position in head. + // We do not need the tails anymore at this point. + for (int i = 0; i < num_arcs_; ++i) { + (*head)[perm[i]] = (*head)[~i]; + } + if (permutation != nullptr) { + permutation->swap(perm); + } +} + +// --------------------------------------------------------------------------- +// Macros to wrap old style iteration into the new range-based for loop style. +// --------------------------------------------------------------------------- + +// The parameters are: +// - c: the class name. +// - t: the iteration type (Outgoing, Incoming, OutgoingOrOppositeIncoming +// or OppositeIncoming). +// - e: the "end" ArcIndexType. +#define DEFINE_RANGE_BASED_ARC_ITERATION(c, t, e) \ + template \ + BeginEndWrapper::t##ArcIterator> \ + c::t##Arcs(NodeIndexType node) const { \ + return BeginEndWrapper(t##ArcIterator(*this, node), \ + t##ArcIterator(*this, node, e)); \ + } \ + template \ + BeginEndWrapper::t##ArcIterator> \ + c::t##ArcsStartingFrom( \ + NodeIndexType node, ArcIndexType from) const { \ + return BeginEndWrapper(t##ArcIterator(*this, node, from), \ + t##ArcIterator(*this, node, e)); \ + } + +// Adapt our old iteration style to support range-based for loops. Add typedefs +// required by std::iterator_traits. 
+#define DEFINE_STL_ITERATOR_FUNCTIONS(iterator_class_name) \ + using iterator_category = std::input_iterator_tag; \ + using difference_type = ptrdiff_t; \ + using pointer = const ArcIndexType*; \ + using reference = const ArcIndexType&; \ + using value_type = ArcIndexType; \ + bool operator!=(const iterator_class_name& other) const { \ + return this->index_ != other.index_; \ + } \ + bool operator==(const iterator_class_name& other) const { \ + return this->index_ == other.index_; \ + } \ + ArcIndexType operator*() const { return this->Index(); } \ + void operator++() { this->Next(); } + +// ListGraph implementation ---------------------------------------------------- + +DEFINE_RANGE_BASED_ARC_ITERATION(ListGraph, Outgoing, Base::kNilArc) + +template +BeginEndWrapper< + typename ListGraph::OutgoingHeadIterator> +ListGraph::operator[](NodeIndexType node) const { + return BeginEndWrapper( + OutgoingHeadIterator(*this, node), + OutgoingHeadIterator(*this, node, Base::kNilArc)); +} + +template +NodeIndexType ListGraph::Tail( + ArcIndexType arc) const { + assert(IsArcValid(arc)); + return tail_[arc]; +} + +template +NodeIndexType ListGraph::Head( + ArcIndexType arc) const { + assert(IsArcValid(arc)); + return head_[arc]; +} + +template +ArcIndexType ListGraph::OutDegree( + NodeIndexType node) const { + ArcIndexType degree(0); + for (auto arc : OutgoingArcs(node)) ++degree; + return degree; +} + +template +void ListGraph::AddNode(NodeIndexType node) { + if (node < num_nodes_) return; + assert(!const_capacities_ || node < node_capacity_); + num_nodes_ = node + 1; + start_.resize(num_nodes_, Base::kNilArc); +} + +template +ArcIndexType ListGraph::AddArc( + NodeIndexType tail, NodeIndexType head) { + assert(tail >= 0); + assert(head >= 0); + AddNode(tail > head ? 
tail : head); + head_.push_back(head); + tail_.push_back(tail); + next_.push_back(start_[tail]); + start_[tail] = num_arcs_; + assert(!const_capacities_ || num_arcs_ < arc_capacity_); + return num_arcs_++; +} + +template +void ListGraph::ReserveNodes(NodeIndexType bound) { + Base::ReserveNodes(bound); + if (bound <= num_nodes_) return; + start_.reserve(bound); +} + +template +void ListGraph::ReserveArcs(ArcIndexType bound) { + Base::ReserveArcs(bound); + if (bound <= num_arcs_) return; + head_.reserve(bound); + tail_.reserve(bound); + next_.reserve(bound); +} + +template +void ListGraph::Build( + std::vector* permutation) { + if (permutation != nullptr) { + permutation->clear(); + } +} + +template +class ListGraph::OutgoingArcIterator { + public: + OutgoingArcIterator(const ListGraph& graph, NodeIndexType node) + : graph_(graph), index_(graph.start_[node]) { + assert(graph.IsNodeValid(node)); + } + OutgoingArcIterator(const ListGraph& graph, NodeIndexType node, + ArcIndexType arc) + : graph_(graph), index_(arc) { + assert(graph.IsNodeValid(node)); + assert(arc == Base::kNilArc || graph.Tail(arc) == node); + } + bool Ok() const { return index_ != Base::kNilArc; } + ArcIndexType Index() const { return index_; } + void Next() { + assert(Ok()); + index_ = graph_.next_[index_]; + } + + DEFINE_STL_ITERATOR_FUNCTIONS(OutgoingArcIterator); + + private: + const ListGraph& graph_; + ArcIndexType index_; +}; + +template +class ListGraph::OutgoingHeadIterator { + public: + using iterator_category = std::input_iterator_tag; + using difference_type = ptrdiff_t; + using pointer = const NodeIndexType*; + using reference = const NodeIndexType&; + using value_type = NodeIndexType; + + OutgoingHeadIterator(const ListGraph& graph, NodeIndexType node) + : graph_(graph), index_(graph.start_[node]) { + assert(graph.IsNodeValid(node)); + } + OutgoingHeadIterator(const ListGraph& graph, NodeIndexType node, + ArcIndexType arc) + : graph_(graph), index_(arc) { + 
assert(graph.IsNodeValid(node)); + assert(arc == Base::kNilArc || graph.Tail(arc) == node); + } + bool Ok() const { return index_ != Base::kNilArc; } + NodeIndexType Index() const { return graph_.Head(index_); } + void Next() { + assert(Ok()); + index_ = graph_.next_[index_]; + } + + bool operator!=( + const typename ListGraph< + NodeIndexType, ArcIndexType>::OutgoingHeadIterator& other) const { + return index_ != other.index_; + } + NodeIndexType operator*() const { return Index(); } + void operator++() { Next(); } + + private: + const ListGraph& graph_; + ArcIndexType index_; +}; + +// StaticGraph implementation -------------------------------------------------- + +DEFINE_RANGE_BASED_ARC_ITERATION(StaticGraph, Outgoing, DirectArcLimit(node)) + +template +BeginEndWrapper +StaticGraph::operator[](NodeIndexType node) const { + return BeginEndWrapper( + head_.data() + start_[node], head_.data() + DirectArcLimit(node)); +} + +template +ArcIndexType StaticGraph::OutDegree( + NodeIndexType node) const { + return DirectArcLimit(node) - start_[node]; +} + +template +void StaticGraph::ReserveNodes( + NodeIndexType bound) { + Base::ReserveNodes(bound); + if (bound <= num_nodes_) return; + start_.reserve(bound); +} + +template +void StaticGraph::ReserveArcs(ArcIndexType bound) { + Base::ReserveArcs(bound); + if (bound <= num_arcs_) return; + head_.reserve(bound); + tail_.reserve(bound); +} + +template +void StaticGraph::AddNode(NodeIndexType node) { + if (node < num_nodes_) return; + assert(!const_capacities_ || node < node_capacity_) << node; + num_nodes_ = node + 1; + start_.resize(num_nodes_, 0); +} + +template +ArcIndexType StaticGraph::AddArc( + NodeIndexType tail, NodeIndexType head) { + assert(tail >= 0); + assert(head >= 0); + assert(!is_built_); + AddNode(tail > head ? 
tail : head); + if (arc_in_order_) { + if (tail >= last_tail_seen_) { + start_[tail]++; + last_tail_seen_ = tail; + } else { + arc_in_order_ = false; + } + } + tail_.push_back(tail); + head_.push_back(head); + assert(!const_capacities_ || num_arcs_ < arc_capacity_); + return num_arcs_++; +} + +template +NodeIndexType StaticGraph::Tail( + ArcIndexType arc) const { + assert(IsArcValid(arc)); + return tail_[arc]; +} + +template +NodeIndexType StaticGraph::Head( + ArcIndexType arc) const { + assert(IsArcValid(arc)); + return head_[arc]; +} + +// Implementation details: A reader may be surprised that we do many passes +// into the data where things could be done in one pass. For instance, during +// construction, we store the edges first, and then do a second pass at the +// end to compute the degree distribution. +// +// This is because it is a lot more efficient cache-wise to do it this way. +// This was determined by various experiments, but can also be understood: +// - during repetitive call to AddArc() a client usually accesses various +// areas of memory, and there is no reason to polute the cache with +// possibly random access to degree[i]. +// - When the degrees are needed, we compute them in one go, maximizing the +// chance of cache hit during the computation. +template +void StaticGraph::Build( + std::vector* permutation) { + assert(!is_built_); + if (is_built_) return; + is_built_ = true; + node_capacity_ = num_nodes_; + arc_capacity_ = num_arcs_; + this->FreezeCapacities(); + + // If Arc are in order, start_ already contains the degree distribution. + if (arc_in_order_) { + if (permutation != nullptr) { + permutation->clear(); + } + this->ComputeCumulativeSum(&start_); + return; + } + + // Computes outgoing degree of each nodes. We have to clear start_, since + // at least the first arc was processed with arc_in_order_ == true. 
+ start_.assign(num_nodes_, 0); + for (int i = 0; i < num_arcs_; ++i) { + start_[tail_[i]]++; + } + this->ComputeCumulativeSum(&start_); + + // Computes the forward arc permutation. + // Note that this temporarily alters the start_ vector. + std::vector perm(num_arcs_); + for (int i = 0; i < num_arcs_; ++i) { + perm[i] = start_[tail_[i]]++; + } + + // We use "tail_" (which now contains rubbish) to permute "head_" faster. + if (tail_.size() != num_arcs_) { + throw isce3::except::RuntimeError( + ISCE_SRCINFO(), "tail_.size() != num_arcs"); + } + tail_.swap(head_); + for (int i = 0; i < num_arcs_; ++i) { + head_[perm[i]] = tail_[i]; + } + + if (permutation != nullptr) { + permutation->swap(perm); + } + + // Restore in start_[i] the index of the first arc with tail >= i. + for (int i = num_nodes_ - 1; i > 0; --i) { + start_[i] = start_[i - 1]; + } + start_[0] = 0; + + // Recompute the correct tail_ vector + for (const NodeIndexType node : Base::AllNodes()) { + for (const ArcIndexType arc : OutgoingArcs(node)) { + tail_[arc] = node; + } + } +} + +template +class StaticGraph::OutgoingArcIterator { + public: + OutgoingArcIterator(const StaticGraph& graph, NodeIndexType node) + : index_(graph.start_[node]), limit_(graph.DirectArcLimit(node)) {} + OutgoingArcIterator(const StaticGraph& graph, NodeIndexType node, + ArcIndexType arc) + : index_(arc), limit_(graph.DirectArcLimit(node)) { + assert(arc >= graph.start_[node]); + } + + bool Ok() const { return index_ < limit_; } + ArcIndexType Index() const { return index_; } + void Next() { + assert(Ok()); + index_++; + } + + // Note(user): we lose a bit by returning a BeginEndWrapper<> on top of + // this iterator rather than a simple IntegerRange<> on the arc indices. + // On my computer: around 420M arcs/sec instead of 440M arcs/sec. + // + // However, it is slightly more consistent to do it this way, and we don't + // have two different codes depending on the way a client iterates on the + // arcs. 
+ DEFINE_STL_ITERATOR_FUNCTIONS(OutgoingArcIterator); + + private: + ArcIndexType index_; + const ArcIndexType limit_; +}; + +// ReverseArcListGraph implementation ------------------------------------------ + +DEFINE_RANGE_BASED_ARC_ITERATION(ReverseArcListGraph, Outgoing, Base::kNilArc) +DEFINE_RANGE_BASED_ARC_ITERATION(ReverseArcListGraph, Incoming, Base::kNilArc) +DEFINE_RANGE_BASED_ARC_ITERATION(ReverseArcListGraph, + OutgoingOrOppositeIncoming, Base::kNilArc) +DEFINE_RANGE_BASED_ARC_ITERATION(ReverseArcListGraph, OppositeIncoming, + Base::kNilArc) + +template +BeginEndWrapper::OutgoingHeadIterator> +ReverseArcListGraph::operator[]( + NodeIndexType node) const { + return BeginEndWrapper( + OutgoingHeadIterator(*this, node), + OutgoingHeadIterator(*this, node, Base::kNilArc)); +} + +template +ArcIndexType ReverseArcListGraph::OutDegree( + NodeIndexType node) const { + ArcIndexType degree(0); + for (auto arc : OutgoingArcs(node)) ++degree; + return degree; +} + +template +ArcIndexType ReverseArcListGraph::InDegree( + NodeIndexType node) const { + ArcIndexType degree(0); + for (auto arc : OppositeIncomingArcs(node)) ++degree; + return degree; +} + +template +ArcIndexType ReverseArcListGraph::OppositeArc( + ArcIndexType arc) const { + assert(IsArcValid(arc)); + return ~arc; +} + +template +NodeIndexType ReverseArcListGraph::Head( + ArcIndexType arc) const { + assert(IsArcValid(arc)); + return head_[arc]; +} + +template +NodeIndexType ReverseArcListGraph::Tail( + ArcIndexType arc) const { + return head_[OppositeArc(arc)]; +} + +template +void ReverseArcListGraph::ReserveNodes( + NodeIndexType bound) { + Base::ReserveNodes(bound); + if (bound <= num_nodes_) return; + start_.reserve(bound); + reverse_start_.reserve(bound); +} + +template +void ReverseArcListGraph::ReserveArcs( + ArcIndexType bound) { + Base::ReserveArcs(bound); + if (bound <= num_arcs_) return; + head_.reserve(bound); + next_.reserve(bound); +} + +template +void ReverseArcListGraph::AddNode( + 
NodeIndexType node) { + if (node < num_nodes_) return; + assert(!const_capacities_ || node < node_capacity_); + num_nodes_ = node + 1; + start_.resize(num_nodes_, Base::kNilArc); + reverse_start_.resize(num_nodes_, Base::kNilArc); +} + +template +ArcIndexType ReverseArcListGraph::AddArc( + NodeIndexType tail, NodeIndexType head) { + assert(tail >= 0); + assert(head >= 0); + AddNode(tail > head ? tail : head); + head_.grow(tail, head); + next_.grow(reverse_start_[head], start_[tail]); + start_[tail] = num_arcs_; + reverse_start_[head] = ~num_arcs_; + assert(!const_capacities_ || num_arcs_ < arc_capacity_); + return num_arcs_++; +} + +template +void ReverseArcListGraph::Build( + std::vector* permutation) { + if (permutation != nullptr) { + permutation->clear(); + } +} + +template +class ReverseArcListGraph::OutgoingArcIterator { + public: + OutgoingArcIterator(const ReverseArcListGraph& graph, NodeIndexType node) + : graph_(graph), index_(graph.start_[node]) { + assert(graph.IsNodeValid(node)); + } + OutgoingArcIterator(const ReverseArcListGraph& graph, NodeIndexType node, + ArcIndexType arc) + : graph_(graph), index_(arc) { + assert(graph.IsNodeValid(node)); + assert(arc == Base::kNilArc || arc >= 0); + assert(arc == Base::kNilArc || graph.Tail(arc) == node); + } + bool Ok() const { return index_ != Base::kNilArc; } + ArcIndexType Index() const { return index_; } + void Next() { + assert(Ok()); + index_ = graph_.next_[index_]; + } + + DEFINE_STL_ITERATOR_FUNCTIONS(OutgoingArcIterator); + + private: + const ReverseArcListGraph& graph_; + ArcIndexType index_; +}; + +template +class ReverseArcListGraph::OppositeIncomingArcIterator { + public: + OppositeIncomingArcIterator(const ReverseArcListGraph& graph, + NodeIndexType node) + : graph_(graph), index_(graph.reverse_start_[node]) { + assert(graph.IsNodeValid(node)); + } + OppositeIncomingArcIterator(const ReverseArcListGraph& graph, + NodeIndexType node, ArcIndexType arc) + : graph_(graph), index_(arc) { + 
assert(graph.IsNodeValid(node)); + assert(arc == Base::kNilArc || arc < 0); + assert(arc == Base::kNilArc || graph.Tail(arc) == node); + } + + bool Ok() const { return index_ != Base::kNilArc; } + ArcIndexType Index() const { return index_; } + void Next() { + assert(Ok()); + index_ = graph_.next_[index_]; + } + + DEFINE_STL_ITERATOR_FUNCTIONS(OppositeIncomingArcIterator); + + protected: + const ReverseArcListGraph& graph_; + ArcIndexType index_; +}; + +template +class ReverseArcListGraph::IncomingArcIterator + : public OppositeIncomingArcIterator { + public: + IncomingArcIterator(const ReverseArcListGraph& graph, NodeIndexType node) + : OppositeIncomingArcIterator(graph, node) {} + IncomingArcIterator(const ReverseArcListGraph& graph, NodeIndexType node, + ArcIndexType arc) + : OppositeIncomingArcIterator( + graph, node, + arc == Base::kNilArc ? Base::kNilArc : graph.OppositeArc(arc)) {} + + // We overwrite OppositeIncomingArcIterator::Index() here. + ArcIndexType Index() const { + return this->index_ == Base::kNilArc + ? 
Base::kNilArc + : this->graph_.OppositeArc(this->index_); + } + + DEFINE_STL_ITERATOR_FUNCTIONS(IncomingArcIterator); +}; + +template +class ReverseArcListGraph::OutgoingOrOppositeIncomingArcIterator { + public: + OutgoingOrOppositeIncomingArcIterator(const ReverseArcListGraph& graph, + NodeIndexType node) + : graph_(graph), index_(graph.reverse_start_[node]), node_(node) { + assert(graph.IsNodeValid(node)); + if (index_ == Base::kNilArc) index_ = graph.start_[node]; + } + OutgoingOrOppositeIncomingArcIterator(const ReverseArcListGraph& graph, + NodeIndexType node, ArcIndexType arc) + : graph_(graph), index_(arc), node_(node) { + assert(graph.IsNodeValid(node)); + assert(arc == Base::kNilArc || graph.Tail(arc) == node); + } + + bool Ok() const { return index_ != Base::kNilArc; } + ArcIndexType Index() const { return index_; } + void Next() { + assert(Ok()); + if (index_ < 0) { + index_ = graph_.next_[index_]; + if (index_ == Base::kNilArc) { + index_ = graph_.start_[node_]; + } + } else { + index_ = graph_.next_[index_]; + } + } + + DEFINE_STL_ITERATOR_FUNCTIONS(OutgoingOrOppositeIncomingArcIterator); + + private: + const ReverseArcListGraph& graph_; + ArcIndexType index_; + const NodeIndexType node_; +}; + +template +class ReverseArcListGraph::OutgoingHeadIterator { + public: + OutgoingHeadIterator(const ReverseArcListGraph& graph, NodeIndexType node) + : graph_(&graph), index_(graph.start_[node]) { + assert(graph.IsNodeValid(node)); + } + OutgoingHeadIterator(const ReverseArcListGraph& graph, NodeIndexType node, + ArcIndexType arc) + : graph_(&graph), index_(arc) { + assert(graph.IsNodeValid(node)); + assert(arc == Base::kNilArc || arc >= 0); + assert(arc == Base::kNilArc || graph.Tail(arc) == node); + } + bool Ok() const { return index_ != Base::kNilArc; } + ArcIndexType Index() const { return graph_->Head(index_); } + void Next() { + assert(Ok()); + index_ = graph_->next_[index_]; + } + + DEFINE_STL_ITERATOR_FUNCTIONS(OutgoingHeadIterator); + + private: + const 
ReverseArcListGraph* graph_; + ArcIndexType index_; +}; + +// ReverseArcStaticGraph implementation ---------------------------------------- + +DEFINE_RANGE_BASED_ARC_ITERATION(ReverseArcStaticGraph, Outgoing, + DirectArcLimit(node)) +DEFINE_RANGE_BASED_ARC_ITERATION(ReverseArcStaticGraph, Incoming, + ReverseArcLimit(node)) +DEFINE_RANGE_BASED_ARC_ITERATION(ReverseArcStaticGraph, + OutgoingOrOppositeIncoming, + DirectArcLimit(node)) +DEFINE_RANGE_BASED_ARC_ITERATION(ReverseArcStaticGraph, OppositeIncoming, + ReverseArcLimit(node)) + +template +ArcIndexType ReverseArcStaticGraph::OutDegree( + NodeIndexType node) const { + return DirectArcLimit(node) - start_[node]; +} + +template +ArcIndexType ReverseArcStaticGraph::InDegree( + NodeIndexType node) const { + return ReverseArcLimit(node) - reverse_start_[node]; +} + +template +BeginEndWrapper +ReverseArcStaticGraph::operator[]( + NodeIndexType node) const { + return BeginEndWrapper( + head_.data() + start_[node], head_.data() + DirectArcLimit(node)); +} + +template +ArcIndexType ReverseArcStaticGraph::OppositeArc( + ArcIndexType arc) const { + assert(is_built_); + assert(IsArcValid(arc)); + return opposite_[arc]; +} + +template +NodeIndexType ReverseArcStaticGraph::Head( + ArcIndexType arc) const { + assert(is_built_); + assert(IsArcValid(arc)); + return head_[arc]; +} + +template +NodeIndexType ReverseArcStaticGraph::Tail( + ArcIndexType arc) const { + assert(is_built_); + return head_[OppositeArc(arc)]; +} + +template +void ReverseArcStaticGraph::ReserveArcs( + ArcIndexType bound) { + Base::ReserveArcs(bound); + if (bound <= num_arcs_) return; + head_.reserve(bound); +} + +template +void ReverseArcStaticGraph::AddNode( + NodeIndexType node) { + if (node < num_nodes_) return; + assert(!const_capacities_ || node < node_capacity_); + num_nodes_ = node + 1; +} + +template +ArcIndexType ReverseArcStaticGraph::AddArc( + NodeIndexType tail, NodeIndexType head) { + assert(tail >= 0); + assert(head >= 0); + AddNode(tail > 
head ? tail : head); + + // We inverse head and tail here because it is more convenient this way + // during build time, see Build(). + head_.grow(head, tail); + assert(!const_capacities_ || num_arcs_ < arc_capacity_); + return num_arcs_++; +} + +template +void ReverseArcStaticGraph::Build( + std::vector* permutation) { + assert(!is_built_); + if (is_built_) return; + is_built_ = true; + node_capacity_ = num_nodes_; + arc_capacity_ = num_arcs_; + this->FreezeCapacities(); + this->BuildStartAndForwardHead(&head_, &start_, permutation); + + // Computes incoming degree of each nodes. + reverse_start_.assign(num_nodes_, 0); + for (int i = 0; i < num_arcs_; ++i) { + reverse_start_[head_[i]]++; + } + this->ComputeCumulativeSum(&reverse_start_); + + // Computes the reverse arcs of the forward arcs. + // Note that this sort the reverse arcs with the same tail by head. + opposite_.reserve(num_arcs_); + for (int i = 0; i < num_arcs_; ++i) { + // TODO(user): the 0 is wasted here, but minor optimisation. + opposite_.grow(0, reverse_start_[head_[i]]++ - num_arcs_); + } + + // Computes in reverse_start_ the start index of the reverse arcs. + for (int i = num_nodes_ - 1; i > 0; --i) { + reverse_start_[i] = reverse_start_[i - 1] - num_arcs_; + } + if (num_nodes_ != 0) { + reverse_start_[0] = -num_arcs_; + } + + // Fill reverse arc information. 
+ for (int i = 0; i < num_arcs_; ++i) { + opposite_[opposite_[i]] = i; + } + for (const NodeIndexType node : Base::AllNodes()) { + for (const ArcIndexType arc : OutgoingArcs(node)) { + head_[opposite_[arc]] = node; + } + } +} + +template +class ReverseArcStaticGraph::OutgoingArcIterator { + public: + OutgoingArcIterator(const ReverseArcStaticGraph& graph, NodeIndexType node) + : index_(graph.start_[node]), limit_(graph.DirectArcLimit(node)) {} + OutgoingArcIterator(const ReverseArcStaticGraph& graph, NodeIndexType node, + ArcIndexType arc) + : index_(arc), limit_(graph.DirectArcLimit(node)) { + assert(arc >= graph.start_[node]); + } + + bool Ok() const { return index_ < limit_; } + ArcIndexType Index() const { return index_; } + void Next() { + assert(Ok()); + index_++; + } + + // TODO(user): we lose a bit by returning a BeginEndWrapper<> on top of this + // iterator rather than a simple IntegerRange on the arc indices. + DEFINE_STL_ITERATOR_FUNCTIONS(OutgoingArcIterator); + + private: + ArcIndexType index_; + const ArcIndexType limit_; +}; + +template +class ReverseArcStaticGraph::OppositeIncomingArcIterator { + public: + OppositeIncomingArcIterator(const ReverseArcStaticGraph& graph, + NodeIndexType node) + : graph_(graph), + limit_(graph.ReverseArcLimit(node)), + index_(graph.reverse_start_[node]) { + assert(graph.IsNodeValid(node)); + assert(index_ <= limit_); + } + OppositeIncomingArcIterator(const ReverseArcStaticGraph& graph, + NodeIndexType node, ArcIndexType arc) + : graph_(graph), limit_(graph.ReverseArcLimit(node)), index_(arc) { + assert(graph.IsNodeValid(node)); + assert(index_ >= graph.reverse_start_[node]); + assert(index_ <= limit_); + } + + bool Ok() const { return index_ < limit_; } + ArcIndexType Index() const { return index_; } + void Next() { + assert(Ok()); + index_++; + } + + DEFINE_STL_ITERATOR_FUNCTIONS(OppositeIncomingArcIterator); + + protected: + const ReverseArcStaticGraph& graph_; + const ArcIndexType limit_; + ArcIndexType index_; +}; 
+ +template +class ReverseArcStaticGraph::IncomingArcIterator + : public OppositeIncomingArcIterator { + public: + IncomingArcIterator(const ReverseArcStaticGraph& graph, NodeIndexType node) + : OppositeIncomingArcIterator(graph, node) {} + IncomingArcIterator(const ReverseArcStaticGraph& graph, NodeIndexType node, + ArcIndexType arc) + : OppositeIncomingArcIterator(graph, node, + arc == graph.ReverseArcLimit(node) + ? graph.ReverseArcLimit(node) + : graph.OppositeArc(arc)) {} + + ArcIndexType Index() const { + return this->index_ == this->limit_ + ? this->limit_ + : this->graph_.OppositeArc(this->index_); + } + + DEFINE_STL_ITERATOR_FUNCTIONS(IncomingArcIterator); +}; + +template +class ReverseArcStaticGraph< + NodeIndexType, ArcIndexType>::OutgoingOrOppositeIncomingArcIterator { + public: + OutgoingOrOppositeIncomingArcIterator(const ReverseArcStaticGraph& graph, + NodeIndexType node) + : index_(graph.reverse_start_[node]), + first_limit_(graph.ReverseArcLimit(node)), + next_start_(graph.start_[node]), + limit_(graph.DirectArcLimit(node)) { + if (index_ == first_limit_) index_ = next_start_; + assert(graph.IsNodeValid(node)); + assert((index_ < first_limit_) || (index_ >= next_start_)); + } + OutgoingOrOppositeIncomingArcIterator(const ReverseArcStaticGraph& graph, + NodeIndexType node, ArcIndexType arc) + : index_(arc), + first_limit_(graph.ReverseArcLimit(node)), + next_start_(graph.start_[node]), + limit_(graph.DirectArcLimit(node)) { + assert(graph.IsNodeValid(node)); + assert((index_ >= graph.reverse_start_[node] && index_ < first_limit_) || + (index_ >= next_start_)); + } + + ArcIndexType Index() const { return index_; } + bool Ok() const { return index_ < limit_; } + void Next() { + assert(Ok()); + index_++; + if (index_ == first_limit_) { + index_ = next_start_; + } + } + + DEFINE_STL_ITERATOR_FUNCTIONS(OutgoingOrOppositeIncomingArcIterator); + + private: + ArcIndexType index_; + const ArcIndexType first_limit_; + const ArcIndexType next_start_; + const 
ArcIndexType limit_; +}; + +// ReverseArcMixedGraph implementation ----------------------------------------- + +DEFINE_RANGE_BASED_ARC_ITERATION(ReverseArcMixedGraph, Outgoing, + DirectArcLimit(node)) +DEFINE_RANGE_BASED_ARC_ITERATION(ReverseArcMixedGraph, Incoming, Base::kNilArc) +DEFINE_RANGE_BASED_ARC_ITERATION(ReverseArcMixedGraph, + OutgoingOrOppositeIncoming, + DirectArcLimit(node)) +DEFINE_RANGE_BASED_ARC_ITERATION(ReverseArcMixedGraph, OppositeIncoming, + Base::kNilArc) + +template +ArcIndexType ReverseArcMixedGraph::OutDegree( + NodeIndexType node) const { + return DirectArcLimit(node) - start_[node]; +} + +template +ArcIndexType ReverseArcMixedGraph::InDegree( + NodeIndexType node) const { + ArcIndexType degree(0); + for (auto arc : OppositeIncomingArcs(node)) ++degree; + return degree; +} + +template +BeginEndWrapper +ReverseArcMixedGraph::operator[]( + NodeIndexType node) const { + return BeginEndWrapper( + head_.data() + start_[node], head_.data() + DirectArcLimit(node)); +} + +template +ArcIndexType ReverseArcMixedGraph::OppositeArc( + ArcIndexType arc) const { + assert(IsArcValid(arc)); + return ~arc; +} + +template +NodeIndexType ReverseArcMixedGraph::Head( + ArcIndexType arc) const { + assert(is_built_); + assert(IsArcValid(arc)); + return head_[arc]; +} + +template +NodeIndexType ReverseArcMixedGraph::Tail( + ArcIndexType arc) const { + assert(is_built_); + return head_[OppositeArc(arc)]; +} + +template +void ReverseArcMixedGraph::ReserveArcs( + ArcIndexType bound) { + Base::ReserveArcs(bound); + if (bound <= num_arcs_) return; + head_.reserve(bound); +} + +template +void ReverseArcMixedGraph::AddNode( + NodeIndexType node) { + if (node < num_nodes_) return; + assert(!const_capacities_ || node < node_capacity_); + num_nodes_ = node + 1; +} + +template +ArcIndexType ReverseArcMixedGraph::AddArc( + NodeIndexType tail, NodeIndexType head) { + assert(tail >= 0); + assert(head >= 0); + AddNode(tail > head ? 
tail : head); + + // We inverse head and tail here because it is more convenient this way + // during build time, see Build(). + head_.grow(head, tail); + assert(!const_capacities_ || num_arcs_ < arc_capacity_); + return num_arcs_++; +} + +template +void ReverseArcMixedGraph::Build( + std::vector* permutation) { + assert(!is_built_); + if (is_built_) return; + is_built_ = true; + node_capacity_ = num_nodes_; + arc_capacity_ = num_arcs_; + this->FreezeCapacities(); + this->BuildStartAndForwardHead(&head_, &start_, permutation); + + // Fill tails. + for (const NodeIndexType node : Base::AllNodes()) { + for (const ArcIndexType arc : OutgoingArcs(node)) { + head_[~arc] = node; + } + } + + // Fill information for iterating over reverse arcs. + reverse_start_.assign(num_nodes_, Base::kNilArc); + next_.reserve(num_arcs_); + for (const ArcIndexType arc : Base::AllForwardArcs()) { + next_.push_back(reverse_start_[Head(arc)]); + reverse_start_[Head(arc)] = -next_.size(); + } +} + +template +class ReverseArcMixedGraph::OutgoingArcIterator { + public: + OutgoingArcIterator(const ReverseArcMixedGraph& graph, NodeIndexType node) + : index_(graph.start_[node]), limit_(graph.DirectArcLimit(node)) {} + OutgoingArcIterator(const ReverseArcMixedGraph& graph, NodeIndexType node, + ArcIndexType arc) + : index_(arc), limit_(graph.DirectArcLimit(node)) { + assert(arc >= graph.start_[node]); + } + + bool Ok() const { return index_ < limit_; } + ArcIndexType Index() const { return index_; } + void Next() { + assert(Ok()); + index_++; + } + + // TODO(user): we lose a bit by returning a BeginEndWrapper<> on top of this + // iterator rather than a simple IntegerRange on the arc indices. 
+ DEFINE_STL_ITERATOR_FUNCTIONS(OutgoingArcIterator); + + private: + ArcIndexType index_; + const ArcIndexType limit_; +}; + +template +class ReverseArcMixedGraph::OppositeIncomingArcIterator { + public: + OppositeIncomingArcIterator(const ReverseArcMixedGraph& graph, + NodeIndexType node) + : graph_(&graph) { + assert(graph.is_built_); + assert(graph.IsNodeValid(node)); + index_ = graph.reverse_start_[node]; + } + OppositeIncomingArcIterator(const ReverseArcMixedGraph& graph, + NodeIndexType node, ArcIndexType arc) + : graph_(&graph), index_(arc) { + assert(graph.is_built_); + assert(graph.IsNodeValid(node)); + assert(arc == Base::kNilArc || arc < 0); + assert(arc == Base::kNilArc || graph.Tail(arc) == node); + } + bool Ok() const { return index_ != Base::kNilArc; } + ArcIndexType Index() const { return index_; } + void Next() { + assert(Ok()); + index_ = graph_->next_[~index_]; + } + + DEFINE_STL_ITERATOR_FUNCTIONS(OppositeIncomingArcIterator); + + protected: + const ReverseArcMixedGraph* graph_; + ArcIndexType index_; +}; + +template +class ReverseArcMixedGraph::IncomingArcIterator + : public OppositeIncomingArcIterator { + public: + IncomingArcIterator(const ReverseArcMixedGraph& graph, NodeIndexType node) + : OppositeIncomingArcIterator(graph, node) {} + IncomingArcIterator(const ReverseArcMixedGraph& graph, NodeIndexType node, + ArcIndexType arc) + : OppositeIncomingArcIterator( + graph, node, arc == Base::kNilArc ? arc : graph.OppositeArc(arc)) {} + ArcIndexType Index() const { + return this->index_ == Base::kNilArc + ? Base::kNilArc + : this->graph_->OppositeArc(this->index_); + } + + DEFINE_STL_ITERATOR_FUNCTIONS(IncomingArcIterator); +}; + +template +class ReverseArcMixedGraph< + NodeIndexType, ArcIndexType>::OutgoingOrOppositeIncomingArcIterator { + public: + OutgoingOrOppositeIncomingArcIterator(const ReverseArcMixedGraph& graph, + NodeIndexType node) + : graph_(&graph) { + limit_ = graph.DirectArcLimit(node); // also DCHECKs node and is_built_. 
+ index_ = graph.reverse_start_[node]; + restart_ = graph.start_[node]; + if (index_ == Base::kNilArc) { + index_ = restart_; + } + } + OutgoingOrOppositeIncomingArcIterator(const ReverseArcMixedGraph& graph, + NodeIndexType node, ArcIndexType arc) + : graph_(&graph) { + limit_ = graph.DirectArcLimit(node); + index_ = arc; + restart_ = graph.start_[node]; + assert(arc == Base::kNilArc || arc == limit_ || graph.Tail(arc) == node); + } + bool Ok() const { + // Note that we always have limit_ <= Base::kNilArc. + return index_ < limit_; + } + ArcIndexType Index() const { return index_; } + void Next() { + assert(Ok()); + if (index_ < 0) { + index_ = graph_->next_[graph_->OppositeArc(index_)]; + if (index_ == Base::kNilArc) { + index_ = restart_; + } + } else { + index_++; + } + } + + DEFINE_STL_ITERATOR_FUNCTIONS(OutgoingOrOppositeIncomingArcIterator); + + private: + const ReverseArcMixedGraph* graph_; + ArcIndexType index_; + ArcIndexType restart_; + ArcIndexType limit_; +}; + +// CompleteGraph implementation ------------------------------------------------ +// Nodes and arcs are implicit and not stored. + +template +class CompleteGraph : public BaseGraph { + typedef BaseGraph Base; + using Base::arc_capacity_; + using Base::const_capacities_; + using Base::node_capacity_; + using Base::num_arcs_; + using Base::num_nodes_; + + public: + // Builds a complete graph with num_nodes nodes. 
+ explicit CompleteGraph(NodeIndexType num_nodes) { + this->Reserve(num_nodes, num_nodes * num_nodes); + this->FreezeCapacities(); + num_nodes_ = num_nodes; + num_arcs_ = num_nodes * num_nodes; + } + + NodeIndexType Head(ArcIndexType arc) const; + NodeIndexType Tail(ArcIndexType arc) const; + ArcIndexType OutDegree(NodeIndexType node) const; + IntegerRange OutgoingArcs(NodeIndexType node) const; + IntegerRange OutgoingArcsStartingFrom(NodeIndexType node, + ArcIndexType from) const; + IntegerRange operator[](NodeIndexType node) const; +}; + +template +NodeIndexType CompleteGraph::Head( + ArcIndexType arc) const { + assert(this->IsArcValid(arc)); + return arc % num_nodes_; +} + +template +NodeIndexType CompleteGraph::Tail( + ArcIndexType arc) const { + assert(this->IsArcValid(arc)); + return arc / num_nodes_; +} + +template +ArcIndexType CompleteGraph::OutDegree( + NodeIndexType node) const { + return num_nodes_; +} + +template +IntegerRange +CompleteGraph::OutgoingArcs( + NodeIndexType node) const { + assert(node < num_nodes_); + return IntegerRange( + static_cast(num_nodes_) * node, + static_cast(num_nodes_) * (node + 1)); +} + +template +IntegerRange +CompleteGraph::OutgoingArcsStartingFrom( + NodeIndexType node, ArcIndexType from) const { + assert(node < num_nodes_); + return IntegerRange( + from, static_cast(num_nodes_) * (node + 1)); +} + +template +IntegerRange +CompleteGraph::operator[]( + NodeIndexType node) const { + assert(node < num_nodes_); + return IntegerRange(0, num_nodes_); +} + +// CompleteBipartiteGraph implementation --------------------------------------- +// Nodes and arcs are implicit and not stored. + +template +class CompleteBipartiteGraph + : public BaseGraph { + typedef BaseGraph Base; + using Base::arc_capacity_; + using Base::const_capacities_; + using Base::node_capacity_; + using Base::num_arcs_; + using Base::num_nodes_; + + public: + // Builds a complete bipartite graph from a set of left nodes to a set of + // right nodes. 
+ // Indices of left nodes of the bipartite graph range from 0 to left_nodes-1; + // indices of right nodes range from left_nodes to left_nodes+right_nodes-1. + CompleteBipartiteGraph(NodeIndexType left_nodes, NodeIndexType right_nodes) + : left_nodes_(left_nodes), right_nodes_(right_nodes) { + this->Reserve(left_nodes + right_nodes, left_nodes * right_nodes); + this->FreezeCapacities(); + num_nodes_ = left_nodes + right_nodes; + num_arcs_ = left_nodes * right_nodes; + } + + NodeIndexType Head(ArcIndexType arc) const; + NodeIndexType Tail(ArcIndexType arc) const; + ArcIndexType OutDegree(NodeIndexType node) const; + IntegerRange OutgoingArcs(NodeIndexType node) const; + IntegerRange OutgoingArcsStartingFrom(NodeIndexType node, + ArcIndexType from) const; + IntegerRange operator[](NodeIndexType node) const; + + // Deprecated interface. + class OutgoingArcIterator { + public: + OutgoingArcIterator(const CompleteBipartiteGraph& graph, NodeIndexType node) + : index_(graph.right_nodes_ * node), + limit_(node >= graph.left_nodes_ ? index_ + : graph.right_nodes_ * (node + 1)) {} + + bool Ok() const { return index_ < limit_; } + ArcIndexType Index() const { return index_; } + void Next() { index_++; } + + private: + ArcIndexType index_; + const ArcIndexType limit_; + }; + + private: + const NodeIndexType left_nodes_; + const NodeIndexType right_nodes_; +}; + +template +NodeIndexType CompleteBipartiteGraph::Head( + ArcIndexType arc) const { + assert(this->IsArcValid(arc)); + return left_nodes_ + arc % right_nodes_; +} + +template +NodeIndexType CompleteBipartiteGraph::Tail( + ArcIndexType arc) const { + assert(this->IsArcValid(arc)); + return arc / right_nodes_; +} + +template +ArcIndexType CompleteBipartiteGraph::OutDegree( + NodeIndexType node) const { + return (node < left_nodes_) ? 
right_nodes_ : 0; +} + +template +IntegerRange +CompleteBipartiteGraph::OutgoingArcs( + NodeIndexType node) const { + if (node < left_nodes_) { + return IntegerRange(right_nodes_ * node, + right_nodes_ * (node + 1)); + } else { + return IntegerRange(0, 0); + } +} + +template +IntegerRange +CompleteBipartiteGraph::OutgoingArcsStartingFrom( + NodeIndexType node, ArcIndexType from) const { + if (node < left_nodes_) { + return IntegerRange(from, right_nodes_ * (node + 1)); + } else { + return IntegerRange(0, 0); + } +} + +template +IntegerRange +CompleteBipartiteGraph::operator[]( + NodeIndexType node) const { + if (node < left_nodes_) { + return IntegerRange(left_nodes_, left_nodes_ + right_nodes_); + } else { + return IntegerRange(0, 0); + } +} + +// Defining the simplest Graph interface as Graph for convenience. +typedef ListGraph<> Graph; + +} // namespace util + +#undef DEFINE_RANGE_BASED_ARC_ITERATION +#undef DEFINE_STL_ITERATOR_FUNCTIONS + +#endif // UTIL_GRAPH_GRAPH_H_ diff --git a/cxx/isce3/unwrap/ortools/graphs.h b/cxx/isce3/unwrap/ortools/graphs.h new file mode 100644 index 000000000..8d1501b5b --- /dev/null +++ b/cxx/isce3/unwrap/ortools/graphs.h @@ -0,0 +1,78 @@ +// Copyright 2010-2021 Google LLC +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +// Temporary utility class needed as long as we have two slightly +// different graph interface: The one in ebert_graph.h and the one in graph.h + +#ifndef OR_TOOLS_GRAPH_GRAPHS_H_ +#define OR_TOOLS_GRAPH_GRAPHS_H_ + +#include "ebert_graph.h" + +namespace operations_research { + +// Since StarGraph does not have exactly the same interface as the other +// graphs, we define a correspondence there. +template +struct Graphs { + typedef typename Graph::ArcIndex ArcIndex; + typedef typename Graph::NodeIndex NodeIndex; + static ArcIndex OppositeArc(const Graph& graph, ArcIndex arc) { + return graph.OppositeArc(arc); + } + static bool IsArcValid(const Graph& graph, ArcIndex arc) { + return graph.IsArcValid(arc); + } + static NodeIndex NodeReservation(const Graph& graph) { + return graph.node_capacity(); + } + static ArcIndex ArcReservation(const Graph& graph) { + return graph.arc_capacity(); + } + static void Build(Graph* graph) { graph->Build(); } + static void Build(Graph* graph, std::vector* permutation) { + graph->Build(permutation); + } +}; + +template <> +struct Graphs { + typedef operations_research::StarGraph Graph; +#if defined(_MSC_VER) + typedef Graph::ArcIndex ArcIndex; + typedef Graph::NodeIndex NodeIndex; +#else + typedef typename Graph::ArcIndex ArcIndex; + typedef typename Graph::NodeIndex NodeIndex; +#endif + static ArcIndex OppositeArc(const Graph& graph, ArcIndex arc) { + return graph.Opposite(arc); + } + static bool IsArcValid(const Graph& graph, ArcIndex arc) { + return graph.CheckArcValidity(arc); + } + static NodeIndex NodeReservation(const Graph& graph) { + return graph.max_num_nodes(); + } + static ArcIndex ArcReservation(const Graph& graph) { + return graph.max_num_arcs(); + } + static void Build(Graph* /*graph*/) {} + static void Build(Graph* /*graph*/, std::vector* permutation) { + permutation->clear(); + } +}; + +} // namespace operations_research + +#endif // OR_TOOLS_GRAPH_GRAPHS_H_ diff --git a/cxx/isce3/unwrap/ortools/iterators.h 
b/cxx/isce3/unwrap/ortools/iterators.h new file mode 100644 index 000000000..e3fc702e0 --- /dev/null +++ b/cxx/isce3/unwrap/ortools/iterators.h @@ -0,0 +1,178 @@ +// Copyright 2010-2021 Google LLC +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Helper classes to make it easy to implement range-based for loops. + +#ifndef UTIL_GRAPH_ITERATORS_H_ +#define UTIL_GRAPH_ITERATORS_H_ + +#include +#include + +namespace util { + +// This is useful for wrapping iterators of a class that support many different +// iterations. For instance, on a Graph class, one can write: +// +// BeginEndWrapper Graph::OutgoingArcs(NodeInde node) +// const { +// return BeginEndRange( +// OutgoingArcIterator(*this, node, /*at_end=*/false), +// OutgoingArcIterator(*this, node, /*at_end=*/true)); +// } +// +// And a client will use it like this: +// +// for (const ArcIndex arc : graph.OutgoingArcs(node)) { ... } +template +class BeginEndWrapper { + public: + using const_iterator = Iterator; + using value_type = typename std::iterator_traits::value_type; + + BeginEndWrapper(Iterator begin, Iterator end) : begin_(begin), end_(end) {} + Iterator begin() const { return begin_; } + Iterator end() const { return end_; } + + bool empty() const { return begin() == end(); } + + private: + const Iterator begin_; + const Iterator end_; +}; + +// Inline wrapper methods, to make the client code even simpler. 
+// The harm of overloading is probably less than the benefit of the nice, +// compact name, in this special case. +template +inline BeginEndWrapper BeginEndRange(Iterator begin, Iterator end) { + return BeginEndWrapper(begin, end); +} +template +inline BeginEndWrapper BeginEndRange( + std::pair begin_end) { + return BeginEndWrapper(begin_end.first, begin_end.second); +} + +// Shortcut for BeginEndRange(multimap::equal_range(key)). +// TODO(user): go further and expose only the values, not the pairs (key, +// values) since the caller already knows the key. +template +inline BeginEndWrapper EqualRange( + MultiMap& multi_map, const typename MultiMap::key_type& key) { + return BeginEndRange(multi_map.equal_range(key)); +} +template +inline BeginEndWrapper EqualRange( + const MultiMap& multi_map, const typename MultiMap::key_type& key) { + return BeginEndRange(multi_map.equal_range(key)); +} + +// The Reverse() function allows to reverse the iteration order of a range-based +// for loop over a container that support STL reverse iterators. +// The syntax is: +// for (const type& t : Reverse(container_of_t)) { ... } +template +class BeginEndReverseIteratorWrapper { + public: + explicit BeginEndReverseIteratorWrapper(const Container& c) : c_(c) {} + typename Container::const_reverse_iterator begin() const { + return c_.rbegin(); + } + typename Container::const_reverse_iterator end() const { return c_.rend(); } + + private: + const Container& c_; +}; +template +BeginEndReverseIteratorWrapper Reverse(const Container& c) { + return BeginEndReverseIteratorWrapper(c); +} + +// Simple iterator on an integer range, see IntegerRange below. 
+template +class IntegerRangeIterator + : public std::iterator { + public: + explicit IntegerRangeIterator(IntegerType value) : index_(value) {} + IntegerRangeIterator(const IntegerRangeIterator& other) + : index_(other.index_) {} + IntegerRangeIterator& operator=(const IntegerRangeIterator& other) { + index_ = other.index_; + } + bool operator!=(const IntegerRangeIterator& other) const { + // This may seems weird, but using < instead of != avoid almost-infinite + // loop if one use IntegerRange(1, 0) below for instance. + return index_ < other.index_; + } + bool operator==(const IntegerRangeIterator& other) const { + return index_ == other.index_; + } + IntegerType operator*() const { return index_; } + IntegerRangeIterator& operator++() { + ++index_; + return *this; + } + IntegerRangeIterator operator++(int) { + IntegerRangeIterator previous_position(*this); + ++index_; + return previous_position; + } + + private: + IntegerType index_; +}; + +// Allows to easily construct nice functions for range-based for loop. +// This can be used like this: +// +// for (const int i : IntegerRange(0, 10)) { ... } +// +// But it main purpose is to be used as return value for more complex classes: +// +// for (const ArcIndex arc : graph.AllOutgoingArcs()); +// for (const NodeIndex node : graph.AllNodes()); +template +class IntegerRange : public BeginEndWrapper> { + public: + IntegerRange(IntegerType begin, IntegerType end) + : BeginEndWrapper>( + IntegerRangeIterator(begin), + IntegerRangeIterator(end)) {} +}; + +// Allow iterating over a vector as a mutable vector. 
+template +struct MutableVectorIteration { + explicit MutableVectorIteration(std::vector* v) : v_(v) {} + struct Iterator { + explicit Iterator(typename std::vector::iterator it) : it_(it) {} + T* operator*() { return &*it_; } + Iterator& operator++() { + it_++; + return *this; + } + bool operator!=(const Iterator& other) const { return other.it_ != it_; } + + private: + typename std::vector::iterator it_; + }; + Iterator begin() { return Iterator(v_->begin()); } + Iterator end() { return Iterator(v_->end()); } + + private: + std::vector* const v_; +}; +} // namespace util + +#endif // UTIL_GRAPH_ITERATORS_H_ diff --git a/cxx/isce3/unwrap/ortools/max_flow.cc b/cxx/isce3/unwrap/ortools/max_flow.cc new file mode 100644 index 000000000..06432fb4f --- /dev/null +++ b/cxx/isce3/unwrap/ortools/max_flow.cc @@ -0,0 +1,987 @@ +// Copyright 2010-2021 Google LLC +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "max_flow.h" + +#include +#include +#include + +#include + +#include "graphs.h" + +namespace operations_research { + +SimpleMaxFlow::SimpleMaxFlow() : num_nodes_(0) {} + +ArcIndex SimpleMaxFlow::AddArcWithCapacity(NodeIndex tail, NodeIndex head, + FlowQuantity capacity) { + const ArcIndex num_arcs = arc_tail_.size(); + num_nodes_ = std::max(num_nodes_, tail + 1); + num_nodes_ = std::max(num_nodes_, head + 1); + arc_tail_.push_back(tail); + arc_head_.push_back(head); + arc_capacity_.push_back(capacity); + return num_arcs; +} + +NodeIndex SimpleMaxFlow::NumNodes() const { return num_nodes_; } + +ArcIndex SimpleMaxFlow::NumArcs() const { return arc_tail_.size(); } + +NodeIndex SimpleMaxFlow::Tail(ArcIndex arc) const { return arc_tail_[arc]; } + +NodeIndex SimpleMaxFlow::Head(ArcIndex arc) const { return arc_head_[arc]; } + +FlowQuantity SimpleMaxFlow::Capacity(ArcIndex arc) const { + return arc_capacity_[arc]; +} + +void SimpleMaxFlow::SetArcCapacity(ArcIndex arc, FlowQuantity capacity) { + arc_capacity_[arc] = capacity; +} + +SimpleMaxFlow::Status SimpleMaxFlow::Solve(NodeIndex source, NodeIndex sink) { + const ArcIndex num_arcs = arc_capacity_.size(); + arc_flow_.assign(num_arcs, 0); + underlying_max_flow_.reset(); + underlying_graph_.reset(); + optimal_flow_ = 0; + if (source == sink || source < 0 || sink < 0) { + return BAD_INPUT; + } + if (source >= num_nodes_ || sink >= num_nodes_) { + return OPTIMAL; + } + underlying_graph_ = std::make_unique(num_nodes_, num_arcs); + underlying_graph_->AddNode(source); + underlying_graph_->AddNode(sink); + for (int arc = 0; arc < num_arcs; ++arc) { + underlying_graph_->AddArc(arc_tail_[arc], arc_head_[arc]); + } + underlying_graph_->Build(&arc_permutation_); + underlying_max_flow_ = std::make_unique>( + underlying_graph_.get(), source, sink); + for (ArcIndex arc = 0; arc < num_arcs; ++arc) { + ArcIndex permuted_arc = + arc < arc_permutation_.size() ? 
arc_permutation_[arc] : arc; + underlying_max_flow_->SetArcCapacity(permuted_arc, arc_capacity_[arc]); + } + if (underlying_max_flow_->Solve()) { + optimal_flow_ = underlying_max_flow_->GetOptimalFlow(); + for (ArcIndex arc = 0; arc < num_arcs; ++arc) { + ArcIndex permuted_arc = + arc < arc_permutation_.size() ? arc_permutation_[arc] : arc; + arc_flow_[arc] = underlying_max_flow_->Flow(permuted_arc); + } + } + // Translate the GenericMaxFlow::Status. It is different because NOT_SOLVED + // does not make sense in the simple api. + switch (underlying_max_flow_->status()) { + case GenericMaxFlow::NOT_SOLVED: + return BAD_RESULT; + case GenericMaxFlow::OPTIMAL: + return OPTIMAL; + case GenericMaxFlow::INT_OVERFLOW: + return POSSIBLE_OVERFLOW; + case GenericMaxFlow::BAD_INPUT: + return BAD_INPUT; + case GenericMaxFlow::BAD_RESULT: + return BAD_RESULT; + } + return BAD_RESULT; +} + +FlowQuantity SimpleMaxFlow::OptimalFlow() const { return optimal_flow_; } + +FlowQuantity SimpleMaxFlow::Flow(ArcIndex arc) const { return arc_flow_[arc]; } + +void SimpleMaxFlow::GetSourceSideMinCut(std::vector* result) { + if (underlying_max_flow_ == nullptr) return; + underlying_max_flow_->GetSourceSideMinCut(result); +} + +void SimpleMaxFlow::GetSinkSideMinCut(std::vector* result) { + if (underlying_max_flow_ == nullptr) return; + underlying_max_flow_->GetSinkSideMinCut(result); +} + +template +GenericMaxFlow::GenericMaxFlow(const Graph* graph, NodeIndex source, + NodeIndex sink) + : graph_(graph), + node_excess_(), + node_potential_(), + residual_arc_capacity_(), + first_admissible_arc_(), + active_nodes_(), + source_(source), + sink_(sink), + use_global_update_(true), + use_two_phase_algorithm_(true), + process_node_by_height_(true), + check_input_(true), + check_result_(true) { + assert(graph->IsNodeValid(source)); + assert(graph->IsNodeValid(sink)); + const NodeIndex max_num_nodes = Graphs::NodeReservation(*graph_); + if (max_num_nodes > 0) { + node_excess_.Reserve(0, max_num_nodes - 
1); + node_excess_.SetAll(0); + node_potential_.Reserve(0, max_num_nodes - 1); + node_potential_.SetAll(0); + first_admissible_arc_.Reserve(0, max_num_nodes - 1); + first_admissible_arc_.SetAll(Graph::kNilArc); + bfs_queue_.reserve(max_num_nodes); + active_nodes_.reserve(max_num_nodes); + } + const ArcIndex max_num_arcs = Graphs::ArcReservation(*graph_); + if (max_num_arcs > 0) { + residual_arc_capacity_.Reserve(-max_num_arcs, max_num_arcs - 1); + residual_arc_capacity_.SetAll(0); + } +} + +template +bool GenericMaxFlow::CheckInputConsistency() const { + bool ok = true; + for (ArcIndex arc = 0; arc < graph_->num_arcs(); ++arc) { + if (residual_arc_capacity_[arc] < 0) { + ok = false; + } + } + return ok; +} + +template +void GenericMaxFlow::SetArcCapacity(ArcIndex arc, + FlowQuantity new_capacity) { + assert(0 <= new_capacity); + assert(IsArcDirect(arc)); + const FlowQuantity free_capacity = residual_arc_capacity_[arc]; + const FlowQuantity capacity_delta = new_capacity - Capacity(arc); + if (capacity_delta == 0) { + return; // Nothing to do. + } + status_ = NOT_SOLVED; + if (free_capacity + capacity_delta >= 0) { + // The above condition is true if one of the two conditions is true: + // 1/ (capacity_delta > 0), meaning we are increasing the capacity + // 2/ (capacity_delta < 0 && free_capacity + capacity_delta >= 0) + // meaning we are reducing the capacity, but that the capacity + // reduction is not larger than the free capacity. + assert((capacity_delta > 0) || + (capacity_delta < 0 && free_capacity + capacity_delta >= 0)); + residual_arc_capacity_.Set(arc, free_capacity + capacity_delta); + assert(0 <= residual_arc_capacity_[arc]); + } else { + // Note that this breaks the preflow invariants but it is currently not an + // issue since we restart from scratch on each Solve() and we set the status + // to NOT_SOLVED. 
+ // + // TODO(user): The easiest is probably to allow negative node excess in + // other places than the source, but the current implementation does not + // deal with this. + SetCapacityAndClearFlow(arc, new_capacity); + } +} + +template +void GenericMaxFlow::SetArcFlow(ArcIndex arc, FlowQuantity new_flow) { + assert(IsArcValid(arc)); + assert(new_flow >= 0); + const FlowQuantity capacity = Capacity(arc); + assert(capacity >= new_flow); + + // Note that this breaks the preflow invariants but it is currently not an + // issue since we restart from scratch on each Solve() and we set the status + // to NOT_SOLVED. + residual_arc_capacity_.Set(Opposite(arc), -new_flow); + residual_arc_capacity_.Set(arc, capacity - new_flow); + status_ = NOT_SOLVED; +} + +template +void GenericMaxFlow::GetSourceSideMinCut( + std::vector* result) { + ComputeReachableNodes(source_, result); +} + +template +void GenericMaxFlow::GetSinkSideMinCut(std::vector* result) { + ComputeReachableNodes(sink_, result); +} + +template +bool GenericMaxFlow::CheckResult() const { + bool ok = true; + if (node_excess_[source_] != -node_excess_[sink_]) { + pyre::journal::firewall_t channel("isce3.unwrap.ortools.max_flow"); + channel << pyre::journal::at(__HERE__) + << "-node_excess_[source_] = " << -node_excess_[source_] + << " != node_excess_[sink_] = " << node_excess_[sink_] + << pyre::journal::endl; + ok = false; + } + for (NodeIndex node = 0; node < graph_->num_nodes(); ++node) { + if (node != source_ && node != sink_) { + if (node_excess_[node] != 0) { + pyre::journal::firewall_t channel("isce3.unwrap.ortools.max_flow"); + channel << pyre::journal::at(__HERE__) + << "node_excess_[" << node << "] = " << node_excess_[node] + << " != 0" << pyre::journal::endl; + ok = false; + } + } + } + for (ArcIndex arc = 0; arc < graph_->num_arcs(); ++arc) { + const ArcIndex opposite = Opposite(arc); + const FlowQuantity direct_capacity = residual_arc_capacity_[arc]; + const FlowQuantity opposite_capacity = 
residual_arc_capacity_[opposite]; + if (direct_capacity < 0) { + pyre::journal::firewall_t channel("isce3.unwrap.ortools.max_flow"); + channel << pyre::journal::at(__HERE__) + << "residual_arc_capacity_[" << arc + << "] = " << direct_capacity << " < 0" + << pyre::journal::endl; + ok = false; + } + if (opposite_capacity < 0) { + pyre::journal::firewall_t channel("isce3.unwrap.ortools.max_flow"); + channel << pyre::journal::at(__HERE__) + << "residual_arc_capacity_[" << opposite + << "] = " << opposite_capacity << " < 0" + << pyre::journal::endl; + ok = false; + } + // The initial capacity of the direct arcs is non-negative. + if (direct_capacity + opposite_capacity < 0) { + pyre::journal::firewall_t channel("isce3.unwrap.ortools.max_flow"); + channel << pyre::journal::at(__HERE__) + << "initial capacity [" << arc + << "] = " << direct_capacity + opposite_capacity << " < 0" + << pyre::journal::endl; + ok = false; + } + } + return ok; +} + +template +bool GenericMaxFlow::AugmentingPathExists() const { + // We simply compute the reachability from the source in the residual graph. 
+ const NodeIndex num_nodes = graph_->num_nodes(); + std::vector is_reached(num_nodes, false); + std::vector to_process; + + to_process.push_back(source_); + is_reached[source_] = true; + while (!to_process.empty()) { + const NodeIndex node = to_process.back(); + to_process.pop_back(); + for (OutgoingOrOppositeIncomingArcIterator it(*graph_, node); it.Ok(); + it.Next()) { + const ArcIndex arc = it.Index(); + if (residual_arc_capacity_[arc] > 0) { + const NodeIndex head = graph_->Head(arc); + if (!is_reached[head]) { + is_reached[head] = true; + to_process.push_back(head); + } + } + } + } + return is_reached[sink_]; +} + +template +bool GenericMaxFlow::CheckRelabelPrecondition(NodeIndex node) const { + assert(IsActive(node)); + for (OutgoingOrOppositeIncomingArcIterator it(*graph_, node); it.Ok(); + it.Next()) { +#ifndef NDEBUG + const ArcIndex arc = it.Index(); + assert(!IsAdmissible(arc) && DebugString("CheckRelabelPrecondition:", arc)); +#endif + } + return true; +} + +template +std::string GenericMaxFlow::DebugString(const std::string& context, + ArcIndex arc) const { + const NodeIndex tail = Tail(arc); + const NodeIndex head = Head(arc); + return std::string(context) + " Arc " + std::to_string(arc) + + ", from " + std::to_string(tail) + " to " + std::to_string(head) + + ", Capacity = " + std::to_string(Capacity(arc)) + + ", Residual capacity = " + std::to_string(residual_arc_capacity_[arc]) + + ", Flow = residual capacity for reverse arc = " + std::to_string(Flow(arc)) + + ", Height(tail) = " + std::to_string(node_potential_[tail]) + + ", Height(head) = " + std::to_string(node_potential_[head]) + + ", Excess(tail) = " + std::to_string(node_excess_[tail]) + + ", Excess(head) = " + std::to_string(node_excess_[head]); +} + +template +bool GenericMaxFlow::Solve() { + status_ = NOT_SOLVED; + if (check_input_ && !CheckInputConsistency()) { + status_ = BAD_INPUT; + return false; + } + InitializePreflow(); + + // Deal with the case when source_ or sink_ is not inside 
graph_. + // Since they are both specified independently of the graph, we do need to + // take care of this corner case. + const NodeIndex num_nodes = graph_->num_nodes(); + if (sink_ >= num_nodes || source_ >= num_nodes) { + // Behave like a normal graph where source_ and sink_ are disconnected. + // Note that the arc flow is set to 0 by InitializePreflow(). + status_ = OPTIMAL; + return true; + } + if (use_global_update_) { + RefineWithGlobalUpdate(); + } else { + Refine(); + } + if (check_result_) { + if (!CheckResult()) { + status_ = BAD_RESULT; + return false; + } + if (GetOptimalFlow() < kMaxFlowQuantity && AugmentingPathExists()) { + pyre::journal::error_t channel("isce3.unwrap.ortools.max_flow"); + channel << pyre::journal::at(__HERE__) + << "The algorithm terminated, but the flow is not maximal!" + << pyre::journal::endl; + status_ = BAD_RESULT; + return false; + } + } + assert(node_excess_[sink_] == -node_excess_[source_]); + status_ = OPTIMAL; + if (GetOptimalFlow() == kMaxFlowQuantity && AugmentingPathExists()) { + // In this case, we are sure that the flow is > kMaxFlowQuantity. + status_ = INT_OVERFLOW; + } + return true; +} + +template +void GenericMaxFlow::InitializePreflow() { + // InitializePreflow() clears the whole flow that could have been computed + // by a previous Solve(). This is not optimal in terms of complexity. + // TODO(user): find a way to make the re-solving incremental (not an obvious + // task, and there has not been a lot of literature on the subject.) + node_excess_.SetAll(0); + const ArcIndex num_arcs = graph_->num_arcs(); + for (ArcIndex arc = 0; arc < num_arcs; ++arc) { + SetCapacityAndClearFlow(arc, Capacity(arc)); + } + + // All the initial heights are zero except for the source whose height is + // equal to the number of nodes and will never change during the algorithm. 
+ node_potential_.SetAll(0); + node_potential_.Set(source_, graph_->num_nodes()); + + // Initially no arcs are admissible except maybe the one leaving the source, + // but we treat the source in a special way, see + // SaturateOutgoingArcsFromSource(). + const NodeIndex num_nodes = graph_->num_nodes(); + for (NodeIndex node = 0; node < num_nodes; ++node) { + first_admissible_arc_[node] = Graph::kNilArc; + } +} + +// Note(user): Calling this function will break the property on the node +// potentials because of the way we cancel flow on cycle. However, we only call +// that at the end of the algorithm, or just before a GlobalUpdate() that will +// restore the precondition on the node potentials. +template +void GenericMaxFlow::PushFlowExcessBackToSource() { + const NodeIndex num_nodes = graph_->num_nodes(); + + // We implement a variation of Tarjan's strongly connected component algorithm + // to detect cycles published in: Tarjan, R. E. (1972), "Depth-first search + // and linear graph algorithms", SIAM Journal on Computing. A description can + // also be found in wikipedia. + + // Stored nodes are settled nodes already stored in the + // reverse_topological_order (except the sink_ that we do not actually store). + std::vector stored(num_nodes, false); + stored[sink_] = true; + + // The visited nodes that are not yet stored are all the nodes from the + // source_ to the current node in the current dfs branch. + std::vector visited(num_nodes, false); + visited[sink_] = true; + + // Stack of arcs to explore in the dfs search. + // The current node is Head(arc_stack.back()). + std::vector arc_stack; + + // Increasing list of indices into the arc_stack that correspond to the list + // of arcs in the current dfs branch from the source_ to the current node. + std::vector index_branch; + + // Node in reverse_topological_order in the final dfs tree. 
+ std::vector reverse_topological_order; + + // We start by pushing all the outgoing arcs from the source on the stack to + // avoid special conditions in the code. As a result, source_ will not be + // stored in reverse_topological_order, and this is what we want. + for (OutgoingArcIterator it(*graph_, source_); it.Ok(); it.Next()) { + const ArcIndex arc = it.Index(); + const FlowQuantity flow = Flow(arc); + if (flow > 0) { + arc_stack.push_back(arc); + } + } + visited[source_] = true; + + // Start the dfs on the subgraph formed by the direct arcs with positive flow. + while (!arc_stack.empty()) { + const NodeIndex node = Head(arc_stack.back()); + + // If the node is visited, it means we have explored all its arcs and we + // have just backtracked in the dfs. Store it if it is not already stored + // and process the next arc on the stack. + if (visited[node]) { + if (!stored[node]) { + stored[node] = true; + reverse_topological_order.push_back(node); + assert(!index_branch.empty()); + index_branch.pop_back(); + } + arc_stack.pop_back(); + continue; + } + + // The node is a new unexplored node, add all its outgoing arcs with + // positive flow to the stack and go deeper in the dfs. + assert(!stored[node]); + assert(index_branch.empty() || + (arc_stack.size() - 1 > index_branch.back())); + visited[node] = true; + index_branch.push_back(arc_stack.size() - 1); + + for (OutgoingArcIterator it(*graph_, node); it.Ok(); it.Next()) { + const ArcIndex arc = it.Index(); + const FlowQuantity flow = Flow(arc); + const NodeIndex head = Head(arc); + if (flow > 0 && !stored[head]) { + if (!visited[head]) { + arc_stack.push_back(arc); + } else { + // There is a cycle. + // Find the first index to consider, + // arc_stack[index_branch[cycle_begin]] will be the first arc on the + // cycle. 
+ int cycle_begin = index_branch.size(); + while (cycle_begin > 0 && + Head(arc_stack[index_branch[cycle_begin - 1]]) != head) { + --cycle_begin; + } + + // Compute the maximum flow that can be canceled on the cycle and the + // min index such that arc_stack[index_branch[i]] will be saturated. + FlowQuantity max_flow = flow; + int first_saturated_index = index_branch.size(); + for (int i = index_branch.size() - 1; i >= cycle_begin; --i) { + const ArcIndex arc_on_cycle = arc_stack[index_branch[i]]; + if (Flow(arc_on_cycle) <= max_flow) { + max_flow = Flow(arc_on_cycle); + first_saturated_index = i; + } + } + +#ifndef NDEBUG + // This is just here for a DCHECK() below. + const FlowQuantity excess = node_excess_[head]; +#endif + + // Cancel the flow on the cycle, and set visited[node] = false for + // the node that will be backtracked over. + PushFlow(-max_flow, arc); + for (int i = index_branch.size() - 1; i >= cycle_begin; --i) { + const ArcIndex arc_on_cycle = arc_stack[index_branch[i]]; + PushFlow(-max_flow, arc_on_cycle); + if (i >= first_saturated_index) { + assert(visited[Head(arc_on_cycle)]); + visited[Head(arc_on_cycle)] = false; + } else { + assert(Flow(arc_on_cycle) > 0); + } + } + + // This is a simple check that the flow was pushed properly. + assert(excess == node_excess_[head]); + + // Backtrack the dfs just before index_branch[first_saturated_index]. + // If the current node is still active, there is nothing to do. + if (first_saturated_index < index_branch.size()) { + arc_stack.resize(index_branch[first_saturated_index]); + index_branch.resize(first_saturated_index); + + // We backtracked over the current node, so there is no need to + // continue looping over its arcs. + break; + } + } + } + } + } + assert(arc_stack.empty()); + assert(index_branch.empty()); + + // Return the flow to the sink. Note that the sink_ and the source_ are not + // stored in reverse_topological_order. 
+ for (int i = 0; i < reverse_topological_order.size(); i++) { + const NodeIndex node = reverse_topological_order[i]; + if (node_excess_[node] == 0) continue; + for (IncomingArcIterator it(*graph_, node); it.Ok(); it.Next()) { + const ArcIndex opposite_arc = Opposite(it.Index()); + if (residual_arc_capacity_[opposite_arc] > 0) { + const FlowQuantity flow = + std::min(node_excess_[node], residual_arc_capacity_[opposite_arc]); + PushFlow(flow, opposite_arc); + if (node_excess_[node] == 0) break; + } + } + assert(0 == node_excess_[node]); + } + assert(-node_excess_[source_] == node_excess_[sink_]); +} + +template +void GenericMaxFlow::GlobalUpdate() { + bfs_queue_.clear(); + int queue_index = 0; + const NodeIndex num_nodes = graph_->num_nodes(); + node_in_bfs_queue_.assign(num_nodes, false); + node_in_bfs_queue_[sink_] = true; + node_in_bfs_queue_[source_] = true; + + // We do two BFS in the reverse residual graph, one from the sink and one from + // the source. Because all the arcs from the source are saturated (except in + // presence of integer overflow), the source cannot reach the sink in the + // residual graph. However, we still want to relabel all the nodes that cannot + // reach the sink but can reach the source (because if they have excess, we + // need to push it back to the source). + // + // Note that the second pass is not needed here if we use a two-pass algorithm + // to return the flow to the source after we found the min cut. + const int num_passes = use_two_phase_algorithm_ ? 
1 : 2; + for (int pass = 0; pass < num_passes; ++pass) { + if (pass == 0) { + bfs_queue_.push_back(sink_); + } else { + bfs_queue_.push_back(source_); + } + + while (queue_index != bfs_queue_.size()) { + const NodeIndex node = bfs_queue_[queue_index]; + ++queue_index; + const NodeIndex candidate_distance = node_potential_[node] + 1; + for (OutgoingOrOppositeIncomingArcIterator it(*graph_, node); it.Ok(); + it.Next()) { + const ArcIndex arc = it.Index(); + const NodeIndex head = Head(arc); + + // Skip the arc if the height of head was already set to the correct + // value (Remember we are doing reverse BFS). + if (node_in_bfs_queue_[head]) continue; + + // TODO(user): By using more memory we can speed this up quite a bit by + // avoiding to take the opposite arc here, too options: + // - if (residual_arc_capacity_[arc] != arc_capacity_[arc]) + // - if (opposite_arc_is_admissible_[arc]) // need updates. + // Experiment with the first option shows more than 10% gain on this + // function running time, which is the bottleneck on many instances. + const ArcIndex opposite_arc = Opposite(arc); + if (residual_arc_capacity_[opposite_arc] > 0) { + // Note(user): We used to have a DCHECK_GE(candidate_distance, + // node_potential_[head]); which is always true except in the case + // where we can push more than kMaxFlowQuantity out of the source. The + // problem comes from the fact that in this case, we call + // PushFlowExcessBackToSource() in the middle of the algorithm. The + // later call will break the properties of the node potential. Note + // however, that this function will recompute a good node potential + // for all the nodes and thus fix the issue. + + // If head is active, we can steal some or all of its excess. + // This brings a huge gain on some problems. + // Note(user): I haven't seen this anywhere in the literature. 
+ // TODO(user): Investigate more and maybe write a publication :) + if (node_excess_[head] > 0) { + const FlowQuantity flow = std::min( + node_excess_[head], residual_arc_capacity_[opposite_arc]); + PushFlow(flow, opposite_arc); + + // If the arc became saturated, it is no longer in the residual + // graph, so we do not need to consider head at this time. + if (residual_arc_capacity_[opposite_arc] == 0) continue; + } + + // Note that there is no need to touch first_admissible_arc_[node] + // because of the relaxed Relabel() we use. + node_potential_[head] = candidate_distance; + node_in_bfs_queue_[head] = true; + bfs_queue_.push_back(head); + } + } + } + } + + // At the end of the search, some nodes may not be in the bfs_queue_. Such + // nodes cannot reach the sink_ or source_ in the residual graph, so there is + // no point trying to push flow toward them. We obtain this effect by setting + // their height to something unreachable. + // + // Note that this also prevents cycling due to our anti-overflow procedure. + // For instance, suppose there is an edge s -> n outgoing from the source. If + // node n has no other connection and some excess, we will push the flow back + // to the source, but if we don't update the height of n + // SaturateOutgoingArcsFromSource() will push the flow to n again. + // TODO(user): This is another argument for another anti-overflow algorithm. + for (NodeIndex node = 0; node < num_nodes; ++node) { + if (!node_in_bfs_queue_[node]) { + node_potential_[node] = 2 * num_nodes - 1; + } + } + + // Reset the active nodes. Doing it like this pushes the nodes in increasing + // order of height. Note that bfs_queue_[0] is the sink_ so we skip it. 
+ assert(IsEmptyActiveNodeContainer()); + for (int i = 1; i < bfs_queue_.size(); ++i) { + const NodeIndex node = bfs_queue_[i]; + if (node_excess_[node] > 0) { + assert(IsActive(node)); + PushActiveNode(node); + } + } +} + +template +bool GenericMaxFlow::SaturateOutgoingArcsFromSource() { + const NodeIndex num_nodes = graph_->num_nodes(); + + // If sink_ or source_ already have kMaxFlowQuantity, then there is no + // point pushing more flow since it will cause an integer overflow. + if (node_excess_[sink_] == kMaxFlowQuantity) return false; + if (node_excess_[source_] == -kMaxFlowQuantity) return false; + + bool flow_pushed = false; + for (OutgoingArcIterator it(*graph_, source_); it.Ok(); it.Next()) { + const ArcIndex arc = it.Index(); + const FlowQuantity flow = residual_arc_capacity_[arc]; + + // This is a special IsAdmissible() condition for the source. + if (flow == 0 || node_potential_[Head(arc)] >= num_nodes) continue; + + // We are careful in case the sum of the flow out of the source is greater + // than kMaxFlowQuantity to avoid overflow. + const FlowQuantity current_flow_out_of_source = -node_excess_[source_]; + assert(flow >= 0); + assert(current_flow_out_of_source >= 0); + const FlowQuantity capped_flow = + kMaxFlowQuantity - current_flow_out_of_source; + if (capped_flow < flow) { + // We push as much flow as we can so the current flow on the network will + // be kMaxFlowQuantity. + + // Since at the beginning of the function, current_flow_out_of_source + // was different from kMaxFlowQuantity, we are sure to have pushed some + // flow before if capped_flow is 0. + if (capped_flow == 0) return true; + PushFlow(capped_flow, arc); + return true; + } + PushFlow(flow, arc); + flow_pushed = true; + } + assert(node_excess_[source_] <= 0); + return flow_pushed; +} + +template +void GenericMaxFlow::PushFlow(FlowQuantity flow, ArcIndex arc) { + // TODO(user): Do not allow a zero flow after fixing the UniformMaxFlow code. 
+ assert(residual_arc_capacity_[Opposite(arc)] + flow >= 0); + assert(residual_arc_capacity_[arc] - flow >= 0); + + // node_excess_ should be always greater than or equal to 0 except for the + // source where it should always be smaller than or equal to 0. Note however + // that we cannot check this because when we cancel the flow on a cycle in + // PushFlowExcessBackToSource(), we may break this invariant during the + // operation even if it is still valid at the end. + + // Update the residual capacity of the arc and its opposite arc. + residual_arc_capacity_[arc] -= flow; + residual_arc_capacity_[Opposite(arc)] += flow; + + // Update the excesses at the tail and head of the arc. + node_excess_[Tail(arc)] -= flow; + node_excess_[Head(arc)] += flow; +} + +template +void GenericMaxFlow::InitializeActiveNodeContainer() { + assert(IsEmptyActiveNodeContainer()); + const NodeIndex num_nodes = graph_->num_nodes(); + for (NodeIndex node = 0; node < num_nodes; ++node) { + if (IsActive(node)) { + if (use_two_phase_algorithm_ && node_potential_[node] >= num_nodes) { + continue; + } + PushActiveNode(node); + } + } +} + +template +void GenericMaxFlow::Refine() { + // Usually SaturateOutgoingArcsFromSource() will saturate all the arcs from + // the source in one go, and we will loop just once. But in case we can push + // more than kMaxFlowQuantity out of the source the loop is as follow: + // - Push up to kMaxFlowQuantity out of the source on the admissible outgoing + // arcs. Stop if no flow was pushed. + // - Compute the current max-flow. This will push some flow back to the + // source and render more outgoing arcs from the source not admissible. + // + // TODO(user): This may not be the most efficient algorithm if we need to loop + // many times. An alternative may be to handle the source like the other nodes + // in the algorithm, initially putting an excess of kMaxFlowQuantity on it, + // and making the source active like any other node with positive excess. 
To + // investigate. + // + // TODO(user): The code below is buggy when more than kMaxFlowQuantity can be + // pushed out of the source (i.e. when we loop more than once in the while()). + // This is not critical, since this code is not used in the default algorithm + // computation. The issue is twofold: + // - InitializeActiveNodeContainer() doesn't push the nodes in + // the correct order. + // - PushFlowExcessBackToSource() may break the node potential properties, and + // we will need a call to GlobalUpdate() to fix that. + while (SaturateOutgoingArcsFromSource()) { + assert(IsEmptyActiveNodeContainer()); + InitializeActiveNodeContainer(); + while (!IsEmptyActiveNodeContainer()) { + const NodeIndex node = GetAndRemoveFirstActiveNode(); + if (node == source_ || node == sink_) continue; + Discharge(node); + } + if (use_two_phase_algorithm_) { + PushFlowExcessBackToSource(); + } + } +} + +template +void GenericMaxFlow::RefineWithGlobalUpdate() { + // TODO(user): This should be graph_->num_nodes(), but ebert graph does not + // have a correct size if the highest index nodes have no arcs. + const NodeIndex num_nodes = Graphs::NodeReservation(*graph_); + std::vector skip_active_node; + + while (SaturateOutgoingArcsFromSource()) { + int num_skipped; + do { + num_skipped = 0; + skip_active_node.assign(num_nodes, 0); + skip_active_node[sink_] = 2; + skip_active_node[source_] = 2; + GlobalUpdate(); + while (!IsEmptyActiveNodeContainer()) { + const NodeIndex node = GetAndRemoveFirstActiveNode(); + if (skip_active_node[node] > 1) { + if (node != sink_ && node != source_) ++num_skipped; + continue; + } + const NodeIndex old_height = node_potential_[node]; + Discharge(node); + + // The idea behind this is that if a node height augments by more than + // one, then it is likely to push flow back the way it came. This can + // lead to very costly loops. A bad case is: source -> n1 -> n2 and n2 + // just recently isolated from the sink. 
Then n2 will push flow back to + // n1, and n1 to n2 and so on. The height of each node will increase by + // steps of two until the height of the source is reached, which can + // take a long time. If the chain is longer, the situation is even + // worse. The behavior of this heuristic is related to the Gap + // heuristic. + // + // Note that the global update will fix all such cases efficiently. So + // the idea is to discharge the active node as much as possible, and + // then do a global update. + // + // We skip a node when this condition was true 2 times to avoid doing a + // global update too frequently. + if (node_potential_[node] > old_height + 1) { + ++skip_active_node[node]; + } + } + } while (num_skipped > 0); + if (use_two_phase_algorithm_) { + PushFlowExcessBackToSource(); + } + } +} + +template +void GenericMaxFlow::Discharge(NodeIndex node) { + const NodeIndex num_nodes = graph_->num_nodes(); + while (true) { + assert(IsActive(node)); + for (OutgoingOrOppositeIncomingArcIterator it(*graph_, node, + first_admissible_arc_[node]); + it.Ok(); it.Next()) { + const ArcIndex arc = it.Index(); + if (IsAdmissible(arc)) { + assert(IsActive(node)); + const NodeIndex head = Head(arc); + if (node_excess_[head] == 0) { + // The push below will make the node active for sure. Note that we may + // push the sink_, but that is handled properly in Refine(). + PushActiveNode(head); + } + const FlowQuantity delta = + std::min(node_excess_[node], residual_arc_capacity_[arc]); + PushFlow(delta, arc); + if (node_excess_[node] == 0) { + first_admissible_arc_[node] = arc; // arc may still be admissible. + return; + } + } + } + Relabel(node); + if (use_two_phase_algorithm_ && node_potential_[node] >= num_nodes) break; + } +} + +template +void GenericMaxFlow::Relabel(NodeIndex node) { + // Because we use a relaxed version, this is no longer true if the + // first_admissible_arc_[node] was not actually the first arc! 
+ // DCHECK(CheckRelabelPrecondition(node)); + NodeHeight min_height = std::numeric_limits::max(); + ArcIndex first_admissible_arc = Graph::kNilArc; + for (OutgoingOrOppositeIncomingArcIterator it(*graph_, node); it.Ok(); + it.Next()) { + const ArcIndex arc = it.Index(); + if (residual_arc_capacity_[arc] > 0) { + // Update min_height only for arcs with available capacity. + NodeHeight head_height = node_potential_[Head(arc)]; + if (head_height < min_height) { + min_height = head_height; + first_admissible_arc = arc; + + // We found an admissible arc at the current height, just stop there. + // This is the true first_admissible_arc_[node]. + if (min_height + 1 == node_potential_[node]) break; + } + } + } + assert(first_admissible_arc != Graph::kNilArc); + node_potential_[node] = min_height + 1; + + // Note that after a Relabel(), the loop will continue in Discharge(), and + // we are sure that all the arcs before first_admissible_arc are not + // admissible since their height is > min_height. + first_admissible_arc_[node] = first_admissible_arc; +} + +template +typename Graph::ArcIndex GenericMaxFlow::Opposite(ArcIndex arc) const { + return Graphs::OppositeArc(*graph_, arc); +} + +template +bool GenericMaxFlow::IsArcDirect(ArcIndex arc) const { + return IsArcValid(arc) && arc >= 0; +} + +template +bool GenericMaxFlow::IsArcValid(ArcIndex arc) const { + return Graphs::IsArcValid(*graph_, arc); +} + +template +const FlowQuantity GenericMaxFlow::kMaxFlowQuantity = + std::numeric_limits::max(); + +template +template +void GenericMaxFlow::ComputeReachableNodes( + NodeIndex start, std::vector* result) { + // If start is not a valid node index, it can reach only itself. + // Note(user): This is needed because source and sink are given independently + // of the graph and sometimes before it is even constructed. 
+ const NodeIndex num_nodes = graph_->num_nodes(); + if (start >= num_nodes) { + result->clear(); + result->push_back(start); + return; + } + bfs_queue_.clear(); + node_in_bfs_queue_.assign(num_nodes, false); + + int queue_index = 0; + bfs_queue_.push_back(start); + node_in_bfs_queue_[start] = true; + while (queue_index != bfs_queue_.size()) { + const NodeIndex node = bfs_queue_[queue_index]; + ++queue_index; + for (OutgoingOrOppositeIncomingArcIterator it(*graph_, node); it.Ok(); + it.Next()) { + const ArcIndex arc = it.Index(); + const NodeIndex head = Head(arc); + if (node_in_bfs_queue_[head]) continue; + if (residual_arc_capacity_[reverse ? Opposite(arc) : arc] == 0) continue; + node_in_bfs_queue_[head] = true; + bfs_queue_.push_back(head); + } + } + *result = bfs_queue_; +} + +// Explicit instantiations that can be used by a client. +// +// TODO(user): moves this code out of a .cc file and include it at the end of +// the header so it can work with any graph implementation ? +template <> +const FlowQuantity GenericMaxFlow::kMaxFlowQuantity = + std::numeric_limits::max(); +template <> +const FlowQuantity + GenericMaxFlow<::util::ReverseArcListGraph<>>::kMaxFlowQuantity = + std::numeric_limits::max(); +template <> +const FlowQuantity + GenericMaxFlow<::util::ReverseArcStaticGraph<>>::kMaxFlowQuantity = + std::numeric_limits::max(); +template <> +const FlowQuantity + GenericMaxFlow<::util::ReverseArcMixedGraph<>>::kMaxFlowQuantity = + std::numeric_limits::max(); + +template class GenericMaxFlow; +template class GenericMaxFlow<::util::ReverseArcListGraph<>>; +template class GenericMaxFlow<::util::ReverseArcStaticGraph<>>; +template class GenericMaxFlow<::util::ReverseArcMixedGraph<>>; + +} // namespace operations_research diff --git a/cxx/isce3/unwrap/ortools/max_flow.h b/cxx/isce3/unwrap/ortools/max_flow.h new file mode 100644 index 000000000..c42db5eb5 --- /dev/null +++ b/cxx/isce3/unwrap/ortools/max_flow.h @@ -0,0 +1,699 @@ +// Copyright 2010-2021 Google LLC 
+// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// An implementation of a push-relabel algorithm for the max flow problem. +// +// In the following, we consider a graph G = (V,E,s,t) where V denotes the set +// of nodes (vertices) in the graph, E denotes the set of arcs (edges). s and t +// denote distinguished nodes in G called source and target. n = |V| denotes the +// number of nodes in the graph, and m = |E| denotes the number of arcs in the +// graph. +// +// Each arc (v,w) is associated a capacity c(v,w). +// +// A flow is a function from E to R such that: +// +// a) f(v,w) <= c(v,w) for all (v,w) in E (capacity constraint.) +// +// b) f(v,w) = -f(w,v) for all (v,w) in E (flow antisymmetry constraint.) +// +// c) sum on v f(v,w) = 0 (flow conservation.) +// +// The goal of this algorithm is to find the maximum flow from s to t, i.e. +// for example to maximize sum v f(s,v). +// +// The starting reference for this class of algorithms is: +// A.V. Goldberg and R.E. Tarjan. A new approach to the maximum flow problem. +// ACM Symposium on Theory of Computing, pp. 136-146. +// http://portal.acm.org/citation.cfm?id=12144. +// +// The basic idea of the algorithm is to handle preflows instead of flows, +// and to refine preflows until a maximum flow is obtained. +// A preflow is like a flow, except that the inflow can be larger than the +// outflow. 
If it is the case at a given node v, it is said that there is an +// excess at node v, and inflow = outflow + excess. +// +// More formally, a preflow is a function f such that: +// +// 1) f(v,w) <= c(v,w) for all (v,w) in E (capacity constraint). c(v,w) is a +// value representing the maximum capacity for arc (v,w). +// +// 2) f(v,w) = -f(w,v) for all (v,w) in E (flow antisymmetry constraint) +// +// 3) excess(v) = sum on u f(u,v) >= 0 is the excess at node v, the +// algebraic sum of all the incoming preflows at this node. +// +// Each node has an associated "height", in addition to its excess. The +// height of the source is defined to be equal to n, and cannot change. The +// height of the target is defined to be zero, and cannot change either. The +// height of all the other nodes is initialized at zero and is updated during +// the algorithm (see below). For those who want to know the details, the height +// of a node, corresponds to a reduced cost, and this enables one to prove that +// the algorithm actually computes the max flow. Note that the height of a node +// can be initialized to the distance to the target node in terms of number of +// nodes. This has not been tried in this implementation. +// +// A node v is said to be *active* if excess(v) > 0. +// +// In this case the following operations can be applied to it: +// +// - if there are *admissible* incident arcs, i.e. arcs which are not saturated, +// and whose head's height is lower than the height of the active node +// considered, a PushFlow operation can be applied. It consists in sending as +// much flow as both the excess at the node and the capacity of the arc +// permit. +// - if there are no admissible arcs, the active node considered is relabeled, +// i.e. its height is increased to 1 + the minimum height of its neighboring +// nodes on admissible arcs. +// This is implemented in Discharge, which itself calls PushFlow and Relabel. 
+// +// Before running Discharge, it is necessary to initialize the algorithm with a +// preflow. This is done in InitializePreflow, which saturates all the arcs +// leaving the source node, and sets the excess at the heads of those arcs +// accordingly. +// +// The algorithm terminates when there are no remaining active nodes, i.e. all +// the excesses at all nodes are equal to zero. In this case, a maximum flow is +// obtained. +// +// The complexity of this algorithm depends amongst other things on the choice +// of the next active node. It has been shown, for example in: +// L. Tuncel, "On the Complexity of Preflow-Push Algorithms for Maximum-Flow +// Problems", Algorithmica 11(4): 353-359 (1994). +// and +// J. Cheriyan and K. Mehlhorn, "An analysis of the highest-level selection rule +// in the preflow-push max-flow algorithm", Information processing letters, +// 69(5):239-242 (1999). +// http://www.math.uwaterloo.ca/~jcheriya/PS_files/me3.0.ps +// +// ...that choosing the active node with the highest level yields a +// complexity of O(n^2 * sqrt(m)). +// +// TODO(user): implement the above active node choice rule. +// +// This has been validated experimentally in: +// R.K. Ahuja, M. Kodialam, A.K. Mishra, and J.B. Orlin, "Computational +// Investigations of Maximum Flow Algorithms", EJOR 97:509-542(1997). +// http://jorlin.scripts.mit.edu/docs/publications/58-comput%20investigations%20of.pdf. +// +// +// TODO(user): an alternative would be to evaluate: +// A.V. Goldberg, "The Partial Augment-Relabel Algorithm for the Maximum Flow +// Problem.” In Proceedings of Algorithms ESA, LNCS 5193:466-477, Springer 2008. +// http://www.springerlink.com/index/5535k2j1mt646338.pdf +// +// An interesting general reference on network flows is: +// R. K. Ahuja, T. L. Magnanti, J. B. 
Orlin, "Network Flows: Theory, Algorithms, +// and Applications," Prentice Hall, 1993, ISBN: 978-0136175490, +// http://www.amazon.com/dp/013617549X +// +// Keywords: Push-relabel, max-flow, network, graph, Goldberg, Tarjan, Dinic, +// Dinitz. + +#ifndef OR_TOOLS_GRAPH_MAX_FLOW_H_ +#define OR_TOOLS_GRAPH_MAX_FLOW_H_ + +#include +#include +#include +#include +#include + +#include "ebert_graph.h" +#include "graph.h" +#include "zvector.h" + +namespace operations_research { + +// Forward declaration. +template +class GenericMaxFlow; + +// A simple and efficient max-cost flow interface. This is as fast as +// GenericMaxFlow, which is the fastest, but uses +// more memory in order to hide the somewhat involved construction of the +// static graph. +// +// TODO(user): If the need arises, extend this interface to support warm start. +class SimpleMaxFlow { + public: + // The constructor takes no size. + // New node indices will be created lazily by AddArcWithCapacity(). + SimpleMaxFlow(); + + // Adds a directed arc with the given capacity from tail to head. + // * Node indices and capacity must be non-negative (>= 0). + // * Self-looping and duplicate arcs are supported. + // * After the method finishes, NumArcs() == the returned ArcIndex + 1. + ArcIndex AddArcWithCapacity(NodeIndex tail, NodeIndex head, + FlowQuantity capacity); + + // Returns the current number of nodes. This is one more than the largest + // node index seen so far in AddArcWithCapacity(). + NodeIndex NumNodes() const; + + // Returns the current number of arcs in the graph. + ArcIndex NumArcs() const; + + // Returns user-provided data. + // The implementation will crash if "arc" is not in [0, NumArcs()). + NodeIndex Tail(ArcIndex arc) const; + NodeIndex Head(ArcIndex arc) const; + FlowQuantity Capacity(ArcIndex arc) const; + + // Solves the problem (finds the maximum flow from the given source to the + // given sink), and returns the problem status. 
+ enum Status { + // Solve() was called and found an optimal solution. Note that OptimalFlow() + // may be 0 which means that the sink is not reachable from the source. + OPTIMAL, + // There is a flow > std::numeric_limits::max(). Note that in + // this case, the class will contain a solution with a flow reaching that + // bound. + // + // TODO(user): rename POSSIBLE_OVERFLOW to INT_OVERFLOW and modify our + // clients. + POSSIBLE_OVERFLOW, + // The input is inconsistent (bad tail/head/capacity values). + BAD_INPUT, + // This should not happen. There was an error in our code (i.e. file a bug). + BAD_RESULT + }; + Status Solve(NodeIndex source, NodeIndex sink); + + // Returns the maximum flow we can send from the source to the sink in the + // last OPTIMAL Solve() context. + FlowQuantity OptimalFlow() const; + + // Returns the flow on the given arc in the last OPTIMAL Solve() context. + // + // Note: It is possible that there is more than one optimal solution. The + // algorithm is deterministic so it will always return the same solution for + // a given problem. However, there is no guarantee of this from one code + // version to the next (but the code does not change often). + FlowQuantity Flow(ArcIndex arc) const; + + // Returns the nodes reachable from the source by non-saturated arcs (.i.e. + // arc with Flow(arc) < Capacity(arc)), the outgoing arcs of this set form a + // minimum cut. This works only if Solve() returned OPTIMAL. + void GetSourceSideMinCut(std::vector* result); + + // Returns the nodes that can reach the sink by non-saturated arcs, the + // outgoing arcs of this set form a minimum cut. Note that if this is the + // complement set of GetNodeReachableFromSource(), then the min-cut is unique. + // This works only if Solve() returned OPTIMAL. + void GetSinkSideMinCut(std::vector* result); + + // Change the capacity of an arc. 
+ // + // WARNING: This looks like it enables incremental solves, but as of 2018-02, + // the next Solve() will restart from scratch anyway. + // TODO(user): Support incrementality in the max flow implementation. + void SetArcCapacity(ArcIndex arc, FlowQuantity capacity); + + private: + NodeIndex num_nodes_; + std::vector arc_tail_; + std::vector arc_head_; + std::vector arc_capacity_; + std::vector arc_permutation_; + std::vector arc_flow_; + FlowQuantity optimal_flow_; + + // Note that we cannot free the graph before we stop using the max-flow + // instance that uses it. + typedef ::util::ReverseArcStaticGraph Graph; + std::unique_ptr underlying_graph_; + std::unique_ptr > underlying_max_flow_; + + SimpleMaxFlow(const SimpleMaxFlow&); + SimpleMaxFlow& operator=(const SimpleMaxFlow&); +}; + +// Specific but efficient priority queue implementation. The priority type must +// be an integer. The queue allows to retrieve the element with highest priority +// but only allows pushes with a priority greater or equal to the highest +// priority in the queue minus one. All operations are in O(1) and the memory is +// in O(num elements in the queue). Elements with the same priority are +// retrieved with LIFO order. +// +// Note(user): As far as I know, this is an original idea and is the only code +// that use this in the Maximum Flow context. Papers usually refer to an +// height-indexed array of simple linked lists of active node with the same +// height. Even worse, sometimes they use double-linked list to allow arbitrary +// height update in order to detect missing height (used for the Gap heuristic). +// But this can actually be implemented a lot more efficiently by just +// maintaining the height distribution of all the node in the graph. +template +class PriorityQueueWithRestrictedPush { + public: + PriorityQueueWithRestrictedPush() : even_queue_(), odd_queue_() {} + + // Is the queue empty? + bool IsEmpty() const; + + // Clears the queue. 
+ void Clear(); + + // Push a new element in the queue. Its priority must be greater or equal to + // the highest priority present in the queue, minus one. This condition is + // DCHECKed, and violating it yields erroneous queue behavior in NDEBUG mode. + void Push(Element element, IntegerPriority priority); + + // Returns the element with highest priority and remove it from the queue. + // IsEmpty() must be false, this condition is DCHECKed. + Element Pop(); + + private: + // Helper function to get the last element of a vector and pop it. + Element PopBack(std::vector >* queue); + + // This is the heart of the algorithm. basically we split the elements by + // parity of their priority and the precondition on the Push() ensures that + // both vectors are always sorted by increasing priority. + std::vector > even_queue_; + std::vector > odd_queue_; + + PriorityQueueWithRestrictedPush(const PriorityQueueWithRestrictedPush&); + PriorityQueueWithRestrictedPush& operator=(const PriorityQueueWithRestrictedPush&); +}; + +// We want an enum for the Status of a max flow run, and we want this +// enum to be scoped under GenericMaxFlow<>. Unfortunately, swig +// doesn't handle templated enums very well, so we need a base, +// untemplated class to hold it. +class MaxFlowStatusClass { + public: + enum Status { + NOT_SOLVED, // The problem was not solved, or its data were edited. + OPTIMAL, // Solve() was called and found an optimal solution. + INT_OVERFLOW, // There is a feasible flow > max possible flow. + BAD_INPUT, // The input is inconsistent. + BAD_RESULT // There was an error. + }; +}; + +// Generic MaxFlow (there is a default MaxFlow specialization defined below) +// that works with StarGraph and all the reverse arc graphs from graph.h, see +// the end of max_flow.cc for the exact types this class is compiled for. 
+template +class GenericMaxFlow : public MaxFlowStatusClass { + public: + typedef typename Graph::NodeIndex NodeIndex; + typedef typename Graph::ArcIndex ArcIndex; + typedef typename Graph::OutgoingArcIterator OutgoingArcIterator; + typedef typename Graph::OutgoingOrOppositeIncomingArcIterator + OutgoingOrOppositeIncomingArcIterator; + typedef typename Graph::IncomingArcIterator IncomingArcIterator; + typedef ZVector ArcIndexArray; + + // The height of a node never excess 2 times the number of node, so we + // use the same type as a Node index. + typedef NodeIndex NodeHeight; + typedef ZVector NodeHeightArray; + + // Initialize a MaxFlow instance on the given graph. The graph does not need + // to be fully built yet, but its capacity reservation are used to initialize + // the memory of this class. source and sink must also be valid node of + // graph. + GenericMaxFlow(const Graph* graph, NodeIndex source, NodeIndex sink); + virtual ~GenericMaxFlow() {} + + // Returns the graph associated to the current object. + const Graph* graph() const { return graph_; } + + // Returns the status of last call to Solve(). NOT_SOLVED is returned if + // Solve() has never been called or if the problem has been modified in such a + // way that the previous solution becomes invalid. + Status status() const { return status_; } + + // Returns the index of the node corresponding to the source of the network. + NodeIndex GetSourceNodeIndex() const { return source_; } + + // Returns the index of the node corresponding to the sink of the network. + NodeIndex GetSinkNodeIndex() const { return sink_; } + + // Sets the capacity for arc to new_capacity. + void SetArcCapacity(ArcIndex arc, FlowQuantity new_capacity); + + // Sets the flow for arc. + void SetArcFlow(ArcIndex arc, FlowQuantity new_flow); + + // Returns true if a maximum flow was solved. + bool Solve(); + + // Returns the total flow found by the algorithm. 
+ FlowQuantity GetOptimalFlow() const { return node_excess_[sink_]; } + + // Returns the flow on arc using the equations given in the comment on + // residual_arc_capacity_. + FlowQuantity Flow(ArcIndex arc) const { + if (IsArcDirect(arc)) { + return residual_arc_capacity_[Opposite(arc)]; + } else { + return -residual_arc_capacity_[arc]; + } + } + + // Returns the capacity of arc using the equations given in the comment on + // residual_arc_capacity_. + FlowQuantity Capacity(ArcIndex arc) const { + if (IsArcDirect(arc)) { + return residual_arc_capacity_[arc] + + residual_arc_capacity_[Opposite(arc)]; + } else { + return 0; + } + } + + // Returns the nodes reachable from the source in the residual graph, the + // outgoing arcs of this set form a minimum cut. + void GetSourceSideMinCut(std::vector* result); + + // Returns the nodes that can reach the sink in the residual graph, the + // outgoing arcs of this set form a minimum cut. Note that if this is the + // complement of GetNodeReachableFromSource(), then the min-cut is unique. + // + // TODO(user): In the two-phases algorithm, we can get this minimum cut + // without doing the second phase. Add an option for this if there is a need + // to, note that the second phase is pretty fast so the gain will be small. + void GetSinkSideMinCut(std::vector* result); + + // Checks the consistency of the input, i.e. that capacities on the arcs are + // non-negative or null. + bool CheckInputConsistency() const; + + // Checks whether the result is valid, i.e. that node excesses are all equal + // to zero (we have a flow) and that residual capacities are all non-negative + // or zero. + bool CheckResult() const; + + // Returns true if there exists a path from the source to the sink with + // remaining capacity. This allows us to easily check at the end that the flow + // we computed is indeed optimal (provided that all the conditions tested by + // CheckResult() also hold). 
+ bool AugmentingPathExists() const; + + // Sets the different algorithm options. All default to true. + // See the corresponding variable declaration below for more details. + void SetUseGlobalUpdate(bool value) { + use_global_update_ = value; + if (!use_global_update_) process_node_by_height_ = false; + } + void SetUseTwoPhaseAlgorithm(bool value) { use_two_phase_algorithm_ = value; } + void SetCheckInput(bool value) { check_input_ = value; } + void SetCheckResult(bool value) { check_result_ = value; } + void ProcessNodeByHeight(bool value) { + process_node_by_height_ = value && use_global_update_; + } + + protected: + // Returns true if arc is admissible. + bool IsAdmissible(ArcIndex arc) const { + return residual_arc_capacity_[arc] > 0 && + node_potential_[Tail(arc)] == node_potential_[Head(arc)] + 1; + } + + // Returns true if node is active, i.e. if its excess is positive and it + // is neither the source or the sink of the graph. + bool IsActive(NodeIndex node) const { + return (node != source_) && (node != sink_) && (node_excess_[node] > 0); + } + + // Sets the capacity of arc to 'capacity' and clears the flow on arc. + void SetCapacityAndClearFlow(ArcIndex arc, FlowQuantity capacity) { + residual_arc_capacity_.Set(arc, capacity); + residual_arc_capacity_.Set(Opposite(arc), 0); + } + + // Returns true if a precondition for Relabel is met, i.e. the outgoing arcs + // of node are all either saturated or the heights of their heads are greater + // or equal to the height of node. + bool CheckRelabelPrecondition(NodeIndex node) const; + + // Returns context concatenated with information about arc + // in a human-friendly way. + std::string DebugString(const std::string& context, ArcIndex arc) const; + + // Initializes the container active_nodes_. + void InitializeActiveNodeContainer(); + + // Get the first element from the active node container. 
+ NodeIndex GetAndRemoveFirstActiveNode() { + if (process_node_by_height_) return active_node_by_height_.Pop(); + const NodeIndex node = active_nodes_.back(); + active_nodes_.pop_back(); + return node; + } + + // Push element to the active node container. + void PushActiveNode(const NodeIndex& node) { + if (process_node_by_height_) { + active_node_by_height_.Push(node, node_potential_[node]); + } else { + active_nodes_.push_back(node); + } + } + + // Check the emptiness of the container. + bool IsEmptyActiveNodeContainer() { + if (process_node_by_height_) { + return active_node_by_height_.IsEmpty(); + } else { + return active_nodes_.empty(); + } + } + + // Performs optimization step. + void Refine(); + void RefineWithGlobalUpdate(); + + // Discharges an active node node by saturating its admissible adjacent arcs, + // if any, and by relabelling it when it becomes inactive. + void Discharge(NodeIndex node); + + // Initializes the preflow to a state that enables to run Refine. + void InitializePreflow(); + + // Clears the flow excess at each node by pushing the flow back to the source: + // - Do a depth-first search from the source in the direct graph to cancel + // flow cycles. + // - Then, return flow excess along the depth-first search tree (by pushing + // the flow in the reverse dfs topological order). + // The theoretical complexity is O(mn), but it is a lot faster in practice. + void PushFlowExcessBackToSource(); + + // Computes the best possible node potential given the current flow using a + // reverse breadth-first search from the sink in the reverse residual graph. + // This is an implementation of the global update heuristic mentioned in many + // max-flow papers. See for instance: B.V. Cherkassky, A.V. Goldberg, "On + // implementing push-relabel methods for the maximum flow problem", + // Algorithmica, 19:390-410, 1997. 
+ // ftp://reports.stanford.edu/pub/cstr/reports/cs/tr/94/1523/CS-TR-94-1523.pdf + void GlobalUpdate(); + + // Tries to saturate all the outgoing arcs from the source that can reach the + // sink. Most of the time, we can do that in one go, except when more flow + // than kMaxFlowQuantity can be pushed out of the source in which case we + // have to be careful. Returns true if some flow was pushed. + bool SaturateOutgoingArcsFromSource(); + + // Pushes flow on arc, i.e. consumes flow on residual_arc_capacity_[arc], + // and consumes -flow on residual_arc_capacity_[Opposite(arc)]. Updates + // node_excess_ at the tail and head of arc accordingly. + void PushFlow(FlowQuantity flow, ArcIndex arc); + + // Relabels a node, i.e. increases its height by the minimum necessary amount. + // This version of Relabel is relaxed in a way such that if an admissible arc + // exists at the current node height, then the node is not relabeled. This + // enables us to deal with wrong values of first_admissible_arc_[node] when + // updating it is too costly. + void Relabel(NodeIndex node); + + // Handy member functions to make the code more compact. + NodeIndex Head(ArcIndex arc) const { return graph_->Head(arc); } + NodeIndex Tail(ArcIndex arc) const { return graph_->Tail(arc); } + ArcIndex Opposite(ArcIndex arc) const; + bool IsArcDirect(ArcIndex arc) const; + bool IsArcValid(ArcIndex arc) const; + + // Returns the set of nodes reachable from start in the residual graph or in + // the reverse residual graph (if reverse is true). + template + void ComputeReachableNodes(NodeIndex start, std::vector* result); + + // Maximum manageable flow. + static const FlowQuantity kMaxFlowQuantity; + + // A pointer to the graph passed as argument. + const Graph* graph_; + + // An array representing the excess for each node in graph_. + QuantityArray node_excess_; + + // An array representing the height function for each node in graph_. 
For a + // given node, this is a lower bound on the shortest path length from this + // node to the sink in the residual network. The height of a node always goes + // up during the course of a Solve(). + // + // Since initially we saturate all the outgoing arcs of the source, we can + // never reach the sink from the source in the residual graph. Initially we + // set the height of the source to n (the number of node of the graph) and it + // never changes. If a node as an height >= n, then this node can't reach the + // sink and its height minus n is a lower bound on the shortest path length + // from this node to the source in the residual graph. + NodeHeightArray node_potential_; + + // An array representing the residual_capacity for each arc in graph_. + // Residual capacities enable one to represent the capacity and flow for all + // arcs in the graph in the following manner. + // For all arc, residual_arc_capacity_[arc] = capacity[arc] - flow[arc] + // Moreover, for reverse arcs, capacity[arc] = 0 by definition, + // Also flow[Opposite(arc)] = -flow[arc] by definition. + // Therefore: + // - for a direct arc: + // flow[arc] = 0 - flow[Opposite(arc)] + // = capacity[Opposite(arc)] - flow[Opposite(arc)] + // = residual_arc_capacity_[Opposite(arc)] + // - for a reverse arc: + // flow[arc] = -residual_arc_capacity_[arc] + // Using these facts enables one to only maintain residual_arc_capacity_, + // instead of both capacity and flow, for each direct and indirect arc. This + // reduces the amount of memory for this information by a factor 2. + QuantityArray residual_arc_capacity_; + + // An array representing the first admissible arc for each node in graph_. + ArcIndexArray first_admissible_arc_; + + // A stack used for managing active nodes in the algorithm. + // Note that the papers cited above recommend the use of a queue, but + // benchmarking so far has not proved it is better. In particular, processing + // nodes in LIFO order has better cache locality. 
+ std::vector active_nodes_; + + // A priority queue used for managing active nodes in the algorithm. It allows + // to select the active node with highest height before each Discharge(). + // Moreover, since all pushes from this node will be to nodes with height + // greater or equal to the initial discharged node height minus one, the + // PriorityQueueWithRestrictedPush is a perfect fit. + PriorityQueueWithRestrictedPush active_node_by_height_; + + // The index of the source node in graph_. + NodeIndex source_; + + // The index of the sink node in graph_. + NodeIndex sink_; + + // The status of the problem. + Status status_; + + // BFS queue used by the GlobalUpdate() function. We do not use a C++ queue + // because we need access to the vector for different optimizations. + std::vector node_in_bfs_queue_; + std::vector bfs_queue_; + + // Whether or not to use GlobalUpdate(). + bool use_global_update_; + + // Whether or not we use a two-phase algorithm: + // 1/ Only deal with nodes that can reach the sink. At the end we know the + // value of the maximum flow and we have a min-cut. + // 2/ Call PushFlowExcessBackToSource() to obtain a max-flow. This is usually + // a lot faster than the first phase. + bool use_two_phase_algorithm_; + + // Whether or not we use the PriorityQueueWithRestrictedPush to process the + // active nodes rather than a simple queue. This can only be true if + // use_global_update_ is true. + // + // Note(user): using a template will be slightly faster, but since we test + // this in a non-critical path, this only has a minor impact. + bool process_node_by_height_; + + // Whether or not we check the input, this is a small price to pay for + // robustness. Disable only if you know the input is valid because an invalid + // input can cause the algorithm to run into an infinite loop! + bool check_input_; + + // Whether or not we check the result. + // TODO(user): Make the check more exhaustive by checking the optimality? 
+ bool check_result_; + + private: + GenericMaxFlow(const GenericMaxFlow&); + GenericMaxFlow& operator=(const GenericMaxFlow&); +}; + +// Default instance MaxFlow that uses StarGraph. Note that we cannot just use a +// typedef because of dependent code expecting MaxFlow to be a real class. +// TODO(user): Modify this code and remove it. +class MaxFlow : public GenericMaxFlow { + public: + MaxFlow(const StarGraph* graph, NodeIndex source, NodeIndex target) + : GenericMaxFlow(graph, source, target) {} +}; + +template +bool PriorityQueueWithRestrictedPush::IsEmpty() + const { + return even_queue_.empty() && odd_queue_.empty(); +} + +template +void PriorityQueueWithRestrictedPush::Clear() { + even_queue_.clear(); + odd_queue_.clear(); +} + +template +void PriorityQueueWithRestrictedPush::Push( + Element element, IntegerPriority priority) { + // Since users may rely on it, we DCHECK the exact condition. + assert(even_queue_.empty() || priority >= even_queue_.back().second - 1); + assert(odd_queue_.empty() || priority >= odd_queue_.back().second - 1); + + // Note that the DCHECK() below are less restrictive than the ones above but + // check a necessary and sufficient condition for the priority queue to behave + // as expected. 
+ if (priority & 1) { + assert(odd_queue_.empty() || priority >= odd_queue_.back().second); + odd_queue_.push_back(std::make_pair(element, priority)); + } else { + assert(even_queue_.empty() || priority >= even_queue_.back().second); + even_queue_.push_back(std::make_pair(element, priority)); + } +} + +template +Element PriorityQueueWithRestrictedPush::Pop() { + assert(!IsEmpty()); + if (even_queue_.empty()) return PopBack(&odd_queue_); + if (odd_queue_.empty()) return PopBack(&even_queue_); + if (odd_queue_.back().second > even_queue_.back().second) { + return PopBack(&odd_queue_); + } else { + return PopBack(&even_queue_); + } +} + +template +Element PriorityQueueWithRestrictedPush::PopBack( + std::vector >* queue) { + assert(!queue->empty()); + Element element = queue->back().first; + queue->pop_back(); + return element; +} + +} // namespace operations_research +#endif // OR_TOOLS_GRAPH_MAX_FLOW_H_ diff --git a/cxx/isce3/unwrap/ortools/min_cost_flow.cc b/cxx/isce3/unwrap/ortools/min_cost_flow.cc new file mode 100644 index 000000000..663835cab --- /dev/null +++ b/cxx/isce3/unwrap/ortools/min_cost_flow.cc @@ -0,0 +1,1209 @@ +// Copyright 2010-2021 Google LLC +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "min_cost_flow.h" + +#include +#include +#include +#include +#include +#include + +#include +#include + +#include "graphs.h" +#include "max_flow.h" + +namespace operations_research { + +template +GenericMinCostFlow::GenericMinCostFlow( + const Graph* graph) + : graph_(graph), + node_excess_(), + node_potential_(), + residual_arc_capacity_(), + first_admissible_arc_(), + active_nodes_(), + epsilon_(0), + alpha_(5), + cost_scaling_factor_(1), + scaled_arc_unit_cost_(), + total_flow_cost_(0), + status_(NOT_SOLVED), + initial_node_excess_(), + feasible_node_excess_(), + feasibility_checked_(false), + use_price_update_(false), + check_feasibility_(true) { + const NodeIndex max_num_nodes = Graphs::NodeReservation(*graph_); + if (max_num_nodes > 0) { + node_excess_.Reserve(0, max_num_nodes - 1); + node_excess_.SetAll(0); + node_potential_.Reserve(0, max_num_nodes - 1); + node_potential_.SetAll(0); + first_admissible_arc_.Reserve(0, max_num_nodes - 1); + first_admissible_arc_.SetAll(Graph::kNilArc); + initial_node_excess_.Reserve(0, max_num_nodes - 1); + initial_node_excess_.SetAll(0); + feasible_node_excess_.Reserve(0, max_num_nodes - 1); + feasible_node_excess_.SetAll(0); + } + const ArcIndex max_num_arcs = Graphs::ArcReservation(*graph_); + if (max_num_arcs > 0) { + residual_arc_capacity_.Reserve(-max_num_arcs, max_num_arcs - 1); + residual_arc_capacity_.SetAll(0); + scaled_arc_unit_cost_.Reserve(-max_num_arcs, max_num_arcs - 1); + scaled_arc_unit_cost_.SetAll(0); + } +} + +template +void GenericMinCostFlow::SetNodeSupply( + NodeIndex node, FlowQuantity supply) { + assert(graph_->IsNodeValid(node)); + node_excess_.Set(node, supply); + initial_node_excess_.Set(node, supply); + status_ = NOT_SOLVED; + feasibility_checked_ = false; +} + +template +void GenericMinCostFlow::SetArcUnitCost( + ArcIndex arc, ArcScaledCostType unit_cost) { + assert(IsArcDirect(arc)); + scaled_arc_unit_cost_.Set(arc, unit_cost); + scaled_arc_unit_cost_.Set(Opposite(arc), 
-scaled_arc_unit_cost_[arc]); + status_ = NOT_SOLVED; + feasibility_checked_ = false; +} + +template +void GenericMinCostFlow::SetArcCapacity( + ArcIndex arc, ArcFlowType new_capacity) { + assert(0 <= new_capacity); + assert(IsArcDirect(arc)); + const FlowQuantity free_capacity = residual_arc_capacity_[arc]; + const FlowQuantity capacity_delta = new_capacity - Capacity(arc); + if (capacity_delta == 0) { + return; // Nothing to do. + } + status_ = NOT_SOLVED; + feasibility_checked_ = false; + const FlowQuantity new_availability = free_capacity + capacity_delta; + if (new_availability >= 0) { + // The above condition is true when one of two following holds: + // 1/ (capacity_delta > 0), meaning we are increasing the capacity + // 2/ (capacity_delta < 0 && free_capacity + capacity_delta >= 0) + // meaning we are reducing the capacity, but that the capacity + // reduction is not larger than the free capacity. + assert((capacity_delta > 0) || + (capacity_delta < 0 && new_availability >= 0)); + residual_arc_capacity_.Set(arc, new_availability); + assert(0 <= residual_arc_capacity_[arc]); + } else { + // We have to reduce the flow on the arc, and update the excesses + // accordingly. 
+ const FlowQuantity flow = residual_arc_capacity_[Opposite(arc)]; + const FlowQuantity flow_excess = flow - new_capacity; + residual_arc_capacity_.Set(arc, 0); + residual_arc_capacity_.Set(Opposite(arc), new_capacity); + const NodeIndex tail = Tail(arc); + node_excess_.Set(tail, node_excess_[tail] + flow_excess); + const NodeIndex head = Head(arc); + node_excess_.Set(head, node_excess_[head] - flow_excess); + assert(0 <= residual_arc_capacity_[arc]); + assert(0 <= residual_arc_capacity_[Opposite(arc)]); + } +} + +template +void GenericMinCostFlow::SetArcFlow( + ArcIndex arc, ArcFlowType new_flow) { + assert(IsArcValid(arc)); + const FlowQuantity capacity = Capacity(arc); + assert(capacity >= new_flow); + residual_arc_capacity_.Set(Opposite(arc), new_flow); + residual_arc_capacity_.Set(arc, capacity - new_flow); + status_ = NOT_SOLVED; + feasibility_checked_ = false; +} + +template +bool GenericMinCostFlow::CheckInputConsistency() const { + FlowQuantity total_supply = 0; + uint64_t max_capacity = 0; // uint64_t because it is positive and will be + // used to check against FlowQuantity overflows. + for (ArcIndex arc = 0; arc < graph_->num_arcs(); ++arc) { + const uint64_t capacity = + static_cast(residual_arc_capacity_[arc]); + max_capacity = std::max(capacity, max_capacity); + } + uint64_t total_flow = 0; // uint64_t for the same reason as max_capacity. 
+ for (NodeIndex node = 0; node < graph_->num_nodes(); ++node) { + const FlowQuantity excess = node_excess_[node]; + total_supply += excess; + if (excess > 0) { + total_flow += excess; + if (std::numeric_limits::max() < + max_capacity + total_flow) { + pyre::journal::firewall_t channel("isce3.unwrap.ortools.min_cost_flow"); + channel << pyre::journal::at(__HERE__) + << "Input consistency error: max capacity + flow exceed " + "precision" + << pyre::journal::endl; + return false; + } + } + } + if (total_supply != 0) { + pyre::journal::firewall_t channel("isce3.unwrap.ortools.min_cost_flow"); + channel << pyre::journal::at(__HERE__) + << "Input consistency error: unbalanced problem" + << pyre::journal::endl; + return false; + } + return true; +} + +template +bool GenericMinCostFlow::CheckResult() + const { + for (NodeIndex node = 0; node < graph_->num_nodes(); ++node) { + if (node_excess_[node] != 0) { + pyre::journal::firewall_t channel("isce3.unwrap.ortools.min_cost_flow"); + channel << pyre::journal::at(__HERE__) + << "node_excess_[" << node << "] != 0" + << pyre::journal::endl; + return false; + } + for (OutgoingOrOppositeIncomingArcIterator it(*graph_, node); it.Ok(); + it.Next()) { + const ArcIndex arc = it.Index(); + bool ok = true; + if (residual_arc_capacity_[arc] < 0) { + pyre::journal::firewall_t channel("isce3.unwrap.ortools.min_cost_flow"); + channel << pyre::journal::at(__HERE__) + << "residual_arc_capacity_[" << arc << "] < 0" + << pyre::journal::endl; + ok = false; + } + if (residual_arc_capacity_[arc] > 0 && ReducedCost(arc) < -epsilon_) { + pyre::journal::firewall_t channel("isce3.unwrap.ortools.min_cost_flow"); + channel << pyre::journal::at(__HERE__) + << "residual_arc_capacity_[" << arc + << "] > 0 && ReducedCost(" << arc << ") < " << -epsilon_ + << ". (epsilon_ = " << epsilon_ << ")." 
+ << pyre::journal::endl; + ok = false; + } + if (!ok) { + pyre::journal::firewall_t channel("isce3.unwrap.ortools.min_cost_flow"); + channel << pyre::journal::at(__HERE__) + << DebugString("CheckResult ", arc) + << pyre::journal::endl; + return false; + } + } + } + return true; +} + +template +bool GenericMinCostFlow::CheckCostRange() + const { + CostValue min_cost_magnitude = std::numeric_limits::max(); + CostValue max_cost_magnitude = 0; + // Traverse the initial arcs of the graph: + for (ArcIndex arc = 0; arc < graph_->num_arcs(); ++arc) { + const CostValue cost_magnitude = std::abs(scaled_arc_unit_cost_[arc]); + max_cost_magnitude = std::max(max_cost_magnitude, cost_magnitude); + if (cost_magnitude != 0.0) { + min_cost_magnitude = std::min(min_cost_magnitude, cost_magnitude); + } + } + pyre::journal::debug_t debug("isce3.unwrap.ortools.min_cost_flow"); + debug << pyre::journal::at(__HERE__) + << "Min cost magnitude = " << min_cost_magnitude + << ", Max cost magnitude = " << max_cost_magnitude + << pyre::journal::endl; +#if !defined(_MSC_VER) + if (log(std::numeric_limits::max()) < + log(max_cost_magnitude + 1) + log(graph_->num_nodes() + 1)) { + pyre::journal::firewall_t firewall("isce3.unwrap.ortools.min_cost_flow"); + firewall << pyre::journal::at(__HERE__) + << "Maximum cost magnitude " << max_cost_magnitude + << " is too high for the number of nodes. Try changing the data." + << pyre::journal::endl; + return false; + } +#endif + return true; +} + +template +bool GenericMinCostFlow:: + CheckRelabelPrecondition(NodeIndex node) const { + // Note that the classical Relabel precondition assumes IsActive(node), i.e., + // the node_excess_[node] > 0. However, to implement the Push Look-Ahead + // heuristic, we can relax this condition as explained in the section 4.3 of + // the article "An Efficient Implementation of a Scaling Minimum-Cost Flow + // Algorithm", A.V. Goldberg, Journal of Algorithms 22(1), January 1997, pp. + // 1-29. 
+ assert(node_excess_[node] >= 0); + for (OutgoingOrOppositeIncomingArcIterator it(*graph_, node); it.Ok(); + it.Next()) { +#ifndef NDEBUG + const ArcIndex arc = it.Index(); + assert(!IsAdmissible(arc) && DebugString("CheckRelabelPrecondition:", arc)); +#endif + } + return true; +} + +template +std::string +GenericMinCostFlow::DebugString( + const std::string& context, ArcIndex arc) const { + const NodeIndex tail = Tail(arc); + const NodeIndex head = Head(arc); + // Reduced cost is computed directly without calling ReducedCost to avoid + // recursive calls between ReducedCost and DebugString in case a DCHECK in + // ReducedCost fails. + const CostValue reduced_cost = scaled_arc_unit_cost_[arc] + + node_potential_[tail] - node_potential_[head]; + return std::string(context) + " Arc " + std::to_string(arc) + + ", from " + std::to_string(tail) + " to " + std::to_string(head) + + ", Capacity = " + std::to_string(Capacity(arc)) + ", Residual capacity = " + + std::to_string(static_cast(residual_arc_capacity_[arc])) + + ", Flow = residual capacity for reverse arc = " + std::to_string(Flow(arc)) + + ", Height(tail) = " + std::to_string(node_potential_[tail]) + + ", Height(head) = " + std::to_string(node_potential_[head]) + + ", Excess(tail) = " + std::to_string(node_excess_[tail]) + + ", Excess(head) = " + std::to_string(node_excess_[head]) + ", Cost = " + + std::to_string(static_cast(scaled_arc_unit_cost_[arc])) + + ", Reduced cost = " + std::to_string(reduced_cost) + ", "; +} + +template +bool GenericMinCostFlow:: + CheckFeasibility(std::vector* const infeasible_supply_node, + std::vector* const infeasible_demand_node) { + // Create a new graph, which is a copy of graph_, with the following + // modifications: + // Two nodes are added: a source and a sink. + // The source is linked to each supply node (whose supply > 0) by an arc whose + // capacity is equal to the supply at the supply node. 
+ // The sink is linked to each demand node (whose supply < 0) by an arc whose + // capacity is the demand (-supply) at the demand node. + // There are no supplies or demands or costs in the graph, as we will run + // max-flow. + // TODO(user): make it possible to share a graph by MaxFlow and MinCostFlow. + // For this it is necessary to make StarGraph resizable. + feasibility_checked_ = false; + ArcIndex num_extra_arcs = 0; + for (NodeIndex node = 0; node < graph_->num_nodes(); ++node) { + if (initial_node_excess_[node] != 0) { + ++num_extra_arcs; + } + } + const NodeIndex num_nodes_in_max_flow = graph_->num_nodes() + 2; + const ArcIndex num_arcs_in_max_flow = graph_->num_arcs() + num_extra_arcs; + const NodeIndex source = num_nodes_in_max_flow - 2; + const NodeIndex sink = num_nodes_in_max_flow - 1; + StarGraph checker_graph(num_nodes_in_max_flow, num_arcs_in_max_flow); + MaxFlow checker(&checker_graph, source, sink); + checker.SetCheckInput(false); + checker.SetCheckResult(false); + // Copy graph_ to checker_graph. + for (ArcIndex arc = 0; arc < graph_->num_arcs(); ++arc) { + const ArcIndex new_arc = + checker_graph.AddArc(graph_->Tail(arc), graph_->Head(arc)); + assert(arc == new_arc); + checker.SetArcCapacity(new_arc, Capacity(arc)); + } + FlowQuantity total_demand = 0; + FlowQuantity total_supply = 0; + // Create the source-to-supply node arcs and the demand-node-to-sink arcs. 
+ for (NodeIndex node = 0; node < graph_->num_nodes(); ++node) { + const FlowQuantity supply = initial_node_excess_[node]; + if (supply > 0) { + const ArcIndex new_arc = checker_graph.AddArc(source, node); + checker.SetArcCapacity(new_arc, supply); + total_supply += supply; + } else if (supply < 0) { + const ArcIndex new_arc = checker_graph.AddArc(node, sink); + checker.SetArcCapacity(new_arc, -supply); + total_demand -= supply; + } + } + if (total_supply != total_demand) { + pyre::journal::firewall_t channel("isce3.unwrap.ortools.min_cost_flow"); + channel << pyre::journal::at(__HERE__) + << "total_supply(" << total_supply << ") != total_demand(" + << total_demand << ")." + << pyre::journal::endl; + return false; + } + if (!checker.Solve()) { + pyre::journal::firewall_t channel("isce3.unwrap.ortools.min_cost_flow"); + channel << pyre::journal::at(__HERE__) + << "Max flow could not be computed." + << pyre::journal::endl; + return false; + } + const FlowQuantity optimal_max_flow = checker.GetOptimalFlow(); + feasible_node_excess_.SetAll(0); + for (StarGraph::OutgoingArcIterator it(checker_graph, source); it.Ok(); + it.Next()) { + const ArcIndex arc = it.Index(); + const NodeIndex node = checker_graph.Head(arc); + const FlowQuantity flow = checker.Flow(arc); + feasible_node_excess_.Set(node, flow); + if (infeasible_supply_node != nullptr) { + infeasible_supply_node->push_back(node); + } + } + for (StarGraph::IncomingArcIterator it(checker_graph, sink); it.Ok(); + it.Next()) { + const ArcIndex arc = it.Index(); + const NodeIndex node = checker_graph.Tail(arc); + const FlowQuantity flow = checker.Flow(arc); + feasible_node_excess_.Set(node, -flow); + if (infeasible_demand_node != nullptr) { + infeasible_demand_node->push_back(node); + } + } + feasibility_checked_ = true; + return optimal_max_flow == total_supply; +} + +template +bool GenericMinCostFlow::MakeFeasible() { + if (!feasibility_checked_) { + return false; + } + for (NodeIndex node = 0; node < 
graph_->num_nodes(); ++node) { + const FlowQuantity excess = feasible_node_excess_[node]; + node_excess_.Set(node, excess); + initial_node_excess_.Set(node, excess); + } + return true; +} + +template +FlowQuantity GenericMinCostFlow::Flow( + ArcIndex arc) const { + if (IsArcDirect(arc)) { + return residual_arc_capacity_[Opposite(arc)]; + } else { + return -residual_arc_capacity_[arc]; + } +} + +// We use the equations given in the comment of residual_arc_capacity_. +template +FlowQuantity +GenericMinCostFlow::Capacity( + ArcIndex arc) const { + if (IsArcDirect(arc)) { + return residual_arc_capacity_[arc] + residual_arc_capacity_[Opposite(arc)]; + } else { + return 0; + } +} + +template +CostValue GenericMinCostFlow::UnitCost( + ArcIndex arc) const { + assert(IsArcValid(arc)); + assert(uint64_t{1} == cost_scaling_factor_); + return scaled_arc_unit_cost_[arc]; +} + +template +FlowQuantity GenericMinCostFlow::Supply( + NodeIndex node) const { + assert(graph_->IsNodeValid(node)); + return node_excess_[node]; +} + +template +FlowQuantity +GenericMinCostFlow::InitialSupply( + NodeIndex node) const { + return initial_node_excess_[node]; +} + +template +FlowQuantity +GenericMinCostFlow::FeasibleSupply( + NodeIndex node) const { + return feasible_node_excess_[node]; +} + +template +bool GenericMinCostFlow::IsAdmissible( + ArcIndex arc) const { + return FastIsAdmissible(arc, node_potential_[Tail(arc)]); +} + +template +bool GenericMinCostFlow:: + FastIsAdmissible(ArcIndex arc, CostValue tail_potential) const { + assert(node_potential_[Tail(arc)] == tail_potential); + return residual_arc_capacity_[arc] > 0 && + FastReducedCost(arc, tail_potential) < 0; +} + +template +bool GenericMinCostFlow::IsActive( + NodeIndex node) const { + return node_excess_[node] > 0; +} + +template +CostValue +GenericMinCostFlow::ReducedCost( + ArcIndex arc) const { + return FastReducedCost(arc, node_potential_[Tail(arc)]); +} + +template +CostValue +GenericMinCostFlow::FastReducedCost( + ArcIndex 
arc, CostValue tail_potential) const { + assert(node_potential_[Tail(arc)] == tail_potential); + assert(graph_->IsNodeValid(Tail(arc))); + assert(graph_->IsNodeValid(Head(arc))); + assert(node_potential_[Tail(arc)] <= 0 && DebugString("ReducedCost:", arc)); + assert(node_potential_[Head(arc)] <= 0 && DebugString("ReducedCost:", arc)); + return scaled_arc_unit_cost_[arc] + tail_potential - + node_potential_[Head(arc)]; +} + +template +typename GenericMinCostFlow::ArcIndex +GenericMinCostFlow:: + GetFirstOutgoingOrOppositeIncomingArc(NodeIndex node) const { + OutgoingOrOppositeIncomingArcIterator arc_it(*graph_, node); + return arc_it.Index(); +} + +template +bool GenericMinCostFlow::Solve() { + status_ = NOT_SOLVED; + if (!CheckInputConsistency()) { + status_ = UNBALANCED; + return false; + } + if (!CheckCostRange()) { + status_ = BAD_COST_RANGE; + return false; + } + if (check_feasibility_ && !CheckFeasibility(nullptr, nullptr)) { + status_ = INFEASIBLE; + return false; + } + node_potential_.SetAll(0); + ResetFirstAdmissibleArcs(); + ScaleCosts(); + Optimize(); + if (!CheckResult()) { + status_ = BAD_RESULT; + UnscaleCosts(); + return false; + } + UnscaleCosts(); + if (status_ != OPTIMAL) { + pyre::journal::firewall_t channel("isce3.unwrap.ortools.min_cost_flow"); + channel << pyre::journal::at(__HERE__) + << "Status != OPTIMAL" + << pyre::journal::endl; + total_flow_cost_ = 0; + return false; + } + total_flow_cost_ = 0; + for (ArcIndex arc = 0; arc < graph_->num_arcs(); ++arc) { + const FlowQuantity flow_on_arc = residual_arc_capacity_[Opposite(arc)]; + total_flow_cost_ += scaled_arc_unit_cost_[arc] * flow_on_arc; + } + status_ = OPTIMAL; + return true; +} + +template +void GenericMinCostFlow::ResetFirstAdmissibleArcs() { + for (NodeIndex node = 0; node < graph_->num_nodes(); ++node) { + first_admissible_arc_.Set(node, + GetFirstOutgoingOrOppositeIncomingArc(node)); + } +} + +template +void GenericMinCostFlow::ScaleCosts() { + cost_scaling_factor_ = 
graph_->num_nodes() + 1; + epsilon_ = 1LL; + pyre::journal::debug_t channel("isce3.unwrap.ortools.min_cost_flow"); + channel << pyre::journal::at(__HERE__) + << "Number of nodes in the graph = " << graph_->num_nodes() + << pyre::journal::endl + << "Number of arcs in the graph = " << graph_->num_arcs() + << pyre::journal::endl; + for (ArcIndex arc = 0; arc < graph_->num_arcs(); ++arc) { + const CostValue cost = scaled_arc_unit_cost_[arc] * cost_scaling_factor_; + scaled_arc_unit_cost_.Set(arc, cost); + scaled_arc_unit_cost_.Set(Opposite(arc), -cost); + epsilon_ = std::max(epsilon_, std::abs(cost)); + } + channel << pyre::journal::at(__HERE__) + << "Cost scaling factor = " << cost_scaling_factor_ + << pyre::journal::endl + << "Initial epsilon = " << epsilon_ + << pyre::journal::endl; +} + +template +void GenericMinCostFlow::UnscaleCosts() { + for (ArcIndex arc = 0; arc < graph_->num_arcs(); ++arc) { + const CostValue cost = scaled_arc_unit_cost_[arc] / cost_scaling_factor_; + scaled_arc_unit_cost_.Set(arc, cost); + scaled_arc_unit_cost_.Set(Opposite(arc), -cost); + } + cost_scaling_factor_ = 1; +} + +template +void GenericMinCostFlow::Optimize() { + const CostValue kEpsilonMin = 1LL; + num_relabels_since_last_price_update_ = 0; + do { + // Avoid epsilon_ == 0. 
+ epsilon_ = std::max(epsilon_ / alpha_, kEpsilonMin); + pyre::journal::debug_t channel("isce3.unwrap.ortools.min_cost_flow"); + channel << pyre::journal::at(__HERE__) + << "Epsilon changed to: " << epsilon_ + << pyre::journal::endl; + Refine(); + } while (epsilon_ != 1LL && status_ != INFEASIBLE); + if (status_ == NOT_SOLVED) { + status_ = OPTIMAL; + } +} + +template +void GenericMinCostFlow::SaturateAdmissibleArcs() { + for (NodeIndex node = 0; node < graph_->num_nodes(); ++node) { + const CostValue tail_potential = node_potential_[node]; + for (OutgoingOrOppositeIncomingArcIterator it(*graph_, node, + first_admissible_arc_[node]); + it.Ok(); it.Next()) { + const ArcIndex arc = it.Index(); + if (FastIsAdmissible(arc, tail_potential)) { + FastPushFlow(residual_arc_capacity_[arc], arc, node); + } + } + + // We just saturated all the admissible arcs, so there are no arcs with a + // positive residual capacity that are incident to the current node. + // Moreover, during the course of the algorithm, if the residual capacity of + // such an arc becomes positive again, then the arc is still not admissible + // until we relabel the node (because the reverse arc was admissible for + // this to happen). In conclusion, the optimization below is correct. + first_admissible_arc_[node] = Graph::kNilArc; + } +} + +template +void GenericMinCostFlow::PushFlow( + FlowQuantity flow, ArcIndex arc) { + FastPushFlow(flow, arc, Tail(arc)); +} + +template +void GenericMinCostFlow::FastPushFlow( + FlowQuantity flow, ArcIndex arc, NodeIndex tail) { + assert(Tail(arc) == tail); + assert(residual_arc_capacity_[arc] > 0); + assert(flow <= residual_arc_capacity_[arc]); + // Reduce the residual capacity on the arc by flow. + residual_arc_capacity_.Set(arc, residual_arc_capacity_[arc] - flow); + // Increase the residual capacity on the opposite arc by flow. 
+ const ArcIndex opposite = Opposite(arc); + residual_arc_capacity_.Set(opposite, residual_arc_capacity_[opposite] + flow); + // Update the excesses at the tail and head of the arc. + node_excess_.Set(tail, node_excess_[tail] - flow); + const NodeIndex head = Head(arc); + node_excess_.Set(head, node_excess_[head] + flow); +} + +template +void GenericMinCostFlow::InitializeActiveNodeStack() { + assert(active_nodes_.empty()); + for (NodeIndex node = 0; node < graph_->num_nodes(); ++node) { + if (IsActive(node)) { + active_nodes_.push(node); + } + } +} + +template +void GenericMinCostFlow::UpdatePrices() { + + // The algorithm works as follows. Start with a set of nodes S containing all + // the nodes with negative excess. Expand the set along reverse admissible + // arcs. If at the end, the complement of S contains at least one node with + // positive excess, relabel all the nodes in the complement of S by + // subtracting epsilon from their current potential. See the paper cited in + // the .h file. + // + // After this relabeling is done, the heuristic is reapplied by extending S as + // much as possible, relabeling the complement of S, and so on until there is + // no node with positive excess that is not in S. Note that this is not + // described in the paper. + // + // Note(user): The triggering mechanism of this UpdatePrices() is really + // important; if it is not done properly it may degrade performance! + + // This represents the set S. + const NodeIndex num_nodes = graph_->num_nodes(); + std::vector bfs_queue; + std::vector node_in_queue(num_nodes, false); + + // This is used to update the potential of the nodes not in S. + const CostValue kMinCostValue = std::numeric_limits::min(); + std::vector min_non_admissible_potential(num_nodes, kMinCostValue); + std::vector nodes_to_process; + + // Sum of the positive excesses out of S, used for early exit. + FlowQuantity remaining_excess = 0; + + // First consider the nodes which have a negative excess. 
+ for (NodeIndex node = 0; node < num_nodes; ++node) { + if (node_excess_[node] < 0) { + bfs_queue.push_back(node); + node_in_queue[node] = true; + + // This uses the fact that the sum of excesses is always 0. + remaining_excess -= node_excess_[node]; + } + } + + // All the nodes not yet in the bfs_queue will have their potential changed by + // +potential_delta (which becomes more and more negative at each pass). This + // update is applied when a node is pushed into the queue and at the end of + // the function for the nodes that are still unprocessed. + CostValue potential_delta = 0; + + int queue_index = 0; + while (remaining_excess > 0) { + // Reverse BFS that expands S as much as possible in the reverse admissible + // graph. Once S cannot be expanded anymore, perform a relabeling on the + // nodes not in S but that can reach it in one arc and try to expand S + // again. + for (; queue_index < bfs_queue.size(); ++queue_index) { + assert(num_nodes >= bfs_queue.size()); + const NodeIndex node = bfs_queue[queue_index]; + for (OutgoingOrOppositeIncomingArcIterator it(*graph_, node); it.Ok(); + it.Next()) { + const NodeIndex head = Head(it.Index()); + if (node_in_queue[head]) continue; + const ArcIndex opposite_arc = Opposite(it.Index()); + if (residual_arc_capacity_[opposite_arc] > 0) { + node_potential_[head] += potential_delta; + if (ReducedCost(opposite_arc) < 0) { + assert(IsAdmissible(opposite_arc)); + + // TODO(user): Try to steal flow if node_excess_[head] > 0. + // An initial experiment didn't show a big speedup though. + + remaining_excess -= node_excess_[head]; + if (remaining_excess == 0) { + node_potential_[head] -= potential_delta; + break; + } + bfs_queue.push_back(head); + node_in_queue[head] = true; + if (potential_delta < 0) { + first_admissible_arc_[head] = + GetFirstOutgoingOrOppositeIncomingArc(head); + } + } else { + // The opposite_arc is not admissible but is in the residual graph; + // this updates its min_non_admissible_potential. 
+ node_potential_[head] -= potential_delta; + if (min_non_admissible_potential[head] == kMinCostValue) { + nodes_to_process.push_back(head); + } + min_non_admissible_potential[head] = std::max( + min_non_admissible_potential[head], + node_potential_[node] - scaled_arc_unit_cost_[opposite_arc]); + } + } + } + if (remaining_excess == 0) break; + } + if (remaining_excess == 0) break; + + // Decrease by as much as possible instead of decreasing by epsilon. + // TODO(user): Is it worth the extra loop? + CostValue max_potential_diff = kMinCostValue; + for (int i = 0; i < nodes_to_process.size(); ++i) { + const NodeIndex node = nodes_to_process[i]; + if (node_in_queue[node]) continue; + max_potential_diff = + std::max(max_potential_diff, + min_non_admissible_potential[node] - node_potential_[node]); + if (max_potential_diff == potential_delta) break; + } + assert(max_potential_diff <= potential_delta); + potential_delta = max_potential_diff - epsilon_; + + // Loop over nodes_to_process_ and for each node, apply the first of the + // rules below that match or leave it in the queue for later iteration: + // - Remove it if it is already in the queue. + // - If the node is connected to S by an admissible arc after it is + // relabeled by +potential_delta, add it to bfs_queue_ and remove it from + // nodes_to_process. + int index = 0; + for (int i = 0; i < nodes_to_process.size(); ++i) { + const NodeIndex node = nodes_to_process[i]; + if (node_in_queue[node]) continue; + if (node_potential_[node] + potential_delta < + min_non_admissible_potential[node]) { + node_potential_[node] += potential_delta; + first_admissible_arc_[node] = + GetFirstOutgoingOrOppositeIncomingArc(node); + bfs_queue.push_back(node); + node_in_queue[node] = true; + remaining_excess -= node_excess_[node]; + continue; + } + + // Keep the node for later iteration. + nodes_to_process[index] = node; + ++index; + } + nodes_to_process.resize(index); + } + + // Update the potentials of the nodes not yet processed. 
+ if (potential_delta == 0) return; + for (NodeIndex node = 0; node < num_nodes; ++node) { + if (!node_in_queue[node]) { + node_potential_[node] += potential_delta; + first_admissible_arc_[node] = GetFirstOutgoingOrOppositeIncomingArc(node); + } + } +} + +template +void GenericMinCostFlow::Refine() { + SaturateAdmissibleArcs(); + InitializeActiveNodeStack(); + + const NodeIndex num_nodes = graph_->num_nodes(); + while (status_ != INFEASIBLE && !active_nodes_.empty()) { + // TODO(user): Experiment with different factors in front of num_nodes. + if (num_relabels_since_last_price_update_ >= num_nodes) { + num_relabels_since_last_price_update_ = 0; + if (use_price_update_) { + UpdatePrices(); + } + } + const NodeIndex node = active_nodes_.top(); + active_nodes_.pop(); + assert(IsActive(node)); + Discharge(node); + } +} + +template +void GenericMinCostFlow::Discharge( + NodeIndex node) { + do { + // The node is initially active, and we exit as soon as it becomes + // inactive. + assert(IsActive(node)); + const CostValue tail_potential = node_potential_[node]; + for (OutgoingOrOppositeIncomingArcIterator it(*graph_, node, + first_admissible_arc_[node]); + it.Ok(); it.Next()) { + const ArcIndex arc = it.Index(); + if (FastIsAdmissible(arc, tail_potential)) { + const NodeIndex head = Head(arc); + if (!LookAhead(arc, tail_potential, head)) continue; + const bool head_active_before_push = IsActive(head); + const FlowQuantity delta = + std::min(node_excess_[node], + static_cast(residual_arc_capacity_[arc])); + FastPushFlow(delta, arc, node); + if (IsActive(head) && !head_active_before_push) { + active_nodes_.push(head); + } + if (node_excess_[node] == 0) { + // arc may still be admissible. 
+ first_admissible_arc_.Set(node, arc); + return; + } + } + } + Relabel(node); + } while (status_ != INFEASIBLE); +} + +template +bool GenericMinCostFlow::LookAhead( + ArcIndex in_arc, CostValue in_tail_potential, NodeIndex node) { + assert(Head(in_arc) == node); + assert(node_potential_[Tail(in_arc)] == in_tail_potential); + if (node_excess_[node] < 0) return true; + const CostValue tail_potential = node_potential_[node]; + for (OutgoingOrOppositeIncomingArcIterator it(*graph_, node, + first_admissible_arc_[node]); + it.Ok(); it.Next()) { + const ArcIndex arc = it.Index(); + if (FastIsAdmissible(arc, tail_potential)) { + first_admissible_arc_.Set(node, arc); + return true; + } + } + + // The node we looked ahead has no admissible arc at its current potential. + // We relabel it and return true if the original arc is still admissible. + Relabel(node); + return FastIsAdmissible(in_arc, in_tail_potential); +} + +template +void GenericMinCostFlow::Relabel( + NodeIndex node) { + assert(CheckRelabelPrecondition(node)); + ++num_relabels_since_last_price_update_; + + // By setting node_potential_[node] to the guaranteed_new_potential we are + // sure to keep epsilon-optimality of the pseudo-flow. Note that we could + // return right away with this value, but we prefer to check that this value + // will lead to at least one admissible arc, and if not, to decrease the + // potential as much as possible. + const CostValue guaranteed_new_potential = node_potential_[node] - epsilon_; + + // This will be updated to contain the minimum node potential for which + // the node has no admissible arc. We know that: + // - min_non_admissible_potential <= node_potential_[node] + // - We can set the new node potential to min_non_admissible_potential - + // epsilon_ and still keep the epsilon-optimality of the pseudo flow. 
+ const CostValue kMinCostValue = std::numeric_limits::min(); + CostValue min_non_admissible_potential = kMinCostValue; + + // The following variables help setting the first_admissible_arc_[node] to a + // value different from GetFirstOutgoingOrOppositeIncomingArc(node) which + // avoids looking again at some arcs. + CostValue previous_min_non_admissible_potential = kMinCostValue; + ArcIndex first_arc = Graph::kNilArc; + + for (OutgoingOrOppositeIncomingArcIterator it(*graph_, node); it.Ok(); + it.Next()) { + const ArcIndex arc = it.Index(); + if (residual_arc_capacity_[arc] > 0) { + const CostValue min_non_admissible_potential_for_arc = + node_potential_[Head(arc)] - scaled_arc_unit_cost_[arc]; + if (min_non_admissible_potential_for_arc > min_non_admissible_potential) { + if (min_non_admissible_potential_for_arc > guaranteed_new_potential) { + // We found an admissible arc for the guaranteed_new_potential. We + // stop right now instead of trying to compute the minimum possible + // new potential that keeps the epsilon-optimality of the pseudo flow. + node_potential_.Set(node, guaranteed_new_potential); + first_admissible_arc_.Set(node, arc); + return; + } + previous_min_non_admissible_potential = min_non_admissible_potential; + min_non_admissible_potential = min_non_admissible_potential_for_arc; + first_arc = arc; + } + } + } + + // No admissible arc leaves this node! + if (min_non_admissible_potential == kMinCostValue) { + if (node_excess_[node] != 0) { + // Note that this infeasibility detection is incomplete. + // Only max flow can detect that a min-cost flow problem is infeasible. + status_ = INFEASIBLE; + pyre::journal::info_t channel("isce3.unwrap.ortools.min_cost_flow"); + channel << pyre::journal::at(__HERE__) << "Infeasible problem." + << pyre::journal::endl; + } else { + // This source saturates all its arcs, we can actually decrease the + // potential by as much as we want. + // TODO(user): Set it to a minimum value, but be careful of overflow. 
+ node_potential_.Set(node, guaranteed_new_potential); + first_admissible_arc_.Set(node, + GetFirstOutgoingOrOppositeIncomingArc(node)); + } + return; + } + + // We decrease the potential as much as possible, but we do not know the first + // admissible arc (most of the time). Keeping the + // previous_min_non_admissible_potential makes it faster by a few percent. + const CostValue new_potential = min_non_admissible_potential - epsilon_; + node_potential_.Set(node, new_potential); + if (previous_min_non_admissible_potential <= new_potential) { + first_admissible_arc_.Set(node, first_arc); + } else { + // We have no indication of what may be the first admissible arc. + first_admissible_arc_.Set(node, + GetFirstOutgoingOrOppositeIncomingArc(node)); + } +} + +template +typename Graph::ArcIndex +GenericMinCostFlow::Opposite( + ArcIndex arc) const { + return Graphs::OppositeArc(*graph_, arc); +} + +template +bool GenericMinCostFlow::IsArcValid( + ArcIndex arc) const { + return Graphs::IsArcValid(*graph_, arc); +} + +template +bool GenericMinCostFlow::IsArcDirect( + ArcIndex arc) const { + assert(IsArcValid(arc)); + return arc >= 0; +} + +// Explicit instantiations that can be used by a client. +// +// TODO(user): Move this code out of a .cc file and include it at the end of +// the header so it can work with any graph implementation? +template class GenericMinCostFlow; +template class GenericMinCostFlow<::util::ReverseArcListGraph<>>; +template class GenericMinCostFlow<::util::ReverseArcStaticGraph<>>; +template class GenericMinCostFlow<::util::ReverseArcMixedGraph<>>; +template class GenericMinCostFlow< + ::util::ReverseArcStaticGraph>; + +// A more memory-efficient version for large graphs. 
+template class GenericMinCostFlow< + ::util::ReverseArcStaticGraph, + /*ArcFlowType=*/int16_t, + /*ArcScaledCostType=*/int32_t>; + +SimpleMinCostFlow::SimpleMinCostFlow(NodeIndex reserve_num_nodes, + ArcIndex reserve_num_arcs) { + if (reserve_num_nodes > 0) { + node_supply_.reserve(reserve_num_nodes); + } + if (reserve_num_arcs > 0) { + arc_tail_.reserve(reserve_num_arcs); + arc_head_.reserve(reserve_num_arcs); + arc_capacity_.reserve(reserve_num_arcs); + arc_cost_.reserve(reserve_num_arcs); + arc_permutation_.reserve(reserve_num_arcs); + arc_flow_.reserve(reserve_num_arcs); + } +} + +void SimpleMinCostFlow::SetNodeSupply(NodeIndex node, FlowQuantity supply) { + ResizeNodeVectors(node); + node_supply_[node] = supply; +} + +ArcIndex SimpleMinCostFlow::AddArcWithCapacityAndUnitCost(NodeIndex tail, + NodeIndex head, + FlowQuantity capacity, + CostValue unit_cost) { + ResizeNodeVectors(std::max(tail, head)); + const ArcIndex arc = arc_tail_.size(); + arc_tail_.push_back(tail); + arc_head_.push_back(head); + arc_capacity_.push_back(capacity); + arc_cost_.push_back(unit_cost); + return arc; +} + +ArcIndex SimpleMinCostFlow::PermutedArc(ArcIndex arc) { + return arc < arc_permutation_.size() ? 
arc_permutation_[arc] : arc; +} + +SimpleMinCostFlow::Status SimpleMinCostFlow::SolveWithPossibleAdjustment( + SupplyAdjustment adjustment) { + optimal_cost_ = 0; + maximum_flow_ = 0; + arc_flow_.clear(); + const NodeIndex num_nodes = node_supply_.size(); + const ArcIndex num_arcs = arc_capacity_.size(); + if (num_nodes == 0) return OPTIMAL; + + int supply_node_count = 0, demand_node_count = 0; + FlowQuantity total_supply = 0, total_demand = 0; + for (NodeIndex node = 0; node < num_nodes; ++node) { + if (node_supply_[node] > 0) { + ++supply_node_count; + total_supply += node_supply_[node]; + } else if (node_supply_[node] < 0) { + ++demand_node_count; + total_demand -= node_supply_[node]; + } + } + if (adjustment == DONT_ADJUST && total_supply != total_demand) { + return UNBALANCED; + } + + // Feasibility checking, and possible supply/demand adjustment, is done by: + // 1. Creating a new source and sink node. + // 2. Taking all nodes that have a non-zero supply or demand and + // connecting them to the source or sink respectively. The arc thus + // added has a capacity of the supply or demand. + // 3. Computing the max flow between the new source and sink. + // 4. If adjustment isn't being done, checking that the max flow is equal + // to the total supply/demand (and returning INFEASIBLE if it isn't). + // 5. Running min-cost max-flow on this augmented graph, using the max + // flow computed in step 3 as the supply of the source and demand of + // the sink. 
+ const ArcIndex augmented_num_arcs = + num_arcs + supply_node_count + demand_node_count; + const NodeIndex source = num_nodes; + const NodeIndex sink = num_nodes + 1; + const NodeIndex augmented_num_nodes = num_nodes + 2; + + Graph graph(augmented_num_nodes, augmented_num_arcs); + for (ArcIndex arc = 0; arc < num_arcs; ++arc) { + graph.AddArc(arc_tail_[arc], arc_head_[arc]); + } + + for (NodeIndex node = 0; node < num_nodes; ++node) { + if (node_supply_[node] > 0) { + graph.AddArc(source, node); + } else if (node_supply_[node] < 0) { + graph.AddArc(node, sink); + } + } + + graph.Build(&arc_permutation_); + + { + GenericMaxFlow max_flow(&graph, source, sink); + ArcIndex arc; + for (arc = 0; arc < num_arcs; ++arc) { + max_flow.SetArcCapacity(PermutedArc(arc), arc_capacity_[arc]); + } + for (NodeIndex node = 0; node < num_nodes; ++node) { + if (node_supply_[node] != 0) { + max_flow.SetArcCapacity(PermutedArc(arc), std::abs(node_supply_[node])); + ++arc; + } + } + if (arc != augmented_num_arcs) { + throw isce3::except::RuntimeError( + ISCE_SRCINFO(), "arc != augmented_num_arcs"); + } + if (!max_flow.Solve()) { + pyre::journal::info_t channel("isce3.unwrap.ortools.min_cost_flow"); + channel << pyre::journal::at(__HERE__) + << "Max flow could not be computed." 
+ << pyre::journal::endl; + switch (max_flow.status()) { + case MaxFlowStatusClass::NOT_SOLVED: + return NOT_SOLVED; + case MaxFlowStatusClass::OPTIMAL: + channel << pyre::journal::at(__HERE__) + << "Max flow failed but claimed to have an optimal solution" + << pyre::journal::endl; + [[fallthrough]]; + default: + return BAD_RESULT; + } + } + maximum_flow_ = max_flow.GetOptimalFlow(); + } + + if (adjustment == DONT_ADJUST && maximum_flow_ != total_supply) { + return INFEASIBLE; + } + + GenericMinCostFlow min_cost_flow(&graph); + ArcIndex arc; + for (arc = 0; arc < num_arcs; ++arc) { + ArcIndex permuted_arc = PermutedArc(arc); + min_cost_flow.SetArcUnitCost(permuted_arc, arc_cost_[arc]); + min_cost_flow.SetArcCapacity(permuted_arc, arc_capacity_[arc]); + } + for (NodeIndex node = 0; node < num_nodes; ++node) { + if (node_supply_[node] != 0) { + ArcIndex permuted_arc = PermutedArc(arc); + min_cost_flow.SetArcCapacity(permuted_arc, std::abs(node_supply_[node])); + min_cost_flow.SetArcUnitCost(permuted_arc, 0); + ++arc; + } + } + min_cost_flow.SetNodeSupply(source, maximum_flow_); + min_cost_flow.SetNodeSupply(sink, -maximum_flow_); + min_cost_flow.SetCheckFeasibility(false); + + arc_flow_.resize(num_arcs); + if (min_cost_flow.Solve()) { + optimal_cost_ = min_cost_flow.GetOptimalCost(); + for (arc = 0; arc < num_arcs; ++arc) { + arc_flow_[arc] = min_cost_flow.Flow(PermutedArc(arc)); + } + } + return min_cost_flow.status(); +} + +CostValue SimpleMinCostFlow::OptimalCost() const { return optimal_cost_; } + +FlowQuantity SimpleMinCostFlow::MaximumFlow() const { return maximum_flow_; } + +FlowQuantity SimpleMinCostFlow::Flow(ArcIndex arc) const { + return arc_flow_[arc]; +} + +NodeIndex SimpleMinCostFlow::NumNodes() const { return node_supply_.size(); } + +ArcIndex SimpleMinCostFlow::NumArcs() const { return arc_tail_.size(); } + +ArcIndex SimpleMinCostFlow::Tail(ArcIndex arc) const { return arc_tail_[arc]; } + +ArcIndex SimpleMinCostFlow::Head(ArcIndex arc) const { return 
arc_head_[arc]; } + +FlowQuantity SimpleMinCostFlow::Capacity(ArcIndex arc) const { + return arc_capacity_[arc]; +} + +CostValue SimpleMinCostFlow::UnitCost(ArcIndex arc) const { + return arc_cost_[arc]; +} + +FlowQuantity SimpleMinCostFlow::Supply(NodeIndex node) const { + return node_supply_[node]; +} + +void SimpleMinCostFlow::ResizeNodeVectors(NodeIndex node) { + if (node < node_supply_.size()) return; + node_supply_.resize(node + 1); +} + +} // namespace operations_research diff --git a/cxx/isce3/unwrap/ortools/min_cost_flow.h b/cxx/isce3/unwrap/ortools/min_cost_flow.h new file mode 100644 index 000000000..904e0b4c4 --- /dev/null +++ b/cxx/isce3/unwrap/ortools/min_cost_flow.h @@ -0,0 +1,613 @@ +// Copyright 2010-2021 Google LLC +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// An implementation of a cost-scaling push-relabel algorithm for +// the min-cost flow problem. +// +// In the following, we consider a graph G = (V,E) where V denotes the set +// of nodes (vertices) in the graph, E denotes the set of arcs (edges). +// n = |V| denotes the number of nodes in the graph, and m = |E| denotes the +// number of arcs in the graph. +// +// With each arc (v,w) is associated a nonnegative capacity u(v,w) +// (where 'u' stands for "upper bound") and a unit cost c(v,w). With +// each node v is associated a quantity named supply(v), which +// represents a supply of fluid (if >0) or a demand (if <0). 
+// Furthermore, no fluid is created in the graph so
+// sum_{v in V} supply(v) = 0.
+//
+// A flow is a function from E to R such that:
+// a) f(v,w) <= u(v,w) for all (v,w) in E (capacity constraint).
+// b) f(v,w) = -f(w,v) for all (v,w) in E (flow antisymmetry constraint).
+// c) sum on v f(v,w) + supply(w) = 0 (flow conservation).
+//
+// The cost of a flow is sum on (v,w) in E ( f(v,w) * c(v,w) ) [Note:
+// It can be confusing to beginners that the cost is actually double
+// the amount that it might seem at first because of flow
+// antisymmetry.]
+//
+// The problem to solve: find a flow of minimum cost such that all the
+// fluid flows from the supply nodes to the demand nodes.
+//
+// The principles behind this algorithm are the following:
+// 1/ handle pseudo-flows instead of flows and refine pseudo-flows until an
+// epsilon-optimal minimum-cost flow is obtained,
+// 2/ deal with epsilon-optimal pseudo-flows.
+//
+// 1/ A pseudo-flow is like a flow, except that a node's outflow minus
+// its inflow can be different from its supply. If it is the case at a
+// given node v, it is said that there is an excess (or deficit) at
+// node v. A deficit is denoted by a negative excess and inflow =
+// outflow + excess.
+// (Look at ortools/graph/max_flow.h to see that the definition
+// of preflow is more restrictive than the one for pseudo-flow in that a preflow
+// only allows non-negative excesses, i.e., no deficit.)
+// More formally, a pseudo-flow is a function f such that:
+// a) f(v,w) <= u(v,w) for all (v,w) in E (capacity constraint).
+// b) f(v,w) = -f(w,v) for all (v,w) in E (flow antisymmetry constraint).
+//
+// For each v in V, we also define the excess at node v, the algebraic sum of
+// all the incoming preflows at this node, added together with the supply at v.
+// excess(v) = sum on u f(u,v) + supply(v)
+//
+// The goal of the algorithm is to obtain excess(v) = 0 for all v in V, while
+// consuming capacity on some arcs, at the lowest possible cost.
+//
+// 2/ Internally to the algorithm and its analysis (but invisibly to
+// the client), each node has an associated "price" (or potential), in
+// addition to its excess. It is formally a function from V to R (the
+// set of real numbers.). For a given price function p, the reduced
+// cost of an arc (v,w) is:
+// c_p(v,w) = c(v,w) + p(v) - p(w)
+// (c(v,w) is the cost of arc (v,w).) For those familiar with linear
+// programming, the price function can be viewed as a set of dual
+// variables.
+//
+// For a constant epsilon >= 0, a pseudo-flow f is said to be epsilon-optimal
+// with respect to a price function p if for every residual arc (v,w) in E,
+// c_p(v,w) >= -epsilon.
+//
+// A flow f is optimal if and only if there exists a price function p such that
+// no arc is admissible with respect to f and p.
+//
+// If the arc costs are integers, and epsilon < 1/n, any epsilon-optimal flow
+// is optimal. The integer cost case is handled by multiplying all the arc costs
+// and the initial value of epsilon by (n+1). When epsilon reaches 1, and
+// the solution is epsilon-optimal, it means: for all residual arc (v,w) in E,
+// (n+1) * c_p(v,w) >= -1, thus c_p(v,w) >= -1/(n+1) > -1/n, and the
+// solution is optimal.
+//
+// A node v is said to be *active* if excess(v) > 0.
+// In this case the following operations can be applied to it:
+// - if there are *admissible* incident arcs, i.e. arcs which are not saturated,
+// and whose reduced costs are negative, a PushFlow operation can
+// be applied. It consists in sending as much flow as both the excess at the
+// node and the capacity of the arc permit.
+// - if there are no admissible arcs, the active node considered is relabeled.
+// This is implemented in Discharge, which itself calls PushFlow and Relabel.
+// +// Discharge itself is called by Refine. Refine first saturates all the +// admissible arcs, then builds a stack of active nodes. It then applies +// Discharge for each active node, possibly adding new ones in the process, +// until no nodes are active. In that case an epsilon-optimal flow is obtained. +// +// Optimize iteratively calls Refine, while epsilon > 1, and divides epsilon by +// alpha (set by default to 5) before each iteration. +// +// The algorithm starts with epsilon = C, where C is the maximum absolute value +// of the arc costs. In the integer case which we are dealing with, since all +// costs are multiplied by (n+1), the initial value of epsilon is (n+1)*C. +// The algorithm terminates when epsilon = 1, and Refine() has been called. +// In this case, a minimum-cost flow is obtained. +// +// The complexity of the algorithm is O(n^2*m*log(n*C)) where C is the value of +// the largest arc cost in the graph. +// +// IMPORTANT: +// The algorithm is not able to detect the infeasibility of a problem (i.e., +// when a bottleneck in the network prohibits sending all the supplies.) +// Worse, it could in some cases loop forever. This is why feasibility checking +// is enabled by default (FLAGS_min_cost_flow_check_feasibility=true.) +// Feasibility checking is implemented using a max-flow, which has a much lower +// complexity. The impact on performance is negligible, while the risk of being +// caught in an endless loop is removed. Note that using the feasibility checker +// roughly doubles the memory consumption. +// +// The starting reference for this class of algorithms is: +// A.V. Goldberg and R.E. Tarjan, "Finding Minimum-Cost Circulations by +// Successive Approximation." Mathematics of Operations Research, Vol. 15, +// 1990:430-466. +// http://portal.acm.org/citation.cfm?id=92225 +// +// Implementation issues are tackled in: +// A.V. 
Goldberg, "An Efficient Implementation of a Scaling Minimum-Cost Flow +// Algorithm," Journal of Algorithms, (1997) 22:1-29 +// http://citeseerx.ist.psu.edu/viewdoc/summary?doi=10.1.1.31.258 +// +// A.V. Goldberg and M. Kharitonov, "On Implementing Scaling Push-Relabel +// Algorithms for the Minimum-Cost Flow Problem", Network flows and matching: +// First DIMACS implementation challenge, DIMACS Series in Discrete Mathematics +// and Theoretical Computer Science, (1993) 12:157-198. +// ftp://dimacs.rutgers.edu/pub/netflow/submit/papers/Goldberg-mincost/scalmin.ps +// and in: +// U. Bunnagel, B. Korte, and J. Vygen. “Efficient implementation of the +// Goldberg-Tarjan minimum-cost flow algorithm.” Optimization Methods and +// Software (1998) vol. 10, no. 2:157-174. +// http://citeseerx.ist.psu.edu/viewdoc/summary?doi=10.1.1.84.9897 +// +// We have tried as much as possible in this implementation to keep the +// notations and namings of the papers cited above, except for 'demand' or +// 'balance' which have been replaced by 'supply', with the according sign +// changes to better accommodate with the API of the rest of our tools. A demand +// is denoted by a negative supply. +// +// TODO(user): See whether the following can bring any improvements on real-life +// problems. +// R.K. Ahuja, A.V. Goldberg, J.B. Orlin, and R.E. Tarjan, "Finding minimum-cost +// flows by double scaling," Mathematical Programming, (1992) 53:243-266. +// http://www.springerlink.com/index/gu7404218u6kt166.pdf +// +// An interesting general reference on network flows is: +// R. K. Ahuja, T. L. Magnanti, J. B. Orlin, "Network Flows: Theory, Algorithms, +// and Applications," Prentice Hall, 1993, ISBN: 978-0136175490, +// http://www.amazon.com/dp/013617549X +// +// Keywords: Push-relabel, min-cost flow, network, graph, Goldberg, Tarjan, +// Dinic, Dinitz. 
+ +#ifndef OR_TOOLS_GRAPH_MIN_COST_FLOW_H_ +#define OR_TOOLS_GRAPH_MIN_COST_FLOW_H_ + +#include +#include +#include +#include +#include + +#include "ebert_graph.h" +#include "graph.h" +#include "zvector.h" + +namespace operations_research { + +// Forward declaration. +template +class GenericMinCostFlow; + +// Different statuses for a solved problem. +// We use a base class to share it between our different interfaces. +class MinCostFlowBase { + public: + enum Status { + NOT_SOLVED, + OPTIMAL, + FEASIBLE, + INFEASIBLE, + UNBALANCED, + BAD_RESULT, + BAD_COST_RANGE + }; +}; + +// A simple and efficient min-cost flow interface. This is as fast as +// GenericMinCostFlow, which is the fastest, but is uses +// more memory in order to hide the somewhat involved construction of the +// static graph. +// +// TODO(user): If the need arises, extend this interface to support warm start +// and incrementality between solves. Note that this is already supported by the +// GenericMinCostFlow<> interface. +class SimpleMinCostFlow : public MinCostFlowBase { + public: + // By default, the constructor takes no size. New node indices are created + // lazily by AddArcWithCapacityAndUnitCost() or SetNodeSupply() such that the + // set of valid nodes will always be [0, NumNodes()). + // + // You may pre-reserve the internal data structures with a given expected + // number of nodes and arcs, to potentially gain performance. + explicit SimpleMinCostFlow(NodeIndex reserve_num_nodes = 0, + ArcIndex reserve_num_arcs = 0); + + // Adds a directed arc from tail to head to the underlying graph with + // a given capacity and cost per unit of flow. + // * Node indices and the capacity must be non-negative (>= 0). + // * The unit cost can take any integer value (even negative). + // * Self-looping and duplicate arcs are supported. + // * After the method finishes, NumArcs() == the returned ArcIndex + 1. 
+ ArcIndex AddArcWithCapacityAndUnitCost(NodeIndex tail, NodeIndex head, + FlowQuantity capacity, + CostValue unit_cost); + + // Sets the supply of the given node. The node index must be non-negative (>= + // 0). Nodes implicitly created will have a default supply set to 0. A demand + // is modeled as a negative supply. + void SetNodeSupply(NodeIndex node, FlowQuantity supply); + + // Solves the problem, and returns the problem status. This function + // requires that the sum of all node supply minus node demand is zero and + // that the graph has enough capacity to send all supplies and serve all + // demands. Otherwise, it will return INFEASIBLE. + Status Solve() { + return SolveWithPossibleAdjustment(SupplyAdjustment::DONT_ADJUST); + } + + // Same as Solve(), but does not have the restriction that the supply + // must match the demand or that the graph has enough capacity to serve + // all the demand or use all the supply. This will compute a maximum-flow + // with minimum cost. The value of the maximum-flow will be given by + // MaximumFlow(). + Status SolveMaxFlowWithMinCost() { + return SolveWithPossibleAdjustment(SupplyAdjustment::ADJUST); + } + + // Returns the cost of the minimum-cost flow found by the algorithm when + // the returned Status is OPTIMAL. + CostValue OptimalCost() const; + + // Returns the total flow of the minimum-cost flow found by the algorithm + // when the returned Status is OPTIMAL. + FlowQuantity MaximumFlow() const; + + // Returns the flow on arc, this only make sense for a successful Solve(). + // + // Note: It is possible that there is more than one optimal solution. The + // algorithm is deterministic so it will always return the same solution for + // a given problem. However, there is no guarantee of this from one code + // version to the next (but the code does not change often). + FlowQuantity Flow(ArcIndex arc) const; + + // Accessors for the user given data. 
The implementation will crash if "arc" + // is not in [0, NumArcs()) or "node" is not in [0, NumNodes()). + NodeIndex NumNodes() const; + ArcIndex NumArcs() const; + NodeIndex Tail(ArcIndex arc) const; + NodeIndex Head(ArcIndex arc) const; + FlowQuantity Capacity(ArcIndex arc) const; + FlowQuantity Supply(NodeIndex node) const; + CostValue UnitCost(ArcIndex arc) const; + + private: + typedef ::util::ReverseArcStaticGraph Graph; + enum SupplyAdjustment { ADJUST, DONT_ADJUST }; + + // Applies the permutation in arc_permutation_ to the given arc index. + ArcIndex PermutedArc(ArcIndex arc); + // Solves the problem, potentially applying supply and demand adjustment, + // and returns the problem status. + Status SolveWithPossibleAdjustment(SupplyAdjustment adjustment); + void ResizeNodeVectors(NodeIndex node); + + std::vector arc_tail_; + std::vector arc_head_; + std::vector arc_capacity_; + std::vector node_supply_; + std::vector arc_cost_; + std::vector arc_permutation_; + std::vector arc_flow_; + CostValue optimal_cost_; + FlowQuantity maximum_flow_; + + SimpleMinCostFlow(const SimpleMinCostFlow&); + SimpleMinCostFlow& operator=(const SimpleMinCostFlow&); +}; + +// Generic MinCostFlow that works with StarGraph and all the graphs handling +// reverse arcs from graph.h, see the end of min_cost_flow.cc for the exact +// types this class is compiled for. +// +// One can greatly decrease memory usage by using appropriately small integer +// types: +// - For the Graph<> types, i.e. NodeIndexType and ArcIndexType, see graph.h. +// - ArcFlowType is used for the *per-arc* flow quantity. It must be signed, and +// large enough to hold the maximum arc capacity and its negation. +// - ArcScaledCostType is used for a per-arc scaled cost. It must be signed +// and large enough to hold the maximum unit cost of an arc times +// (num_nodes + 1). 
+// +// Note that the latter two are different than FlowQuantity and CostValue, which +// are used for global, aggregated values and may need to be larger. +// +// TODO(user): Avoid using the globally defined type CostValue and FlowQuantity. +// Also uses the Arc*Type where there is no risk of overflow in more places. +template +class GenericMinCostFlow : public MinCostFlowBase { + public: + typedef typename Graph::NodeIndex NodeIndex; + typedef typename Graph::ArcIndex ArcIndex; + typedef typename Graph::OutgoingArcIterator OutgoingArcIterator; + typedef typename Graph::OutgoingOrOppositeIncomingArcIterator + OutgoingOrOppositeIncomingArcIterator; + typedef ZVector ArcIndexArray; + + // Initialize a MinCostFlow instance on the given graph. The graph does not + // need to be fully built yet, but its capacity reservation is used to + // initialize the memory of this class. + explicit GenericMinCostFlow(const Graph* graph); + + // Returns the graph associated to the current object. + const Graph* graph() const { return graph_; } + + // Returns the status of last call to Solve(). NOT_SOLVED is returned if + // Solve() has never been called or if the problem has been modified in such a + // way that the previous solution becomes invalid. + Status status() const { return status_; } + + // Sets the supply corresponding to node. A demand is modeled as a negative + // supply. + void SetNodeSupply(NodeIndex node, FlowQuantity supply); + + // Sets the unit cost for the given arc. + void SetArcUnitCost(ArcIndex arc, ArcScaledCostType unit_cost); + + // Sets the capacity for the given arc. + void SetArcCapacity(ArcIndex arc, ArcFlowType new_capacity); + + // Sets the flow for the given arc. Note that new_flow must be smaller than + // the capacity of the arc. + void SetArcFlow(ArcIndex arc, ArcFlowType new_flow); + + // Solves the problem, returning true if a min-cost flow could be found. 
+ bool Solve(); + + // Checks for feasibility, i.e., that all the supplies and demands can be + // matched without exceeding bottlenecks in the network. + // If infeasible_supply_node (resp. infeasible_demand_node) are not NULL, + // they are populated with the indices of the nodes where the initial supplies + // (resp. demands) are too large. Feasible values for the supplies and + // demands are accessible through FeasibleSupply. + // Note that CheckFeasibility is called by Solve() when the flag + // min_cost_flow_check_feasibility is set to true (which is the default.) + bool CheckFeasibility(std::vector* const infeasible_supply_node, + std::vector* const infeasible_demand_node); + + // Makes the min-cost flow problem solvable by truncating supplies and + // demands to a level acceptable by the network. There may be several ways to + // do it. In our case, the levels are computed from the result of the max-flow + // algorithm run in CheckFeasibility(). + // MakeFeasible returns false if CheckFeasibility() was not called before. + bool MakeFeasible(); + + // Returns the cost of the minimum-cost flow found by the algorithm. + CostValue GetOptimalCost() const { return total_flow_cost_; } + + // Returns the flow on the given arc using the equations given in the + // comment on residual_arc_capacity_. + FlowQuantity Flow(ArcIndex arc) const; + + // Returns the capacity of the given arc. + FlowQuantity Capacity(ArcIndex arc) const; + + // Returns the unscaled cost for the given arc. + CostValue UnitCost(ArcIndex arc) const; + + // Returns the supply at a given node. Demands are modelled as negative + // supplies. + FlowQuantity Supply(NodeIndex node) const; + + // Returns the initial supply at a given node. + FlowQuantity InitialSupply(NodeIndex node) const; + + // Returns the largest supply (if > 0) or largest demand in absolute value + // (if < 0) admissible at node. 
If the problem is not feasible, some of these + // values will be smaller (in absolute value) than the initial supplies + // and demand given as input. + FlowQuantity FeasibleSupply(NodeIndex node) const; + + // Whether to use the UpdatePrices() heuristic. + void SetUseUpdatePrices(bool value) { use_price_update_ = value; } + + // Whether to check the feasibility of the problem with a max-flow, prior to + // solving it. This uses about twice as much memory, but detects infeasible + // problems (where the flow can't be satisfied) and makes Solve() return + // INFEASIBLE. If you disable this check, you will spare memory but you must + // make sure that your problem is feasible, otherwise the code can loop + // forever. + void SetCheckFeasibility(bool value) { check_feasibility_ = value; } + + private: + // Returns true if the given arc is admissible i.e. if its residual capacity + // is strictly positive, and its reduced cost strictly negative, i.e., pushing + // more flow into it will result in a reduction of the total cost. + bool IsAdmissible(ArcIndex arc) const; + bool FastIsAdmissible(ArcIndex arc, CostValue tail_potential) const; + + // Returns true if node is active, i.e., if its supply is positive. + bool IsActive(NodeIndex node) const; + + // Returns the reduced cost for a given arc. + CostValue ReducedCost(ArcIndex arc) const; + CostValue FastReducedCost(ArcIndex arc, CostValue tail_potential) const; + + // Returns the first incident arc of a given node. + ArcIndex GetFirstOutgoingOrOppositeIncomingArc(NodeIndex node) const; + + // Checks the consistency of the input, i.e., whether the sum of the supplies + // for all nodes is equal to zero. To be used in a DCHECK. + bool CheckInputConsistency() const; + + // Checks whether the result is valid, i.e. whether for each arc, + // residual_arc_capacity_[arc] == 0 || ReducedCost(arc) >= -epsilon_. + // (A solution is epsilon-optimal if ReducedCost(arc) >= -epsilon.) + // To be used in a DCHECK. 
+ bool CheckResult() const; + + // Checks that the cost range fits in the range of int64_t's. + // To be used in a DCHECK. + bool CheckCostRange() const; + + // Checks the relabel precondition (to be used in a DCHECK): + // - The node must be active, or have a 0 excess (relaxation for the Push + // Look-Ahead heuristic). + // - The node must have no admissible arcs. + bool CheckRelabelPrecondition(NodeIndex node) const; + + // Returns context concatenated with information about a given arc + // in a human-friendly way. + std::string DebugString(const std::string& context, ArcIndex arc) const; + + // Resets the first_admissible_arc_ array to the first incident arc of each + // node. + void ResetFirstAdmissibleArcs(); + + // Scales the costs, by multiplying them by (graph_->num_nodes() + 1). + void ScaleCosts(); + + // Unscales the costs, by dividing them by (graph_->num_nodes() + 1). + void UnscaleCosts(); + + // Optimizes the cost by dividing epsilon_ by alpha_ and calling Refine(). + void Optimize(); + + // Saturates the admissible arcs, i.e., push as much flow as possible. + void SaturateAdmissibleArcs(); + + // Pushes flow on a given arc, i.e., consumes flow on + // residual_arc_capacity_[arc], and consumes -flow on + // residual_arc_capacity_[Opposite(arc)]. Updates node_excess_ at the tail + // and head of the arc accordingly. + void PushFlow(FlowQuantity flow, ArcIndex arc); + void FastPushFlow(FlowQuantity flow, ArcIndex arc, NodeIndex tail); + + // Initializes the stack active_nodes_. + void InitializeActiveNodeStack(); + + // Price update heuristics as described in A.V. Goldberg, "An Efficient + // Implementation of a Scaling Minimum-Cost Flow Algorithm," Journal of + // Algorithms, (1997) 22:1-29 + // http://citeseerx.ist.psu.edu/viewdoc/summary?doi=10.1.1.31.258 + void UpdatePrices(); + + // Performs an epsilon-optimization step by saturating admissible arcs + // and discharging the active nodes. 
+ void Refine(); + + // Discharges an active node by saturating its admissible adjacent arcs, + // if any, and by relabelling it when it becomes inactive. + void Discharge(NodeIndex node); + + // Part of the Push LookAhead heuristic. When we are about to push on the + // in_arc, we check that the head (i.e node here) can accept the flow and + // return true if this is the case: + // - Returns true if the node excess is < 0. + // - Returns true if node is an admissible arc at its current potential. + // - If the two conditions above are false, the node can be relabeled. We + // do that and return true if the in_arc is still admissible. + bool LookAhead(ArcIndex in_arc, CostValue in_tail_potential, NodeIndex node); + + // Relabels node, i.e., decreases its potential while keeping the + // epsilon-optimality of the pseudo flow. See CheckRelabelPrecondition() for + // details on the preconditions. + void Relabel(NodeIndex node); + + // Handy member functions to make the code more compact. + NodeIndex Head(ArcIndex arc) const { return graph_->Head(arc); } + NodeIndex Tail(ArcIndex arc) const { return graph_->Tail(arc); } + ArcIndex Opposite(ArcIndex arc) const; + bool IsArcDirect(ArcIndex arc) const; + bool IsArcValid(ArcIndex arc) const; + + // Pointer to the graph passed as argument. + const Graph* graph_; + + // An array representing the supply (if > 0) or the demand (if < 0) + // for each node in graph_. + QuantityArray node_excess_; + + // An array representing the potential (or price function) for + // each node in graph_. + CostArray node_potential_; + + // An array representing the residual_capacity for each arc in graph_. + // Residual capacities enable one to represent the capacity and flow for all + // arcs in the graph in the following manner. + // For all arcs, residual_arc_capacity_[arc] = capacity[arc] - flow[arc] + // Moreover, for reverse arcs, capacity[arc] = 0 by definition. + // Also flow[Opposite(arc)] = -flow[arc] by definition. 
+ // Therefore: + // - for a direct arc: + // flow[arc] = 0 - flow[Opposite(arc)] + // = capacity[Opposite(arc)] - flow[Opposite(arc)] + // = residual_arc_capacity_[Opposite(arc)] + // - for a reverse arc: + // flow[arc] = -residual_arc_capacity_[arc] + // Using these facts enables one to only maintain residual_arc_capacity_, + // instead of both capacity and flow, for each direct and indirect arc. This + // reduces the amount of memory for this information by a factor 2. + // Note that the sum of the largest capacity of an arc in the graph and of + // the total flow in the graph mustn't exceed the largest 64 bit integer + // to avoid errors. CheckInputConsistency() verifies this constraint. + ZVector residual_arc_capacity_; + + // An array representing the first admissible arc for each node in graph_. + ArcIndexArray first_admissible_arc_; + + // A stack used for managing active nodes in the algorithm. + // Note that the papers cited above recommend the use of a queue, but + // benchmarking so far has not proved it is better. + std::stack active_nodes_; + + // epsilon_ is the tolerance for optimality. + CostValue epsilon_; + + // alpha_ is the factor by which epsilon_ is divided at each iteration of + // Refine(). + const int64_t alpha_; + + // cost_scaling_factor_ is the scaling factor for cost. + CostValue cost_scaling_factor_; + + // An array representing the scaled unit cost for each arc in graph_. + ZVector scaled_arc_unit_cost_; + + // The total cost of the flow. + CostValue total_flow_cost_; + + // The status of the problem. + Status status_; + + // An array containing the initial excesses (i.e. the supplies) for each + // node. This is used to create the max-flow-based feasibility checker. + QuantityArray initial_node_excess_; + + // An array containing the best acceptable excesses for each of the + // nodes. These excesses are imposed by the result of the max-flow-based + // feasibility checker for the nodes with an initial supply != 0. 
For the + // other nodes, the excess is simply 0. + QuantityArray feasible_node_excess_; + + // Number of Relabel() since last UpdatePrices(). + int num_relabels_since_last_price_update_; + + // A Boolean which is true when feasibility has been checked. + bool feasibility_checked_; + + // Whether to use the UpdatePrices() heuristic. + bool use_price_update_; + + // Whether to check the problem feasibility with a max-flow. + bool check_feasibility_; + + GenericMinCostFlow(const GenericMinCostFlow&); + GenericMinCostFlow& operator=(const GenericMinCostFlow&); +}; + +// Default MinCostFlow instance that uses StarGraph. +// New clients should use SimpleMinCostFlow if they can. +class MinCostFlow : public GenericMinCostFlow { + public: + explicit MinCostFlow(const StarGraph* graph) : GenericMinCostFlow(graph) {} +}; + +} // namespace operations_research +#endif // OR_TOOLS_GRAPH_MIN_COST_FLOW_H_ diff --git a/cxx/isce3/unwrap/ortools/permutation.h b/cxx/isce3/unwrap/ortools/permutation.h new file mode 100644 index 000000000..7c122e76f --- /dev/null +++ b/cxx/isce3/unwrap/ortools/permutation.h @@ -0,0 +1,227 @@ +// Copyright 2010-2021 Google LLC +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// +// Classes for permuting indexable, ordered containers of data without +// depending on that data to be accessible in any particular way. The +// client needs to give us two things: +// 1. a permutation to apply to some container(s) of data, and +// 2. 
a description of how to move data around in the container(s). +// +// The permutation (1) comes to us in the form of an array argument to +// PermutationApplier::Apply(), along with index values that tell us +// where in that array the permutation of interest lies. Typically +// those index values will span the entire array that describes the +// permutation. +// +// Applying a permutation involves decomposing the permutation into +// disjoint cycles and walking each element of the underlying data one +// step around the unique cycle in which it participates. The +// decomposition into disjoint cycles is done implicitly on the fly as +// the code in PermutationApplier::Apply() advances through the array +// describing the permutation. As an important piece of bookkeeping to +// support the decomposition into cycles, the elements of the +// permutation array typically get modified somehow to indicate which +// ones have already been used. +// +// At first glance, it would seem that if the containers are +// indexable, we don't need anything more complicated than just the +// permutation and the container of data we want to permute; it would +// seem we can just use the container's operator[] to retrieve and +// assign elements within the container. Unfortunately it's not so +// simple because the containers of interest can be indexable without +// providing any consistent way of accessing their contents that +// applies to all the containers of interest. For instance, if we +// could insist that every indexable container must define an lvalue +// operator[]() we could simply use that for the assignments we need +// to do while walking around cycles of the permutation. But we cannot +// insist on any such thing. 
To see why, consider the PackedArray +// class template in ortools/util/packed_array.h +// where operator[] is supplied for rvalues, but because each logical +// array element is packed across potentially multiple instances of +// the underlying data type that the C++ language knows about, there +// is no way to have a C++ reference to an element of a +// PackedArray. There are other such examples besides PackedArray, +// too. This is the main reason we need a codified description (2) of +// how to move data around in the indexable container. That +// description comes to us via the PermutationApplier constructor's +// argument which is a PermutationCycleHandler instance. Such an +// object has three important methods defined: SetTempFromIndex(), +// SetIndexFromIndex(), and SetIndexFromTemp(). Those methods embody +// all we need to know about how to move data in the indexable +// container(s) underlying the PermutationCycleHandler. +// +// Another reason we need the description (2) of how to move elements +// around in the container(s) is that it is often important to permute +// side-by-side containers of elements according to the same +// permutation. This situation, too, is covered by defining a +// PermutationCycleHandler that knows about multiple underlying +// indexable containers. +// +// The above-mentioned PermutationCycleHandler methods embody +// knowledge of how to assign elements. It happens that +// PermutationCycleHandler is also a convenient place to embody the +// knowledge of how to keep track of which permutation elements have +// been consumed by the process of walking data around cycles. We +// depend on the PermutationCycleHandler instance we're given to +// define SetSeen() and Unseen() methods for that purpose. +// +// For the common case in which elements can be accessed using +// operator[](), we provide the class template +// ArrayIndexCycleHandler. 
+ +#ifndef OR_TOOLS_UTIL_PERMUTATION_H_ +#define OR_TOOLS_UTIL_PERMUTATION_H_ + +#include + +#include + +namespace operations_research { + +// Abstract base class template defining the interface needed by +// PermutationApplier to handle a single cycle of a permutation. +template +class PermutationCycleHandler { + public: + // Sets the internal temporary storage from the given index in the + // underlying container(s). + virtual void SetTempFromIndex(IndexType source) = 0; + + // Moves a data element one step along its cycle. + virtual void SetIndexFromIndex(IndexType source, + IndexType destination) const = 0; + + // Sets a data element from the temporary. + virtual void SetIndexFromTemp(IndexType destination) const = 0; + + // Marks an element of the permutation as handled by + // PermutationHandler::Apply(), meaning that we have read the + // corresponding value from the data to be permuted, and put that + // value somewhere (either in the temp or in its ultimate + // destination in the data. + // + // This method must be overridden in implementations where it is + // called. If an implementation doesn't call it, no need to + // override. + virtual void SetSeen(IndexType* unused_permutation_element) const { + pyre::journal::error_t channel("isce3.unwrap.ortools.permutation"); + channel << pyre::journal::at(__HERE__) + << "Base implementation of SetSeen() must not be called." + << pyre::journal::endl; + } + + // Returns true iff the given element of the permutation is unseen, + // meaning that it has not yet been handled by + // PermutationApplier::Apply(). + // + // This method must be overridden in implementations where it is + // called. If an implementation doesn't call it, no need to + // override. + virtual bool Unseen(IndexType unused_permutation_element) const { + pyre::journal::error_t channel("isce3.unwrap.ortools.permutation"); + channel << pyre::journal::at(__HERE__) + << "Base implementation of Unseen() must not be called." 
+ << pyre::journal::endl; + return false; + } + + virtual ~PermutationCycleHandler() {} + + protected: + PermutationCycleHandler() {} + + private: + PermutationCycleHandler(const PermutationCycleHandler&); + PermutationCycleHandler& operator=(const PermutationCycleHandler&); +}; + +// A generic cycle handler class for the common case in which the +// object to be permuted is indexable with T& operator[](int), and the +// permutation is represented by a mutable array of nonnegative +// int-typed index values. To mark a permutation element as seen, we +// replace it by its ones-complement value. +template +class ArrayIndexCycleHandler : public PermutationCycleHandler { + public: + explicit ArrayIndexCycleHandler(DataType* data) : data_(data) {} + + void SetTempFromIndex(IndexType source) override { temp_ = data_[source]; } + void SetIndexFromIndex(IndexType source, + IndexType destination) const override { + data_[destination] = data_[source]; + } + void SetIndexFromTemp(IndexType destination) const override { + data_[destination] = temp_; + } + void SetSeen(IndexType* permutation_element) const override { + *permutation_element = -*permutation_element - 1; + } + bool Unseen(IndexType permutation_element) const override { + return permutation_element >= 0; + } + + private: + // Pointer to the base of the array of data to be permuted. + DataType* data_; + + // Temporary storage for the one extra element we need. + DataType temp_; + + ArrayIndexCycleHandler(const ArrayIndexCycleHandler&); + ArrayIndexCycleHandler& operator=(const ArrayIndexCycleHandler&); +}; + +// Note that this template is not implemented in an especially +// performance-sensitive way. In particular, it makes multiple virtual +// method calls for each element of the permutation. 
+template +class PermutationApplier { + public: + explicit PermutationApplier(PermutationCycleHandler* cycle_handler) + : cycle_handler_(cycle_handler) {} + + void Apply(IndexType permutation[], int permutation_start, + int permutation_end) { + for (IndexType current = permutation_start; current < permutation_end; + ++current) { + IndexType next = permutation[current]; + // cycle_start is only for debugging. + const IndexType cycle_start = current; + if (cycle_handler_->Unseen(next)) { + cycle_handler_->SetSeen(&permutation[current]); + assert(!cycle_handler_->Unseen(permutation[current])); + cycle_handler_->SetTempFromIndex(current); + while (cycle_handler_->Unseen(permutation[next])) { + cycle_handler_->SetIndexFromIndex(next, current); + current = next; + next = permutation[next]; + cycle_handler_->SetSeen(&permutation[current]); + assert(!cycle_handler_->Unseen(permutation[current])); + } + cycle_handler_->SetIndexFromTemp(current); + // Set current back to the start of this cycle. + current = next; + } + assert(cycle_start == current); + } + } + + private: + PermutationCycleHandler* cycle_handler_; + + PermutationApplier(const PermutationApplier&); + PermutationApplier& operator=(const PermutationApplier&); +}; +} // namespace operations_research +#endif // OR_TOOLS_UTIL_PERMUTATION_H_ diff --git a/cxx/isce3/unwrap/ortools/zvector.h b/cxx/isce3/unwrap/ortools/zvector.h new file mode 100644 index 000000000..ce3722ba5 --- /dev/null +++ b/cxx/isce3/unwrap/ortools/zvector.h @@ -0,0 +1,169 @@ +// Copyright 2010-2021 Google LLC +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef OR_TOOLS_UTIL_ZVECTOR_H_ +#define OR_TOOLS_UTIL_ZVECTOR_H_ + +#if (defined(__APPLE__) || defined(__FreeBSD__)) && defined(__GNUC__) +#include +#elif !defined(_MSC_VER) +#include +#endif +#include +#include +#include +#include +#include +#include +#include + +#include + +// An array class for storing arrays of integers. +// +// The range of indices is specified at the construction of the object. +// The minimum and maximum indices are inclusive. +// Think of the Pascal syntax array[min_index..max_index] of ... +// +// For example, ZVector(-100000,100000) will store 200001 +// signed integers of 32 bits each, and the possible range of indices +// will be -100000..100000. + +namespace operations_research { + +template +class ZVector { + public: + ZVector() + : base_(nullptr), min_index_(0), max_index_(-1), size_(0), storage_() {} + + ZVector(int64_t min_index, int64_t max_index) + : base_(nullptr), min_index_(0), max_index_(-1), size_(0), storage_() { + if (!Reserve(min_index, max_index)) { + pyre::journal::firewall_t channel("isce3.unwrap.ortools.zvector"); + channel << pyre::journal::at(__HERE__) + << "Could not reserve memory for indices ranging from " + << min_index << " to " << max_index + << pyre::journal::endl; + } + } + + int64_t min_index() const { return min_index_; } + + int64_t max_index() const { return max_index_; } + + // Returns the value stored at index. 
+ T Value(int64_t index) const { + assert(min_index_ <= index); + assert(max_index_ >= index); + assert(base_ != nullptr); + return base_[index]; + } + + // Shortcut for returning the value stored at index. + T& operator[](int64_t index) { + assert(min_index_ <= index); + assert(max_index_ >= index); + assert(base_ != nullptr); + return base_[index]; + } + + const T operator[](int64_t index) const { + assert(min_index_ <= index); + assert(max_index_ >= index); + assert(base_ != nullptr); + return base_[index]; + } + + // Sets to value the content of the array at index. + void Set(int64_t index, T value) { + assert(min_index_ <= index); + assert(max_index_ >= index); + assert(base_ != nullptr); + base_[index] = value; + } + + // Reserves memory for new minimum and new maximum indices. + // Returns true if the memory could be reserved. + // Never shrinks the memory allocated. + bool Reserve(int64_t new_min_index, int64_t new_max_index) { + if (new_min_index > new_max_index) { + return false; + } + const uint64_t new_size = new_max_index - new_min_index + 1; + if (base_ != nullptr) { + if (new_min_index >= min_index_ && new_max_index <= max_index_) { + min_index_ = new_min_index; + max_index_ = new_max_index; + size_ = new_size; + return true; + } else if (new_min_index > min_index_ || new_max_index < max_index_) { + return false; + } + } + T* new_storage = new T[new_size]; + if (new_storage == nullptr) { + return false; + } + + T* const new_base = new_storage - new_min_index; + if (base_ != nullptr) { + T* const destination = new_base + min_index_; + std::memcpy(destination, storage_.get(), size_ * sizeof(*base_)); + } + + base_ = new_base; + size_ = new_size; + min_index_ = new_min_index; + max_index_ = new_max_index; + storage_.reset(new_storage); + return true; + } + + // Sets all the elements in the array to value. 
+ void SetAll(T value) { + for (int64_t i = 0; i < size_; ++i) { + base_[min_index_ + i] = value; + } + } + + private: + // Pointer to the element indexed by zero in the array. + T* base_; + + // Minimum index for the array. + int64_t min_index_; + + // Maximum index for the array. + int64_t max_index_; + + // The number of elements in the array. + int64_t size_; + + // Storage memory for the array. + std::unique_ptr storage_; +}; + +// Shorthands for all the types of ZVector's. +typedef ZVector Int8ZVector; +typedef ZVector Int16ZVector; +typedef ZVector Int32ZVector; +typedef ZVector Int64ZVector; +typedef ZVector UInt8ZVector; +typedef ZVector UInt16ZVector; +typedef ZVector UInt32ZVector; +typedef ZVector UInt64ZVector; + +} // namespace operations_research + +#endif // OR_TOOLS_UTIL_ZVECTOR_H_ diff --git a/cxx/isce3/unwrap/snaphu/snaphu.cpp b/cxx/isce3/unwrap/snaphu/snaphu.cpp index 8336683f8..f36bc730f 100644 --- a/cxx/isce3/unwrap/snaphu/snaphu.cpp +++ b/cxx/isce3/unwrap/snaphu/snaphu.cpp @@ -517,9 +517,8 @@ int UnwrapTile(infileT *infiles, outfileT *outfiles, paramT *params, }else if(params->initmethod==MCFINIT){ - fflush(NULL); - throw isce3::except::InvalidArgument(ISCE_SRCINFO(), - "MCF initialization not implemented"); + /* use minimum cost flow (MCF) algorithm */ + MCFInitFlows(wrappedphase,&flows,mstcosts,nrow,ncol); }else{ fflush(NULL); @@ -725,7 +724,6 @@ int UnwrapTile(infileT *infiles, outfileT *outfiles, paramT *params, /* flip the sign of the unwrapped phase array if it was flipped initially, */ FlipPhaseArraySign(unwrappedphase,params,nrow,ncol); - /* write the unwrapped output */ fprintf(sp1,"Writing output to file %s\n",outfiles->outfile); WriteOutputFile(mag,unwrappedphase,outfiles->outfile,outfiles, diff --git a/cxx/isce3/unwrap/snaphu/snaphu.h b/cxx/isce3/unwrap/snaphu/snaphu.h index ee1649d28..72185cab5 100644 --- a/cxx/isce3/unwrap/snaphu/snaphu.h +++ b/cxx/isce3/unwrap/snaphu/snaphu.h @@ -19,8 +19,6 @@ #include -namespace isce3::unwrap 
{ - /**********************/ /* defined constants */ /**********************/ @@ -129,6 +127,7 @@ namespace isce3::unwrap { #define NARMS 8 /* number of arms for Despeckle() */ #define ARMLEN 5 /* length of arms for Despeckle() */ #define KEDGE 5 /* length of edge detection window */ +#define ARCUBOUND 200 /* capacities for MCF solver */ #define MSTINIT 1 /* initialization method */ #define MCFINIT 2 /* initialization method */ #define BIGGESTDZRHOMAX 10000.0 @@ -419,6 +418,8 @@ namespace isce3::unwrap { "\n" +namespace isce3::unwrap { + /********************/ /* type definitions */ /********************/ @@ -823,6 +824,8 @@ totalcostT EvaluateTotalCost(Array2D& costs, Array2D& wrappedphase, Array2D* flowsptr, Array2D& mstcosts, long nrow, long ncol, Array2D* nodes, nodeT *ground, long maxflow); +int MCFInitFlows(Array2D& wrappedphase, Array2D* flowsptr, Array2D& mstcosts, + long nrow, long ncol); /* functions in snaphu_cost.c */ diff --git a/cxx/isce3/unwrap/snaphu/snaphu_solver.cpp b/cxx/isce3/unwrap/snaphu/snaphu_solver.cpp index 67d74cacb..f04973f84 100644 --- a/cxx/isce3/unwrap/snaphu/snaphu_solver.cpp +++ b/cxx/isce3/unwrap/snaphu/snaphu_solver.cpp @@ -10,8 +10,10 @@ #include #include +#include #include +#include #include "snaphu.h" @@ -2350,7 +2352,7 @@ int InitNetwork(Array2D& flows, long *ngroundarcsptr, long *ncycleptr, long i; /* get and initialize memory for nodes */ - if(ground!=NULL && nodesptr->size()){ + if(ground!=NULL && !nodesptr->size()){ *nodesptr = Array2D(nrow-1, ncol-1); InitNodeNums(nrow-1,ncol-1,*nodesptr,ground); } @@ -3673,6 +3675,173 @@ signed char ClipFlow(Array2D& residue, Array2D& flows, } +/* function: MCFInitFlows() + * ------------------------ + * Initializes the flow on the network using a minimum cost flow + * algorithm. 
+ */ +int MCFInitFlows(Array2D& wrappedphase, Array2D* flowsptr, + Array2D& mstcosts, long nrow, long ncol){ + + /* number of rows & cols of nodes in the residue network */ + const auto m=nrow-1; + const auto n=ncol-1; + + /* calculate phase residues (integer numbers of cycles) */ + auto residue=Array2D(m,n); + CycleResidue(wrappedphase,residue,nrow,ncol); + + /* total number of nodes and directed arcs in the network */ + const auto nnodes=m*n+1; + const auto narcs=2*((m+1)*n+(n+1)*m); + + /* the solver uses 32-bit integers for node & arc indices */ + /* check for possible overflow */ + using operations_research::NodeIndex; + using operations_research::ArcIndex; + if(nnodes>std::numeric_limits::max()){ + throw isce3::except::RuntimeError(ISCE_SRCINFO(), + "Number of MCF network nodes exceeds maximum representable value"); + } + if(narcs>std::numeric_limits::max()){ + throw isce3::except::RuntimeError(ISCE_SRCINFO(), + "Number of MCF network arcs exceeds maximum representable value"); + } + + /* begin building the network topology and setting up the MCF problem */ + using Network=operations_research::SimpleMinCostFlow; + auto network=Network(nnodes,narcs); + + /* assigns a positive integer label to each grid node */ + /* grid node indices begin at 1 (index 0 is used for the ground node) */ + auto GetNodeIndex=[=](long i, long j)->NodeIndex{ + return 1+i*n+j; + }; + constexpr NodeIndex ground=0; + + /* adds a pair of forward & reverse arcs to the network connecting two nodes */ + /* sister arcs have equal cost and capacity */ + using operations_research::CostValue; + using operations_research::FlowQuantity; + auto AddSisterArcs=[&](NodeIndex node1, NodeIndex node2, CostValue cost){ + constexpr static auto capacity=static_cast(ARCUBOUND); + network.AddArcWithCapacityAndUnitCost(node2,node1,capacity,cost); + network.AddArcWithCapacityAndUnitCost(node1,node2,capacity,cost); + }; + + /* break down arc costs into row (horizontal) & col (vertical) cost arrays */ + const 
auto rowcosts=mstcosts.topLeftCorner(m,n+1); + const auto colcosts=mstcosts.bottomLeftCorner(m+1,n); + + /* arcs are assigned sequential indices (starting from 0) in the order that + they're added to the network */ + /* we rely on this fact later on when extracting flows from the network */ + + /* begin adding horizontal arcs to the network */ + for(long i=0;i(rowcosts(i,0)); + AddSisterArcs(ground,node,cost); + } + + /* add a pair of horizontal arcs between each adjacent grid node */ + for(long j=0;j(rowcosts(i,j+1)); + AddSisterArcs(node1,node2,cost); + } + + /* add a pair of arcs between the right border node and the ground node */ + { + const auto node=GetNodeIndex(i,n-1); + const auto cost=static_cast(rowcosts(i,n)); + AddSisterArcs(node,ground,cost); + } + } + + /* begin adding vertical arcs to the network */ + /* add a pair of arcs between each top border node and the ground node */ + for(long j=0;j(colcosts(0,j)); + AddSisterArcs(ground,node,cost); + } + /* add a pair of vertical arcs between each adjacent grid node */ + for(long i=0;i(colcosts(i+1,j)); + AddSisterArcs(node1,node2,cost); + } + } + /* add a pair of arcs between each bottom border node and the ground node */ + for(long j=0;j(colcosts(m,j)); + AddSisterArcs(node,ground,cost); + } + + /* add node supplies to the network */ + FlowQuantity totalsupply=0; + for(long i=0;i(residue(i,j)); + network.SetNodeSupply(node,supply); + totalsupply+=supply; + } + } + + /* add enough demand to the ground node to balance the network */ + network.SetNodeSupply(ground,-totalsupply); + + /* run the solver to produce L1-optimal flows */ + if(network.Solve() != Network::OPTIMAL){ + throw isce3::except::RuntimeError(ISCE_SRCINFO(), + "MCF initialization failed"); + } + + *flowsptr=MakeRowColArray2D(nrow,ncol); + + /* break down arc flows into row (horizontal) & col (vertical) flow arrays */ + auto rowflows=flowsptr->topLeftCorner(m,n+1); + auto colflows=flowsptr->bottomLeftCorner(m+1,n); + + /* extract arc flows 
from the network */ + /* the easiest way to do this is in the exact order in which the arcs were + added to the network (relying implicitly on the sequential ordering of arc + indices) */ + + /* extract horizontal flows from the network */ + ArcIndex arcidx=0; + for(long i=0;i&, Array2D&, nodeT*, \ nodeT*, Array1D*, \ diff --git a/doc/doxygen/tutorial/geometry.dox b/doc/doxygen/tutorial/geometry.dox index f20ee3078..71328436e 100644 --- a/doc/doxygen/tutorial/geometry.dox +++ b/doc/doxygen/tutorial/geometry.dox @@ -246,7 +246,7 @@ int main(int argc, char *argv[]) ellipse, //Ellipsoid orbit, //Orbit dop, //Doppler - mode, //Product metadata + mode, //RadarGridProduct metadata aztime, //Estimated azimuth time rng, //Estimated slant range side, //Look side diff --git a/doc/sphinx/library.rst b/doc/sphinx/library.rst index 1bc2179dd..77e11436b 100644 --- a/doc/sphinx/library.rst +++ b/doc/sphinx/library.rst @@ -30,7 +30,7 @@ I/O Datastructures Product Datastructures ---------------------- * :doc:`RadarGridParameters <./product/RadarGridParameters>` -* :doc:`Product <./product/Product>` +* :doc:`RadarGridProduct <./product/RadarGridProduct>` * :doc:`Metadata <./product/Metadata>` Image Datastructures diff --git a/doc/sphinx/product/Product.rst b/doc/sphinx/product/Product.rst index 5bc3a22d3..5ec05b3c7 100644 --- a/doc/sphinx/product/Product.rst +++ b/doc/sphinx/product/Product.rst @@ -1,17 +1,17 @@ :orphan: -.. title:: Product +.. title:: RadarGridProduct -Product +RadarGridProduct ========= -Product provides a light wrapper for the isce::product::Product class instantiated from a IH5File object, +RadarGridProduct provides a light wrapper for the isce::product::RadarGridProduct class instantiated from a IH5File object, which is the highest level representation of an ISCE radar product. Documentation ---------------- -.. autoclass:: isce3.product.Product.Product +.. 
autoclass:: isce3.product.Product.RadarGridProduct :members: :inherited-members: diff --git a/python/extensions/pybind_isce3/Sources.cmake b/python/extensions/pybind_isce3/Sources.cmake index 4bafcac8a..baa2f578b 100644 --- a/python/extensions/pybind_isce3/Sources.cmake +++ b/python/extensions/pybind_isce3/Sources.cmake @@ -74,6 +74,7 @@ product/GeoGridParameters.cpp product/product.cpp product/RadarGridParameters.cpp product/Swath.cpp +product/Grid.cpp unwrap/unwrap.cpp unwrap/ICU.cpp unwrap/Phass.cpp diff --git a/python/extensions/pybind_isce3/product/Grid.cpp b/python/extensions/pybind_isce3/product/Grid.cpp new file mode 100644 index 000000000..b50250914 --- /dev/null +++ b/python/extensions/pybind_isce3/product/Grid.cpp @@ -0,0 +1,73 @@ +#include "Grid.h" + +#include + +#include + +#include +#include +#include + +namespace py = pybind11; + +using isce3::product::Grid; + +void addbinding(pybind11::class_ & pyGrid) +{ + pyGrid + .def(py::init<>()) + .def(py::init([](const std::string &h5file, const char freq) + { + // open file + isce3::io::IH5File file(h5file); + + // instantiate and load a product + isce3::product::GeoGridProduct product(file); + + // return grid from product + return product.grid(freq); + }), + py::arg("h5file"), py::arg("freq")) + + .def_property_readonly("wavelength", &Grid::wavelength) + .def_property("geogrid", + py::overload_cast<>(&Grid::geogrid), + py::overload_cast(&Grid::geogrid)) + .def_property("range_bandwidth", + py::overload_cast<>(&Grid::rangeBandwidth, py::const_), + py::overload_cast(&Grid::rangeBandwidth)) + .def_property("azimuth_bandwidth", + py::overload_cast<>(&Grid::azimuthBandwidth, py::const_), + py::overload_cast(&Grid::azimuthBandwidth)) + .def_property("center_frequency", + py::overload_cast<>(&Grid::centerFrequency, py::const_), + py::overload_cast(&Grid::centerFrequency)) + .def_property("slant_range_spacing", + py::overload_cast<>(&Grid::slantRangeSpacing, py::const_), + 
py::overload_cast(&Grid::slantRangeSpacing)) + .def_property("zero_doppler_time_spacing", + py::overload_cast<>(&Grid::zeroDopplerTimeSpacing, py::const_), + py::overload_cast(&Grid::zeroDopplerTimeSpacing)) + + .def_property("start_x", + py::overload_cast<>(&Grid::startX, py::const_), + py::overload_cast(&Grid::startX)) + .def_property("start_y", + py::overload_cast<>(&Grid::startY, py::const_), + py::overload_cast(&Grid::startY)) + .def_property("spacing_x", + py::overload_cast<>(&Grid::spacingX, py::const_), + py::overload_cast(&Grid::spacingX)) + .def_property("spacing_y", + py::overload_cast<>(&Grid::spacingY, py::const_), + py::overload_cast(&Grid::spacingY)) + .def_property("width", + py::overload_cast<>(&Grid::width, py::const_), + py::overload_cast(&Grid::width)) + .def_property("length", + py::overload_cast<>(&Grid::length, py::const_), + py::overload_cast(&Grid::length)) + .def_property("epsg", + py::overload_cast<>(&Grid::epsg, py::const_), + py::overload_cast(&Grid::epsg)); +} diff --git a/python/extensions/pybind_isce3/product/Grid.h b/python/extensions/pybind_isce3/product/Grid.h new file mode 100644 index 000000000..cc9b81fb3 --- /dev/null +++ b/python/extensions/pybind_isce3/product/Grid.h @@ -0,0 +1,6 @@ +#pragma once + +#include +#include + +void addbinding(pybind11::class_ &); diff --git a/python/extensions/pybind_isce3/product/RadarGridParameters.cpp b/python/extensions/pybind_isce3/product/RadarGridParameters.cpp index 43811ebd8..624c067de 100644 --- a/python/extensions/pybind_isce3/product/RadarGridParameters.cpp +++ b/python/extensions/pybind_isce3/product/RadarGridParameters.cpp @@ -5,7 +5,7 @@ #include #include -#include +#include namespace py = pybind11; @@ -22,7 +22,7 @@ void addbinding(pybind11::class_ & pyRadarGridParameters) isce3::io::IH5File file(h5file); // instantiate and load a product - isce3::product::Product product(file); + isce3::product::RadarGridProduct product(file); // return swath from product return 
RadarGridParameters(product, freq); diff --git a/python/extensions/pybind_isce3/product/Swath.cpp b/python/extensions/pybind_isce3/product/Swath.cpp index d39b3788c..f925efdb8 100644 --- a/python/extensions/pybind_isce3/product/Swath.cpp +++ b/python/extensions/pybind_isce3/product/Swath.cpp @@ -7,7 +7,7 @@ #include #include -#include +#include namespace py = pybind11; @@ -23,7 +23,7 @@ void addbinding(pybind11::class_ & pySwath) isce3::io::IH5File file(h5file); // instantiate and load a product - isce3::product::Product product(file); + isce3::product::RadarGridProduct product(file); // return swath from product return product.swath(freq); diff --git a/python/extensions/pybind_isce3/product/product.cpp b/python/extensions/pybind_isce3/product/product.cpp index e1e099241..a5b23e592 100644 --- a/python/extensions/pybind_isce3/product/product.cpp +++ b/python/extensions/pybind_isce3/product/product.cpp @@ -3,6 +3,7 @@ #include "GeoGridParameters.h" #include "RadarGridParameters.h" #include "Swath.h" +#include "Grid.h" namespace py = pybind11; @@ -14,10 +15,12 @@ void addsubmodule_product(py::module & m) py::class_ pyGeoGridParameters(m_product, "GeoGridParameters"); py::class_ pyRadarGridParameters(m_product, "RadarGridParameters"); py::class_ pySwath(m_product, "Swath"); + py::class_ pyGrid(m_product, "Grid"); // add bindings addbinding(pyGeoGridParameters); addbinding(pyRadarGridParameters); addbinding(pySwath); + addbinding(pyGrid); addbinding_bbox_to_geogrid(m_product); } diff --git a/python/packages/CMakeLists.txt b/python/packages/CMakeLists.txt index 2e5ebf235..ed37c8ef6 100644 --- a/python/packages/CMakeLists.txt +++ b/python/packages/CMakeLists.txt @@ -27,6 +27,7 @@ endforeach() set (list_of_exe nisar/workflows/crossmul.py nisar/workflows/focus.py + nisar/workflows/gen_doppler_range_product.py nisar/workflows/geo2rdr.py nisar/workflows/geocode_insar.py nisar/workflows/gcov.py diff --git a/python/packages/isce3/signal/__init__.py 
b/python/packages/isce3/signal/__init__.py index 470af4ffd..052947f05 100644 --- a/python/packages/isce3/signal/__init__.py +++ b/python/packages/isce3/signal/__init__.py @@ -1,2 +1,5 @@ from pybind_isce3.signal import * +from .fir_filter_func import cheby_equi_ripple_filter +from .doppler_est_func import (corr_doppler_est, sign_doppler_est, + unwrap_doppler) from . import point_target_info diff --git a/python/packages/isce3/signal/doppler_est_func.py b/python/packages/isce3/signal/doppler_est_func.py new file mode 100644 index 000000000..0c3233669 --- /dev/null +++ b/python/packages/isce3/signal/doppler_est_func.py @@ -0,0 +1,338 @@ +""" +Collection of functions for doppler centroid estimation. +""" +import functools +import numbers +import collections as cl +import numpy as np +from scipy import fft + + +def corr_doppler_est(echo, prf, lag=1, axis=None): + """Estimate Doppler centroid based on complex correlator. + + It uses the Correlation Doppler Estimator (CDE) approach + proposed by [MADSEN1989]_ + + Parameters + ---------- + echo : np.ndarray(complex) + 1-D or 2-D numpy complex array + prf : float + Pulse-repetition frequency or sampling rate in the azimuth + direction in (Hz). + lag : int, default=1 + Lag of the correlator, a positive value. + axis : None or int, optional + Axis along which the correlator is performed. + If None it will be the first axis. + + Returns + ------- + float + Ambiguous Doppler centroid within [-0.5*prf, 0.5*prf] + float + Correlation coefficient, a value within [0, 1] + + Raises + ------ + ValueError + For bad input arguments + TypeError + If echo is not numpy array + RuntimeError: + Mismtach between lag and number of elements of echo used in correlator + np.AxisError: + Mismtach between axis value and echo dimension + + See Also + -------- + sign_doppler_est : Sign-Doppler estimator + wavelen_diversity_doppler_est + + References + ---------- + .. [MADSEN1989] S. 
Madsen, 'Estimating The Doppler Centroid of SAR Data', + IEEE Transaction On Aerospace and Elect Sys, March 1989 + + """ + if prf <= 0.0: + raise ValueError('prf must be a positive value') + if not isinstance(echo, np.ndarray): + raise TypeError('echo must be a numpy array') + if echo.ndim > 2: + raise ValueError('Max dimension of echo must be 2') + if lag < 1: + raise ValueError('Lag must be equal or larger than 1') + if axis is None: + axis = 0 + else: + if axis > (echo.ndim - 1): + raise np.AxisError( + f'axis {axis} is out of bound for dimenion {echo.ndim}') + + if axis == 0: + if echo.shape[0] < (lag + 1): + raise RuntimeError( + f'Not enough samples for correlator along axis {axis}') + xcor_cmp = (echo[lag:] * echo[:-lag].conj()).mean() + # get mag of product of auto correlations + acor_mag = np.sqrt((abs(echo[lag:])**2).mean()) + acor_mag *= np.sqrt((abs(echo[:-lag])**2).mean()) + else: + if echo.shape[1] < (lag + 1): + raise RuntimeError( + f'Not enough samples for correlator along axis {axis}') + xcor_cmp = (echo[:, lag:] * echo[:, :-lag].conj()).mean() + # get mag of product of auto correlations + acor_mag = np.sqrt((abs(echo[:, lag:])**2).mean()) + acor_mag *= np.sqrt((abs(echo[:, :-lag])**2).mean()) + + # calculate correlation coefficient + if acor_mag > 0: + corr_coef = abs(xcor_cmp) / acor_mag + else: + corr_coef = 0.0 + + return prf / (2.0 * np.pi * lag) * np.angle(xcor_cmp), corr_coef + + +def sign_doppler_est(echo, prf, lag=1, axis=None): + """Estimate Doppler centroid based on sign of correlator coeffs. + + It uses Sign-Doppler estimator (SDE) approach proposed by [MADSEN1989]_ + + Parameters + ---------- + echo : np.ndarray(complex) + 1-D or 2-D numpy complex array + prf : float + Pulse-repetition frequency or sampling rate in the azimuth + direction in (Hz). + lag : int, default=1 + Lag of the correlator, a positive value. + axis : None or int, optional + Axis along which the correlator is perform. + If None it will be the firsr axis. 
+ + Returns + ------- + float + Ambiguous Doppler centroid within [-0.5*prf, 0.5*prf] + + Raises + ------ + ValueError + For bad input arguments + TypeError + If echo is not numpy array + RuntimeError: + Mismtach between lag and number of elements of echo used in correlator + np.AxisError: + Mismtach between Axis value and echo dimension + + See Also + -------- + corr_doppler_est : Correlation Doppler Estimator (CDE) + wavelen_diversity_doppler_est + + References + ---------- + .. [MADSEN1989] S. Madsen, 'Estimating The Doppler Centroid of SAR Data', + IEEE Transaction On Aerospace and Elect Sys, March 1989 + + """ + if prf <= 0.0: + raise ValueError('prf must be a positive value') + if not isinstance(echo, np.ndarray): + raise TypeError('echo must be a numpy array') + if echo.ndim > 2: + raise ValueError('Max dimension of echo must be 2') + if lag < 1: + raise ValueError('Lag must be equal or larger than 1') + if axis is None: + axis = 0 + else: + if axis > (echo.ndim - 1): + raise np.AxisError( + f'axis {axis} is out of bound for dimenion {echo.ndim}') + + sgn_i = _sgn(echo.real) + sgn_q = _sgn(echo.imag) + + if axis == 0: + if echo.shape[0] < (lag + 1): + raise RuntimeError( + f'Not enough samples for correlator along axis {axis}') + xcor_ii = (sgn_i[lag:] * sgn_i[:-lag]).mean() + xcor_qq = (sgn_q[lag:] * sgn_q[:-lag]).mean() + xcor_iq = (sgn_i[lag:] * sgn_q[:-lag]).mean() + xcor_qi = (sgn_q[lag:] * sgn_i[:-lag]).mean() + else: + if echo.shape[1] < (lag + 1): + raise RuntimeError( + f'Not enough samples for correlator along axis {axis}') + xcor_ii = (sgn_i[:, lag:] * sgn_i[:, :-lag]).mean() + xcor_qq = (sgn_q[:, lag:] * sgn_q[:, :-lag]).mean() + xcor_iq = (sgn_i[:, lag:] * sgn_q[:, :-lag]).mean() + xcor_qi = (sgn_q[:, lag:] * sgn_i[:, :-lag]).mean() + + r_sinlaw = np.sin(0.5 * np.pi * np.asarray([xcor_ii, xcor_qq, + xcor_qi, -xcor_iq])) + xcor_cmp = 0.5 * complex(r_sinlaw[:2].sum(), r_sinlaw[2:].sum()) + + return prf / (2.0 * np.pi * lag) * np.angle(xcor_cmp) + 
+ +def wavelen_diversity_doppler_est(echo, prf, samprate, bandwidth, + centerfreq): + """Estimate Doppler based on wavelength diversity. + + It uses slope of phase of range frequency along with single-lag + time-domain correlator approach proposed by [BAMLER1991]_. + + Parameters + ---------- + echo : np.ndarray(complex) + 2-D complex basebanded echo, azimuth by range in time domain. + prf : float + Pulse repetition frequency in (Hz) + samprate : float + Sampling rate in range , second dim, in (Hz) + bandwidth : float + RF/chirp bandiwdth in (Hz) + centerfreq : float + RF center frequency of chirp in (Hz) + + Returns + ------- + float + Unambiguous Doppler centroid at center frequency in (Hz) + + Raises + ------ + ValueError + For bad input + TypeError + If echo is not numpy array + + See Also + -------- + corr_doppler_est : Correlation Doppler Estimator (CDE) + sign_doppler_est : Sign-Doppler estimator (SDE) + + References + ---------- + .. [BAMLER1991] R. Bamler and H. Runge, 'PRF-Ambiguity Resolving by + Wavelength Diversity', IEEE Transaction on GeoSci and Remote Sensing, + November 1991. 
+ + """ + if prf <= 0: + raise ValueError('PRF must be positive value!') + if samprate <= 0: + raise ValueError('samprate must be positive value!') + if bandwidth <= 0 or bandwidth >= samprate: + raise ValueError('badnwidth must be positive less than samprate!') + if centerfreq <= 0.0: + raise ValueError('centerfreq must be positive value!') + if not isinstance(echo, np.ndarray): + raise TypeError('echo must be a numpy array') + if echo.ndim != 2: + raise ValueError('echo must have two dimensions') + num_azb, num_rgb = echo.shape + if num_azb <= 2: + raise ValueError('The first dimension of echo must be larger than 2') + if num_rgb > 2: + raise ValueError('The second dimension of echo must be larger than 2!') + + # FFT along range + nfft = fft.next_fast_len(num_rgb) + echo_fft = fft.fft(echo, nfft, axis=1) + + # one-lag correlator along azimuth + az_corr = (echo_fft[1:] * echo_fft[:-1].conj()).mean(axis=0) + + # Get the unwrapped phase of range spectrum within +/-bandwidth/2. + df = samprate / nfft + half_bw = 0.5 * bandwidth + idx_hbw = nfft // 2 - int(half_bw / df) + unwrap_phs_rg = np.unwrap(np.angle(fft.fftshift(az_corr) + [idx_hbw: -idx_hbw])) # (rad) + + # perform linear regression in range freq within bandwidth + freq_bw = -half_bw + df * np.arange(nfft - 2 * idx_hbw) + pf_coef = np.polyfit(freq_bw, unwrap_phs_rg, deg=1) + + # get the doppler centroid at center freq based on slope + dop_slope = prf / (2. * np.pi) * pf_coef[0] + + return centerfreq * dop_slope + + +@functools.singledispatch +def unwrap_doppler(dop, prf): + """Unwrap doppler value(s) + + Parameters + ---------- + dop : float or np.ndarray(float) or Sequence[float] + Doppler centroid value(s) in (Hz) + prf : float + Pulse repetition frequency in (Hz). 
+ + Returns + ------- + float or np.ndarray(float) + Unwrapped Doppler values the same format as input in (Hz) + + Raises + ------ + ValueError + For non-positive prf + TypeError: + Bad data stype for dop + + """ + raise TypeError('Unsupported data type for doppler') + + +@unwrap_doppler.register(numbers.Real) +def _unwrap_doppler_scalar(dop: float, prf: float) -> float: + """Returns single doppler as it is""" + if prf <= 0.0: + raise ValueError('prf must be a positive value') + return dop + + +@unwrap_doppler.register(np.ndarray) +def _unwrap_doppler_array(dop: np.ndarray, prf: float) -> np.ndarray: + """Unwrap doppler values stored as numpy array""" + if prf <= 0.0: + raise ValueError('prf must be a positive value') + freq2phs = 2 * np.pi / prf + phs2freq = 1.0 / freq2phs + return phs2freq*np.unwrap(freq2phs * dop) + + +@unwrap_doppler.register(cl.abc.Sequence) +def _unwrap_doppler_sequence(dop: cl.abc.Sequence, prf: float) -> np.ndarray: + """Unwrap doppler values stored as Sequence """ + if prf <= 0.0: + raise ValueError('prf must be a positive value') + freq2phs = 2 * np.pi / prf + phs2freq = 1.0 / freq2phs + return phs2freq*np.unwrap(freq2phs * np.asarray(dop)) + +# List of helper functions + + +def _sgn(x: np.ndarray) -> np.ndarray: + """Wrapper around numpy.sign. + + It replaces zero values with one. + + """ + s = np.sign(x) + s[s == 0] = 1 + return s diff --git a/python/packages/isce3/signal/fir_filter_func.py b/python/packages/isce3/signal/fir_filter_func.py new file mode 100644 index 000000000..8fcf2a724 --- /dev/null +++ b/python/packages/isce3/signal/fir_filter_func.py @@ -0,0 +1,87 @@ +""" +Generate arbitrary FIR , LPF or BPF, Filter Coefficients +""" +import numpy as np +import scipy.signal as spsg + + +def cheby_equi_ripple_filter(samprate, bandwidth, rolloff=1.2, ripple=0.1, + stopatt=40, centerfreq=0.0, force_odd_len=False): + """ + Generate an arbitrary FIR equi-ripple Chebyshev , Low Pass Filter (LPF) + or Band Pass Filter (BPF) coefficients. 
+ + It uses 'remez' optmization algorithm for designing Chebyshev filter + with equal pass-band and stop-band ripples. + The min length of the filter is determined based on 'Kaiser' formula. + + Parameters + ---------- + samprate : float + Sampling frequency in Hz, MHz, etc. + bandwidth : float + Bandwidth in same unit as samprate + rollfoff : float, default=1.2 + Roll-off factor or shaping factor of the filter. This must be > 1.0. + ripple : float, default=0.1 + Pass-band ripples in dB. + stopatt : float, default=40.0 + Minimum Stopband attenuation in dB. + centerfreq : float, default=0.0 + Center frequency in the same unit as samprate. + force_odd_len : bool, default=False + Whether or to not to force the filter length to be an odd value. + + Returns + ------- + numpy.ndarray + Filter coefficients. + + Raises + ------ + ValueError + For bad inputs. + + """ + if samprate <= 0.0: + raise ValueError('samprate must be a positive value') + if bandwidth <= 0.0 or bandwidth >= samprate: + raise ValueError( + 'bandwidth must be a positive value less than samprate') + max_rolloff = samprate / bandwidth + if rolloff <= 1 or rolloff > max_rolloff: + raise ValueError( + 'rolloff must be a value greater than 1 and equal or less' + f' than {max_rolloff}' + ) + if ripple <= 0: + raise ValueError('rippler must be a positive value') + if stopatt <= 0: + raise ValueError('stopatt must be a positive value') + + # LPF params + delta_pas = 10.**(ripple/20.) - 1 + delta_stp = 10.**(-stopatt/20.) + weight_fact = delta_pas / delta_stp + max_rolloff = samprate / bandwidth + + # get LPF length + fstop = rolloff * bandwidth + deltaf = (fstop - bandwidth) / samprate / 2.0 + len_flt = np.int_(np.ceil((-20. * np.log10(np.sqrt(delta_stp*delta_pas)) + - 13.) 
/ 14.6 / deltaf) + 1) + + if (force_odd_len and len_flt % 2 == 0): + len_flt += 1 + + # get LPF coeffs + coeffs = spsg.remez(len_flt, 0.5 / samprate + * np.array([0, bandwidth, fstop, samprate]), + np.array([1.0, 0.0]), np.array([1, weight_fact]), + Hz=1, type='bandpass', maxiter=50) + + # up/down conversion + if abs(centerfreq) > 0.0: + return coeffs * np.exp(2j * np.pi * centerfreq / samprate * + np.arange(len_flt)) + return coeffs diff --git a/python/packages/isce3/unwrap/snaphu.py b/python/packages/isce3/unwrap/snaphu.py index 76f7acf3a..502c9489c 100644 --- a/python/packages/isce3/unwrap/snaphu.py +++ b/python/packages/isce3/unwrap/snaphu.py @@ -3,25 +3,13 @@ import pathlib import tempfile from dataclasses import dataclass -from typing import Literal, Optional, Union +from typing import Optional, Union import isce3 import numpy as np from pybind_isce3.unwrap import _snaphu_unwrap -TransmitMode = Literal["pingpong", "repeat_pass", "single_antenna_transmit"] -TransmitMode.__doc__ = """Radar transmit mode - - 'pingpong' and 'repeat_pass' modes indicate that both antennas both - transmitted and received. Both modes have the same effect in the algorithm. - - 'single_antenna_transmit' indicates that a single antenna was used to - transmit while both antennas received. In this mode, the baseline is - effectively halved. 
- """ - - @dataclass(frozen=True) class TopoCostParams: r"""Configuration parameters for SNAPHU "topo" cost mode @@ -159,7 +147,7 @@ class TopoCostParams: range_res: float az_res: float wavelength: float - transmit_mode: TransmitMode + transmit_mode: str altitude: float earth_radius: float = 6_378_000.0 kds: float = 0.02 @@ -656,7 +644,7 @@ def tostring(self): @contextlib.contextmanager -def scratch_directory(d: Optional[os.PathLike] = None, /) -> pathlib.Path: +def scratch_directory(d: Optional[os.PathLike] = None) -> pathlib.Path: """Context manager that creates a (possibly temporary) filesystem directory If the input is a path-like object, a directory will be created at the @@ -772,9 +760,6 @@ def from_flat_file( raster.data[i0:i1] = mmap[i0:i1] -CostMode = Literal["topo", "defo", "smooth", "p-norm"] -CostMode.__doc__ = """SNAPHU cost mode options""" - CostParams = Union[ TopoCostParams, DefoCostParams, SmoothCostParams, PNormCostParams, ] @@ -787,8 +772,9 @@ def unwrap( igram: isce3.io.gdal.Raster, corr: isce3.io.gdal.Raster, nlooks: float, - cost: CostMode = "smooth", + cost: str = "smooth", cost_params: Optional[CostParams] = None, + init_method: str = "mcf", pwr: Optional[isce3.io.gdal.Raster] = None, mask: Optional[isce3.io.gdal.Raster] = None, unwest: Optional[isce3.io.gdal.Raster] = None, @@ -881,6 +867,10 @@ def unwrap( Configuration parameters for the specified cost mode. This argument is required for "topo" mode and optional for all other modes. If None, the default configuration parameters are used. (default: None) + init_method: {"mst", "mcf"}, optional + Algorithm used for initialization of unwrapped phase gradients. + Supported algorithms include Minimum Spanning Tree ("mst") and Minimum + Cost Flow ("mcf"). (default: "mcf") pwr : isce3.io.gdal.Raster or None, optional Average intensity of the two SLCs, in linear units (not dB). Only used in "topo" cost mode. 
If None, interferogram magnitude is used as @@ -986,8 +976,14 @@ def cost_string(): configstr += f"STATCOSTMODE {cost_string()}\n" - # XXX Currently, only "MST" initialization method is supported. - configstr += "INITMETHOD MST\n" + def init_string(): + if init_method == "mst": + return "MST" + if init_method == "mcf": + return "MCF" + raise ValueError(f"invalid init method '{init_method}'") + + configstr += f"INITMETHOD {init_string()}\n" # Check cost mode-specific configuration params. if cost == "topo": diff --git a/python/packages/nisar/products/readers/Base/Base.py b/python/packages/nisar/products/readers/Base/Base.py index 76fbf0bb0..046712003 100644 --- a/python/packages/nisar/products/readers/Base/Base.py +++ b/python/packages/nisar/products/readers/Base/Base.py @@ -232,6 +232,11 @@ def RootPath(self): self.filename) return self._RootPath + @property + def sarBand(self): + """SAR band string such as 'L' or 'S' for NISAR.""" + return self.RootPath[-4] + @property def IdentificationPath(self): return os.path.join(self.RootPath, self._IdentificationPath) diff --git a/python/packages/nisar/products/readers/Raw/DataDecoder.py b/python/packages/nisar/products/readers/Raw/DataDecoder.py index 45e56aa82..ffcec2699 100644 --- a/python/packages/nisar/products/readers/Raw/DataDecoder.py +++ b/python/packages/nisar/products/readers/Raw/DataDecoder.py @@ -28,6 +28,9 @@ def __init__(self, h5dataset): self.decoder = lambda key: self.dataset[key] self.dataset = h5dataset self.shape = self.dataset.shape + self.ndim = self.dataset.ndim + self.dtype = np.dtype('c8') + self.dtype_storage = self.dataset.dtype group = h5dataset.parent if "BFPQLUT" in group: assert group["BFPQLUT"].dtype == np.float32 diff --git a/python/packages/nisar/products/readers/Raw/Raw.py b/python/packages/nisar/products/readers/Raw/Raw.py index 486c11131..9858e6812 100644 --- a/python/packages/nisar/products/readers/Raw/Raw.py +++ b/python/packages/nisar/products/readers/Raw/Raw.py @@ -130,6 +130,26 @@ 
def getChirpParameters(self, frequency: str = 'A', tx: str = 'H'): fc = self.getCenterFrequency(frequency, tx) return fc, fs, K, T + def getRangeBandwidth(self, frequency: str = 'A', tx: str = 'H'): + """Get RF bandwidth of a desired TX frequency band and pol. + + Parameters + ---------- + frequency : {'A', 'B'}, optional + Sub-band + tx : {'H', 'V', 'L', 'R'}, optional + Transmit polarization + + Returns + ------- + float + Bandwidth in Hz. + + """ + tx_path = self._pulseMetaPath(frequency=frequency, tx=tx) + with h5py.File(self.filename, 'r', libver='latest', swmr=True) as f: + return f[tx_path]["rangeBandwidth"][()] + @property def TelemetryPath(self): return f"{self.ProductPath}/lowRateTelemetry" @@ -189,6 +209,51 @@ def getPulseTimes(self, frequency='A', tx='H'): epoch = isce3.io.get_ref_epoch(f[txpath], name) return epoch, t + def getNominalPRF(self, frequency='A', tx='H'): + """Nominal PRF defined as mean PRF for dithered case. + + Parameters + ---------- + frequency : {'A', 'B'}, optional + Sub-band. Typically main science band is 'A'. + + tx : {'H', 'V', 'L', 'R'} + Transmit polarization. Abbreviations correspond to horizontal + (linear), vertical (linear), left circular, right circular + + Returns + ------- + float + PRF in Hz. + + """ + _, az_time = self.getPulseTimes(frequency, tx) + return (az_time.size - 1) / (az_time[-1] - az_time[0]) + + def isDithered(self, frequency='A', tx='H'): + """Whether or not PRF is dithering. + + That is more than one PRF value within entire azimuth duration. + + Parameters + ---------- + frequency : {'A', 'B'}, optional + Sub-band. Typically main science band is 'A'. + + tx : {'H', 'V', 'L', 'R'} + Transmit polarization. Abbreviations correspond to horizontal + (linear), vertical (linear), left circular, right circular + + Returns + ------- + bool + True if multiple PRF values and False if PRF is fixed. 
+ + """ + _, az_time = self.getPulseTimes(frequency, tx) + tm_diff = np.diff(az_time) + return not np.isclose(tm_diff.min(), tm_diff.max()) + def getCenterFrequency(self, frequency: str = 'A', tx: str = None): if tx is None: diff --git a/python/packages/nisar/workflows/__init__.py b/python/packages/nisar/workflows/__init__.py index e69de29bb..3d60589b8 100644 --- a/python/packages/nisar/workflows/__init__.py +++ b/python/packages/nisar/workflows/__init__.py @@ -0,0 +1 @@ +from .doppler_lut_from_raw import doppler_lut_from_raw diff --git a/python/packages/nisar/workflows/doppler_lut_from_raw.py b/python/packages/nisar/workflows/doppler_lut_from_raw.py new file mode 100644 index 000000000..31c3be1ad --- /dev/null +++ b/python/packages/nisar/workflows/doppler_lut_from_raw.py @@ -0,0 +1,628 @@ +""" +Function to generate Doppler LUT2d from Raw L0B data. +""" +import logging +import os +import numpy as np +from scipy import fft +try: + from matplotlib import pyplot as plt +except ImportError: + plt = None + +from isce3.signal import (cheby_equi_ripple_filter, corr_doppler_est, + sign_doppler_est, unwrap_doppler) +from isce3.core import LUT2d + + +def doppler_lut_from_raw(raw_obj, freq_band='A', txrx_pol=None, + num_rgb_avg=16, az_block_dur=4.0, time_interval=2.0, + dop_method='CDE', subband=False, + polyfit_deg=3, polyfit=False, out_path='.', + plot=False, logger=None): + """Generates 2-D Doppler LUT as a function of slant range and azimuth time. + + It generates Doppler map in isce3.core.LUT2d format. + It optionally generates Doppler plots as a function of + slant ranges at various azimuth times stored in PNG files. + For algorithms + See references [GHAEMI2018]_, [MADSEN1989]_, [BAMLER1991]_. 
+ + The subbanding is a joint time-frequency approach where three frequency + bands lower, mid, and upper part of the echo are individually used in + time-domain doppler correlator estimator and then a linear regression is + applied to the three doppler values as a function of frequency. Finally, + the doppler is evaluated at the center frequency of the band from the + first-degree polyfit coefficients. + + To generate three sub-bands, a FIR Chebyshev Equi-rippler low-pass filter + is designed. This filter is up/down converted to perform band-pass + filtering of lower/upper part of the band. + + In case of polyfit, the Doppler at invalid range bins will be replaced by + poly evaluated ones and thus the final respective valid mask will all set + to True. That is no invalid range bins will be reported! + + Parameters + ---------- + raw_obj : nisar.products.readers.Raw.RawBase + Raw L0B product parser base object + freq_band : {'A', 'B'} + Frequency band in multi-band TX chirp. + txrx_pol : str, optional + TxRx polarization such as {'HH', 'HV',...}. If not provided the first + product under `freq_band` will be used. + num_rgb_avg : int, default=16 + Number of range bins to be averaged in final Doppler values. + az_block_dur : float, default=4.0 + Azimuth block duration in seconds defining time-domain correlator + length used in Doppler estimator. + time_interval : float, default=2.0 + Time stamp interval between azimuth blocks in seconds. + It should not be larger than "az_block_dur". + dop_method : {'CDE', 'SDE'} + Correlator-based time-domain Doppler estimator method, either of + Correlation Doppler Estimator ('CDE') or Sign-Doppler estimator ('SDE') + See [MADSEN1989]_. These methods used as a base method + in subbanding time-frequency approach if requested via `subband`. + See [BAMLER1991]_ and [GHAEMI2018]_. + subband : bool, default=False + Whether or not use sub-banding frequency approach on top of correlator + one in Doppler estimation. 
+ polyfit_deg : int, default=3 + Polyfit degree used in Doppler plots for polyfitting of doppler as a + function of slant ranges and its statistical mean/std variation over + the swath. If "polyfit" flag set to True, the polyfitted version of + the estimated dopplers in slant range will be used as the final 2-D + LUT product! + The polyfitting will be done over valid range bins per azimuth block! + polyfit : bool, default=False + If is True, then polyfitted Doppler product with degree "polyfit_deg" + will be used in place of estimated one as a function of slant range per + azimuth block. + out_path : str, default='.' + Ouput directory for dumping PNG files, if `plot` is True. + plot : bool, default=False + If True, it will generate bunch of .png plots of both True and + poly-fitted Doppler centroid as a function of slant ranges per azimuth + block. The polyfit degree used in plotting is 3 if not set by the + `polyfit_deg`! + logger : logging.Logger, optional + If not provided a longger with StreamHandler will be set. + + Notes + ----- + PRF must be constant. Dithered PRF is not supported. + The LUT2d product requires at least two blocks in each directions. + In case of polyfit, the number of valid range bins must be larger than + (polyfit_deg * num_rgb_avg). + NISAR Non-science multi-channel aka diagnostic mode # 2 (DM2) L0B product + is not supported. Simply single-channel SAR (non-NISAR) or composite DBFed + SAR data (NISAR science mode) are supported. 
+ + Returns + ------- + isce3.core.LUT2d + Doppler values (Hz) as a function of `x=`slant range (m) and + `y=`azimuth/pulse time (sec) + isce3.core.DateTime + Reference epoch UTC time for azimuth/pulse times + np.ndarray(bool) + Mask array for valid averaged range bins + np.ndarray(float32) + Correlation coefficients within [0,1] + str + TxRx polarization of the product + float + Center frequency of the `freq_band` in (Hz) + np.ndarray or None + Prototype filter coeffs centered at chirp center frequency (LPF) + if `subband=True`, otherwise None. + + Raises + ------ + ValueError + For bad input parameters or non-existent polarization and/or + frequency band. + RuntimeError + For dithered PRF. + Less than 2 azimuth blocks. + Too many invalid range bins w.r.t polyfit degree in case of polyfit. + NotImplementedError + For non-science multi-channel aka diagnostic mode # 2 (DM2). + + References + ---------- + .. [GHAEMI2018] H. Ghaemi and S. Durden, 'Pointing Estimation Algorithms + and Simulation Results', JPL Report, February 2018. + .. [MADSEN1989] S. Madsen, 'Estimating The Doppler Centroid of SAR Data', + IEEE Transaction On Aerospace and Elect Sys, March 1989. + .. [BAMLER1991] R. Bamler and H. Runge, 'PRF-Ambiguity Resolving by + Wavelength Diversity', IEEE Transaction on GeoSci and Remote Sensing, + November 1991. 
+ + """ + # List of Constants + num_subband = 3 + ripple_flt = 0.2 # passband ripple of subband filter (dB) + rolloff_flt = 1.25 # roll-off of subband filter + stopatt_flt = 27.0 # stop-band attenuation of subband filter (dB) + # check inputs + if polyfit_deg < 1: + raise ValueError('polyfit_deg must be greater than 0') + if az_block_dur <= 0.0: + raise ValueError('az_block_dur must be a positive value') + if (time_interval <= 0.0 or time_interval > az_block_dur): + raise ValueError( + 'time_interval must be a positive value less than az_block_dur') + if num_rgb_avg < 1: + raise ValueError('Number of range bins must be a positive value') + # set logger + if logger is None: + logger = set_logger("DopplerLUT") + + # check if there is matplotlib package needed for plotting if requested + if plot: + if plt is None: + logger.warning('No plots due to missing package "matplotlib"!') + plot = False + + # Check frequency band + if freq_band not in raw_obj.polarizations: + raise ValueError( + 'Wrong frequency band! The available bands -> ' + f'{list(raw_obj.polarizations)}' + ) + logger.info(f"Frequency band -> '{freq_band}'") + # check for txrx_pol + list_txrx_pols = raw_obj.polarizations[freq_band] + if txrx_pol is None: + txrx_pol = list_txrx_pols[0] + elif txrx_pol not in list_txrx_pols: + raise ValueError( + f'Wrong TxRx polarization! The available ones -> {list_txrx_pols}') + logger.info(f"TxRx Pol -> '{txrx_pol}'") + + # Get chirp parameters + centerfreq, samprate, _, pulsewidth = \ + raw_obj.getChirpParameters(freq_band, txrx_pol[0]) + + bandwidth = raw_obj.getRangeBandwidth(freq_band, txrx_pol[0]) + + # Get Pulse/azimuth time and ref epoch + epoch_utc, az_time = raw_obj.getPulseTimes(freq_band, + txrx_pol[0]) + epoch_utc_str = epoch_utc.isoformat() + # get PRF and check for dithering + prf = raw_obj.getNominalPRF(freq_band, txrx_pol[0]) + dithered = raw_obj.isDithered(freq_band, txrx_pol[0]) + pri = 1. 
/ prf + + if dithered: + raise RuntimeError("Dithered PRF is not supported!") + logger.info(f'Fast-time sampling rate -> {samprate * 1e-6:.2f} (MHz)') + logger.info(f'Chirp bandwidth -> {bandwidth * 1e-6:.2f} (MHz)') + logger.info(f'Chirp pulsewidth -> {pulsewidth * 1e6:.2f} (us)') + logger.info(f'Chirp center frequency -> {centerfreq * 1e-6:.2f} (MHz)') + logger.info(f'PRF -> {prf:.3f} (Hz)') + + # Get raw dataset + raw_dset = raw_obj.getRawDataset(freq_band, txrx_pol) + if raw_dset.ndim > 2: + raise NotImplementedError( + 'Multi-channel Raw echo aka Diagnostic Mode 2 ' + '(DM2) has not supported yet!' + ) + tot_pulses, tot_rgbs = raw_dset.shape + logger.info( + f'Shape of the echo data (pulses, ranges) -> {tot_pulses, tot_rgbs}') + + # blocksize in range + if num_rgb_avg > (tot_rgbs // 2): + raise ValueError( + 'Number of range bins to be averaged must be equal or less than ' + f'{tot_rgbs // 2} to result in at least 2 range blocks!' + ) + logger.info(f'Number of range bins per range block -> {num_rgb_avg}') + num_blk_rg = tot_rgbs // num_rgb_avg + logger.info(f'Number of range blocks -> {num_blk_rg}') + + # Get prototype LPF coeffs if subbanding is requested + coeff_lpf = None + if subband: + logger.info("Perform sub-banding on echo data!") + logger.info(f'Number of subbands -> {num_subband}') + bw_flt = bandwidth / num_subband + coeff_lpf = cheby_equi_ripple_filter(samprate, bw_flt, rolloff_flt, + ripple_flt, stopatt_flt, + force_odd_len=True) + len_flt = len(coeff_lpf) + logger.info( + 'Subbanding filter passband bandiwdth -> ' + f'{bw_flt * 1e-6:.2f} (MHz)' + ) + logger.info( + f'Subbanding filter passband ripple -> {ripple_flt:.2f} (dB)') + logger.info( + f'Subbanding filter stopband attenuaton -> {stopatt_flt:.2f} (dB)') + logger.info(f'Subbanding filter rolloff factor -> {rolloff_flt}') + logger.info(f'Length of subband filter -> {len_flt}') + + # convolution length and group delay + len_conv = tot_rgbs + len_flt - 1 + grp_del = len_flt // 2 + + # Get 
number of FFT and total group delay caused by rgcomp + subband + nfft = fft.next_fast_len(len_conv) + logger.info( + f'Number of FFT points in rangecomp and/or subbanding -> {nfft}') + + # Get FFT of the prototype LPF + coeff_lpf_fft = fft.fft(coeff_lpf, nfft) + slice_grp_del = slice(grp_del, grp_del + tot_rgbs) + + # Calculate center frequencies for only three bands:first, mid,and last + fcnt_first = (1 - num_subband) / (2. * num_subband) * bandwidth + fcnt_last = -fcnt_first + fcnt_subbands = [fcnt_first, 0.0, fcnt_last] + fcnt_rf_subbands = np.asarray(fcnt_subbands) + centerfreq + logger.info( + 'The RF center freq of subbands -> ' + '({:.2f}, {:.2f}, {:.2f}) (MHz)'.format(*(fcnt_rf_subbands * 1e-6)) + ) + + # Mixer func for up/down conversion of LPF -> BPF + def mixer_fun(fc): + return np.exp( + 1j * 2.0 * np.pi * fc / samprate * np.arange(len_flt)) + + # Get freq-domain BPFs Coeffs for two edge bands from LPF prototype + coef_bpf_fft_first = fft.fft(coeff_lpf * mixer_fun(fcnt_first), nfft) + coef_bpf_fft_last = fft.fft(coeff_lpf * mixer_fun(fcnt_last), nfft) + + # plot three suband BPF in frequency domain + if plot: + plt_name = f'Subband_Filter_Plot_Freq{freq_band}_Pol{txrx_pol}.png' + name_plot = os.path.join(out_path, plt_name) + _plot_subband_filters(samprate, centerfreq, coef_bpf_fft_first, + coeff_lpf_fft, coef_bpf_fft_last, name_plot) + + # get complex echo for all range bins but limited pulses + dop_method = dop_method.upper() + logger.info( + f'Doppler estimator method per block and per band -> {dop_method}') + # form a generic doppler estimator function covering both methods + if dop_method == 'SDE': + def time_dop_est(echo, prf, lag=1, axis=None): + return sign_doppler_est(echo, prf, lag, axis), 1 + elif dop_method == 'CDE': + time_dop_est = corr_doppler_est + else: + raise ValueError( + f'Unexpected time-domain Doppler method "{dop_method}"') + + # Get slant ranges per range block centered at each block + sr_lsp = raw_obj.getRanges(freq_band, 
txrx_pol[0]) + sr_spacing = sr_lsp.spacing * num_rgb_avg + sr_start = sr_lsp.first + 0.5 * sr_spacing + sr_stop = sr_start + (num_blk_rg - 1) * sr_spacing + slrg_per_blk = np.linspace(sr_start, sr_stop, num=num_blk_rg) + + # form the blocks of range lines / azimuth bins + len_az_blk_dur, len_tm_int, num_blk_az = _get_az_block_interval_len( + tot_pulses, az_block_dur, prf, time_interval) + + logger.info( + f'Final full azimuth block duration -> {len_az_blk_dur/prf:.3f} (sec)') + logger.info( + f'Number of range lines of a full azimuth block -> {len_az_blk_dur}') + logger.info( + f'Time interval between azimuth blocks -> {len_tm_int/prf:.3f} (sec)') + logger.info( + 'Number of range line seperation between azimuth blocks -> ' + f'{len_tm_int}' + ) + logger.info(f'Total number of azimuth blocks -> {num_blk_az}') + + slice_lines = _azblk_slice_gen( + tot_pulses, len_az_blk_dur, len_tm_int, num_blk_az) + + # parse valid subswath index for all range lines used later + valid_sbsw_all = raw_obj.getSubSwaths(freq_band, txrx_pol[0]) + # initialized output mask array for averaged range bins for + # all azimuth blocks + mask_rgb_avg_all = np.zeros((num_blk_az, num_blk_rg), dtype='bool') + + # initialize correlator coeff for all range bins and azimuth blocks + corr_coef = np.ones((num_blk_az, num_blk_rg), dtype='float32') + + # initialize the azimuth time block and set an intermediate var + half_az_blk_dur = (len_az_blk_dur - 1) / 2 + az_time_blk = np.full(num_blk_az, az_time[0] + half_az_blk_dur * pri, + dtype=float) + tm_int_pri_prod = len_tm_int * pri + + # doppler centroid map is azimuth block by slant-range block + dop_cnt_map = np.zeros((num_blk_az, num_blk_rg), dtype='float32') + + # loop over azimuth blocks /range line blocks + for n_azblk, slice_line in enumerate(slice_lines): + num_lines = slice_line.stop - slice_line.start + logger.info( + f'(start, stop) of AZ block # {n_azblk + 1} -> ' + f'{slice_line.start, slice_line.stop}' + ) + logger.info( + 'Block size 
(lines, ranges) for Doppler estimation -> ' + f'({num_lines, num_rgb_avg})' + ) + + # get decoded raw echoes of one azimuth block and for all range bins + echo = raw_dset[slice_line] + + # create a mask for invalid/bad range bins for any reason + # invalid values are either nan or zero but this does not include + # TX gaps that may be filled with TX chirp! + mask_bad = (np.isnan(echo) | (echo == 0.0)).sum(axis=0) > 0 + + # build a mask array of range bins assuming fixed PRF within + # each azimuth block. This is needed in case the TX gaps are filled + # with TX chirp rather than invalid/bad value! + mask_valid_rgb = _form_mask_valid_range( + tot_rgbs, valid_sbsw_all[:, slice_line.start, :]) + mask_valid_rgb &= _form_mask_valid_range( + tot_rgbs, valid_sbsw_all[:, slice_line.stop - 1, :]) + + # Update valid mask with invalid range bins over all range lines + mask_valid_rgb[mask_bad] = False + + # decimate the range bins mask to fill in mask for averaged range bins + # per azimuth block. Make sure a valid averaged block contains all + # valid range bins otherwise set to invalid. 
+ mask_rgb_avg_all[n_azblk, :] = mask_valid_rgb[ + :num_blk_rg * num_rgb_avg].reshape((num_blk_rg, num_rgb_avg)).sum( + axis=1) == num_rgb_avg + + # azimuth time at mid part of the azimuth block + az_time_blk[n_azblk] += n_azblk * tm_int_pri_prod + + # form mask for NaN values in echo and replace it with 0 + echo[np.isnan(echo)] = 0.0 + + if subband: + echo_sub_first = np.zeros(echo.shape, dtype=echo.dtype) + echo_sub_last = np.copy(echo_sub_first) + # Loop over range lines for one azimuth block + for line in range(num_lines): + # apply subband BPF in freq domain + rgc_line_fft = fft.fft(echo[line, :], nfft) + # first band + rgc_line_fft_edge = rgc_line_fft * coef_bpf_fft_first + echo_sub_first[line, :] = fft.ifft( + rgc_line_fft_edge)[slice_grp_del] + # last band + rgc_line_fft_edge = rgc_line_fft * coef_bpf_fft_last + echo_sub_last[line, :] = fft.ifft( + rgc_line_fft_edge)[slice_grp_del] + # mid band + rgc_line_fft *= coeff_lpf_fft + # go back to time and get rid of all group delays + echo[line, :] = fft.ifft(rgc_line_fft)[slice_grp_del] + + # estimate doppler per band, per azimuth block over all range blocks + dop_cnt = np.zeros(num_blk_rg, dtype="float32") + for n_blk in range(num_blk_rg): + slice_rgb = slice(n_blk * num_rgb_avg, (n_blk + 1) * num_rgb_avg) + # CDE or SDE + dop_cnt[n_blk], corr_coef[n_azblk, n_blk] = time_dop_est( + echo[:, slice_rgb], prf) + + if subband: + dop_cnt_bands = np.zeros((3, num_blk_rg), dtype="float32") + dop_cnt_bands[1, :] = dop_cnt + for n_blk in range(num_blk_rg): + slice_rgb = slice(n_blk * num_rgb_avg, + (n_blk + 1) * num_rgb_avg) + # CDE or SDE for each subband + dop_cnt_bands[0, n_blk], corr_coef_low = time_dop_est( + echo_sub_first[:, slice_rgb], prf) + dop_cnt_bands[-1, n_blk], corr_coef_high = time_dop_est( + echo_sub_last[:, slice_rgb], prf) + # sum correlation coeff among all three bands + corr_coef[n_azblk, n_blk] += (corr_coef_low + corr_coef_high) + # average correlation coeff among all three bands + 
corr_coef[n_azblk, :] /= 3.0 + + # perform doppler unwrapping over three bands + dop_cnt_bands = unwrap_doppler(dop_cnt_bands, prf) + + # perform linear (1st degree) polyfit over 3 bands for + # all range blocks + # IF version: np.polyfit(fcnt_subbands, dop_cnt_bands, 1) + pf_coef_subbands = np.polyfit(fcnt_rf_subbands, dop_cnt_bands, 1) + + # eval doppler centroid at the center freq of the chirp + # IF version: pf_coef_subbands[1, :] + dop_cnt = np.polyval(pf_coef_subbands, centerfreq) + + # collect doppler centroid vectors + if polyfit: # replace actual value by polyfitted ones + sr_valid = slrg_per_blk[mask_rgb_avg_all[n_azblk, :]] + # check if the number of valid range blocks > polyfit_deg + if sr_valid.size <= polyfit_deg: + raise RuntimeError( + 'Too many bad range bins! Polyfit requires at least ' + f'{polyfit_deg + 1} valid range blocks or ' + f'{(polyfit_deg + 1) * num_rgb_avg} valid range bins!' + ) + dop_cnt_valid = dop_cnt[mask_rgb_avg_all[n_azblk, :]] + pf_coef_dop_cnt = np.polyfit(sr_valid, dop_cnt_valid, polyfit_deg) + dop_cnt_map[n_azblk, :] = np.polyval(pf_coef_dop_cnt, slrg_per_blk) + # given estimation of invalid range bins from polyfit, + # set the mask to be all True after polyeval! 
+ mask_rgb_avg_all[n_azblk, :] = True + else: # keep the actual values + dop_cnt_map[n_azblk, :] = dop_cnt + + # plot Doppler centroid per azimuth block + if plot: + _plot_save_dop(n_azblk, slrg_per_blk, dop_cnt, az_time_blk, + epoch_utc_str, out_path, freq_band, txrx_pol, + polyfit_deg, mask_rgb_avg_all[n_azblk, :]) + + # form Doppler LUT2d object + dop_lut = LUT2d(slrg_per_blk, az_time_blk, dop_cnt_map) + + return dop_lut, epoch_utc, mask_rgb_avg_all, corr_coef, txrx_pol, \ + centerfreq, coeff_lpf + + +def set_logger(name: str) -> logging.Logger: + """set logger""" + logger = logging.getLogger(name) + + if not logger.handlers: + logger.setLevel(logging.DEBUG) + log_hdl = logging.StreamHandler() + log_hdl.setLevel(logging.DEBUG) + log_fmt = logging.Formatter( + fmt="%(asctime)s : %(name)s : %(levelname)s : %(message)s", + datefmt="%Y-%m-%dT%H:%M:%S") + log_hdl.setFormatter(log_fmt) + logger.addHandler(log_hdl) + + return logger + +# list of private helper functions: + + +def _plot_subband_filters(samprate: float, centerfreq: float, + coef_bpf_fft_first: np.ndarray, + coeff_lpf_fft: np.ndarray, + coef_bpf_fft_last: np.ndarray, + name_plot: str): + """Plot spectrum of three subbands filters""" + # form RF frequency vector + nfft = coef_bpf_fft_first.size + min_rf_freq = centerfreq - 0.5 * samprate + freq = min_rf_freq + (samprate / nfft) * np.arange(nfft) + freq *= 1e-6 # (MHz) + def amp2db_fft(amp): return 20 * np.log10(abs(fft.fftshift(amp))) + plt.figure() + plt.plot(freq, amp2db_fft(coef_bpf_fft_first), 'b', + freq, amp2db_fft(coeff_lpf_fft), 'g', + freq, amp2db_fft(coef_bpf_fft_last), 'r', + linewidth=2) + plt.legend(['First', 'Mid', 'Last'], loc='best') + plt.xlabel('RF Frequency (MHz)') + plt.ylabel('Magnitude (dB)') + plt.title('Spectrum of the three subband filters') + plt.ylim([-50.0, 1.0]) + plt.grid(True) + plt.savefig(name_plot) + plt.close() + + +def _plot_save_dop(n_azblk: int, slrg_per_blk: np.ndarray, dop_cnt: np.ndarray, + az_time_blk: np.ndarray, 
epoch_utc_str: str, out_path: str, + freq_band: str, txrx_pol: str, polyfit_deg: int, + mask_valid_rgb: np.ndarray): + """Plot Doppler as a function Slant range and save it as PNG file""" + fig = plt.figure(n_azblk) + ax = fig.add_subplot(111) + # only polyfit over the valid range blocks! + pf_coeff_dop_rg = np.polyfit(slrg_per_blk[mask_valid_rgb], + dop_cnt[mask_valid_rgb], polyfit_deg) + pv_dop_rg = np.polyval(pf_coeff_dop_rg, slrg_per_blk) + slrg_km = slrg_per_blk * 1e-3 + ax.plot(slrg_km, dop_cnt, 'r*--', slrg_km, pv_dop_rg, 'b--') + ax.legend(["Echo", f"PF(order={polyfit_deg})"], loc='best') + diff_dop_pf = dop_cnt - pv_dop_rg + plt_textstr = '\n'.join(( + 'Deviation from PF:', + r'$\mathrm{MEAN}$=%.1f Hz' % diff_dop_pf.mean(), + r'$\mathrm{STD}$=%.1f Hz' % diff_dop_pf.std())) + plt_props = dict(boxstyle='round', facecolor='green', alpha=0.5) + ax.text(0.5, 0.2, plt_textstr, transform=ax.transAxes, fontsize=10, + horizontalalignment='center', verticalalignment='center', + bbox=plt_props) + ax.grid(True) + ax.set_title( + 'Doppler Centroids\n@ azimuth-time = ' + f'{az_time_blk[n_azblk]:.3f} sec\nsince {epoch_utc_str}' + ) + ax.set_ylabel("Doppler (Hz)") + ax.set_xlabel("Slant Range (Km)") + fig.savefig(os.path.join( + out_path, 'Doppler_SlantRange_Plot_Freq' + f'{freq_band}_Pol{txrx_pol}_AzBlock{n_azblk + 1}.png' + )) + + +def _get_az_block_interval_len(num_pls: int, az_block_dur: float, prf: float, + time_interval: float): + """Get block size and interval lengths for azimuth blocks + + Returns + ------- + int + length of a full azimuth block + int + length of time interval + int + total number of blocks , full + partial (last block). + + """ + # time interval shall be equal ot less than block duration + time_interval = min(time_interval, az_block_dur) + # make sure the min time interval is one PRI! + len_tm_int = max(int(time_interval * prf), 1) + len_az_blk_dur = int(az_block_dur * prf) + # if the block_dur + interval is too large then raise an exception! 
+ if (len_tm_int + len_az_blk_dur) > num_pls: + raise ValueError( + 'Sum of azimuth block duration and time interval is large than ' + f'echo duration {(num_pls - 1) / prf} (sec)!' + ) + # get number of blocks which must be at least 2! + num_blk_az = int(np.ceil((num_pls - len_az_blk_dur) / len_tm_int)) + 1 + if num_blk_az < 2: + raise RuntimeError( + 'At least two azimuth blocks are required to form LUT2d! Try to ' + 'reduce time interval!' + ) + return len_az_blk_dur, len_tm_int, num_blk_az + + +def _azblk_slice_gen(num_pls: int, len_az_blk_dur: int, len_tm_int: int, + num_blk_az: int): + """Slice index generator for azimuth blocks/range lines""" + # generate slice for azimuth block indexing + i_str = 0 + i_stp = len_az_blk_dur + for bb in range(num_blk_az): + yield slice(i_str, i_stp) + i_str += len_tm_int + i_stp = min(i_str + len_az_blk_dur, num_pls) + + +def _form_mask_valid_range(tot_rgbs, rgb_valid_sbsw): + """Form valid mask for range bins for a specific range line. + + Parameters + ---------- + tot_rgbs : int + Total number of range bins + rgb_valid_sbsw : np.ndarray(np.ndarray(int)) + 2-D array-like integers for valid range bins of a specific range line + + Returns + ------- + np.ndarray(bool) + Mask array for valid range bins + + """ + msk_valid_rg = np.zeros(tot_rgbs, dtype=bool) + for start_stop in rgb_valid_sbsw: + msk_valid_rg[slice(*start_stop)] = True + return msk_valid_rg diff --git a/python/packages/nisar/workflows/gen_doppler_range_product.py b/python/packages/nisar/workflows/gen_doppler_range_product.py new file mode 100755 index 000000000..f50a217ff --- /dev/null +++ b/python/packages/nisar/workflows/gen_doppler_range_product.py @@ -0,0 +1,209 @@ +#!/usr/bin/env python3 +""" +Generate Doppler Centroid product from L0B data +""" +import os +import time +import argparse as argp +import numpy as np +from datetime import datetime + +from nisar.workflows import doppler_lut_from_raw +from nisar.workflows.doppler_lut_from_raw import set_logger +from 
nisar.products.readers.Raw import open_rrsd +from isce3.core import TimeDelta, Linspace +from nisar.products.readers.antenna import AntennaParser + + +def cmd_line_parser(): + """Parse command line input arguments. + + Notes + ----- + It also allows parsing arguments via an ASCII file + by using prefix char "@". + + Returns + ------- + argparse.Namespace + + """ + prs = argp.ArgumentParser( + description='Estimate Doppler centroid from L0B raw echo and creates ' + 'a 2-D Doppler LUT dumped into a CSV file', + fromfile_prefix_chars="@", + formatter_class=argp.ArgumentDefaultsHelpFormatter + ) + prs.add_argument('filename_l0b', type=str, + help='Filename of HDF5 L0B product') + prs.add_argument('-antenna_file', type=str, dest='antenna_file', + help='Filename of HDF5 Antenna product used to extract ' + 'averaged azimuth angle for EL cuts of TX + RX pol of ' + 'first beam. If not provided, the azimuth angle is ' + 'assumed to be zero!') + prs.add_argument('-f', '--freq', type=str, choices=['A', 'B'], default='A', + dest='freq_band', help='Frequency band such as "A".') + prs.add_argument('-p', '--pol', type=str, dest='txrx_pol', + choices=["HH", "VV", "HV", "VH"], + help='TxRx Polarization such as "HH". Default is the ' + 'first pol in the specified frequency band') + prs.add_argument('-r', '--rgb', type=int, dest='num_rgb_avg', default=16, + help='Number of range bins to be averaged in Doppler ' + 'Estimator block. Shall be equal or larger than 1.') + prs.add_argument('-a', '--az_block_dur', type=float, dest='az_block_dur', + default=4.0, + help='Azimuth block duration in seconds defining time-' + 'domain correlator length used in Doppler estimator.') + prs.add_argument('-t', '--time_interval', type=float, dest='time_interval', + default=2.0, + help='Time stamp interval between azimuth blocks in ' + 'seconds. 
Must not be larger than "az_block_dur".') + prs.add_argument('-m', '--method', type=str, dest='dop_method', + default='CDE', choices=['SDE', 'CDE'], + help='Time-domain Doppler estimator methods "CDE"/"SDE"' + ' which are Correlator/Sign Doppler Estimator.') + prs.add_argument('--subband', action='store_true', dest='subband', + help='Perform fast-time frequency subbanding on top of ' + 'time-domain correlator in Doppler estimator') + prs.add_argument('-d', '--deg', type=int, dest='polyfit_deg', + default=3, help='Degree of the polyfit.') + prs.add_argument('--polyfit', action='store_true', dest='polyfit', + help='If set, it will replace actual estimated doppler ' + 'by its polyfitted ones in slant range.') + prs.add_argument('--plot', action='store_true', dest='plot', + help='Plot Doppler centroids and save them in ' + '*.png files at the specified output path') + prs.add_argument('-o', '--out', type=str, dest='out_path', default='.', + help='Output directory to dump Doppler product as well as' + 'PNG plots.') + + return prs.parse_args() + + +def gen_doppler_range_product(args): + """Generate Doppler-Range LUT Product. + + It generates Doppler centroid LUT as a function of slant range + at various azimuth/pulse times and dump them into a CSV file. + + The format of the file and output filename convention is defined + in reference [1]_. + + Parameters + ---------- + args : argparse.Namespace + All input arguments parsed from a command line or an ASCII file. + + References + ---------- + .. [1] D. Kannapan, "D&C Radar Data Product SIS," JPL D-104976, + December 3, 2020. 
+
+    """
+    # Const
+    PREFIX_NAME_CSV = 'NISAR_ANC'
+
+    tic = time.time()
+    # set logger
+    logger = set_logger("DopplerRangeProduct")
+
+    # get keyword args for function "doppler_lut_from_raw"
+    kwargs = {key: val for key, val in args.__dict__.items() if
+              'file' not in key}
+
+    # get Raw object
+    raw_obj = open_rrsd(args.filename_l0b)
+    # get the SAR band char
+    sar_band_char = raw_obj.sarBand
+    logger.info(f'SAR band char -> {sar_band_char}')
+
+    # operation mode, whether DBF (single or a composite channel) or
+    # 'DM2' (multi-channel)
+    # currently, the underlying module only supports DBF or single channel
+    op_mode = 'DBF'
+
+    # generate Doppler LUT2d from Raw L0B
+    dop_lut, ref_utc, mask_rgb, corr_coef, txrx_pol, centerfreq, _ = \
+        doppler_lut_from_raw(raw_obj, logger=logger, **kwargs)
+
+    # check out antenna file to extract azimuth angle for EL cuts used for
+    # Doppler CSV product
+    if args.antenna_file is None:
+        az_ang_deg = 0.0
+        logger.warning(
+            'No antenna file! Azimuth angle for Doppler product is '
+            'assumed to be zero!'
+        )
+    else:
+        logger.info(
+            'Extracting the azimuth angle of EL cuts from antenna file.')
+        ant_obj = AntennaParser(args.antenna_file)
+
+        ant_tx = ant_obj.el_cut(pol=txrx_pol[0])
+        az_ang = ant_tx.cut_angle
+        # if RX pol is different from TX pol then take average of both
+        if txrx_pol[0] != txrx_pol[1]:
+            ant_rx = ant_obj.el_cut(pol=txrx_pol[1])
+            az_ang += ant_rx.cut_angle
+            az_ang *= 0.5
+        az_ang_deg = np.rad2deg(az_ang)
+        logger.info(
+            'Azimuth angle extracted from antenna file -> '
+            f'{az_ang_deg:.3f} (deg)'
+        )
+
+    # form Linspace object for uniformly-spaced azimuth time and slant range
+    azt_lsp = Linspace(dop_lut.y_start, dop_lut.y_spacing, dop_lut.length)
+    sr_lsp = Linspace(dop_lut.x_start, dop_lut.x_spacing, dop_lut.width)
+
+    # get the first and last utc azimuth time w/o fractional seconds
+    # in "%Y%m%dT%H%M%S" format to be used as part of CSV product filename.
+ dt_utc_start = sec2str(ref_utc, azt_lsp.first) + dt_utc_stop = sec2str(ref_utc, azt_lsp.last) + # get current time w/o fractional seconds in "%Y%m%dT%H%M%S" format + # used as part of CSV product filename + dt_utc_cur = datetime.now().strftime('%Y%m%dT%H%M%S') + + # naming convention of CSV file and product spec is defined in Doc: + # See reference [1] + name_csv = (f'{PREFIX_NAME_CSV}_{sar_band_char}_{op_mode}_DOPP_' + f'{dt_utc_cur}_{dt_utc_start}_{dt_utc_stop}.csv') + file_csv = os.path.join(args.out_path, name_csv) + logger.info(f'Dump Doppler product in "CSV" format to file -> {file_csv}') + + with open(file_csv, 'wt') as fid_csv: + fid_csv.write( + 'UTC Time,Frequency (Hz),Doppler (Hz),Range (m),Azimuth (deg),' + 'Correlation\n' + ) + # loop over azimuth time and slant ranges + for i_row, azt in enumerate(azt_lsp): + tm_utc_str = sec2isofmt(ref_utc, azt) + + for i_col, sr in enumerate(sr_lsp): + fid_csv.write( + '{:s},{:.1f},{:.3f},{:.3f},{:.3f},{:.3f}\n'.format( + tm_utc_str, centerfreq, dop_lut.data[i_row, i_col], + sr, az_ang_deg, + mask_rgb[i_row, i_col] * corr_coef[i_row, i_col]) + ) + + # total elapsed time + logger.info(f'Elapsed time -> {time.time() - tic:.1f} (sec)') + + +def sec2isofmt(ref_utc: 'isce3.core.DateTime', seconds: float) -> str: + """seconds to isoformat string""" + return (ref_utc + TimeDelta(seconds)).isoformat() + + +def sec2str(ref_utc: 'isce3.core.DateTime', seconds: float) -> str: + """seconds to string format '%Y%m%dT%H%M%S'""" + fmt = '%Y%m%dT%H%M%S' + dt_iso = sec2isofmt(ref_utc, seconds) + return datetime.fromisoformat(dt_iso.split('.')[0]).strftime(fmt) + + +if __name__ == "__main__": + """Main driver""" + gen_doppler_range_product(cmd_line_parser()) diff --git a/python/packages/nisar/workflows/h5_prep.py b/python/packages/nisar/workflows/h5_prep.py index 93f1663f8..866b5de5f 100644 --- a/python/packages/nisar/workflows/h5_prep.py +++ b/python/packages/nisar/workflows/h5_prep.py @@ -250,11 +250,34 @@ def cp_geocode_meta(cfg, 
output_hdf5, dst): copy_insar_meta(cfg, dst, src_h5, dst_h5, src_meta_path) else: copy_gslc_gcov_meta(ref_slc.SwathPath, dst, src_h5, dst_h5) + if ref_slc.productType in ref_slc.SwathPath: + # Regular case + dst_path = ref_slc.SwathPath.replace(ref_slc.productType, dst) + else: + # ProductType is RSLC and SwathPath contains /SLC/ + dst_path = ref_slc.SwathPath.replace('SLC', dst) + # Copy zeroDopplerTimeSpacing scalar (GCOV and GSLC) + for freq in freq_pols.keys(): + frequency = f'frequency{freq}' + dst_freq_path = dst_path.replace('swaths', 'grids')+f'/{frequency}' + copy_zero_doppler_time_spacing(src_h5, ref_slc.SwathPath, + dst_h5, dst_freq_path) src_h5.close() dst_h5.close() +def copy_zero_doppler_time_spacing(src_h5, swath_path, dst_h5, dst_path): + az_spacing = src_h5[f'{swath_path}/zeroDopplerTimeSpacing'][()] + descr = "Time interval in the along track direction for raster layers. " \ + "This is the same as the spacing between consecutive entries in " \ + "zeroDopplerTime array" + _create_datasets(dst_h5[dst_path], [0], np.float32, + 'zeroDopplerTimeSpacing', + descr=descr, units="seconds", data=az_spacing, + long_name="zero doppler time spacing") + + def copy_gslc_gcov_meta(src_swath_path, dst, src_h5, dst_h5): ''' Copy metadata info for GSLC GCOV workflows @@ -365,7 +388,6 @@ def prep_ds(cfg, output_hdf5, dst): else: prep_ds_insar(cfg, dst, dst_h5) - def prep_ds_gslc_gcov(cfg, dst, dst_h5): ''' Prepare datasets for GSLC and GCOV diff --git a/tests/cxx/isce3/Sources.cmake b/tests/cxx/isce3/Sources.cmake index 09390b869..47b64e03f 100644 --- a/tests/cxx/isce3/Sources.cmake +++ b/tests/cxx/isce3/Sources.cmake @@ -79,6 +79,7 @@ signal/signal.cpp signal/signal_utils.cpp unwrap/icu/icu.cpp unwrap/phass/phass.cpp +unwrap/snaphu/mcf.cpp ) #This is a temporary fix - since GDAL does not support diff --git a/tests/cxx/isce3/cuda/geometry/geo2rdr/gpuGeo2rdr.cpp b/tests/cxx/isce3/cuda/geometry/geo2rdr/gpuGeo2rdr.cpp index 0c3a2b1f3..c63dfa010 100644 --- 
a/tests/cxx/isce3/cuda/geometry/geo2rdr/gpuGeo2rdr.cpp +++ b/tests/cxx/isce3/cuda/geometry/geo2rdr/gpuGeo2rdr.cpp @@ -20,7 +20,7 @@ #include "isce3/io/Raster.h" // isce3::product -#include "isce3/product/Product.h" +#include "isce3/product/RadarGridProduct.h" // isce3::cuda::geometry #include "isce3/cuda/geometry/Geo2rdr.h" @@ -32,7 +32,7 @@ TEST(Geo2rdrTest, RunGeo2rdr) { isce3::io::IH5File file(h5file); // Load the product - isce3::product::Product product(file); + isce3::product::RadarGridProduct product(file); // Create geo2rdr instance isce3::cuda::geometry::Geo2rdr geo(product, 'A', true); diff --git a/tests/cxx/isce3/cuda/geometry/geometry/gpuGeometry.cpp b/tests/cxx/isce3/cuda/geometry/geometry/gpuGeometry.cpp index b9a12550c..2b5c561d2 100644 --- a/tests/cxx/isce3/cuda/geometry/geometry/gpuGeometry.cpp +++ b/tests/cxx/isce3/cuda/geometry/geometry/gpuGeometry.cpp @@ -27,7 +27,7 @@ #include "isce3/core/Serialization.h" // isce3::product -#include "isce3/product/Product.h" +#include "isce3/product/RadarGridProduct.h" #include "isce3/product/RadarGridParameters.h" // isce3::geometry @@ -61,8 +61,8 @@ struct GpuGeometryTest : public ::testing::Test { std::string h5file(TESTDATA_DIR "envisat.h5"); isce3::io::IH5File file(h5file); - // Instantiate a Product - isce3::product::Product product(file); + // Instantiate a RadarGridProduct + isce3::product::RadarGridProduct product(file); // Extract core and product objects orbit = product.metadata().orbit(); diff --git a/tests/cxx/isce3/cuda/geometry/rtc/gpuRTC.cpp b/tests/cxx/isce3/cuda/geometry/rtc/gpuRTC.cpp index 24601eb46..59f4b759d 100644 --- a/tests/cxx/isce3/cuda/geometry/rtc/gpuRTC.cpp +++ b/tests/cxx/isce3/cuda/geometry/rtc/gpuRTC.cpp @@ -3,13 +3,13 @@ #include "isce3/core/Serialization.h" #include "isce3/io/IH5.h" #include "isce3/io/Raster.h" -#include "isce3/product/Product.h" +#include "isce3/product/RadarGridProduct.h" #include "isce3/cuda/geometry/gpuRTC.h" TEST(TestRTC, RunRTC) { // Open HDF5 file and 
load products isce3::io::IH5File file(TESTDATA_DIR "envisat.h5"); - isce3::product::Product product(file); + isce3::product::RadarGridProduct product(file); // Open DEM raster isce3::io::Raster dem(TESTDATA_DIR "srtm_cropped.tif"); diff --git a/tests/cxx/isce3/cuda/geometry/topo/gpuTopo.cpp b/tests/cxx/isce3/cuda/geometry/topo/gpuTopo.cpp index 3a6d83978..06aa209fe 100644 --- a/tests/cxx/isce3/cuda/geometry/topo/gpuTopo.cpp +++ b/tests/cxx/isce3/cuda/geometry/topo/gpuTopo.cpp @@ -21,7 +21,7 @@ #include "isce3/io/Raster.h" // isce3::product -#include "isce3/product/Product.h" +#include "isce3/product/RadarGridProduct.h" // isce3::cuda::geometry #include "isce3/cuda/geometry/Topo.h" @@ -36,7 +36,7 @@ TEST(GPUTopoTest, RunTopo) { isce3::io::IH5File file(h5file); // Load the product - isce3::product::Product product(file); + isce3::product::RadarGridProduct product(file); // Create topo instance isce3::cuda::geometry::Topo topo(product, 'A', true); diff --git a/tests/cxx/isce3/cuda/image/resampslc/gpuResampSlc.cpp b/tests/cxx/isce3/cuda/image/resampslc/gpuResampSlc.cpp index 7ab507693..5935c7ffd 100644 --- a/tests/cxx/isce3/cuda/image/resampslc/gpuResampSlc.cpp +++ b/tests/cxx/isce3/cuda/image/resampslc/gpuResampSlc.cpp @@ -21,7 +21,7 @@ #include "isce3/io/Raster.h" // isce3::product -#include "isce3/product/Product.h" +#include "isce3/product/RadarGridProduct.h" // isce3::cuda::image #include "isce3/cuda/image/ResampSlc.h" @@ -36,7 +36,7 @@ TEST(ResampSlcTest, Resamp) { isce3::io::IH5File file(h5file); // Create product - isce3::product::Product product(file); + isce3::product::RadarGridProduct product(file); // Instantiate a ResampSLC object isce3::cuda::image::ResampSlc gpu_resamp(product); diff --git a/tests/cxx/isce3/cuda/signal/gpuCrossMul.cpp b/tests/cxx/isce3/cuda/signal/gpuCrossMul.cpp index ccdcc87d6..a2afdc2f9 100644 --- a/tests/cxx/isce3/cuda/signal/gpuCrossMul.cpp +++ b/tests/cxx/isce3/cuda/signal/gpuCrossMul.cpp @@ -9,7 +9,7 @@ #include 
"isce3/io/Raster.h" #include -#include +#include #include #include "isce3/cuda/signal/gpuCrossMul.h" @@ -40,7 +40,7 @@ TEST(gpuCrossmul, Crossmul) isce3::io::IH5File file(h5file); // Create a product and swath - isce3::product::Product product(file); + isce3::product::RadarGridProduct product(file); const isce3::product::Swath & swath = product.swath('A'); // get the Doppler polynomial for refernce SLC @@ -125,7 +125,7 @@ TEST(gpuCrossmul, MultilookCrossmul) isce3::io::IH5File file(h5file); // Create a product and swath - isce3::product::Product product(file); + isce3::product::RadarGridProduct product(file); const isce3::product::Swath & swath = product.swath('A'); // get the Doppler polynomial for refernce SLC @@ -211,7 +211,7 @@ TEST(gpuCrossmul, CrossmulAzimuthFilter) isce3::io::IH5File file(h5file); // Create a product and swath - isce3::product::Product product(file); + isce3::product::RadarGridProduct product(file); const isce3::product::Swath & swath = product.swath('A'); // get the Doppler polynomial for refernce SLC diff --git a/tests/cxx/isce3/cuda/signal/gpuFilter.cpp b/tests/cxx/isce3/cuda/signal/gpuFilter.cpp index 368216a0c..3ab6c2150 100644 --- a/tests/cxx/isce3/cuda/signal/gpuFilter.cpp +++ b/tests/cxx/isce3/cuda/signal/gpuFilter.cpp @@ -10,7 +10,7 @@ #include "isce3/io/Raster.h" #include #include -#include +#include #include "isce3/cuda/signal/gpuSignal.h" #include "isce3/cuda/signal/gpuFilter.h" @@ -33,7 +33,7 @@ TEST(Filter, constructAzimuthCommonbandFilter) isce3::io::IH5File file(h5file); // Create a product and swath - isce3::product::Product product(file); + isce3::product::RadarGridProduct product(file); const isce3::product::Swath & swath = product.swath('A'); // Get the Doppler polynomial and use it for both refernce and secondary SLCs @@ -79,7 +79,7 @@ TEST(Filter, constructBoxcarRangeBandpassFilter) isce3::io::IH5File file(h5file); // Create a product and swath - isce3::product::Product product(file); + isce3::product::RadarGridProduct 
product(file); const isce3::product::Swath & swath = product.swath('A'); // get the range bandwidth diff --git a/tests/cxx/isce3/geocode/geocode.cpp b/tests/cxx/isce3/geocode/geocode.cpp index cc73dab01..ef01deff9 100644 --- a/tests/cxx/isce3/geocode/geocode.cpp +++ b/tests/cxx/isce3/geocode/geocode.cpp @@ -19,7 +19,7 @@ #include #include #include -#include +#include #include std::set geocode_mode_set = {"interp", "area_proj"}; @@ -59,7 +59,7 @@ TEST(GeocodeTest, TestGeocodeCov) { isce3::io::IH5File file(h5file); // Load the product - isce3::product::Product product(file); + isce3::product::RadarGridProduct product(file); const isce3::product::Swath & swath = product.swath('A'); isce3::core::Orbit orbit = product.metadata().orbit(); @@ -545,7 +545,7 @@ TEST(GeocodeTest, TestGeocodeSlc) // Load the product std::cout << "create the product" << std::endl; - isce3::product::Product product(file); + isce3::product::RadarGridProduct product(file); // std::cout << "get the swath" << std::endl; // const isce3::product::Swath & swath = product.swath('A'); @@ -816,7 +816,7 @@ void createTestData() isce3::io::IH5File file(h5file); // Load the product - isce3::product::Product product(file); + isce3::product::RadarGridProduct product(file); // Create topo instance with native Doppler isce3::geometry::Topo topo(product, 'A', true); diff --git a/tests/cxx/isce3/geometry/geo2rdr/geo2rdr.cpp b/tests/cxx/isce3/geometry/geo2rdr/geo2rdr.cpp index 8f7584a8d..c70b48de8 100644 --- a/tests/cxx/isce3/geometry/geo2rdr/geo2rdr.cpp +++ b/tests/cxx/isce3/geometry/geo2rdr/geo2rdr.cpp @@ -20,7 +20,7 @@ #include "isce3/io/Raster.h" // isce3::product -#include "isce3/product/Product.h" +#include "isce3/product/RadarGridProduct.h" // isce3::geometry #include "isce3/geometry/Geo2rdr.h" @@ -32,7 +32,7 @@ TEST(Geo2rdrTest, RunGeo2rdr) { isce3::io::IH5File file(h5file); // Load the product - isce3::product::Product product(file); + isce3::product::RadarGridProduct product(file); // Create geo2rdr 
instance isce3::geometry::Geo2rdr geo(product, 'A', true); diff --git a/tests/cxx/isce3/geometry/geometry/geometry.cpp b/tests/cxx/isce3/geometry/geometry/geometry.cpp index 428eeef27..020b9a27e 100644 --- a/tests/cxx/isce3/geometry/geometry/geometry.cpp +++ b/tests/cxx/isce3/geometry/geometry/geometry.cpp @@ -25,7 +25,7 @@ #include // isce3::product -#include +#include // isce3::geometry #include @@ -57,8 +57,8 @@ struct GeometryTest : public ::testing::Test { std::string h5file(TESTDATA_DIR "envisat.h5"); isce3::io::IH5File file(h5file); - // Instantiate a Product - isce3::product::Product product(file); + // Instantiate a RadarGridProduct + isce3::product::RadarGridProduct product(file); // Extract core and product objects orbit = product.metadata().orbit(); diff --git a/tests/cxx/isce3/geometry/metadata_cubes/metadata_cubes.cpp b/tests/cxx/isce3/geometry/metadata_cubes/metadata_cubes.cpp index 27f4e723d..00e311708 100644 --- a/tests/cxx/isce3/geometry/metadata_cubes/metadata_cubes.cpp +++ b/tests/cxx/isce3/geometry/metadata_cubes/metadata_cubes.cpp @@ -13,7 +13,7 @@ #include #include #include -#include +#include #include template @@ -384,7 +384,7 @@ TEST(radarGridCubeTest, testRadarGridCube) isce3::io::IH5File file(h5file); // Load the product - isce3::product::Product product(file); + isce3::product::RadarGridProduct product(file); // Create radar grid parameter char frequency = 'A'; @@ -691,7 +691,7 @@ TEST(metadataCubesTest, testMetadataCubes) { isce3::io::IH5File file(h5file); // Load the product - isce3::product::Product product(file); + isce3::product::RadarGridProduct product(file); // Create radar grid parameter char frequency = 'A'; diff --git a/tests/cxx/isce3/geometry/rtc/rtc.cpp b/tests/cxx/isce3/geometry/rtc/rtc.cpp index b6daf6cac..22a5b9f3d 100644 --- a/tests/cxx/isce3/geometry/rtc/rtc.cpp +++ b/tests/cxx/isce3/geometry/rtc/rtc.cpp @@ -4,7 +4,7 @@ #include #include #include -#include +#include #include #include @@ -19,7 +19,7 @@ std::set 
rtc_algorithm_set = { TEST(TestRTC, RunRTC) { // Open HDF5 file and load products isce3::io::IH5File file(TESTDATA_DIR "envisat.h5"); - isce3::product::Product product(file); + isce3::product::RadarGridProduct product(file); char frequency = 'A'; // Open DEM raster diff --git a/tests/cxx/isce3/geometry/topo/topo.cpp b/tests/cxx/isce3/geometry/topo/topo.cpp index 9e39addd7..9f3514940 100644 --- a/tests/cxx/isce3/geometry/topo/topo.cpp +++ b/tests/cxx/isce3/geometry/topo/topo.cpp @@ -20,7 +20,7 @@ #include "isce3/io/Raster.h" // isce3::product -#include "isce3/product/Product.h" +#include "isce3/product/RadarGridProduct.h" // isce3::geometry #include "isce3/geometry/Topo.h" @@ -35,7 +35,7 @@ TEST(TopoTest, RunTopo) { isce3::io::IH5File file(h5file); // Load the product - isce3::product::Product product(file); + isce3::product::RadarGridProduct product(file); // Create topo instance isce3::geometry::Topo topo(product, 'A', true); diff --git a/tests/cxx/isce3/image/resampslc/resampslc.cpp b/tests/cxx/isce3/image/resampslc/resampslc.cpp index e4c6fa0fc..d2531e4c9 100644 --- a/tests/cxx/isce3/image/resampslc/resampslc.cpp +++ b/tests/cxx/isce3/image/resampslc/resampslc.cpp @@ -21,7 +21,7 @@ #include "isce3/io/Raster.h" // isce3::product -#include "isce3/product/Product.h" +#include "isce3/product/RadarGridProduct.h" // isce3::image #include "isce3/image/ResampSlc.h" @@ -36,7 +36,7 @@ TEST(ResampSlcTest, Resamp) { isce3::io::IH5File file(h5file); // Create product - isce3::product::Product product(file); + isce3::product::RadarGridProduct product(file); // Instantiate a ResampSLC object isce3::image::ResampSlc resamp(product); diff --git a/tests/cxx/isce3/product/radargrid/radargrid.cpp b/tests/cxx/isce3/product/radargrid/radargrid.cpp index 13f262885..9b1c252be 100644 --- a/tests/cxx/isce3/product/radargrid/radargrid.cpp +++ b/tests/cxx/isce3/product/radargrid/radargrid.cpp @@ -12,7 +12,7 @@ // isce3::product #include #include -#include +#include using 
isce3::core::LookSide; @@ -23,7 +23,7 @@ TEST(RadarGridTest, fromProduct) { isce3::io::IH5File file(h5file); // Instantiate and load a product - isce3::product::Product product(file); + isce3::product::RadarGridProduct product(file); //Create radar grid from product isce3::product::RadarGridParameters grid(product); @@ -44,7 +44,7 @@ TEST(RadarGridTest, fromSwath) { isce3::io::IH5File file(h5file); // Instantiate and load a product - isce3::product::Product product(file); + isce3::product::RadarGridProduct product(file); // Get the swath const isce3::product::Swath &swath = product.swath('A'); diff --git a/tests/cxx/isce3/product/serialization/serializeProduct.cpp b/tests/cxx/isce3/product/serialization/serializeProduct.cpp index 7dedaa5fb..a53eb8fbe 100644 --- a/tests/cxx/isce3/product/serialization/serializeProduct.cpp +++ b/tests/cxx/isce3/product/serialization/serializeProduct.cpp @@ -13,7 +13,7 @@ #include // isce3::product -#include +#include #include TEST(ProductTest, FromHDF5) { @@ -23,7 +23,7 @@ TEST(ProductTest, FromHDF5) { isce3::io::IH5File file(h5file); // Instantiate and load a product - isce3::product::Product product(file); + isce3::product::RadarGridProduct product(file); // Get the swath const isce3::product::Swath & swath = product.swath('A'); diff --git a/tests/cxx/isce3/signal/crossmul.cpp b/tests/cxx/isce3/signal/crossmul.cpp index 6a65d9997..d27388b73 100644 --- a/tests/cxx/isce3/signal/crossmul.cpp +++ b/tests/cxx/isce3/signal/crossmul.cpp @@ -11,7 +11,7 @@ #include "isce3/io/Raster.h" #include "isce3/signal/Crossmul.h" #include -#include +#include #include using isce3::core::avgLUT2dToLUT1d; @@ -40,7 +40,7 @@ TEST(Crossmul, RunCrossmul) isce3::io::IH5File file(h5file); // Create a product and swath - isce3::product::Product product(file); + isce3::product::RadarGridProduct product(file); const isce3::product::Swath & swath = product.swath('A'); // get the Doppler polynomial for refernce SLC @@ -125,7 +125,7 @@ TEST(Crossmul, 
RunCrossmulWithAzimuthCommonBandFilter) isce3::io::IH5File file(h5file); // Create a product and swath - isce3::product::Product product(file); + isce3::product::RadarGridProduct product(file); const isce3::product::Swath & swath = product.swath('A'); // get the Doppler polynomial for refernce SLC diff --git a/tests/cxx/isce3/signal/filter.cpp b/tests/cxx/isce3/signal/filter.cpp index 5b4336a74..430a39518 100644 --- a/tests/cxx/isce3/signal/filter.cpp +++ b/tests/cxx/isce3/signal/filter.cpp @@ -9,7 +9,7 @@ #include #include #include -#include +#include #include #include #include @@ -32,7 +32,7 @@ TEST(Filter, constructAzimuthCommonbandFilter) isce3::io::IH5File file(h5file); // Create a product and swath - isce3::product::Product product(file); + isce3::product::RadarGridProduct product(file); const isce3::product::Swath & swath = product.swath('A'); // Get the Doppler polynomial and use it for both refernce and secondary SLCs @@ -75,7 +75,7 @@ TEST(Filter, constructBoxcarRangeBandpassFilter) isce3::io::IH5File file(h5file); // Create a product and swath - isce3::product::Product product(file); + isce3::product::RadarGridProduct product(file); const isce3::product::Swath & swath = product.swath('A'); // get the range bandwidth diff --git a/tests/cxx/isce3/unwrap/snaphu/mcf.cpp b/tests/cxx/isce3/unwrap/snaphu/mcf.cpp new file mode 100644 index 000000000..ea25893ec --- /dev/null +++ b/tests/cxx/isce3/unwrap/snaphu/mcf.cpp @@ -0,0 +1,193 @@ +#include + +#include + +#include + +/** Check if two arrays are element-wise equal. */ +template +bool allEqual(const ArrayLike2D& a, const ArrayLike2D& b) +{ + auto m = a.rows(); + auto n = a.cols(); + + // Check that `a` and `b` have the same shape. + if (b.rows() != m or b.cols() != n) { + return false; + } + + // Check that each element of `a` matches the corresponding element from + // `b`. 
+ for (decltype(m) i = 0; i < m; ++i) { + for (decltype(n) j = 0; j < n; ++j) { + if (a(i, j) != b(i, j)) { + return false; + } + } + } + + return true; +} + +TEST(MCFTest, MCF1) +{ + // Dimensions of the 2D wrapped phase field. + const long nrow = 3; + const long ncol = 5; + + // Construct a wrapped phase field with a positive residue in the top-left + // corner and a negative residue in the bottom-left corner. + auto wrapped_phase = isce3::unwrap::Array2D(nrow, ncol); + wrapped_phase.row(0) = M_PI; + wrapped_phase.row(1) = 0.0f; + wrapped_phase.row(2) = M_PI; + wrapped_phase.col(0) = -M_PI; + + std::cout << "wrapped_phase = " << std::endl << wrapped_phase << std::endl; + + // Calculate residues (just for debugging purposes -- not required for + // testing). + auto residue = isce3::unwrap::Array2D(nrow - 1, ncol - 1); + isce3::unwrap::CycleResidue(wrapped_phase, residue, nrow, ncol); + + std::cout << "residue = " << std::endl << residue << std::endl; + + // Initialize all arc costs to an arbitrary large value. + auto costs = isce3::unwrap::MakeRowColArray2D(nrow, ncol); + costs = 99; + + // Carve out a zero-cost path from the positive residue to the negative + // residue. + auto rowcosts = costs.topLeftCorner(nrow - 1, ncol); + auto colcosts = costs.bottomLeftCorner(nrow, ncol - 1); + rowcosts(0, 1) = 0; + rowcosts(0, 2) = 0; + rowcosts(0, 3) = 0; + colcosts(1, 3) = 0; + rowcosts(1, 3) = 0; + rowcosts(1, 2) = 0; + rowcosts(1, 1) = 0; + + std::cout << "rowcosts = " << std::endl << rowcosts << std::endl; + std::cout << "colcosts = " << std::endl << colcosts << std::endl; + + // Calculate arc flows using the MCF initializer. 
+ isce3::unwrap::Array2D flows; + isce3::unwrap::MCFInitFlows(wrapped_phase, &flows, costs, nrow, ncol); + + auto rowflows = flows.topLeftCorner(nrow - 1, ncol); + auto colflows = flows.bottomLeftCorner(nrow, ncol - 1); + + std::cout << "rowflows = " << std::endl << rowflows << std::endl; + std::cout << "colflows = " << std::endl << colflows << std::endl; + + // Get the expected resulting flows. + auto true_flows = isce3::unwrap::MakeRowColArray2D(nrow, ncol); + true_flows = 0; + + auto true_rowflows = true_flows.topLeftCorner(nrow - 1, ncol); + auto true_colflows = true_flows.bottomLeftCorner(nrow, ncol - 1); + true_rowflows(0, 1) = 1; + true_rowflows(0, 2) = 1; + true_rowflows(0, 3) = 1; + true_colflows(1, 3) = 1; + true_rowflows(1, 3) = -1; + true_rowflows(1, 2) = -1; + true_rowflows(1, 1) = -1; + + std::cout << "true_rowflows = " << std::endl << true_rowflows << std::endl; + std::cout << "true_colflows = " << std::endl << true_colflows << std::endl; + + // Make sure the computed flows match the expected flows. + EXPECT_TRUE(allEqual(rowflows, true_rowflows)); + EXPECT_TRUE(allEqual(colflows, true_colflows)); +} + +TEST(MCFTest, MCF2) +{ + // Dimensions of the 2D wrapped phase field. + const long nrow = 8; + const long ncol = 3; + + // Construct a wrapped phase field with a single positive residue and two + // negative residues. The residue array looks like this: + // + // [ 0 0 ] + // [-1 0 ] + // [ 0 0 ] + // [ 1 0 ] + // [ 0 0 ] + // [-1 0 ] + // [ 0 0 ] + // + auto wrapped_phase = isce3::unwrap::Array2D(nrow, ncol); + wrapped_phase.col(0) = 0.0f; + wrapped_phase.col(1) = M_PI; + wrapped_phase.col(2) = M_PI; + wrapped_phase.row(2) = -M_PI; + wrapped_phase.row(3) = -M_PI; + wrapped_phase.row(6) = -M_PI; + wrapped_phase.row(7) = -M_PI; + + std::cout << "wrapped_phase = " << std::endl << wrapped_phase << std::endl; + + // Calculate residues (just for debugging purposes -- not required for + // testing). 
+ auto residue = isce3::unwrap::Array2D(nrow - 1, ncol - 1); + isce3::unwrap::CycleResidue(wrapped_phase, residue, nrow, ncol); + + std::cout << "residue = " << std::endl << residue << std::endl; + + // Initialize all arc costs to an arbitrary large value. + auto costs = isce3::unwrap::MakeRowColArray2D(nrow, ncol); + costs = 99; + + // Carve out a zero-cost path between each residue and the "ground" node. + auto rowcosts = costs.topLeftCorner(nrow - 1, ncol); + auto colcosts = costs.bottomLeftCorner(nrow, ncol - 1); + colcosts(0, 0) = 0; + colcosts(1, 0) = 0; + rowcosts(3, 1) = 0; + rowcosts(3, 2) = 0; + colcosts(6, 0) = 0; + colcosts(7, 0) = 0; + + std::cout << "rowcosts = " << std::endl << rowcosts << std::endl; + std::cout << "colcosts = " << std::endl << colcosts << std::endl; + + // Calculate arc flows using the MCF initializer. + isce3::unwrap::Array2D flows; + isce3::unwrap::MCFInitFlows(wrapped_phase, &flows, costs, nrow, ncol); + + auto rowflows = flows.topLeftCorner(nrow - 1, ncol); + auto colflows = flows.bottomLeftCorner(nrow, ncol - 1); + + std::cout << "rowflows = " << std::endl << rowflows << std::endl; + std::cout << "colflows = " << std::endl << colflows << std::endl; + + // Get the expected resulting flows. + auto true_flows = isce3::unwrap::MakeRowColArray2D(nrow, ncol); + true_flows = 0; + + auto true_rowflows = true_flows.topLeftCorner(nrow - 1, ncol); + auto true_colflows = true_flows.bottomLeftCorner(nrow, ncol - 1); + true_colflows(0, 0) = 1; + true_colflows(1, 0) = 1; + true_rowflows(3, 1) = 1; + true_rowflows(3, 2) = 1; + true_colflows(6, 0) = -1; + true_colflows(7, 0) = -1; + + std::cout << "true_rowflows = " << std::endl << true_rowflows << std::endl; + std::cout << "true_colflows = " << std::endl << true_colflows << std::endl; + + // Make sure the computed flows match the expected flows. 
+ EXPECT_TRUE(allEqual(rowflows, true_rowflows)); + EXPECT_TRUE(allEqual(colflows, true_colflows)); +} + +int main(int argc, char* argv[]) +{ + testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tests/data/README.md b/tests/data/README.md index 2cb78edc9..fcaa85169 100644 --- a/tests/data/README.md +++ b/tests/data/README.md @@ -41,6 +41,19 @@ f415fc38e1feff0bb1453782be3d2b5f *smoothFinal.xml.gz This file was uncompressed and trimmed to the first ten state vectors in order to reduce the size of the `orbit.xml` file stored here. +# nisar_129_gcov_crop.h5 + +Minimal GCOV sample product obtained from processing a small subset of the +UAVSAR dataset NISARA_13905_19070_007_190930_L090_CX_129_02 from the AM/PM +campaing. The sample product was generated with parameters: +``` +top_left: + y_abs: 35.18 + x_abs: -83.46 +bottom_right: + y_abs: 35.13 + x_abs: -83.41 +``` ## REE Multi-channel L0B Raw data and HDF5 antenna pattern cuts (NISAR antenna format *v2*) diff --git a/tests/data/nisar_129_gcov_crop.h5 b/tests/data/nisar_129_gcov_crop.h5 new file mode 100644 index 000000000..3610e29cd Binary files /dev/null and b/tests/data/nisar_129_gcov_crop.h5 differ diff --git a/tests/python/extensions/pybind/CMakeLists.txt b/tests/python/extensions/pybind/CMakeLists.txt index 298d0cbcc..20267f327 100644 --- a/tests/python/extensions/pybind/CMakeLists.txt +++ b/tests/python/extensions/pybind/CMakeLists.txt @@ -44,6 +44,7 @@ signal/filter2D.py product/generic_product.py product/radargridparameters.py product/swath.py +product/grid.py unwrap/icu.py unwrap/phass.py geometry/ltpcoordinates.py diff --git a/tests/python/extensions/pybind/product/grid.py b/tests/python/extensions/pybind/product/grid.py new file mode 100644 index 000000000..648bc3c19 --- /dev/null +++ b/tests/python/extensions/pybind/product/grid.py @@ -0,0 +1,29 @@ +#!/usr/bin/env python3 + +import numpy.testing as npt + +import pybind_isce3 as isce +import iscetest + +def 
test_radargridparameters(): + + # Create Grid object + gcov_crop_path = iscetest.data + "nisar_129_gcov_crop.h5" + grid = isce.product.Grid(gcov_crop_path, 'A') + + # Check its values + npt.assert_equal(grid.wavelength, 0.24118460016090104) + npt.assert_equal(grid.range_bandwidth, 20000000.0) + npt.assert_equal(grid.azimuth_bandwidth, 19.100906906009193) + npt.assert_equal(grid.center_frequency, 1243000000.0) + npt.assert_equal(grid.slant_range_spacing, 6.245676208) + npt.assert_equal(grid.zero_doppler_time_spacing, 0.022481251507997513) + npt.assert_equal(grid.start_x, -83.45989999999999) + npt.assert_equal(grid.start_y, 35.179899999999996) + npt.assert_equal(grid.spacing_x, 0.0002) + npt.assert_equal(grid.spacing_y, -0.0002) + npt.assert_equal(grid.width, 250) + npt.assert_equal(grid.length, 250) + npt.assert_equal(grid.epsg, 4326) + +# end of file diff --git a/tests/python/packages/CMakeLists.txt b/tests/python/packages/CMakeLists.txt index 39ba3caf5..3c09348b3 100644 --- a/tests/python/packages/CMakeLists.txt +++ b/tests/python/packages/CMakeLists.txt @@ -1,12 +1,16 @@ set(TESTFILES isce3/core/gpu_check.py +isce3/signal/doppler_est_func.py +isce3/signal/fir_filter_func.py isce3/signal/point_target_info.py isce3/unwrap/snaphu.py nisar/products/readers/attitude.py nisar/products/readers/orbit.py nisar/products/readers/raw.py nisar/workflows/crossmul.py +nisar/workflows/doppler_lut_from_raw.py nisar/workflows/focus.py +nisar/workflows/gen_doppler_range_product.py nisar/workflows/stage_dem.py nisar/workflows/gcov.py nisar/workflows/geo2rdr.py diff --git a/tests/python/packages/isce3/signal/doppler_est_func.py b/tests/python/packages/isce3/signal/doppler_est_func.py new file mode 100644 index 000000000..c40b105fe --- /dev/null +++ b/tests/python/packages/isce3/signal/doppler_est_func.py @@ -0,0 +1,134 @@ +#!/usr/bin/env python3 +import iscetest +from isce3.focus import form_linear_chirp +from isce3.signal import corr_doppler_est, sign_doppler_est + +import numpy as 
np +import numpy.testing as npt + + +# List of functions used in generating simulated data +def rcoswin(size, ped=1.0): + """Raised-cosine symmetric window function. + + Parameters + ---------- + size : int + size of the window + ped : float, default=1.0 + pedestal, a value within [0. 1.0] + + Returns + ------- + np.ndarray(float) + + Raises + ------ + AssertionError + For bad inputs + + """ + assert 0 <= ped and ped <= 1, 'Pedestal value shall be wihtin [0, 1]!' + assert size > 0, 'Size must be a positive number!' + return (1 + ped) / 2. - (1 - ped) / 2. * np.cos(2.0 * np.pi / (size - 1) * + np.arange(0, size)) + + +def form_weighted_chirp(bandwidth, duration, prf, pedestal_win): + """Form a weighted complex baseband chirp. + + Parameters + ---------- + bandwidth : float + chirp bandwidth in (Hz) + duration : float + chirp duration in (sec) + prf : float + PRF in (Hz) + pedestal_win : float + Pedestal of a rasied cosine window + + Returns + ------- + np.ndarray(complex) + Complex windowed baseband chirp + + """ + chirp_rate = bandwidth / duration # (Hz/sec) + chirp = np.asarray(form_linear_chirp(chirp_rate, duration, prf, 0.0)) + # apply weighting + chirp *= rcoswin(len(chirp), ped=pedestal_win) + return chirp + + +class TestDopplerEstFunc: + # List of parameters for generating noisy azimuth chirp signal + + # azimuth sampling rate (Hz) + prf = 2000. 
+ # bandwidth of azimuth chirp(Hz) + bandwidth = 1500.0 + # duration of azimuth chirp (sec) + duration = 2.5 + # signal to noise ratio (dB) + snr = 8.0 + # number of range bins + num_rgb = 8 + # pedestal of raised cosine window function + pedestal_win = 0.4 + # seed number for random generator + seed_rnd = 10 + + # list of desired doppler centroids to be tested + doppler_list = [-550., -215, -50, 0, 50, 215, 550] # (Hz) + + # absolute doppler tolerance in (Hz) per requirement for validating + # Doppler estimators output against list of Dopplers + atol_dop = 15.0 + + # generating noise-free baseband chirp as well as noise signal + # used in testing all methods + + # form a noise-free weighted complex baseband chirp + chirp = form_weighted_chirp(bandwidth, duration, prf, pedestal_win) + + # generate complex Gaussian zero-mean random noise, + # one indepedent set per range bin used for all dopplers + std_iq_noise = 1./np.sqrt(2) * 10 ** (-snr / 20) + rnd_gen = np.random.RandomState(seed=seed_rnd) + noise = std_iq_noise * (rnd_gen.randn(chirp.size, num_rgb) + + 1j * rnd_gen.randn(chirp.size, num_rgb)) + + def _validate_doppler_est(self, method: str): + """Validate estimated doppler for a method""" + # form a generic doppler estimator function covering both methods + if method == 'CDE': + dop_func = corr_doppler_est + else: # 'SDE' + def dop_func(echo, prf, lag=1, axis=None): + return sign_doppler_est(echo, prf, lag, axis), 1 + + # loop over list of doppler centroid to generate noisy pass-band signal + for doppler in self.doppler_list: + # create a pass-band chirp from baseband one per doppler + chirp_dop = self.chirp * np.exp(1j * 2 * np.pi * doppler / self.prf + * np.arange(self.chirp.size)) + # create a noisy complex signal , one set per range bin + sig = np.repeat(chirp_dop.reshape((chirp_dop.size, 1)), + self.num_rgb, axis=1) + self.noise + # estimate doppler and check its value + dop_est, corr_coef = dop_func(sig, self.prf) + npt.assert_allclose(dop_est, doppler, 
atol=self.atol_dop, + err_msg='Large error for doppler ' + f'{doppler:.2f} (Hz) in method "{method}"') + npt.assert_equal((corr_coef >= 0 and corr_coef <= 1), True, + err_msg = 'Correlation coeff is out of range' + f' for method {method}') + print(f'Correlation coef for method {method} & Doppler ' + f'{doppler:.1f} (Hz) -> {corr_coef:.3f}') + + def test_corr_doppler_est(self): + self._validate_doppler_est('CDE') + + def test_sign_doppler_est(self): + self._validate_doppler_est('SDE') diff --git a/tests/python/packages/isce3/signal/fir_filter_func.py b/tests/python/packages/isce3/signal/fir_filter_func.py new file mode 100644 index 000000000..001c1c4f9 --- /dev/null +++ b/tests/python/packages/isce3/signal/fir_filter_func.py @@ -0,0 +1,99 @@ +#!/usr/bin/env python3 +import iscetest +from isce3.signal import cheby_equi_ripple_filter + +import numpy as np +import numpy.testing as npt +import bisect + + +def est_centroid_spectrum(pow_spec, freq): + """Estimate the centroid of the power spectrum""" + return (pow_spec * freq).sum() / pow_spec.sum() + + +def est_equivalent_bandwidth(pow_spec, df): + """Estimate rectangle equivalent bandwidth of spectrum""" + return pow_spec.sum() * df / pow_spec.max() + + +def test_cheby_equi_ripple_filter(): + # number of fft for power spectrum calculation + nfft = 2048 + + # relative tolerance for pass-band ripple (dB), stop-band attenuation (dB) + # and bandwidth (MHz) + rtol_pass = 0.5 + rtol_stop = 0.1 + rtol_bw = 0.02 + + # desired filter spec + samprate = 1.6 # (MHz) + bandwidth = 1.0 # (MHz) + rolloff = 1.2 # roll off or oversampling factor (-) + ripple = 0.2 # pass band ripple (dB) + stopatt = 35.0 # min stop-band attenutation (dB) + centerfreq = 0.1 # (MHz) + + # generate filter coeffs + coefs = cheby_equi_ripple_filter(samprate, bandwidth, rolloff, ripple, + stopatt, centerfreq) + + # get the power spectrum of the filter in linear scale + pow_spec = abs(np.fft.fftshift(np.fft.fft(coefs, nfft)))**2 + # frequency vector within 
[-samprate/2., samprate/2.[ + # freq resolution defines absolute tolerance in ceter freq, + # and roll-off factor estimation! + df = samprate / nfft + freq = -0.5 * samprate + df * np.arange(nfft) + + # check the center freq + freq_cent_est = est_centroid_spectrum(pow_spec, freq) + print(f'Estimated frequency centroid -> {freq_cent_est:.3f} (MHz)') + npt.assert_allclose(freq_cent_est, centerfreq, atol=df, + err_msg='Wrong center frequency!') + + # check bandwidth + rect_bw_est = est_equivalent_bandwidth(pow_spec, df) + print(f'Estimated bandwidth -> {rect_bw_est:.3f} (MHz)') + npt.assert_allclose(rect_bw_est, bandwidth, rtol=rtol_bw, + err_msg='Wrong bandwidth!') + + # get the expected [low, high[ index wihtin pass-band region + frq_pass_low = -0.5 * bandwidth + centerfreq + frq_pass_high = 0.5 * bandwidth + centerfreq + idx_pass_low = bisect.bisect_left(freq, frq_pass_low) + idx_pass_high = bisect.bisect_right(freq, frq_pass_high) + slice_pass = slice(idx_pass_low, idx_pass_high) + + # make sure the peak occurs wihtin expected [low,high[ of passband + idx_max = pow_spec.argmax() + npt.assert_equal(idx_pass_low <= idx_max and idx_max <= idx_pass_high, + True, err_msg='The peak gain occurs outside expected \ +pass-band region') + # get peak-to-peak ripple within pass band to check pass-band ripple + max_val_pass = pow_spec[idx_max] + min_val_pass = pow_spec[slice_pass].min() + est_ripple = 5.0 * np.log10(max_val_pass / min_val_pass) + print(f'Estimated ripple within passband -> {est_ripple:.2f} (dB)') + npt.assert_allclose(est_ripple, ripple, rtol=rtol_pass, + err_msg='Wrong pass-band ripple') + + # get expected start {left, right} index for edge of stop band region + frq_stop_left = -0.5 * rolloff * bandwidth + centerfreq + frq_stop_right = 0.5 * rolloff * bandwidth + centerfreq + idx_stop_left = bisect.bisect_left(freq, frq_stop_left) + idx_stop_right = bisect.bisect_right(freq, frq_stop_right) + slice_stop_left = slice(0, idx_stop_left) + slice_stop_right = 
slice(idx_stop_right, nfft) + # check min stop-band attenuation within stop band on each side + min_att_left = abs(10 * np.log10(pow_spec[slice_stop_left].max())) + min_att_right = abs(10 * np.log10(pow_spec[slice_stop_right].max())) + print(f'Estimated min stop-band attenuation on the left side -> \ +{min_att_left:.2f} (dB)') + print(f'Estimated min stop-band attenuation on the right side -> \ +{min_att_right:.2f} (dB)') + npt.assert_allclose(min_att_left, stopatt, rtol=rtol_stop, + err_msg='Wrong stop-band attenuation on the left') + npt.assert_allclose(min_att_right, stopatt, rtol=rtol_stop, + err_msg='Wrong stop-band attenuation on the right') diff --git a/tests/python/packages/isce3/unwrap/snaphu.py b/tests/python/packages/isce3/unwrap/snaphu.py index 3b5a851c4..86755459b 100644 --- a/tests/python/packages/isce3/unwrap/snaphu.py +++ b/tests/python/packages/isce3/unwrap/snaphu.py @@ -2,6 +2,7 @@ import isce3 import numpy as np +import pytest from isce3.unwrap import snaphu @@ -131,7 +132,7 @@ def next_power_of_two(n): return z[:length, :width] -def jaccard_similarity(a, b, /): +def jaccard_similarity(a, b): """Compute the Jaccard similarity coefficient (intersect-over-union) of two boolean arrays. 
@@ -195,7 +196,8 @@ def simulate_phase_noise(corr, nlooks: float, *, seed: Optional[int] = None): class TestSnaphu: - def test_smooth_cost(self): + @pytest.mark.parametrize("init_method", ["mcf", "mst"]) + def test_smooth_cost(self, init_method): """Test SNAPHU unwrapping using "smooth" cost mode.""" # Interferogram dimensions l, w = 1100, 256 @@ -249,6 +251,7 @@ def test_smooth_cost(self): corr_raster, nlooks=20.0, cost="smooth", + init_method=init_method, conncomp_params=conncomp_params, ) @@ -272,7 +275,8 @@ def test_smooth_cost(self): cc = ccl_raster.data == label assert jaccard_similarity(cc, mask) > 0.9 - def test_topo_cost(self): + @pytest.mark.parametrize("init_method", ["mcf", "mst"]) + def test_topo_cost(self, init_method): """Test SNAPHU unwrapping using "topo" cost mode.""" # Simulate a topographic interferometric phase signal using notionally # NISAR-like 20 MHz L-band parameters. @@ -344,6 +348,7 @@ def test_topo_cost(self): nlooks=nlooks, cost="topo", cost_params=cost_params, + init_method=init_method, ) # Check the connected component labels. There should be a single diff --git a/tests/python/packages/nisar/products/readers/raw.py b/tests/python/packages/nisar/products/readers/raw.py index d720a7b0e..6ca432c5a 100644 --- a/tests/python/packages/nisar/products/readers/raw.py +++ b/tests/python/packages/nisar/products/readers/raw.py @@ -28,11 +28,18 @@ def test_raw(): t, grid = raw.getRadarGrid(frequency=freq, tx=tx) ds = raw.getRawDataset(freq, pol) swaths = raw.getSubSwaths(frequency=freq, tx=tx) + prf = raw.getNominalPRF(freq, tx) + bandwidth = raw.getRangeBandwidth(freq, tx) # Verify assumptions. 
npt.assert_equal(orbit.reference_epoch, attitude.reference_epoch) npt.assert_equal(side, "right") - + npt.assert_equal(raw.isDithered(freq, tx), False) + npt.assert_equal(raw.sarBand, 'L') + npt.assert_equal(ds.ndim, 2) + npt.assert_equal(ds.dtype, np.complex64) + print(f'Datatype of raw dataset before decoding -> {ds.dtype_storage}') + # Check quaternion convention. # RCS frame has Y-axis nearly parallel to velocity (for small rotations). # In this case they should be exactly aligned since roll == pitch == 0. diff --git a/tests/python/packages/nisar/workflows/doppler_lut_from_raw.py b/tests/python/packages/nisar/workflows/doppler_lut_from_raw.py new file mode 100644 index 000000000..02ffb66b0 --- /dev/null +++ b/tests/python/packages/nisar/workflows/doppler_lut_from_raw.py @@ -0,0 +1,251 @@ +import iscetest +from nisar.workflows import doppler_lut_from_raw +from nisar.products.readers.Raw import open_rrsd +from isce3.core import speed_of_light + +import os +import numpy as np +import numpy.testing as npt + + +def get_doppler_from_attitude_eb(raw_obj, freq_band, txrx_pol, tm_mid, + eb_angle_deg): + """ + Estimate doppler at mid pulse time from attitude due to e.g., imperfect + zero-doppler steering as well as contribution from electrical boresight. 
+ + Parameters + ---------- + raw_obj : nisar.products.readers.Raw.RawBase + freq_band : str + txrx_pol : str + tm_mid : float + Mid azimuth time of echo in (sec) + eb_angle_deg : float + Electrical boresight angle in (deg) + + Returns + ------- + float + doppler centroid in (Hz) + float + total squint angle, residual yaw plus EB, in (rad) + + """ + # get orbit and attitude objects + orb = raw_obj.getOrbit() + att = raw_obj.getAttitude() + + # get state vectors and quaternions at mid echo time + pos, vel = orb.interpolate(tm_mid) + vel_mag = np.linalg.norm(vel) + quat = att.interpolate(tm_mid) + + # get sign of Y-axis from radar/antenna looking direction + sgn = {'R': -1}.get(raw_obj.identification.lookDirection[0], 1) + + # get Y-axis with proper sign in ECEF + y_ecef = quat.rotate([0, sgn, 0]) + + # get the yaw angle (rad) due to imperfect zero-doppler steering + yaw_ang = np.arccos(np.dot(vel / vel_mag, y_ecef)) + + # get the wavelength + wl = speed_of_light / raw_obj.getCenterFrequency(freq_band) + + # total squint angle + squnit_ang = np.deg2rad(eb_angle_deg) + yaw_ang + + # calculate doppler centroid + dop_cnt = 2. * vel_mag / wl * np.sin(squnit_ang) + + return dop_cnt, squnit_ang + + +class TestDopplerLutFromRaw: + # List of inputs + + # filename of ALOS1 PALSAR data over homogeneous scene like + # Amazon rainforest + filename = 'ALPSRP081257070-H1.0__A_HH_2500_LINES.h5' + + # TxRx Polarization of the echo product + txrx_pol = 'HH' + + # frequency band 'A' or 'B' + freq_band = 'A' + + # absolute MSE tolerance in Doppler centroid estimation in (Hz) + dop_cnt_err = 12.0 + + # electrical boresight (EB) angle in (deg). EB along with residual + # Yaw angle defines final squint angle (deviation from zero doppler plane) + # and thus Doppler centroid. + eb_angle_deg = 0.0 + + # azimuth block duration and time interval in (sec) + # values are chosen to result in at least 2 azimuth blocks!
+ az_block_dur = 0.875 + time_interval = 0.29 + + # expected prototype Chebyshev Equi-ripple filter length when subbanding + # requested in joint time-freq doppler estimation + filter_length = 33 + + # The object/values obtained from the inputs and shared by all methods + + # get ISCE3 Raw object from L0B file + raw_obj = open_rrsd(os.path.join(iscetest.data, filename)) + + # get number of bad values for the first range line + dset = raw_obj.getRawDataset(freq_band, txrx_pol)[0] + num_bad_vals = np.where(np.isnan(dset))[0].size + + # get ref epoch and mid azimuth pulse time of the echo + ref_epoch, az_tm = raw_obj.getPulseTimes(freq_band, txrx_pol[0]) + tm_mid = az_tm.mean() + pri = az_tm[1] - az_tm[0] + + # calculate number of azimuth blocks + _len_tm_int = int(time_interval / pri) + _len_az_blk_dur = int(az_block_dur / pri) + num_az_blocks = int(np.ceil((len(az_tm) - _len_az_blk_dur) / + _len_tm_int)) + 1 + + # get slant range + sr = raw_obj.getRanges(freq_band, txrx_pol[0]) + + # get the expected mean doppler centroid of the echo in (Hz) + dop_cnt_mean, squint_ang = get_doppler_from_attitude_eb( + raw_obj, freq_band, txrx_pol, tm_mid, eb_angle_deg) + + def _validate_doppler_lut(self, dop_lut, num_rgb_avg=1, err_msg=''): + """ + Compare mean, std, ref of Doppler LUT2d values with expected mean + within "dop_cnt_err". Check the shape and axes of the LUT2d. + + Parameters + ---------- + dop_lut : isce3.core.LUT2d + Estimated Doppler LUT + num_rgb_avg : int, default=1 + Number of range bins to be averaged.
+ err_msg : str, default='' + + """ + # check the shape of LUT + num_sr = self.sr.size // num_rgb_avg + lut_shape = (self.num_az_blocks, num_sr) + npt.assert_equal( + dop_lut.data.shape, lut_shape, + err_msg=f'Wrong shape of LUT2d {err_msg}' + ) + # check statistics of the LUT + npt.assert_allclose( + abs(dop_lut.data.mean() - self.dop_cnt_mean), 0.0, + atol=self.dop_cnt_err, + err_msg='Mean Doppler centroids from Raw exceeds error ' + f'{self.dop_cnt_err} (Hz) {err_msg}' + ) + + npt.assert_allclose( + dop_lut.data.std(), 0.0, atol=self.dop_cnt_err, + err_msg='STD of Doppler centroids from Raw exceeds error ' + f'{self.dop_cnt_err} (Hz) {err_msg}' + ) + # check the start and spacing for x-axis (slant range) in (m) + spacing_sr = num_rgb_avg * self.sr.spacing + npt.assert_allclose( + dop_lut.x_spacing, spacing_sr, + err_msg=f'Wrong slant range/X spacing of LUT {err_msg}' + ) + + start_sr = self.sr[num_rgb_avg // 2] + npt.assert_allclose( + dop_lut.x_start, start_sr, + err_msg=f'Wrong start slant range/X of LUT {err_msg}' + ) + # check azimuth time/y-axis start and spacing by using PRI + # as absolute tol + start_az = self.az_tm[0] + self.az_block_dur / 2.0 + npt.assert_allclose( + dop_lut.y_spacing, self.time_interval, atol=self.pri, + err_msg=f'Wrong spacing az time/Y of LUT {err_msg}' + ) + npt.assert_allclose( + dop_lut.y_start, start_az, atol=self.pri, + err_msg=f'Wrong start az time/Y of LUT {err_msg}' + ) + + def test_doppler_est_time(self): + # print expected doppler centroid and squint angle values + print( + 'Expected mean squint angle from attitude plus EB -> ' + f'{np.rad2deg(self.squint_ang) * 1e3:.1f} (mdeg)' + ) + print( + 'Expected mean Doppler centroid from attitude plus EB -> ' + f'{self.dop_cnt_mean:.2f} (Hz)' + ) + num_rgb_avg = 4 + # estimate Doppler LUT + dop_lut, dop_epoch, dop_mask, corr_coef, dop_pol, centerfreq, \ + dop_flt_coef = doppler_lut_from_raw( + self.raw_obj, num_rgb_avg=num_rgb_avg, + az_block_dur=self.az_block_dur, + 
time_interval=self.time_interval) + + # validate center freq + npt.assert_allclose(centerfreq, self.raw_obj.getCenterFrequency( + self.freq_band, self.txrx_pol[0]), + err_msg='Wrong center frequency') + # validate pol + npt.assert_equal(dop_pol, self.txrx_pol, err_msg='Wrong TxRx Pol') + # validate epoch + npt.assert_equal(dop_epoch, self.ref_epoch, err_msg='Wrong Ref epoch') + # validate Doppler LUT axes, shape, statistics + self._validate_doppler_lut(dop_lut, num_rgb_avg=num_rgb_avg, + err_msg=' in time approach') + # validate mask array shape and values + npt.assert_equal(dop_mask.shape, dop_lut.data.shape, + err_msg='Wrong shape of mask array') + # get total number of bad items for all azimuth blocks + tot_bad_vals = (self.num_bad_vals // num_rgb_avg) * dop_mask.shape[0] + # check the mask array False values against the expected one + npt.assert_equal(np.where(~dop_mask)[0].size, tot_bad_vals, + err_msg='Wrong number of False values in mask array') + # validate correlation coeffs shape and values + npt.assert_equal(corr_coef.shape, dop_lut.data.shape, + err_msg='Wrong shape of correlation coeffs') + + npt.assert_equal(np.all(corr_coef >= 0) and np.all(corr_coef <= 1), + True, err_msg='Correlation coeffs are out of range' + ' [0,1]') + # check filter coeffs if any + npt.assert_equal(dop_flt_coef, None, + err_msg='Existence of filter coeffs in time approach') + + def test_doppler_est_time_subband(self): + # estimate Doppler LUT + dop_lut, _, _, _, _, _, dop_flt_coef = doppler_lut_from_raw( + self.raw_obj, az_block_dur=self.az_block_dur, + time_interval=self.time_interval, subband=True) + # validate Doppler LUT axes, shape, statistics + self._validate_doppler_lut(dop_lut, num_rgb_avg=16, + err_msg=' in joint time-frequency approach') + # check filter coeffs + npt.assert_equal( + dop_flt_coef.size, self.filter_length, + err_msg='Wrong filter length in time-frequency approach' + ) + + def test_doppler_est_time_polyfit(self): + # estimate Doppler LUT + dop_lut, _,
_, _, _, _, _ = doppler_lut_from_raw( + self.raw_obj, az_block_dur=self.az_block_dur, + time_interval=self.time_interval, polyfit=True) + # validate Doppler LUT axes, shape, statistics + self._validate_doppler_lut( + dop_lut, num_rgb_avg=16, + err_msg=' in time approach with polyfitted output' + ) diff --git a/tests/python/packages/nisar/workflows/gen_doppler_range_product.py b/tests/python/packages/nisar/workflows/gen_doppler_range_product.py new file mode 100644 index 000000000..9cfa5d797 --- /dev/null +++ b/tests/python/packages/nisar/workflows/gen_doppler_range_product.py @@ -0,0 +1,30 @@ +import iscetest +from nisar.workflows.gen_doppler_range_product import gen_doppler_range_product + +import numpy.testing as npt +import argparse +import os + + +class TestGenDopplerRangeProduct: + # L0B filename + l0b_file = 'ALPSRP081257070-H1.0__A_HH_2500_LINES.h5' + ant_file = 'ALOS1_PALSAR_ANTPAT_FIVE_BEAMS.h5' + + # set input arguments + args = argparse.Namespace( + filename_l0b=os.path.join(iscetest.data, l0b_file), + freq_band='A', txrx_pol='HH', num_rgb_avg=32, + dop_method='CDE', az_block_dur=0.8753, time_interval=0.2918, + subband=False, polyfit=False, polyfit_deg=3, + plot=True, out_path='.', + antenna_file=os.path.join(iscetest.data, ant_file)) + + def test_correct_args(self): + gen_doppler_range_product(self.args) + + def test_incorrect_args(self): + # change the frequency band to a non-existing one + self.args.freq_band = 'B' + with npt.assert_raises(ValueError): + gen_doppler_range_product(self.args)