Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding ChannelDirectionGroup to the Prediction summaries #286

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions python/pdstools/adm/Aggregates.py
Original file line number Diff line number Diff line change
Expand Up @@ -447,6 +447,8 @@ def name_normalizer(x):
ChannelDirectionGroup=pl.when(
pl.col("ChannelGroup").is_not_null()
& pl.col("DirectionGroup").is_not_null()
& pl.col("ChannelGroup").is_in(["Other", "Unknown", ""]).not_()
& pl.col("DirectionGroup").is_in(["Other", "Unknown", ""]).not_()
)
.then(pl.concat_str(["ChannelGroup", "DirectionGroup"], separator="/"))
.otherwise(pl.lit("Other")),
Expand Down
12 changes: 11 additions & 1 deletion python/pdstools/prediction/Prediction.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,8 @@ def summary_by_channel(
period_expr = []

return (
self.predictions.join(
self.predictions.with_columns(pl.col("ModelName").str.to_uppercase())
.join(
self.cdh_guidelines.get_predictions_channel_mapping(
custom_predictions
).lazy(),
Expand Down Expand Up @@ -249,6 +250,15 @@ def summary_by_channel(
+ pl.col("Negatives_NBA")
),
CTR=(pl.col("Positives")) / (pl.col("ResponseCount")),
ChannelDirectionGroup=pl.when(
pl.col("Channel").is_not_null()
& pl.col("Direction").is_not_null()
& pl.col("Channel").is_in(["Other", "Unknown", ""]).not_()
& pl.col("Direction").is_in(["Other", "Unknown", ""]).not_()
& pl.col("isMultiChannelPrediction").not_()
)
.then(pl.concat_str(["Channel", "Direction"], separator="/"))
.otherwise(pl.lit("Other")),
isValid=self.prediction_validity_expr,
)
.sort(["Prediction"] + (["Period"] if by_period is not None else []))
Expand Down
78 changes: 40 additions & 38 deletions python/tests/test_Prediction.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,26 +14,17 @@
"pySnapShotTime": cdh_utils.to_prpc_date_time(datetime.datetime.now())[
0:15
], # Polars doesn't like time zones like GMT+0200
"pyModelId": ["DATA-DECISION-REQUEST-CUSTOMER!PREDICTWEBPROPENSITY"] * 4
+ ["DATA-DECISION-REQUEST-CUSTOMER!PREDICTMOBILEPROPENSITY"] * 4,
# "Channel": ["Web"] * 4 + ["Mobile"] * 4,
# "Direction": ["Inbound"] * 4 + ["Outbound"] * 4,
"pyModelId": ["DATA-DECISION-REQUEST-CUSTOMER!MYCUSTOMPREDICTION"] * 4
+ ["DATA-DECISION-REQUEST-CUSTOMER!PredictActionPropensity"] * 4
+ ["DATA-DECISION-REQUEST-CUSTOMER!PREDICTMOBILEPROPENSITY"] * 4
+ ["DATA-DECISION-REQUEST-CUSTOMER!PREDICTWEBPROPENSITY"] * 4,
"pyModelType": "PREDICTION",
"pySnapshotType": (["Daily"] * 3 + [None]) * 2,
"pyDataUsage": [
"Control",
"Test",
"NBA",
"",
"Control",
"Test",
"NBA",
"",
],
"pyPositives": [100, 400, 500, 1000, 200, 800, 1000, 2000],
"pyNegatives": [1000, 2000, 3000, 6000, 3000, 6000, 9000, 18000],
"pyCount": [1100, 2400, 3500, 7000, 3200, 6800, 10000, 20000],
"pyValue": [0.65] * 4 + [0.70] * 4,
"pySnapshotType": (["Daily"] * 3 + [None]) * 4,
"pyDataUsage": ["Control", "Test", "NBA", ""] * 4,
"pyPositives": [100, 400, 500, 1000, 200, 800, 1000, 2000] * 2,
"pyNegatives": [1000, 2000, 3000, 6000, 3000, 6000, 9000, 18000] * 2,
"pyCount": [1100, 2400, 3500, 7000, 3200, 6800, 10000, 20000] * 2,
"pyValue": ([0.65] * 4 + [0.70] * 4) * 2,
}
).lazy()

Expand Down Expand Up @@ -94,24 +85,24 @@ def test_summary_by_channel_cols(test):
"ControlPercentage",
"TestPercentage",
"CTR",
"ChannelDirectionGroup",
"isValid",
]
assert len(summary) == 2


