Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Handle multiple reduction runs #500

Merged
merged 3 commits into from
Nov 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions src/snapred/backend/dao/request/MatchRunsRequest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
from typing import List

from pydantic import BaseModel


class MatchRunsRequest(BaseModel):
    """Request to look up, for each run, the calibration/normalization version that applies to it."""

    # Run numbers to match, as strings (SNAPRed passes run numbers as strings throughout).
    runNumbers: List[str]
    # Whether to match against lite-mode data; affects which versions apply.
    useLiteMode: bool
23 changes: 23 additions & 0 deletions src/snapred/backend/service/CalibrationService.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
FocusSpectraRequest,
HasStateRequest,
InitializeStateRequest,
MatchRunsRequest,
SimpleDiffCalRequest,
)
from snapred.backend.dao.response.CalibrationAssessmentResponse import CalibrationAssessmentResponse
Expand Down Expand Up @@ -99,6 +100,7 @@ def __init__(self):
self.registerPath("group", self.groupCalibration)
self.registerPath("validateWritePermissions", self.validateWritePermissions)
self.registerPath("residual", self.calculateResidual)
self.registerPath("fetchMatches", self.fetchMatchingCalibrations)
return

@staticmethod
Expand Down Expand Up @@ -365,6 +367,27 @@ def load(self, run: RunConfig, version: Optional[int] = None):
"""
return self.dataFactoryService.getCalibrationRecord(run.runNumber, run.useLiteMode, version)

def matchRunsToCalibrationVersions(self, request: MatchRunsRequest) -> Dict[str, Any]:
    """
    For each run in the list, find the calibration version that applies to it.

    Returns a dict mapping each requested run number to the version reported by
    the data factory (which may be None when no calibration applies).
    """
    return {
        runNumber: self.dataFactoryService.getThisOrLatestCalibrationVersion(runNumber, request.useLiteMode)
        for runNumber in request.runNumbers
    }

@FromString
def fetchMatchingCalibrations(self, request: MatchRunsRequest):
    """
    Load the diffraction-calibration tables matching each requested run.

    Returns a tuple of (set of fetched workspace names, run -> calibration-version dict).
    """
    calibrations = self.matchRunsToCalibrationVersions(request)
    for runNumber in request.runNumbers:
        # BUGFIX: matchRunsToCalibrationVersions inserts a key for EVERY requested run
        # (possibly with value None), so the previous `runNumber in calibrations` test was
        # always true and runs without an applicable calibration were still queued with a
        # None version.  Guard on the value instead, mirroring
        # NormalizationService.fetchMatchingNormalizations.
        if calibrations.get(runNumber) is not None:
            self.groceryClerk.diffcal_table(runNumber, calibrations[runNumber]).useLiteMode(
                request.useLiteMode
            ).add()
    return set(self.groceryService.fetchGroceryList(self.groceryClerk.buildList())), calibrations

@FromString
def saveCalibrationToIndex(self, entry: IndexEntry):
"""
Expand Down
24 changes: 24 additions & 0 deletions src/snapred/backend/service/NormalizationService.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from pathlib import Path
from typing import Any, Dict

from snapred.backend.dao import Limit
from snapred.backend.dao.indexing.IndexEntry import IndexEntry
Expand All @@ -14,6 +15,7 @@
CreateNormalizationRecordRequest,
FarmFreshIngredients,
FocusSpectraRequest,
MatchRunsRequest,
NormalizationExportRequest,
NormalizationRequest,
SmoothDataExcludingPeaksRequest,
Expand Down Expand Up @@ -378,3 +380,25 @@ def smoothDataExcludingPeaks(self, request: SmoothDataExcludingPeaksRequest):
smoothedVanadium=request.outputWorkspace,
detectorPeaks=peaks,
).dict()

def matchRunsToNormalizationVersions(self, request: MatchRunsRequest) -> Dict[str, Any]:
    """
    For each run in the list, find the normalization version that applies to it.

    Returns a dict mapping each requested run number to the version reported by
    the data factory (which may be None when no normalization applies).
    """
    return {
        runNumber: self.dataFactoryService.getThisOrLatestNormalizationVersion(runNumber, request.useLiteMode)
        for runNumber in request.runNumbers
    }

