Skip to content

Commit

Permalink
feat: Improve error message when comparing Series with list literal, …
Browse files Browse the repository at this point in the history
…or when using multi-output expressions in unsupported context (#1382)

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
MarcoGorelli and pre-commit-ci[bot] authored Nov 15, 2024
1 parent 4ce30d6 commit aba8584
Show file tree
Hide file tree
Showing 7 changed files with 31 additions and 10 deletions.
11 changes: 9 additions & 2 deletions narwhals/_arrow/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,8 +152,15 @@ def validate_column_comparand(other: Any) -> Any:

if isinstance(other, list):
if len(other) > 1:
# e.g. `plx.all() + plx.all()`
msg = "Multi-output expressions are not supported in this context"
if hasattr(other[0], "__narwhals_expr__") or hasattr(
other[0], "__narwhals_series__"
):
# e.g. `plx.all() + plx.all()`
msg = "Multi-output expressions (e.g. `nw.all()` or `nw.col('a', 'b')`) are not supported in this context"
raise ValueError(msg)
msg = (
f"Expected scalar value, Series, or Expr, got list of : {type(other[0])}"
)
raise ValueError(msg)
other = other[0]
if isinstance(other, ArrowDataFrame):
Expand Down
2 changes: 1 addition & 1 deletion narwhals/_dask/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def maybe_evaluate(df: DaskLazyFrame, obj: Any) -> Any:
if isinstance(obj, DaskExpr):
results = obj._call(df)
if len(results) != 1: # pragma: no cover
msg = "Multi-output expressions not supported in this context"
msg = "Multi-output expressions (e.g. `nw.all()` or `nw.col('a', 'b')`) not supported in this context"
raise NotImplementedError(msg)
result = results[0]
validate_comparand(df._native_frame, result)
Expand Down
11 changes: 9 additions & 2 deletions narwhals/_pandas_like/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,8 +93,15 @@ def validate_column_comparand(index: Any, other: Any) -> Any:

if isinstance(other, list):
if len(other) > 1:
# e.g. `plx.all() + plx.all()`
msg = "Multi-output expressions are not supported in this context"
if hasattr(other[0], "__narwhals_expr__") or hasattr(
other[0], "__narwhals_series__"
):
# e.g. `plx.all() + plx.all()`
msg = "Multi-output expressions (e.g. `nw.all()` or `nw.col('a', 'b')`) are not supported in this context"
raise ValueError(msg)
msg = (
f"Expected scalar value, Series, or Expr, got list of : {type(other[0])}"
)
raise ValueError(msg)
other = other[0]
if isinstance(other, PandasLikeDataFrame):
Expand Down
3 changes: 1 addition & 2 deletions narwhals/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -2692,8 +2692,7 @@ def mode(self: Self) -> Self:
return self.__class__(lambda plx: self._call(plx).mode())

def cum_count(self: Self, *, reverse: bool = False) -> Self:
r"""
Return the cumulative count of the non-null values in the column.
r"""Return the cumulative count of the non-null values in the column.
Arguments:
reverse: reverse the operation
Expand Down
3 changes: 1 addition & 2 deletions narwhals/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -2630,8 +2630,7 @@ def mode(self: Self) -> Self:
return self._from_compliant_series(self._compliant_series.mode())

def cum_count(self: Self, *, reverse: bool = False) -> Self:
r"""
Return the cumulative count of the non-null values in the series.
r"""Return the cumulative count of the non-null values in the series.
Arguments:
reverse: reverse the operation
Expand Down
2 changes: 1 addition & 1 deletion tests/frame/reindex_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,5 +28,5 @@ def test_reindex(df_raw: Any) -> None:
result = df.with_columns(s.sort())
expected = {"a": [1, 2, 3], "b": [4, 4, 6], "z": [7.0, 8.0, 9.0]} # type: ignore[list-item]
assert_equal_data(result, expected)
with pytest.raises(ValueError, match="Multi-output expressions are not supported"):
with pytest.raises(ValueError, match="Multi-output expressions"):
nw.to_native(df.with_columns(nw.all() + nw.all()))
9 changes: 9 additions & 0 deletions tests/frame/select_test.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import pandas as pd
import pyarrow as pa
import pytest

import narwhals.stable.v1 as nw
Expand Down Expand Up @@ -44,3 +45,11 @@ def test_select_boolean_cols(request: pytest.FixtureRequest) -> None:
assert_equal_data(result.to_dict(as_series=False), {True: [1, 2]}) # type: ignore[dict-item]
result = df.select(nw.col([False, True])) # type: ignore[list-item]
assert_equal_data(result.to_dict(as_series=False), {True: [1, 2], False: [3, 4]}) # type: ignore[dict-item]


def test_comparison_with_list_error_message() -> None:
msg = "Expected scalar value, Series, or Expr, got list of : <class 'int'>"
with pytest.raises(ValueError, match=msg):
nw.from_native(pa.chunked_array([[1, 2, 3]]), series_only=True) == [1, 2, 3] # noqa: B015
with pytest.raises(ValueError, match=msg):
nw.from_native(pd.Series([[1, 2, 3]]), series_only=True) == [1, 2, 3] # noqa: B015

0 comments on commit aba8584

Please sign in to comment.