
[Feature] Request for Multi-Input Support in BaseModel #1584

Open
shenshanf opened this issue Oct 15, 2024 · 1 comment

Comments

@shenshanf
What is the feature?

Description

The current implementation of BaseModel in mmengine assumes a single inputs parameter of type torch.Tensor in the forward method:

def forward(self,
            inputs: torch.Tensor,
            data_samples: Optional[list] = None,
            mode: str = 'tensor') -> Union[Dict[str, torch.Tensor], list]:

While this works well for many scenarios, it poses challenges for models that require multiple, disparate inputs that cannot easily be concatenated into a single tensor. For example, multi-modal learning tasks may need to process point cloud data of shape [b, n, c] and image data of shape [b, c1, h, w] simultaneously.
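The shape mismatch can be seen directly. A minimal sketch (the shapes below just follow the [b, n, c] point cloud and [b, c1, h, w] image layout described above; the concrete sizes are illustrative):

```python
import torch

# Hypothetical multi-modal batch.
points = torch.randn(2, 1024, 3)    # point cloud: batch, num_points, xyz
images = torch.randn(2, 3, 32, 32)  # image: batch, channels, height, width

# The two modalities cannot be merged into the single `inputs` tensor
# that BaseModel.forward expects: their dimensionalities differ, so
# torch.cat rejects them.
try:
    torch.cat([points, images], dim=0)
except RuntimeError as e:
    print(f'cannot concatenate: {e}')
```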

Current Workarounds and Their Limitations

  1. Concatenating inputs: Not suitable for inputs with different dimensions or semantic meanings.
  2. Using a dictionary input: Breaks compatibility with the current BaseModel interface.
  3. Creating a custom data structure: Also breaks compatibility and requires significant changes to existing codebases.
  4. Setting inputs before calling forward: Requires modifying training loops and doesn't align with PyTorch's typical usage patterns.

Feature Request

We propose extending the BaseModel to support multiple input tensors in a way that maintains backward compatibility. This could potentially be achieved by:

  1. Allowing inputs to be a tuple or list of tensors.
  2. Adding an optional parameter for additional inputs.
  3. Creating a new base class specifically for multi-input models.
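Option 1 could be sketched roughly as follows (the class name and normalization logic here are hypothetical, not mmengine API): a forward that accepts either a single tensor or a sequence of tensors, so existing single-input callers keep working unchanged:

```python
from typing import Optional, Sequence, Union

import torch
import torch.nn as nn


class FlexibleInputModel(nn.Module):
    def forward(self,
                inputs: Union[torch.Tensor, Sequence[torch.Tensor]],
                data_samples: Optional[list] = None,
                mode: str = 'tensor') -> list:
        # Normalize to a list so the rest of the model handles one case.
        if isinstance(inputs, torch.Tensor):
            inputs = [inputs]
        return [x.mean() for x in inputs]


model = FlexibleInputModel()
single = model(torch.ones(2, 3))                    # backward compatible
multi = model([torch.ones(2, 3), torch.zeros(4)])  # multi-input
print(len(single), len(multi))  # 1 2
```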

Benefits

  1. Improved flexibility for complex model architectures.
  2. Better support for multi-modal learning tasks.
  3. Easier integration of models with multiple input types.
  4. Consistency with PyTorch's typical usage patterns.

Questions for Discussion

  1. What is the best way to implement this feature while maintaining backward compatibility?
  2. Are there any potential drawbacks or performance implications to consider?
  3. How might this change affect other parts of the mmengine ecosystem?

We appreciate your consideration of this feature request and look forward to any feedback or discussion on this topic.

Any other context?

No response

@shenshanf
Author

Description

There appears to be an inconsistency between the type annotation of the inputs parameter in the BaseModel.forward() method and how it's actually used in other parts of the code.

In BaseModel.forward(), the inputs parameter is annotated as torch.Tensor:

def forward(self,
            inputs: torch.Tensor,
            data_samples: Optional[list] = None,
            mode: str = 'tensor') -> Union[Dict[str, torch.Tensor], list]:

However, in the _run_forward() method, inputs is treated as either a dict, tuple, or list:

def _run_forward(self, data: Union[dict, tuple, list],
                 mode: str) -> Union[Dict[str, torch.Tensor], list]:
    if isinstance(data, dict):
        results = self(**data, mode=mode)
    elif isinstance(data, (list, tuple)):
        results = self(*data, mode=mode)
    else:
        raise TypeError('Output of `data_preprocessor` should be '
                        f'list, tuple or dict, but got {type(data)}')
    return results

Additionally, the val_step() method suggests that self.data_preprocessor outputs a sequence of tensors.
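The dispatch above can be reproduced with plain functions (names here are illustrative, not mmengine API) to show how a tuple batch lands in `inputs` as a non-tensor even though the annotation says `torch.Tensor`:

```python
import torch


def forward(inputs, data_samples=None, mode='tensor'):
    # Report what actually arrived in `inputs`.
    return type(inputs).__name__


def run_forward(data, mode):
    # Toy reproduction of BaseModel._run_forward's dispatch.
    if isinstance(data, dict):
        return forward(**data, mode=mode)
    elif isinstance(data, (list, tuple)):
        return forward(*data, mode=mode)
    raise TypeError(f'list, tuple or dict expected, got {type(data)}')


# A dict batch binds `inputs` to a tensor...
print(run_forward({'inputs': torch.ones(2)}, 'tensor'))           # Tensor
# ...but a tuple batch can bind `inputs` to a list of tensors.
print(run_forward(([torch.ones(2), torch.zeros(2)],), 'tensor'))  # list
```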

Proposed Fix

To resolve this inconsistency, we suggest modifying the type annotation of the inputs parameter in BaseModel.forward() to:

from typing import Dict, Optional, Sequence, Union

def forward(self,
            inputs: Union[torch.Tensor, Sequence[torch.Tensor]],
            data_samples: Optional[list] = None,
            mode: str = 'tensor') -> Union[Dict[str, torch.Tensor], list]:

This change would make the type annotation consistent with how inputs is actually used in the codebase.

Additional Context

This issue was discovered while reviewing the BaseModel class implementation in mmengine/model/base_model.py. The inconsistency could potentially lead to type checking errors or unexpected behavior when using static type checkers or IDEs with type inference.

Thank you for your attention to this matter. We appreciate your work on mmengine and are happy to provide any additional information or clarification if needed.
