Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Extend DataFrame parameter support to other libraries #975

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 38 additions & 7 deletions param/parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -2277,7 +2277,11 @@ def deserialize(cls, value):

class DataFrame(ClassSelector):
"""
Parameter whose value is a pandas DataFrame.
Parameter whose value is a DataFrame of one of the enabled libraries.

The supported libraries can be controlled with the libraries argument.
Only pandas is enabled by default; pandas and polars are the currently
supported libraries, and either or both of them can be enabled.

The structure of the DataFrame can be constrained by the rows and
columns arguments:
Expand All @@ -2294,31 +2298,54 @@ class DataFrame(ClassSelector):
same columns and in the same order and no other columns.
"""

__slots__ = ['rows', 'columns', 'ordered']
__slots__ = ['rows', 'columns', 'ordered', 'libraries']

_slot_defaults = _dict_update(
ClassSelector._slot_defaults, rows=None, columns=None, ordered=None
ClassSelector._slot_defaults, rows=None, columns=None, ordered=None, libraries=None
)

_supported_libraries = ('pandas', 'polars')

@typing.overload
def __init__(
self,
default=None, *, rows=None, columns=None, ordered=None, is_instance=True,
default=None, *, rows=None, columns=None, ordered=None, libraries=None, is_instance=True,
allow_None=False, doc=None, label=None, precedence=None, instantiate=True,
constant=False, readonly=False, pickle_default_value=True, per_instance=True,
allow_refs=False, nested_refs=False
):
...

@_deprecate_positional_args
def __init__(self, default=Undefined, *, rows=Undefined, columns=Undefined, ordered=Undefined, **params):
from pandas import DataFrame as pdDFrame
def __init__(self, default=Undefined, *, rows=Undefined, columns=Undefined, ordered=Undefined, libraries=Undefined, **params):
if libraries in (None, Undefined):
libraries = ('pandas',)
elif any(l not in self._supported_libraries for l in libraries):
raise ValueError(f'DataFrame parameter libraries must be one of {self._supported_libraries}')
self.rows = rows
self.columns = columns
self.ordered = ordered
super().__init__(default=default, class_=pdDFrame, **params)
self.libraries = libraries
super().__init__(default=default, class_=None, **params)
self._validate(self.default)

@property
def class_(self):
    """
    Type (or tuple of types) accepted as a value for this parameter.

    The result is derived from ``self.libraries`` each time it is read.
    A library only contributes its types when it has already been
    imported (i.e. it is present in ``sys.modules``), so reading this
    property never triggers a first import of pandas or polars.

    Returns ``type(None)`` when none of the enabled libraries is loaded,
    a single class when exactly one type qualifies, and a tuple of
    classes otherwise.
    """
    enabled = self.libraries
    if enabled is None:
        # None is the declared slot default and the constructor treats it
        # as "pandas only"; apply the same fallback here instead of
        # failing with a TypeError on the membership tests below.
        enabled = ('pandas',)

    frame_types = []
    if 'pandas' in enabled and 'pandas' in sys.modules:
        import pandas
        frame_types.append(pandas.DataFrame)
    if 'polars' in enabled and 'polars' in sys.modules:
        import polars
        frame_types.extend((polars.DataFrame, polars.LazyFrame))

    if not frame_types:
        return type(None)
    if len(frame_types) == 1:
        return frame_types[0]
    return tuple(frame_types)

@class_.setter
def class_(self, value):
    # Assignments are deliberately ignored: the accepted types are always
    # computed from ``self.libraries`` by the getter above.
    pass

def _length_bounds_check(self, bounds, length, name):
message = f'{name} length {length} does not match declared bounds of {bounds}'
if not isinstance(bounds, tuple):
Expand Down Expand Up @@ -2374,6 +2401,10 @@ def _validate(self, val):
def serialize(cls, value):
    """
    Convert a DataFrame value into a list of per-row dictionaries.

    Dispatch is by duck typing rather than by importing any library:

    * ``None`` is passed through unchanged.
    * An object exposing ``collect`` (a Polars LazyFrame) is materialized
      first by calling ``collect()``.
    * An object exposing ``to_dicts`` (a Polars DataFrame) is converted
      with ``to_dicts()``.
    * Anything else is converted with ``to_dict('records')`` (the pandas
      DataFrame spelling).

    ``cls`` is not used by the body.
    NOTE(review): the decorator for this method lies outside the visible
    hunk; it is presumably a classmethod - confirm.
    """
    if value is None:
        return None
    if hasattr(value, 'collect'):
        value = value.collect()  # Polars LazyFrame
    if hasattr(value, 'to_dicts'):
        return value.to_dicts()
    return value.to_dict('records')

@classmethod
Expand Down
Loading