Skip to content

Commit

Permalink
save progress
Browse files Browse the repository at this point in the history
  • Loading branch information
rboston628 committed Nov 21, 2024
1 parent bb31f0d commit cf21ac0
Show file tree
Hide file tree
Showing 7 changed files with 164 additions and 1 deletion.
8 changes: 8 additions & 0 deletions src/snapred/backend/dao/request/MatchRunsRequest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
from typing import List

from pydantic import BaseModel


class MatchRunsRequest(BaseModel):
runNumbers: List[str]
useLiteMode: bool
23 changes: 23 additions & 0 deletions src/snapred/backend/service/CalibrationService.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
FocusSpectraRequest,
HasStateRequest,
InitializeStateRequest,
MatchRunsRequest,
SimpleDiffCalRequest,
)
from snapred.backend.dao.response.CalibrationAssessmentResponse import CalibrationAssessmentResponse
Expand Down Expand Up @@ -99,6 +100,7 @@ def __init__(self):
self.registerPath("group", self.groupCalibration)
self.registerPath("validateWritePermissions", self.validateWritePermissions)
self.registerPath("residual", self.calculateResidual)
self.registerPath("fetchMatches", self.fetchMatchingCalibrations)
return

@staticmethod
Expand Down Expand Up @@ -365,6 +367,27 @@ def load(self, run: RunConfig, version: Optional[int] = None):
"""
return self.dataFactoryService.getCalibrationRecord(run.runNumber, run.useLiteMode, version)

def matchRunsToCalibrationVersions(self, request: MatchRunsRequest) -> Dict[str, Any]:
"""
For each run in the list, find the calibration version that applies to it
"""
response = {}
for runNumber in request.runNumbers:
response[runNumber] = self.dataFactoryService.getThisOrLatestCalibrationVersion(
runNumber, request.useLiteMode
)
return response

@FromString
def fetchMatchingCalibrations(self, request: MatchRunsRequest):
calibrations = self.matchRunsToCalibrationVersions(request)
for runNumber in request.runNumbers:
if runNumber in calibrations:
self.groceryClerk.diffcal_table(runNumber, calibrations[runNumber]).useLiteMode(
request.useLiteMode
).add()
return set(self.groceryService.fetchGroceryList(self.groceryClerk.buildList())), calibrations

@FromString
def saveCalibrationToIndex(self, entry: IndexEntry):
"""
Expand Down
24 changes: 24 additions & 0 deletions src/snapred/backend/service/NormalizationService.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from pathlib import Path
from typing import Any, Dict

from snapred.backend.dao import Limit
from snapred.backend.dao.indexing.IndexEntry import IndexEntry
Expand All @@ -14,6 +15,7 @@
CreateNormalizationRecordRequest,
FarmFreshIngredients,
FocusSpectraRequest,
MatchRunsRequest,
NormalizationExportRequest,
NormalizationRequest,
SmoothDataExcludingPeaksRequest,
Expand Down Expand Up @@ -378,3 +380,25 @@ def smoothDataExcludingPeaks(self, request: SmoothDataExcludingPeaksRequest):
smoothedVanadium=request.outputWorkspace,
detectorPeaks=peaks,
).dict()

def matchRunsToNormalizationVersions(self, request: MatchRunsRequest) -> Dict[str, Any]:
"""
For each run in the list, find the calibration version that applies to it
"""
response = {}
for runNumber in request.runNumbers:
response[runNumber] = self.dataFactoryService.getThisOrLatestNormalizationVersion(
runNumber, request.useLiteMode
)
return response

@FromString
@Register("fetchMatches")
def fetchMatchingNormalizations(self, request: MatchRunsRequest):
normalizations = self.matchRunsToNormalizationVersions(request)
for runNumber in request.runNumbers:
if normalizations.get(runNumber) is not None:
self.groceryClerk.normalization(runNumber, normalizations[runNumber]).useLiteMode(
request.useLiteMode
).add()
return set(self.groceryService.fetchGroceryList(self.groceryClerk.buildList())), normalizations
14 changes: 13 additions & 1 deletion src/snapred/ui/view/reduction/ReductionRequestView.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
from snapred.meta.decorators.Resettable import Resettable
from snapred.ui.view.BackendRequestView import BackendRequestView
from snapred.ui.widget.Toggle import Toggle
from snapred.ui.widget.StoplightIcon import StoplightColor
from snapred.ui.widget.StagedReductionRun import StagedReductionRun

logger = snapredLogger.getLogger(__name__)

Expand Down Expand Up @@ -97,6 +99,15 @@ def addRunNumber(self):
except ValueError as e:
QMessageBox.warning(self, "Warning", str(e), buttons=QMessageBox.Ok, defaultButton=QMessageBox.Ok)
self.runNumberInput.clear()
# NOTE added to temporarily fix defect in EWM 8287. Remove when a more complete solution is implemented
if len(self.runNumbers) > 1:
QMessageBox.warning(
self,
"Multirun Reduction",
"SNAPRed can currently only reduce one run at a time. Only the top run will be processed.",
buttons=QMessageBox.Ok,
defaultButton=QMessageBox.Ok,
)

def parseInputRunNumbers(self) -> List[str]:
# WARNING: run numbers are strings.
Expand Down Expand Up @@ -125,7 +136,8 @@ def parseInputRunNumbers(self) -> List[str]:

