Skip to content

Commit

Permalink
Merge pull request #37 from lsst/tickets/DM-44287
Browse files Browse the repository at this point in the history
DM-44287: Refactor error handling to work better with measurement framework and quiet excessive logging
  • Loading branch information
erykoff authored May 14, 2024
2 parents af2e528 + f4033f0 commit d724642
Show file tree
Hide file tree
Showing 3 changed files with 132 additions and 60 deletions.
97 changes: 53 additions & 44 deletions python/lsst/meas/extensions/gaap/_gaap.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,33 +44,12 @@


class GaapConvolutionError(measBase.MeasurementError):
"""Collection of any unexpected errors in GAaP during PSF Gaussianization.
The PSF Gaussianization procedure using `modelPsfMatchTask` may throw
exceptions for certain target PSFs. Such errors are caught until all
measurements are at least attempted. The complete traceback information
is lost, but unique error messages are preserved.
Parameters
----------
errors : `dict` [`str`, `Exception`]
The values are exceptions raised, while the keys are the loop variables
(in `str` format) where the exceptions were raised.
"""Raised when there is an error in GAaP convolution.
"""
def __init__(self, errors: dict[str, Exception]):
self.errorDict = errors
message = "Problematic scaling factors = "
message += ", ".join(errors)
message += " Errors: "
message += " | ".join(set(msg.__repr__() for msg in errors.values())) # msg.cpp.what() misses type
super().__init__(message, 1) # the second argument does not matter.


class NoPixelError(Exception):
class NoPixelError(measBase.MeasurementError):
"""Raised when the footprint has no pixels.
This is caught by the measurement framework, which then calls the
`fail` method of the plugin without passing in a value for `error`.
"""


Expand Down Expand Up @@ -175,7 +154,6 @@ def _sigmas(self) -> list:

def setDefaults(self) -> None:
# Docstring inherited
# TODO: DM-27482 might change these values.
self._modelPsfMatch.kernel.active.alardNGauss = 1
self._modelPsfMatch.kernel.active.alardDegGaussDeconv = 1
self._modelPsfMatch.kernel.active.alardDegGauss = [4]
Expand Down Expand Up @@ -557,21 +535,28 @@ def _gaussianizeAndMeasure(self, measRecord: lsst.afw.table.SourceRecord,
This method is the entry point to the mixin from the concrete derived
classes.
"""
# First make sure we have a PSF.
if (psf := exposure.getPsf()) is None:
raise measBase.FatalAlgorithmError("No PSF in exposure")

# Raise errors if the plugin would fail for this record for all
# scaling factors and sigmas.
if measRecord.getFootprint().getArea() == 0:
self._setFlag(measRecord, self.name, "no_pixel")
raise NoPixelError

if (psf := exposure.getPsf()) is None:
raise measBase.FatalAlgorithmError("No PSF in exposure")
self._setScalingAndSigmaFlags(measRecord, self.config.scalingFactors)
raise NoPixelError("No good pixels in footprint", 1)

psfSigma = psf.computeShape(center).getTraceRadius()
if not (psfSigma > 0): # This captures NaN and negative values.
errorCollection = {str(scalingFactor): measBase.MeasurementError("PSF size could not be measured")
for scalingFactor in self.config.scalingFactor}
raise GaapConvolutionError(errorCollection)
center = measRecord.getCentroid()
self.log.debug("Invalid PSF sigma; cannot solve for PSF matching kernel in GAaP for (%f, %f): %s",
center.getX(), center.getY(), "GAaP Convolution Error")
self._setScalingAndSigmaFlags(
measRecord,
self.config.scalingFactors,
specificFlag="flag_gaussianization",
)
raise GaapConvolutionError("Failed to solve for PSF matching kernel", 1)
else:
errorCollection = dict()

