diff --git a/tjpcov/__init__.py b/tjpcov/__init__.py index dbfc2a06..0957d84c 100644 --- a/tjpcov/__init__.py +++ b/tjpcov/__init__.py @@ -1,5 +1,4 @@ -# flake8: noqa -from .covariance_builder import CovarianceBuilder +#!/usr/bin/python3 def covariance_from_name(name): @@ -11,17 +10,18 @@ def covariance_from_name(name): Returns: :class:`~tjpcov.covariance_builder.CovarianceBuilder` child class """ - - def all_subclasses(cls): - # Recursively find all subclasses (and their subclasses) - # From https://stackoverflow.com/questions/3862310 - return set(cls.__subclasses__()).union( - [s for c in cls.__subclasses__() for s in all_subclasses(c)] - ) - - subcs = all_subclasses(CovarianceBuilder) - mappers = {m.__name__: m for m in subcs} - if name in mappers: - return mappers[name] + # TODO: Make this automatic + if name == "FourierGaussianNmt": + from .covariance_fourier_gaussian_nmt import FourierGaussianNmt as Cov + elif name == "FourierSSCHaloModel": + from .covariance_fourier_ssc import FourierSSCHaloModel as Cov + elif name == "ClusterCounts": + from .covariance_cluster_counts import ClusterCounts as Cov + elif name == "FourierGaussianFsky": + from .covariance_gaussian_fsky import FourierGaussianFsky as Cov + elif name == "RealGaussianFsky": + from .covariance_gaussian_fsky import RealGaussianFsky as Cov else: raise ValueError(f"Unknown covariance {name}") + + return Cov