Skip to content

Commit

Permalink
- Rename processor configs from ProcessorNameConfig to Config.
Browse files Browse the repository at this point in the history
- Add explicit defaults for 2D EM data to flow_config.
- Remove defaults from the respective Configs' data

PiperOrigin-RevId: 691539391
  • Loading branch information
timblakely authored and copybara-github committed Nov 6, 2024
1 parent 20426e0 commit 1e39e30
Show file tree
Hide file tree
Showing 2 changed files with 166 additions and 57 deletions.
133 changes: 121 additions & 12 deletions pipeline/flow_config.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# coding=utf-8
# coding=utf-8
# Copyright 2022 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
Expand All @@ -19,29 +19,138 @@
"""

import dataclasses
import enum
from typing import Any

from connectomics.common import utils
import dataclasses_json
from sofima.processor import flow


EstimateFlowConfig = flow.EstimateFlow.EstimateFlowConfig
ReconcileFlowsConfig = flow.ReconcileAndFilterFlows.ReconcileFlowsConfig
EstimateMissingFlowConfig = flow.EstimateMissingFlow.EstimateMissingFlowConfig
# TODO(blakely): Combine with mesh_config.DefaultPipeline
class DefaultPipeline(enum.Enum):
EM_2D = 'em_2d'


@dataclasses.dataclass(frozen=True)
class FlowPipelineConfig(dataclasses_json.DataClassJsonMixin):
"""Configuration for end-to-end SOFIMA flow pipelines."""

estimate_flow: EstimateFlowConfig = dataclasses.field(
default_factory=EstimateFlowConfig,
estimate_flow: flow.EstimateFlow.Config
reconcile_flows: flow.ReconcileAndFilterFlows.Config
estimate_missing_flow: flow.EstimateMissingFlow.Config
reconcile_missing_flows: flow.ReconcileAndFilterFlows.Config


def default_em_2d_estimate_flow_config(
overrides: dict[str, Any] | None = None,
) -> flow.EstimateFlow.Config:
"""Default configuration for estimating flow fields in EM 2D data."""
config = flow.EstimateFlow.Config(
patch_size=160,
stride=40,
z_stride=1,
fixed_current=False,
mask_configs=None,
mask_only_for_patch_selection=True,
selection_mask_configs=None,
batch_size=1024,
)
reconcile_flows: ReconcileFlowsConfig = dataclasses.field(
default_factory=ReconcileFlowsConfig,
if overrides is not None:
config = utils.update_dataclass(config, overrides)
return config


def default_em_2d_reconcile_flows_config(
overrides: dict[str, Any] | None = None,
) -> flow.ReconcileAndFilterFlows.Config:
"""Default configuration for reconciling flow fields in EM 2D data."""
config = flow.ReconcileAndFilterFlows.Config(
flow_volinfos=None,
mask_configs=None,
min_peak_ratio=1.6,
min_peak_sharpness=1.6,
max_magnitude=40,
max_deviation=10,
max_gradient=40,
min_patch_size=400,
multi_section=False,
base_delta_z=1,
)
estimate_missing_flow: EstimateMissingFlowConfig = dataclasses.field(
default_factory=EstimateMissingFlowConfig,
if overrides is not None:
config = utils.update_dataclass(config, overrides)
return config


def default_em_2d_estimate_missing_flow_config(
overrides: dict[str, Any] | None = None,
) -> flow.EstimateMissingFlow.Config:
"""Default configuration for estimating missing flow fields in EM 2D data."""
config = flow.EstimateMissingFlow.Config(
patch_size=160,
stride=40,
delta_z=1,
max_delta_z=4,
max_attempts=2,
mask_configs=None,
mask_only_for_patch_selection=True,
selection_mask_configs=None,
min_peak_ratio=1.6,
min_peak_sharpness=1.6,
max_magnitude=40,
batch_size=1024,
image_volinfo=None,
image_cache_bytes=int(1e9),
mask_cache_bytes=int(1e9),
search_radius=0,
)
reconcile_missing_flows: ReconcileFlowsConfig = dataclasses.field(
default_factory=ReconcileFlowsConfig,
if overrides is not None:
config = utils.update_dataclass(config, overrides)
return config


def default_em_2d_reconcile_missing_flows_config(
overrides: dict[str, Any] | None = None,
) -> flow.ReconcileAndFilterFlows.Config:
"""Default configuration for reconciling missing flow fields in EM 2D data."""
config = utils.update_dataclass(
default_em_2d_reconcile_flows_config(),
{
'multi_section': True,
'max_magnitude': 0,
'max_deviation': 10,
'max_gradient': 10,
'min_patch_size': 400,
'base_delta_z': 1,
},
)
if overrides is not None:
config = utils.update_dataclass(config, overrides)
return config


def default_em_2d(
overrides: dict[str, Any] | None = None,
) -> FlowPipelineConfig:
"""Default flow pipeline configuration for EM 2D data."""
config = FlowPipelineConfig(
estimate_flow=default_em_2d_estimate_flow_config(),
reconcile_flows=default_em_2d_reconcile_flows_config(),
estimate_missing_flow=default_em_2d_estimate_missing_flow_config(),
reconcile_missing_flows=default_em_2d_reconcile_missing_flows_config(),
)
if overrides is not None:
config = utils.update_dataclass(config, overrides)
return config


_DEFAULT_CONFIG_TYPE_DISPATCH = {
DefaultPipeline.EM_2D: default_em_2d,
}


def default(
default_type: DefaultPipeline, overrides: dict[str, Any] | None = None
) -> FlowPipelineConfig:
"""Default flow pipeline configuration for a given data type."""
return _DEFAULT_CONFIG_TYPE_DISPATCH[default_type](overrides)
90 changes: 45 additions & 45 deletions processor/flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ class EstimateFlow(subvolume_processor.SubvolumeProcessor):
"""

