From ccc5c85a8ba39728f710fe64e8afbe98f84bf50e Mon Sep 17 00:00:00 2001 From: A Vertex SDK engineer Date: Wed, 6 Nov 2024 11:26:07 -0800 Subject: [PATCH] fix: deepcopy error from baseline_model in pairwiseMetric PiperOrigin-RevId: 693800802 --- tests/unit/vertexai/test_evaluation.py | 48 ++++++++++++++++++++++++++ vertexai/evaluation/_evaluation.py | 23 +++++++++--- 2 files changed, 66 insertions(+), 5 deletions(-) diff --git a/tests/unit/vertexai/test_evaluation.py b/tests/unit/vertexai/test_evaluation.py index 40f32e3295..de517041d2 100644 --- a/tests/unit/vertexai/test_evaluation.py +++ b/tests/unit/vertexai/test_evaluation.py @@ -317,6 +317,20 @@ ) ), ) +_MOCK_PAIRWISE_RESULT = ( + gapic_evaluation_service_types.EvaluateInstancesResponse( + pairwise_metric_result=gapic_evaluation_service_types.PairwiseMetricResult( + pairwise_choice=gapic_evaluation_service_types.PairwiseChoice.BASELINE, + explanation="explanation", + ) + ), + gapic_evaluation_service_types.EvaluateInstancesResponse( + pairwise_metric_result=gapic_evaluation_service_types.PairwiseMetricResult( + pairwise_choice=gapic_evaluation_service_types.PairwiseChoice.BASELINE, + explanation="explanation", + ) + ), +) _MOCK_SUMMARIZATION_QUALITY_RESULT = ( gapic_evaluation_service_types.EvaluateInstancesResponse( pointwise_metric_result=gapic_evaluation_service_types.PointwiseMetricResult( @@ -1216,6 +1230,40 @@ def test_evaluate_baseline_response_column_and_baseline_model_provided(self): test_eval_task.evaluate(model=mock.MagicMock()) _TEST_PAIRWISE_METRIC._baseline_model = None + def test_evaluate_baseline_model_provided_but_no_baseline_response_column(self): + mock_baseline_model = mock.create_autospec( + generative_models.GenerativeModel, instance=True + ) + mock_baseline_model.generate_content.return_value = ( + _MOCK_MODEL_INFERENCE_RESPONSE + ) + mock_baseline_model._model_name = "publishers/google/model/gemini-pro" + _TEST_PAIRWISE_METRIC._baseline_model = mock_baseline_model + + mock_candidate_model = mock.create_autospec( + generative_models.GenerativeModel, instance=True + ) + mock_candidate_model.generate_content.return_value = ( + _MOCK_MODEL_INFERENCE_RESPONSE + ) + mock_candidate_model._model_name = "publishers/google/model/gemini-1.0-pro" + mock_metric_results = _MOCK_PAIRWISE_RESULT + eval_dataset = _TEST_EVAL_DATASET_WITHOUT_RESPONSE.copy(deep=True) + test_eval_task = EvalTask( + dataset=eval_dataset, + metrics=[_TEST_PAIRWISE_METRIC], + ) + with mock.patch.object( + target=gapic_evaluation_services.EvaluationServiceClient, + attribute="evaluate_instances", + side_effect=mock_metric_results, + ): + test_result = test_eval_task.evaluate( + model=mock_candidate_model, + ) + _TEST_PAIRWISE_METRIC._baseline_model = None + assert test_result.summary_metrics["row_count"] == 2 + def test_evaluate_response_column_and_model_not_provided(self): test_eval_task = EvalTask( dataset=_TEST_EVAL_DATASET_SINGLE, diff --git a/vertexai/evaluation/_evaluation.py b/vertexai/evaluation/_evaluation.py index c8663f0fbe..9f8a2093c7 100644 --- a/vertexai/evaluation/_evaluation.py +++ b/vertexai/evaluation/_evaluation.py @@ -856,15 +856,28 @@ def evaluate( """ _validate_metrics(metrics) metrics = _convert_metric_prompt_template_example(metrics) - + copied_metrics = [] + for metric in metrics: + if isinstance(metric, pairwise_metric.PairwiseMetric): + copied_metrics.append( + pairwise_metric.PairwiseMetric( + metric=metric.metric_name, + metric_prompt_template=metric.metric_prompt_template, + baseline_model=metric.baseline_model, + ) + ) + else: + copied_metrics.append(copy.deepcopy(metric)) evaluation_run_config = evaluation_base.EvaluationRunConfig( dataset=dataset.copy(deep=True), - metrics=copy.deepcopy(metrics), + metrics=copied_metrics, metric_column_mapping=copy.deepcopy(metric_column_mapping), client=utils.create_evaluation_service_client(), - evaluation_service_qps=evaluation_service_qps - if evaluation_service_qps - else constants.QuotaLimit.EVAL_SERVICE_QPS, + evaluation_service_qps=( + evaluation_service_qps + if evaluation_service_qps + else constants.QuotaLimit.EVAL_SERVICE_QPS + ), retry_timeout=retry_timeout, )