Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] Proposal: Refactor image loading transforms for better modularity and extensibility #3198

Open
shenshanf opened this issue Nov 3, 2024 · 2 comments

Comments

@shenshanf
Copy link

What is the feature?

Current Issues:

  1. The LoadImageFromFile class has hardcoded key mappings, making it difficult to extend for different use cases like stereo or multi-view scenarios.
  2. The core image-loading logic has to be duplicated every time a new loading transform is created.
  3. The current inheritance-based approach to creating new loading transforms (e.g., for stereo) is not flexible enough for dynamic scenarios, such as a varying number of views in multi-view stereo (MVS).

Proposed Solution:
Extract the core image loading logic into a separate ImageLoader class and refactor the transforms to use this common functionality. Here's a basic implementation proposal:

class ImageLoader:
    """Read an image file and decode it into a numpy array.

    Bundles the decoding options in one object so that several loading
    transforms can share a single implementation.

    Args:
        to_float32: If True, the decoded array is cast to ``np.float32``.
        color_type: Forwarded to ``mmcv.imfrombytes`` as ``flag``.
        imdecode_backend: Forwarded to ``mmcv.imfrombytes`` as ``backend``.
        ignore_empty: If True, :meth:`load` returns ``None`` instead of
            raising when the file cannot be read or decoded.
        backend_args: Forwarded to ``fileio.get`` as ``backend_args``.
    """

    def __init__(self,
                 to_float32: bool = False,
                 color_type: str = 'color',
                 imdecode_backend: str = 'cv2',
                 ignore_empty: bool = False,
                 backend_args: Optional[dict] = None):
        self.to_float32 = to_float32
        self.color_type = color_type
        self.imdecode_backend = imdecode_backend
        self.ignore_empty = ignore_empty
        self.backend_args = backend_args

    def load(self, filepath: str) -> Optional[np.ndarray]:
        """Load a single image from ``filepath``.

        Args:
            filepath: Location of the image, handed to ``fileio.get``.

        Returns:
            The decoded image (cast to ``np.float32`` when ``to_float32``
            is set), or ``None`` when loading failed and ``ignore_empty``
            is True.

        Raises:
            ValueError: If the decoder produced no image and
                ``ignore_empty`` is False.
            Exception: Whatever ``fileio.get`` or ``mmcv.imfrombytes``
                raised, re-raised unchanged when ``ignore_empty`` is False.
        """
        # Only reading and decoding are covered by ``ignore_empty``; the
        # dtype cast below is not an I/O failure and must not be swallowed.
        try:
            img_bytes = fileio.get(filepath, backend_args=self.backend_args)
            img = mmcv.imfrombytes(
                img_bytes,
                flag=self.color_type,
                backend=self.imdecode_backend,
            )
        except Exception:
            if self.ignore_empty:
                return None
            # Bare ``raise`` re-raises the active exception as-is.
            raise

        # A decoder can hand back ``None`` for unreadable data instead of
        # raising (``cv2.imdecode`` does so). Without this check the caller
        # would see ``None`` -- i.e. "skip this sample" -- even though
        # ``ignore_empty`` was not requested, or hit an AttributeError on
        # ``astype`` below.
        if img is None:
            if self.ignore_empty:
                return None
            raise ValueError(f'failed to load image: {filepath}')

        if self.to_float32:
            img = img.astype(np.float32)
        return img

### Any other context?

_No response_
@shenshanf
Copy link
Author

@TRANSFORMS.register_module()
class LoadImageFromFile(BaseTransform):
    """Transform that loads one image referenced by ``results['img_path']``.

    Decoding is delegated to an ``ImageLoader`` built from the constructor
    arguments.

    Required keys:
        - ``img_path``

    Added / overwritten keys:
        - ``img``
        - ``img_shape``
        - ``ori_shape``

    Args:
        to_float32: Forwarded to ``ImageLoader``.
        color_type: Forwarded to ``ImageLoader``.
        imdecode_backend: Forwarded to ``ImageLoader``.
        ignore_empty: Forwarded to ``ImageLoader``.
        backend_args: Forwarded to ``ImageLoader``.
    """

    def __init__(self,
                 to_float32: bool = False,
                 color_type: str = 'color',
                 imdecode_backend: str = 'cv2',
                 ignore_empty: bool = False,
                 backend_args: Optional[dict] = None):
        super().__init__()
        loader_options = dict(
            to_float32=to_float32,
            color_type=color_type,
            imdecode_backend=imdecode_backend,
            ignore_empty=ignore_empty,
            backend_args=backend_args,
        )
        self.loader = ImageLoader(**loader_options)

    def transform(self, results: dict) -> Optional[dict]:
        """Load the image and record it, with its shape, in ``results``.

        Args:
            results: Sample dict; must contain ``img_path``.

        Returns:
            The same ``results`` dict, updated in place, or ``None`` when
            the loader returned ``None``.
        """
        image = self.loader.load(results['img_path'])
        if image is not None:
            # First two entries of the array shape; stored under both keys.
            leading_dims = image.shape[:2]
            results['img'] = image
            results['img_shape'] = leading_dims
            results['ori_shape'] = leading_dims
            return results
        return None