def test_summary_by_channel_channels(test):
summary = test.summary_by_channel().collect()
assert summary.select(pl.len()).item() == 2
assert summary.select(pl.len()).item() == 4


def test_summary_by_channel_validity(test):
summary = test.summary_by_channel().collect()
assert summary["isValid"].to_list() == [True, True]
assert summary["isValid"].to_list() == [True, True, True, True]


def test_summary_by_channel_ia(test):
summary = test.summary_by_channel().collect()
assert summary["usesImpactAnalyzer"].to_list() == [True, True]
assert summary["usesImpactAnalyzer"].to_list() == [True, True, True, True]

test = Prediction(
mock_prediction_data.filter(
Expand All @@ -122,38 +113,48 @@ def test_summary_by_channel_ia(test):
)
)
)
# only Web still has the NBA indicator
assert test.summary_by_channel().collect()["usesImpactAnalyzer"].to_list() == [
False,
False,
False,
True,
]


def test_summary_by_channel_lift(test):
summary = test.summary_by_channel().collect()
assert [round(x, 5) for x in summary["Lift"].to_list()] == [0.88235, 0.83333]
assert [round(x, 5) for x in summary["Lift"].to_list()] == [0.83333, 0.88235] * 2


def test_summary_by_channel_controlpct(test):
summary = test.summary_by_channel().collect()
assert [round(x, 5) for x in summary["ControlPercentage"].to_list()] == [
16.0,
15.71429,
]
16.0,
] * 2
assert [round(x, 5) for x in summary["TestPercentage"].to_list()] == [
34.0,
34.28571,
]
34.0,
] * 2


def test_summary_by_channel_trend(test):
summary = test.summary_by_channel(by_period="1d").collect()
assert summary.select(pl.len()).item() == 2
assert summary.select(pl.len()).item() == 4


def test_summary_by_channel_trend2(test2):
summary = test2.summary_by_channel(by_period="1d").collect()
assert summary.select(pl.len()).item() == 4
assert summary.select(pl.len()).item() == 8


def test_summary_by_channel_channeldirectiongroup(test):
summary = test.summary_by_channel().collect()

assert summary["isMultiChannelPrediction"].to_list() == [False, True, False, False]
assert summary["isStandardNBADPrediction"].to_list() == [False, True, True, True]
assert summary["ChannelDirectionGroup"].to_list() == ["Other", "Other", "Mobile/Inbound", "Web/Inbound"]

def test_overall_summary_cols(test):
summary = test.overall_summary().collect()
Expand All @@ -174,20 +175,21 @@ def test_overall_summary_cols(test):


def test_overall_summary_n_valid_channels(test):
print(test.overall_summary().collect())
assert test.overall_summary().collect()["Number of Valid Channels"].item() == 2
assert test.overall_summary().collect()["Number of Valid Channels"].item() == 3


def test_overall_summary_overall_lift(test):
assert round(test.overall_summary().collect()["Overall Lift"].item(), 5) == 0.86964
# print(test.overall_summary().collect())
# print(test.summary_by_channel().collect())
assert round(test.overall_summary().collect()["Overall Lift"].item(), 5) == 0.86217


def test_overall_summary_positives(test):
assert test.overall_summary().collect()["Positives"].item() == 3000
assert test.overall_summary().collect()["Positives"].item() == 4000


def test_overall_summary_responsecount(test):
assert test.overall_summary().collect()["ResponseCount"].item() == 27000
assert test.overall_summary().collect()["ResponseCount"].item() == 34000


def test_overall_summary_channel_min_lift(test):
Expand All @@ -202,16 +204,16 @@ def test_overall_summary_min_lift(test):


def test_overall_summary_ctr(test):
assert round(test.overall_summary().collect()["CTR"].item(), 5) == 0.11111
assert round(test.overall_summary().collect()["CTR"].item(), 5) == 0.11765


def test_overall_summary_controlpct(test):
assert (
round(test.overall_summary().collect()["ControlPercentage"].item(), 5)
== 15.92593
== 15.88235
)
assert (
round(test.overall_summary().collect()["TestPercentage"].item(), 5) == 34.07407
round(test.overall_summary().collect()["TestPercentage"].item(), 5) == 34.11765
)


Expand Down
Loading