@dataclasses.dataclass(eq=True)
class EstimateFlowConfig(utils.NPDataClassJsonMixin):
class Config(utils.NPDataClassJsonMixin):
"""Configuration for EstimateFlow.
Attributes:
Expand All @@ -84,18 +84,18 @@ class EstimateFlowConfig(utils.NPDataClassJsonMixin):
batch_size: Max number of patches to process in parallel.
"""

patch_size: int = 160
stride: int = 40
z_stride: int = 1
fixed_current: bool = False
mask_configs: str | mask_lib.MaskConfigs | None = None
mask_only_for_patch_selection: bool = False
selection_mask_configs: mask_lib.MaskConfigs | None = None
batch_size: int = 1024
patch_size: int
stride: int
z_stride: int
fixed_current: bool
mask_configs: str | mask_lib.MaskConfigs | None
mask_only_for_patch_selection: bool
selection_mask_configs: mask_lib.MaskConfigs | None
batch_size: int

_config: EstimateFlowConfig
_config: Config

def __init__(self, config: EstimateFlowConfig, input_volinfo_or_ts_spec=None):
def __init__(self, config: Config, input_volinfo_or_ts_spec=None):
"""Constructor.
Args:
Expand Down Expand Up @@ -289,7 +289,7 @@ class ReconcileAndFilterFlows(subvolume_processor.SubvolumeProcessor):
crop_at_borders = False

@dataclasses.dataclass(eq=True)
class ReconcileFlowsConfig(utils.NPDataClassJsonMixin):
class Config(utils.NPDataClassJsonMixin):
"""Configuration for ReconcileAndFilterFlows.
Attributes:
Expand All @@ -313,23 +313,23 @@ class ReconcileFlowsConfig(utils.NPDataClassJsonMixin):
channel to initialize the output flow with
"""

flow_volinfos: Sequence[str] | str | None = None
mask_configs: str | mask_lib.MaskConfigs | None = None
min_peak_ratio: float = 1.6
min_peak_sharpness: float = 1.6
max_magnitude: float = 40
max_deviation: float = 10
max_gradient: float = 40
min_patch_size: int = 400
multi_section: bool = False
base_delta_z: int = 0

_config: ReconcileFlowsConfig
flow_volinfos: Sequence[str] | str | None
mask_configs: str | mask_lib.MaskConfigs | None
min_peak_ratio: float
min_peak_sharpness: float
max_magnitude: float
max_deviation: float
max_gradient: float
min_patch_size: int
multi_section: bool
base_delta_z: int

_config: Config
_metadata: list[metadata.VolumeMetadata] = []

def __init__(
self,
config: ReconcileFlowsConfig,
config: Config,
input_path_or_metadata: (
file.PathLike | metadata.VolumeMetadata | None
) = None,
Expand Down Expand Up @@ -495,7 +495,7 @@ class EstimateMissingFlow(subvolume_processor.SubvolumeProcessor):
"""

@dataclasses.dataclass(frozen=True)
class EstimateMissingFlowConfig(dataclasses_json.DataClassJsonMixin):
class Config(dataclasses_json.DataClassJsonMixin):
"""Configuration for EstimateMissingFlow.
Attributes:
Expand Down Expand Up @@ -530,28 +530,28 @@ class EstimateMissingFlowConfig(dataclasses_json.DataClassJsonMixin):
direction when extracting data for the 'previous' section
"""

patch_size: int = 160
stride: int = 40
delta_z: int = 1
max_delta_z: int = 4
max_attempts: int = 2
mask_configs: str | mask_lib.MaskConfigs | None = None
mask_only_for_patch_selection: bool = True
selection_mask_configs: str | mask_lib.MaskConfigs | None = None
min_peak_ratio: float = 1.6
min_peak_sharpness: float = 1.6
max_magnitude: int = 40
batch_size: int = 1024
image_volinfo: str | None = None
image_cache_bytes: int = int(1e9)
mask_cache_bytes: int = int(1e9)
search_radius: int = 0

_config: EstimateMissingFlowConfig
patch_size: int
stride: int
delta_z: int
max_delta_z: int
max_attempts: int
mask_configs: str | mask_lib.MaskConfigs | None
mask_only_for_patch_selection: bool
selection_mask_configs: str | mask_lib.MaskConfigs | None
min_peak_ratio: float
min_peak_sharpness: float
max_magnitude: int
batch_size: int
image_volinfo: str | None
image_cache_bytes: int
mask_cache_bytes: int
search_radius: int

_config: Config

def __init__(
self,
config: EstimateMissingFlowConfig,
config: Config,
input_volinfo_or_ts_spec=None,
):
"""Constructor.
Expand Down

0 comments on commit 1e39e30

Please sign in to comment.