Expand Down Expand Up @@ -630,7 +615,19 @@ def _gaussianizeAndMeasure(self, measRecord: lsst.afw.table.SourceRecord,
# Raise GaapConvolutionError before exiting the plugin
# if the collection of errors is not empty
if errorCollection:
raise GaapConvolutionError(errorCollection)
message = "Problematic scaling factors = "
message += ", ".join(errorCollection)
message += " Errors: "
message += " | ".join(set(msg.__repr__() for msg in errorCollection.values()))
center = measRecord.getCentroid()
self.log.debug("Failed to solve for PSF matching kernel in GAaP for (%f, %f): %s",
center.getX(), center.getY(), message)
self._setScalingAndSigmaFlags(
measRecord,
errorCollection.keys(),
specificFlag="flag_gaussianization",
)
raise GaapConvolutionError("Failed to solve for PSF matching kernel", 1)

@staticmethod
def _setFlag(measRecord, baseName, flagName=None):
Expand Down Expand Up @@ -658,6 +655,27 @@ def _setFlag(measRecord, baseName, flagName=None):
genericFlagKey = measRecord.schema.join(baseName, "flag")
measRecord.set(genericFlagKey, True)

def _setScalingAndSigmaFlags(self, measRecord, scalingFactors, specificFlag=None):
"""Set a full suite of flags for scalingFactors/sigmas.
Parameters
----------
measRecord : `~lsst.afw.table.SourceRecord`
Record describing the source being measured.
scalingFactors : `list` [`float`]
List of scaling factors.
specificFlag : `str`, optional
Specific type of flag to set if needed.
"""
for scalingFactor in scalingFactors:
if specificFlag is not None:
flagName = self.ConfigClass._getGaapResultName(scalingFactor, specificFlag,
self.name)
measRecord.set(flagName, True)
for sigma in self.config._sigmas:
baseName = self.ConfigClass._getGaapResultName(scalingFactor, sigma, self.name)
self._setFlag(measRecord, baseName)

def _isAllFailure(self, measRecord, scalingFactor, targetSigma) -> bool:
"""Check if all measurements would result in failure.
Expand Down Expand Up @@ -722,18 +740,9 @@ def fail(self, measRecord, error=None):
error : `Exception`
Error causing failure, or `None`.
"""
if error is not None:
center = measRecord.getCentroid()
self.log.error("Failed to solve for PSF matching kernel in GAaP for (%f, %f): %s",
center.getX(), center.getY(), error)
for scalingFactor in error.errorDict:
flagName = self.ConfigClass._getGaapResultName(scalingFactor, "flag_gaussianization",
self.name)
measRecord.set(flagName, True)
for sigma in self.config._sigmas:
baseName = self.ConfigClass._getGaapResultName(scalingFactor, sigma, self.name)
self._setFlag(measRecord, baseName)
else:
# We only need to set the failKey if no error was specified which
# signifies that the flagging was already handled.
if error is None:
measRecord.set(self._failKey, True)


Expand Down
3 changes: 2 additions & 1 deletion python/lsst/meas/extensions/gaap/_gaussianizePsf.py
Original file line number Diff line number Diff line change
Expand Up @@ -328,7 +328,8 @@ def _solve(self, kernelCellSet, basisList):
spatialKernel, spatialBackground = spatialkv.getSolutionPair()
spatialSolution = spatialkv.getKernelSolution()
except Exception as e:
self.log.error("ERROR: Unable to calculate psf matching kernel")
# This is just a debug log because it is caught by the GAaP plugin.
self.log.debug("Unable to calculate psf matching kernel")
getTraceLogger(self.log.getChild("_solve"), 1).debug("%s", e)
raise e

Expand Down
92 changes: 77 additions & 15 deletions tests/test_gaap.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,11 @@ def check(self, psfSigma=0.5, flux=1000., scalingFactors=[1.15], forced=False):
measConfig = TaskClass.ConfigClass()
algName = "ext_gaap_GaapFlux"

# Remove sky coordinate plugin because we don't have the columns
# in the tests.
if "base_SkyCoord" in measConfig.plugins.names:
measConfig.plugins.names.remove("base_SkyCoord")

measConfig.plugins.names.add(algName)

if forced:
Expand Down Expand Up @@ -259,43 +264,100 @@ def testFail(self, scalingFactors=[100.], sigmas=[500.]):
exposure, catalog = self.dataset.realize(0.0, sfmTask.schema)
self.recordPsfShape(catalog)

# Expected error messages in the logs when running `sfmTask`.
# Expected debug messages in the logs when running `sfmTask`.
errorMessage = [("Failed to solve for PSF matching kernel in GAaP for (100.000000, 670.000000): "
"Problematic scaling factors = 100.0 "
"Errors: RuntimeError('Unable to determine kernel sum; 0 candidates')"),
("MeasurementError in ext_gaap_GaapFlux.measure on record 1: "
"Failed to solve for PSF matching kernel"),
("Failed to solve for PSF matching kernel in GAaP for (100.000000, 870.000000): "
"Problematic scaling factors = 100.0 "
"Errors: RuntimeError('Unable to determine kernel sum; 0 candidates')"),
("MeasurementError in ext_gaap_GaapFlux.measure on record 2: "
"Failed to solve for PSF matching kernel"),
("Failed to solve for PSF matching kernel in GAaP for (-10.000000, -20.000000): "
"Problematic scaling factors = 100.0 "
"Errors: RuntimeError('Unable to determine kernel sum; 0 candidates')")]
"Errors: RuntimeError('Unable to determine kernel sum; 0 candidates')"),
("MeasurementError in ext_gaap_GaapFlux.measure on record 3: "
"Failed to solve for PSF matching kernel")]

