diff --git a/tests/system/vertexai/test_generative_models.py b/tests/system/vertexai/test_generative_models.py
index 2f05d659c8..8cb7d9164f 100644
--- a/tests/system/vertexai/test_generative_models.py
+++ b/tests/system/vertexai/test_generative_models.py
@@ -596,3 +596,16 @@ def test_count_tokens_from_text(self):
             response_with_si_and_tool.total_billable_characters
             > response_with_si.total_billable_characters
         )
+        # content + generation_config
+        response_with_generation_config = model.count_tokens(
+            content,
+            generation_config=generative_models.GenerationConfig(response_schema=_RESPONSE_SCHEMA_STRUCT),
+        )
+        assert (
+            response_with_generation_config.total_tokens
+            > response_with_si_and_tool.total_tokens
+        )
+        assert (
+            response_with_generation_config.total_billable_characters
+            > response_with_si_and_tool.total_billable_characters
+        )
diff --git a/vertexai/generative_models/_generative_models.py b/vertexai/generative_models/_generative_models.py
index 805123bc1d..adf82a0e92 100644
--- a/vertexai/generative_models/_generative_models.py
+++ b/vertexai/generative_models/_generative_models.py
@@ -872,7 +872,7 @@ async def async_generator():
         return async_generator()
 
     def count_tokens(
-        self, contents: ContentsType, *, tools: Optional[List["Tool"]] = None
+        self, contents: ContentsType, *, tools: Optional[List["Tool"]] = None, generation_config: Optional["GenerationConfig"] = None,
     ) -> gapic_prediction_service_types.CountTokensResponse:
         """Counts tokens.
 
@@ -885,6 +885,7 @@ def count_tokens(
                 * List[Union[str, Image, Part]],
                 * List[Content]
             tools: A list of tools (functions) that the model can try calling.
+            generation_config: Parameters for the generate_content method.
 
         Returns:
             A CountTokensResponse object that has the following attributes:
@@ -894,6 +895,7 @@ def count_tokens(
         request = self._prepare_request(
             contents=contents,
             tools=tools,
+            generation_config=generation_config,
         )
         return self._gapic_count_tokens(
             prediction_resource_name=self._prediction_resource_name,
@@ -907,6 +909,7 @@ async def count_tokens_async(
         contents: ContentsType,
         *,
         tools: Optional[List["Tool"]] = None,
+        generation_config: Optional["GenerationConfig"] = None,
     ) -> gapic_prediction_service_types.CountTokensResponse:
         """Counts tokens asynchronously.
 
@@ -919,6 +922,7 @@ async def count_tokens_async(
                 * List[Union[str, Image, Part]],
                 * List[Content]
             tools: A list of tools (functions) that the model can try calling.
+            generation_config: Parameters for the generate_content method.
 
         Returns:
             And awaitable for a CountTokensResponse object that has the following attributes:
@@ -928,6 +932,7 @@ async def count_tokens_async(
         request = self._prepare_request(
             contents=contents,
             tools=tools,
+            generation_config=generation_config,
         )
         return await self._gapic_count_tokens_async(
             prediction_resource_name=self._prediction_resource_name,
@@ -942,6 +947,7 @@ def _gapic_count_tokens(
         contents: List[gapic_content_types.Content],
         system_instruction: Optional[gapic_content_types.Content] = None,
         tools: Optional[List[gapic_tool_types.Tool]] = None,
+        generation_config: Optional[gapic_content_types.GenerationConfig] = None,
     ) -> gapic_prediction_service_types.CountTokensResponse:
         request = gapic_prediction_service_types.CountTokensRequest(
             endpoint=prediction_resource_name,
@@ -949,6 +955,7 @@ def _gapic_count_tokens(
             contents=contents,
             system_instruction=system_instruction,
             tools=tools,
+            generation_config=generation_config,
         )
         return self._prediction_client.count_tokens(request=request)
 
@@ -958,6 +965,7 @@ async def _gapic_count_tokens_async(
         contents: List[gapic_content_types.Content],
         system_instruction: Optional[gapic_content_types.Content] = None,
         tools: Optional[List[gapic_tool_types.Tool]] = None,
+        generation_config: Optional[gapic_content_types.GenerationConfig] = None,
     ) -> gapic_prediction_service_types.CountTokensResponse:
         request = gapic_prediction_service_types.CountTokensRequest(
             endpoint=prediction_resource_name,
@@ -965,6 +973,7 @@ async def _gapic_count_tokens_async(
             contents=contents,
             system_instruction=system_instruction,
             tools=tools,
+            generation_config=generation_config,
         )
         return await self._prediction_async_client.count_tokens(request=request)
 
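
For reviewers: a minimal usage sketch of the new keyword, assuming an initialized SDK. The project, location, model name, prompt, and schema below are illustrative placeholders, not part of the patch; only the `generation_config` parameter on `count_tokens` comes from this diff.

```python
import vertexai
from vertexai import generative_models

# Assumed project/region and model name -- placeholders, not from the patch.
vertexai.init(project="my-project", location="us-central1")
model = generative_models.GenerativeModel("gemini-1.5-pro")

# An illustrative response schema. Per this diff, the generation config is
# serialized into the CountTokensRequest, so a schema raises the token and
# billable-character totals, which the new system test asserts.
schema = {
    "type": "object",
    "properties": {"answer": {"type": "string"}},
}

response = model.count_tokens(
    "Why is the sky blue?",
    generation_config=generative_models.GenerationConfig(response_schema=schema),
)
print(response.total_tokens, response.total_billable_characters)
```

The async path mirrors this: `await model.count_tokens_async(..., generation_config=...)` goes through `_gapic_count_tokens_async` and sets the same field on the request.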