@TRANSFORMS.register_module()
class LoadMultiViewImage(BaseTransform):
    """Transform that loads every image listed in ``results['img_paths']``.

    Decoding is delegated to an ``ImageLoader`` built from the constructor
    arguments. The number of views is whatever the path list contains.

    Required keys:
        - ``img_paths``

    Added / overwritten keys:
        - ``imgs``
        - ``img_shapes``
        - ``ori_shapes``

    Args:
        to_float32: Forwarded to ``ImageLoader``.
        color_type: Forwarded to ``ImageLoader``.
        imdecode_backend: Forwarded to ``ImageLoader``.
        ignore_empty: Forwarded to ``ImageLoader``.
        backend_args: Forwarded to ``ImageLoader``.
    """

    def __init__(self,
                 to_float32: bool = False,
                 color_type: str = 'color',
                 imdecode_backend: str = 'cv2',
                 ignore_empty: bool = False,
                 backend_args: Optional[dict] = None):
        super().__init__()
        loader_options = dict(
            to_float32=to_float32,
            color_type=color_type,
            imdecode_backend=imdecode_backend,
            ignore_empty=ignore_empty,
            backend_args=backend_args,
        )
        self.loader = ImageLoader(**loader_options)

    def transform(self, results: dict) -> Optional[dict]:
        """Load all views and record them, with their shapes, in ``results``.

        Args:
            results: Sample dict; must contain ``img_paths``.

        Returns:
            The same ``results`` dict, updated in place, or ``None`` as soon
            as the loader returns ``None`` for any view (``results`` is left
            untouched in that case).
        """
        # Load the multi-view images one by one.
        view_images = []
        view_shapes = []

        for view_path in results['img_paths']:
            view = self.loader.load(view_path)
            if view is None:
                # One missing view invalidates the whole sample.
                return None

            view_images.append(view)
            view_shapes.append(view.shape[:2])

        results['imgs'] = view_images
        results['img_shapes'] = view_shapes
        # Separate list object so the two entries can be edited independently.
        results['ori_shapes'] = list(view_shapes)
        return results

@shenshanf
Copy link
Author

@TRANSFORMS.register_module()
class LoadStereoImage(BaseTransform):
    """Transform that loads a left/right image pair.

    Decoding is delegated to an ``ImageLoader`` built from the constructor
    arguments.

    Required keys:
        - ``left_img_path``
        - ``right_img_path``

    Added / overwritten keys:
        - ``left_img``, ``right_img``
        - ``left_img_shape``, ``right_img_shape``
        - ``ori_left_shape``, ``ori_right_shape``

    Args:
        to_float32: Forwarded to ``ImageLoader``.
        color_type: Forwarded to ``ImageLoader``.
        imdecode_backend: Forwarded to ``ImageLoader``.
        ignore_empty: Forwarded to ``ImageLoader``.
        backend_args: Forwarded to ``ImageLoader``.
    """

    def __init__(self,
                 to_float32: bool = False,
                 color_type: str = 'color',
                 imdecode_backend: str = 'cv2',
                 ignore_empty: bool = False,
                 backend_args: Optional[dict] = None):
        super().__init__()
        loader_options = dict(
            to_float32=to_float32,
            color_type=color_type,
            imdecode_backend=imdecode_backend,
            ignore_empty=ignore_empty,
            backend_args=backend_args,
        )
        self.loader = ImageLoader(**loader_options)

    def transform(self, results: dict) -> Optional[dict]:
        """Load both images and record them, with their shapes, in ``results``.

        Args:
            results: Sample dict; must contain ``left_img_path`` and
                ``right_img_path``.

        Returns:
            The same ``results`` dict, updated in place, or ``None`` when
            the loader returned ``None`` for either image. The right image
            is not attempted if the left one came back as ``None``.
        """
        # Load the left image.
        left = self.loader.load(results['left_img_path'])
        if left is None:
            return None

        # Load the right image.
        right = self.loader.load(results['right_img_path'])
        if right is None:
            return None

        # Record the results; the key order matches a single update() call.
        results['left_img'] = left
        results['right_img'] = right
        results['left_img_shape'] = left.shape[:2]
        results['right_img_shape'] = right.shape[:2]
        results['ori_left_shape'] = left.shape[:2]
        results['ori_right_shape'] = right.shape[:2]
        return results

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

No branches or pull requests

1 participant