@FromString
@Register("fetchMatches")
def fetchMatchingNormalizations(self, request: MatchRunsRequest):
    """
    Load the normalization workspaces matching each requested run.

    Returns a tuple of (set of fetched workspace names, run -> normalization-version dict).
    """
    normVersions = self.matchRunsToNormalizationVersions(request)
    for run in request.runNumbers:
        version = normVersions.get(run)
        if version is None:
            # No applicable normalization for this run -- nothing to queue.
            continue
        self.groceryClerk.normalization(run, version).useLiteMode(request.useLiteMode).add()
    fetched = set(self.groceryService.fetchGroceryList(self.groceryClerk.buildList()))
    return fetched, normVersions
50 changes: 36 additions & 14 deletions src/snapred/ui/workflow/ReductionWorkflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from snapred.backend.dao.ingredients import ArtificialNormalizationIngredients
from snapred.backend.dao.request import (
CreateArtificialNormalizationRequest,
MatchRunsRequest,
ReductionExportRequest,
ReductionRequest,
)
Expand Down Expand Up @@ -181,27 +182,48 @@ def _triggerReduction(self, workflowPresenter):
response = self.request(path="reduction/groupings", payload=request_)
self._keeps = set(response.data["groupingWorkspaces"])

for runNumber in self.runNumbers:
self._artificialNormalizationView.updateRunNumber(runNumber)
request_ = self._createReductionRequest(runNumber)
# get the calibration and normalization versions for all runs to be processed
matchRequest = MatchRunsRequest(runNumbers=self.runNumbers, useLiteMode=self.useLiteMode)
loadedCalibrations, calVersions = self.request(path="calibration/fetchMatches", payload=matchRequest).data
loadedNormalizations, normVersions = self.request(path="normalization/fetchMatches", payload=matchRequest).data
self._keeps.update(loadedCalibrations)
self._keeps.update(loadedNormalizations)

distinctNormVersions = set(normVersions.values())
if len(distinctNormVersions) > 1 and None in distinctNormVersions:
raise RuntimeError(
"Some of your workspaces require Artificial Normalization. "
"SNAPRed can currently only handle the situation where all, or none "
"of the runs require Artificial Normalization. Please clear the list "
"and try again."
)

# Validate reduction; if artificial normalization is needed, handle it
response = self.request(path="reduction/validate", payload=request_)
if ContinueWarning.Type.MISSING_NORMALIZATION in self.continueAnywayFlags:
# Validate reduction; if artificial normalization is needed, handle it
# NOTE: this logic ONLY works because we are forbidding mixed cases of artnorm or loaded norm
response = self.request(path="reduction/validate", payload=request_)
if ContinueWarning.Type.MISSING_NORMALIZATION in self.continueAnywayFlags:
if len(self.runNumbers) > 1:
raise RuntimeError(
"Currently, Artificial Normalization can only be performed on a "
"single run at a time. Please clear your run list and try again."
)
for runNumber in self.runNumbers:
self._artificialNormalizationView.updateRunNumber(runNumber)
self._artificialNormalizationView.showAdjustView()
request_ = self._createReductionRequest(runNumber)
response = self.request(path="reduction/grabWorkspaceforArtificialNorm", payload=request_)
self._artificialNormalization(workflowPresenter, response.data, runNumber)
else:
# Proceed with reduction if artificial normalization is not needed
else:
for runNumber in self.runNumbers:
self._artificialNormalizationView.showSkippedView()
request_ = self._createReductionRequest(runNumber)
response = self.request(path="reduction/", payload=request_)
if response.code == ResponseCode.OK:
record, unfocusedData = response.data.record, response.data.unfocusedData
self._finalizeReduction(record, unfocusedData)
# after each run, clean workspaces except groupings, calibrations, normalizations, and outputs
self._keeps.update(self.outputs)
self._clearWorkspaces(exclude=self._keeps, clearCachedWorkspaces=True)
workflowPresenter.advanceWorkflow()
self._finalizeReduction(response.data.record, response.data.unfocusedData)
# after each run, clean workspaces except groupings, calibrations, normalizations, and outputs
self._keeps.update(self.outputs)
self._clearWorkspaces(exclude=self._keeps, clearCachedWorkspaces=True)
workflowPresenter.advanceWorkflow()
# SPECIAL FOR THE REDUCTION WORKFLOW: clear everything _except_ the output workspaces
# _before_ transitioning to the "save" panel.
# TODO: make '_clearWorkspaces' a public method (i.e make this combination a special `cleanup` method).
Expand Down
27 changes: 27 additions & 0 deletions tests/unit/backend/service/test_CalibrationService.py
Original file line number Diff line number Diff line change
Expand Up @@ -1043,6 +1043,33 @@ def test_fitPeaks(self, FitMultiplePeaksRecipe):
res = self.instance.fitPeaks(request)
assert res == FitMultiplePeaksRecipe.return_value.executeRecipe.return_value