testCatalog = catalog.copy(deep=True)
plugin_logger_name = sfmTask.log.getChild(algName).name
self.assertEqual(plugin_logger_name, "lsst.measurement.ext_gaap_GaapFlux")
with self.assertLogs(plugin_logger_name, "ERROR") as cm:
sfmTask.run(catalog, exposure)
with self.assertLogs(plugin_logger_name, "DEBUG") as cm:
sfmTask.run(testCatalog, exposure)
self.assertEqual([record.message for record in cm.records], errorMessage)

self._checkAllFlags(
testCatalog,
algName,
scalingFactors,
sigmas,
gaapConfig,
specificFlag="flag_gaussianization",
)

# Trigger a "not (psfSigma > 0) error":
exposureJunkPsf = exposure.clone()
testCatalog = catalog.copy(deep=True)
junkPsf = afwDetection.GaussianPsf(1, 1, 0)
exposureJunkPsf.setPsf(junkPsf)
sfmTask.run(testCatalog, exposureJunkPsf)

self._checkAllFlags(
testCatalog,
algName,
scalingFactors,
sigmas,
gaapConfig,
specificFlag="flag_gaussianization",
)

# Trigger a NoPixelError.
testCatalog = catalog.copy(deep=True)
testCatalog[0].setFootprint(afwDetection.Footprint())
with self.assertLogs(plugin_logger_name, "DEBUG") as cm:
sfmTask.run(testCatalog, exposure)

self.assertEqual(
cm.records[0].message,
"MeasurementError in ext_gaap_GaapFlux.measure on record 1: No good pixels in footprint",
)
self.assertEqual(testCatalog[f"{algName}_flag_no_pixel"][0], True)
self.assertEqual(testCatalog[f"{algName}_flag"][0], True)

self._checkAllFlags(testCatalog[0: 1], algName, scalingFactors, sigmas, gaapConfig, allFailFlag=True)

# Try and "fail" with no PSF.
# Since fatal exceptions are not caught by the measurement framework,
# use a context manager and catch it here.
exposure.setPsf(None)
with self.assertRaises(lsst.meas.base.FatalAlgorithmError):
sfmTask.run(catalog, exposure)

def _checkAllFlags(
self,
catalog,
algName,
scalingFactors,
sigmas,
gaapConfig,
specificFlag=None,
allFailFlag=False
):
for record in catalog:
self.assertFalse(record[algName + "_flag"])
self.assertEqual(record[algName + "_flag"], allFailFlag)
for scalingFactor in scalingFactors:
flagName = gaapConfig._getGaapResultName(scalingFactor, "flag_gaussianization", algName)
self.assertTrue(record[flagName])
if specificFlag is not None:
flagName = gaapConfig._getGaapResultName(scalingFactor, specificFlag, algName)
self.assertTrue(record[flagName])
for sigma in sigmas + ["Optimal"]:
baseName = gaapConfig._getGaapResultName(scalingFactor, sigma, algName)
self.assertTrue(record[baseName + "_flag"])
self.assertFalse(record[baseName + "_flag_bigPsf"])

baseName = gaapConfig._getGaapResultName(scalingFactor, "PsfFlux", algName)
self.assertTrue(record[baseName + "_flag"])

# Try and "fail" with no PSF.
# Since fatal exceptions are not caught by the measurement framework,
# use a context manager and catch it here.
exposure.setPsf(None)
with self.assertRaises(lsst.meas.base.FatalAlgorithmError):
sfmTask.run(catalog, exposure)

def testFlags(self, sigmas=[0.4, 0.5, 0.7], scalingFactors=[1.15, 1.25, 1.4, 100.]):
"""Test that GAaP flags are set properly.
Expand Down

0 comments on commit d724642

Please sign in to comment.