Skip to content

Commit

Permalink
feat: Add generation_config to count_tokens
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 688307166
  • Loading branch information
happy-qiao authored and copybara-github committed Oct 22, 2024
1 parent f713417 commit a2de5e2
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 1 deletion.
13 changes: 13 additions & 0 deletions tests/system/vertexai/test_generative_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -596,3 +596,16 @@ def test_count_tokens_from_text(self):
response_with_si_and_tool.total_billable_characters
> response_with_si.total_billable_characters
)
# content + generation_config
response_with_generation_config = model.count_tokens(
content,
generation_config=generative_models.GenerationConfig(response_schema=_RESPONSE_SCHEMA_STRUCT),
)
assert (
response_with_generation_config.total_tokens
> response_with_si_and_tool.total_tokens
)
assert (
response_with_generation_config.total_billable_characters
> response_with_si_and_tool.total_billable_characters
)
11 changes: 10 additions & 1 deletion vertexai/generative_models/_generative_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -872,7 +872,7 @@ async def async_generator():
return async_generator()

def count_tokens(
self, contents: ContentsType, *, tools: Optional[List["Tool"]] = None
self, contents: ContentsType, *, tools: Optional[List["Tool"]] = None, generation_config: Optional["GenerationConfig"] = None,
) -> gapic_prediction_service_types.CountTokensResponse:
"""Counts tokens.
Expand All @@ -885,6 +885,7 @@ def count_tokens(
* List[Union[str, Image, Part]],
* List[Content]
tools: A list of tools (functions) that the model can try calling.
generation_config: Parameters for the generate_content method.
Returns:
A CountTokensResponse object that has the following attributes:
Expand All @@ -894,6 +895,7 @@ def count_tokens(
request = self._prepare_request(
contents=contents,
tools=tools,
generation_config=generation_config,
)
return self._gapic_count_tokens(
prediction_resource_name=self._prediction_resource_name,
Expand All @@ -907,6 +909,7 @@ async def count_tokens_async(
contents: ContentsType,
*,
tools: Optional[List["Tool"]] = None,
generation_config: Optional["GenerationConfig"] = None,
) -> gapic_prediction_service_types.CountTokensResponse:
"""Counts tokens asynchronously.
Expand All @@ -919,6 +922,7 @@ async def count_tokens_async(
* List[Union[str, Image, Part]],
* List[Content]
tools: A list of tools (functions) that the model can try calling.
generation_config: Parameters for the generate_content method.
Returns:
And awaitable for a CountTokensResponse object that has the following attributes:
Expand All @@ -928,6 +932,7 @@ async def count_tokens_async(
request = self._prepare_request(
contents=contents,
tools=tools,
generation_config=generation_config,
)
return await self._gapic_count_tokens_async(
prediction_resource_name=self._prediction_resource_name,
Expand All @@ -942,13 +947,15 @@ def _gapic_count_tokens(
contents: List[gapic_content_types.Content],
system_instruction: Optional[gapic_content_types.Content] = None,
tools: Optional[List[gapic_tool_types.Tool]] = None,
generation_config: Optional[gapic_content_types.GenerationConfig] = None,
) -> gapic_prediction_service_types.CountTokensResponse:
request = gapic_prediction_service_types.CountTokensRequest(
endpoint=prediction_resource_name,
model=prediction_resource_name,
contents=contents,
system_instruction=system_instruction,
tools=tools,
generation_config=generation_config,
)
return self._prediction_client.count_tokens(request=request)

Expand All @@ -958,13 +965,15 @@ async def _gapic_count_tokens_async(
contents: List[gapic_content_types.Content],
system_instruction: Optional[gapic_content_types.Content] = None,
tools: Optional[List[gapic_tool_types.Tool]] = None,
generation_config: Optional[gapic_content_types.GenerationConfig] = None,
) -> gapic_prediction_service_types.CountTokensResponse:
request = gapic_prediction_service_types.CountTokensRequest(
endpoint=prediction_resource_name,
model=prediction_resource_name,
contents=contents,
system_instruction=system_instruction,
tools=tools,
generation_config=generation_config,
)
return await self._prediction_async_client.count_tokens(request=request)

Expand Down

0 comments on commit a2de5e2

Please sign in to comment.