def test_matchRuns(self):
    """matchRunsToCalibrationVersions maps each requested run to its version, in order."""
    # Exactly two runs are requested, so provide exactly two side-effect values
    # (the original listed an unused third sentinel, which was misleading).
    self.instance.dataFactoryService.getThisOrLatestCalibrationVersion = mock.Mock(
        side_effect=[mock.sentinel.version1, mock.sentinel.version2],
    )
    request = mock.Mock(runNumbers=[mock.sentinel.run1, mock.sentinel.run2], useLiteMode=True)
    response = self.instance.matchRunsToCalibrationVersions(request)
    assert response == {
        mock.sentinel.run1: mock.sentinel.version1,
        mock.sentinel.run2: mock.sentinel.version2,
    }

def test_fetchRuns(self):
    """fetchMatchingCalibrations deduplicates fetched workspaces and returns the version map."""
    # Three runs mapped onto two distinct versions; only run1/run2 are requested below.
    expectedVersions = {
        mock.sentinel.run1: mock.sentinel.version1,
        mock.sentinel.run2: mock.sentinel.version2,
        mock.sentinel.run3: mock.sentinel.version2,
    }
    # The grocery fetch returns a duplicate entry, which should collapse in the result set.
    fetchedItems = [mock.sentinel.grocery1, mock.sentinel.grocery2, mock.sentinel.grocery2]
    self.instance.matchRunsToCalibrationVersions = mock.Mock(return_value=expectedVersions)
    self.instance.groceryService.fetchGroceryList = mock.Mock(return_value=fetchedItems)
    self.instance.groceryClerk = mock.Mock()

    request = mock.Mock(runNumbers=[mock.sentinel.run1, mock.sentinel.run2], useLiteMode=True)
    workspaces, versions = self.instance.fetchMatchingCalibrations(request)
    assert workspaces == {mock.sentinel.grocery1, mock.sentinel.grocery2}
    assert versions == expectedVersions

def test_initializeState(self):
testCalibration = DAOFactory.calibrationParameters()
mockInitializeState = mock.Mock(return_value=testCalibration.instrumentState)
Expand Down
27 changes: 27 additions & 0 deletions tests/unit/backend/service/test_NormalizationService.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,33 @@ def test_smoothDataExcludingPeaks(
SmoothingParameter=mockRequest.smoothingParameter,
)

def test_matchRuns(self):
    """matchRunsToNormalizationVersions maps each requested run to its version, in order."""
    # The data factory hands back one version per lookup, in request order.
    versions = [mock.sentinel.version1, mock.sentinel.version2]
    self.instance.dataFactoryService.getThisOrLatestNormalizationVersion = mock.Mock(side_effect=versions)

    runs = [mock.sentinel.run1, mock.sentinel.run2]
    request = mock.Mock(runNumbers=runs, useLiteMode=True)
    result = self.instance.matchRunsToNormalizationVersions(request)

    assert result == dict(zip(runs, versions))

def test_fetchRuns(self):
    """fetchMatchingNormalizations deduplicates fetched workspaces and returns the version map."""
    # Renamed from 'mockCalibrations': this is the NormalizationService test, and the
    # mapping holds normalization versions (copy-paste inconsistency in the original).
    mockNormalizations = {
        mock.sentinel.run1: mock.sentinel.version1,
        mock.sentinel.run2: mock.sentinel.version2,
        mock.sentinel.run3: mock.sentinel.version2,
    }
    # The grocery fetch returns a duplicate entry, which should collapse in the result set.
    mockGroceries = [mock.sentinel.grocery1, mock.sentinel.grocery2, mock.sentinel.grocery2]
    self.instance.matchRunsToNormalizationVersions = mock.Mock(return_value=mockNormalizations)
    self.instance.groceryService.fetchGroceryList = mock.Mock(return_value=mockGroceries)
    self.instance.groceryClerk = mock.Mock()

    request = mock.Mock(runNumbers=[mock.sentinel.run1, mock.sentinel.run2], useLiteMode=True)
    groceries, norm = self.instance.fetchMatchingNormalizations(request)
    assert groceries == {mock.sentinel.grocery1, mock.sentinel.grocery2}
    assert norm == mockNormalizations

def test_normalizationAssessment(self):
self.instance = NormalizationService()
self.instance.sousChef = SculleryBoy()
Expand Down