def updateRunNumberList(self):
self.runNumberDisplay.clear()
self.runNumberDisplay.addItems(self.runNumbers)
for runNumber in self.runNumbers:
self.runNumberDisplay.insertItem(0, StagedReductionRun(runNumber, StoplightColor.green))

def clearRunNumbers(self):
self.runNumbers.clear()
Expand Down
12 changes: 12 additions & 0 deletions src/snapred/ui/widget/StagedReductionRun.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
from qtpy.QtWidgets import QListWidgetItem
from qtpy.QtCore import QSize

from snapred.ui.widget.StoplightIcon import StoplightIcon, StoplightColor


class StagedReductionRun(QListWidgetItem):
def __init__(self, runNumber: str, color: StoplightColor):
super().__init__()
self.setText(runNumber)
self.setIcon(StoplightIcon(color=color))
self.setSizeHint(QSize(100, 24))
35 changes: 35 additions & 0 deletions src/snapred/ui/widget/StoplightIcon.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
from enum import Enum
from qtpy.QtGui import QPixmap, QPainter, QColor, QIcon


class StoplightColor(Enum):
green=2
yellow=1
red=0

class StoplightIcon(QIcon):
def __init__(self, color: StoplightColor = StoplightColor.red, size: int=10):
super().__init__()
pixmap = QPixmap(size, size)
pixmap.fill(QColor("transparent"))

painter = QPainter(pixmap)
painter.setRenderHint(QPainter.Antialiasing)

color = QColor(self._matchColor(color))
painter.setBrush(color)
painter.setPen(QColor("transparent"))

painter.drawEllipse(0, 0, size - 1, size - 1)
painter.end()

self.addPixmap(pixmap)

def _matchColor(self, color: StoplightColor):
match(color):
case StoplightColor.green:
return "green"
case StoplightColor.yellow:
return "yellow"
case StoplightColor.red:
return "red"
49 changes: 49 additions & 0 deletions src/snapred/ui/workflow/ReductionWorkflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
CreateArtificialNormalizationRequest,
ReductionExportRequest,
ReductionRequest,
MatchRunsRequest,
)
from snapred.backend.dao.SNAPResponse import ResponseCode, SNAPResponse
from snapred.backend.error.ContinueWarning import ContinueWarning
Expand Down Expand Up @@ -181,6 +182,19 @@ def _triggerReduction(self, workflowPresenter):
response = self.request(path="reduction/groupings", payload=request_)
self._keeps = set(response.data["groupingWorkspaces"])

# get the calibration and normalization versions for all runs to be processed
matchRequest = MatchRunsRequest(runNumbers=self.runNumbers, useLiteMode=self.useLiteMode)
loadedCalibrations, calVersions = self.request(path="calibration/fetchMatches", payload=matchRequest).data
loadedNormalizations, normVersions = self.request(path="normalization/fetchMatches", payload=matchRequest).data
self._keeps.update(loadedCalibrations)
self._keeps.update(loadedNormalizations)

if len(loadedNormalizations) > 1 and None in normVersions:
raise RuntimeError("Some of your workspaces require Artificial Normalization. "
"SNAPRed can currently only handle the situation where all, or none "
"of the runs require Artificial Normalization. Please clear the list "
"and try again.")

for runNumber in self.runNumbers:
self._artificialNormalizationView.updateRunNumber(runNumber)
request_ = self._createReductionRequest(runNumber)
Expand Down Expand Up @@ -250,6 +264,41 @@ def onArtificialNormalizationValueChange(self, smoothingValue, lss, decreasePara
self._artificialNormalizationView.updateWorkspaces(diffractionWorkspace, response.data)
self._artificialNormalizationView.enableRecalculateButton()

# def _continueWithNormalization(self, workflowPresenter): # noqa: ARG002
# """
# For each workspace run in the list, optionally create an artificial normalization,
# then reduce the data.
# """

# for runNumber in self.runNumbers[0]:
# self._artificialNormalizationView.updateRunNumber(runNumber)

# request_ = self._createReductionRequest(runNumber)

# # check if this run has a valid calibration and normalization
# # NOTE this must be performed on each individual run
# self.request(path="reduction/validate", payload=request_)
# missingNormalization = ContinueWarning.Type.MISSING_NORMALIZATION in self.continueAnywayFlags

# # if normalization is missing, need to perform artificial normalization
# if missingNormalization:
# self._artificialNormalizationView.showAdjustView()
# # response = self.request(path="reduction/grabWorkspaceforArtificialNorm", payload=request_)
# # self._artificialNormalization(workflowPresenter, response.data, runNumber)
# else:
# self._artificialNormalizationView.showSkippedView()

# response = self.request(path="reduction/", payload=request_)
# if response.code == ResponseCode.OK:
# record, unfocusedData = response.data.record, response.data.unfocusedData
# self._finalizeReduction(record, unfocusedData)
# # after each run, clean out temporary workspaces, but keep the needed groupings, calibrations, normalizations, and outputs
# self._keeps.update(self.outputs)
# self._clearWorkspaces(exclude=self._keeps, clearCachedWorkspaces=True)
# # workflowPresenter.advanceWorkflow()
# self._clearWorkspaces(exclude=self.outputs, clearCachedWorkspaces=True)
# return self.responses[-1]

def _continueWithNormalization(self, workflowPresenter): # noqa: ARG002
"""Continues the workflow using the artificial normalization workspace."""
artificialNormIngredients = ArtificialNormalizationIngredients(
Expand Down

0 comments on commit cf21ac0

Please sign in to comment.