diff --git a/vertexai/generative_models/_prompts.py b/vertexai/generative_models/_prompts.py
index c70e37ba2f..92b55bc6c5 100644
--- a/vertexai/generative_models/_prompts.py
+++ b/vertexai/generative_models/_prompts.py
@@ -17,6 +17,10 @@
 from google.cloud.aiplatform import base
 from google.cloud.aiplatform import initializer as aiplatform_initializer
+from google.cloud.aiplatform.compat.services import dataset_service_client
+from google.cloud.aiplatform.compat.types import dataset as gca_dataset
+from google.cloud.aiplatform_v1.types import dataset_version as gca_dataset_version
+from google.protobuf import field_mask_pb2
 from vertexai.generative_models import (
     Content,
     Image,
@@ -39,6 +42,7 @@
     SafetySettingsType,
 )
 
+import dataclasses
 import re
 from typing import (
     Any,
@@ -51,9 +55,138 @@
 _LOGGER = base.Logger(__name__)
 
+DEFAULT_API_SCHEMA_VERSION = "1.0.0"
 VARIABLE_NAME_REGEX = r"(\{[^\W0-9]\w*\})"
 
 
+@dataclasses.dataclass
+class PromptMessage:
+    """PromptMessage.
+
+    Attributes:
+        model: The model name (in the format publishers/google/models/{model_name}).
+        contents: The contents of the prompt.
+        system_instruction: The system instruction of the prompt.
+        tools: The tools of the prompt.
+        tool_config: The tool config of the prompt.
+        generation_config: The generation config of the prompt.
+        safety_settings: The safety settings of the prompt.
+    """
+
+    model: Optional[str] = None
+    contents: Optional[List[Content]] = None
+    system_instruction: Optional[Content] = None
+    tools: Optional[List[Tool]] = None
+    tool_config: Optional[ToolConfig] = None
+    generation_config: Optional[GenerationConfig] = None
+    safety_settings: Optional[List[SafetySetting]] = None
+
+    def to_dict(self) -> Dict[str, Any]:
+        dct = {}
+        dct["model"] = self.model
+        if self.contents:
+            dct["contents"] = [content.to_dict() for content in self.contents]
+        if self.system_instruction:
+            dct["systemInstruction"] = self.system_instruction.to_dict()
+        if self.tools:
+            dct["tools"] = [tool.to_dict() for tool in self.tools]
+        # dct["toolConfig"] = self.tool_config.to_dict() if self.tool_config else None  # no toolconfig for now
+        if self.generation_config:
+            dct["generationConfig"] = self.generation_config.to_dict()
+        if self.safety_settings:
+            dct["safetySettings"] = [
+                safety_setting.to_dict() for safety_setting in self.safety_settings
+            ]
+        return dct
+
+
+@dataclasses.dataclass
+class Arguments:
+    """Arguments.
+
+    Attributes:
+        variables: The variable values for an execution, keyed by variable name.
+    """
+
+    variables: Dict[str, List[Part]]
+
+    def to_dict(self) -> Dict[str, Any]:
+        dct = {}
+        for variable_name in self.variables:
+            dct[variable_name] = {
+                "partList": {
+                    "parts": [
+                        part.to_dict() for part in self.variables[variable_name]
+                    ]
+                }
+            }
+        return dct
+
+
+@dataclasses.dataclass
+class Execution:
+    """Execution.
+
+    Attributes:
+        arguments: The arguments of the execution.
+    """
+
+    arguments: Arguments
+
+    def __init__(self, arguments: Dict[str, List[Part]]):
+        self.arguments = Arguments(variables=arguments)
+
+    def to_dict(self) -> Dict[str, Any]:
+        dct = {}
+        dct["arguments"] = self.arguments.to_dict()
+        return dct
+
+
+@dataclasses.dataclass
+class MultimodalPrompt:
+    """MultimodalPrompt.
+
+    Attributes:
+        prompt_message: The schema for the prompt. Mirrors the GenerateContentRequest schema.
+        api_schema_version: The api schema version of the prompt when it was last modified.
+        executions: Contains data related to executions of the prompt (e.g. variables).
+    """
+
+    prompt_message: PromptMessage
+    api_schema_version: Optional[str] = DEFAULT_API_SCHEMA_VERSION
+    executions: Optional[List[Execution]] = None
+
+    def to_dict(self) -> Dict[str, Any]:
+        dct = {"multimodalPrompt": {}}
+        dct["apiSchemaVersion"] = self.api_schema_version
+        dct["multimodalPrompt"]["promptMessage"] = self.prompt_message.to_dict()
+        if self.executions:
+            # Only add executions whose variable sets are non-empty.
+            execution_dcts = []
+            for execution in self.executions:
+                execution_dct = execution.to_dict()
+                if execution_dct and execution_dct["arguments"]:
+                    execution_dcts.append(execution_dct)
+            if execution_dcts:
+                dct["executions"] = execution_dcts
+        return dct
+
+
+@dataclasses.dataclass
+class PromptDatasetMetadata:
+    """PromptDatasetMetadata.
+
+    Attributes:
+        prompt_type: Required. SDK only supports "freeform" or "multimodalFreeform".
+        prompt_api_schema: Required. SDK only supports multimodalPrompt.
+    """
+
+    prompt_type: str
+    prompt_api_schema: MultimodalPrompt
+
+    def to_dict(self) -> Dict[str, Any]:
+        dct = {}
+        dct["promptType"] = self.prompt_type  # TODO(tangmatthew): Check if multimodal
+        dct["promptApiSchema"] = self.prompt_api_schema.to_dict()
+        return dct
+
+
 class Prompt:
     """A prompt which may be a template with variables.
@@ -157,6 +290,10 @@ def __init__(
         self._system_instruction = None
         self._tools = None
         self._tool_config = None
+        self._dataset_client_value = None
+        self._dataset = None
+        self._prompt_name = None
+        self._version_id = None
 
         self.prompt_data = prompt_data
         self.variables = variables if variables else [{}]
@@ -567,6 +704,115 @@ def generate_content(
             stream=stream,
         )
 
+    @property
+    def _dataset_client(self) -> dataset_service_client.DatasetServiceClient:
+        if not getattr(self, "_dataset_client_value", None):
+            self._dataset_client_value = (
+                aiplatform_initializer.global_config.create_client(
+                    client_class=dataset_service_client.DatasetServiceClient,
+                )
+            )
+        return self._dataset_client_value
+
+    def create_version(
+        self,
+        create_new_prompt: bool = False,
+    ) -> None:
+        """Creates a prompt resource, or a new version of it, in the online prompt store."""
+        if not self._dataset or create_new_prompt:
+            return self._create_prompt_resource()
+        else:
+            return self._create_prompt_version_resource()
+
+    def _format_dataset_metadata_dict(self) -> Dict[str, Any]:
+        contents = [Content._from_gapic(_to_content(value=self.prompt_data))]
+        return PromptDatasetMetadata(
+            prompt_type="freeform",
+            prompt_api_schema=MultimodalPrompt(
+                prompt_message=PromptMessage(
+                    model=self.model_name,
+                    contents=contents,
+                    system_instruction=self.system_instruction,
+                    tools=self.tools,
+                    tool_config=self.tool_config,
+                    safety_settings=self.safety_settings,
+                    generation_config=self.generation_config,
+                ),
+                executions=[
+                    Execution(variable_set) for variable_set in self.variables
+                ],
+            ),
+        ).to_dict()
+
+    def _create_dataset(self, parent: str) -> gca_dataset.Dataset:
+        metadata_schema_uri = "gs://google-cloud-aiplatform/schema/dataset/metadata/text_prompt_1.0.0.yaml"
+        dataset_metadata = self._format_dataset_metadata_dict()
+
+        dataset = gca_dataset.Dataset(
+            name=parent,
+            display_name=self._prompt_name or "test1",  # TODO(tangmatthew): Remove default display name
+            metadata_schema_uri=metadata_schema_uri,
+            metadata=dataset_metadata,
+            model_reference=self.model_name,
+        )
+        operation = self._dataset_client.create_dataset(
+            parent=parent,
+            dataset=dataset,
+        )
+        dataset = operation.result()
+
+        # Purge labels
+        dataset.labels = None
+        return dataset
+
+    def _create_dataset_version(self, parent: str) -> gca_dataset_version.DatasetVersion:
+        dataset_version = gca_dataset_version.DatasetVersion(
+            display_name=self._prompt_name,
+        )
+
+        operation = self._dataset_client.create_dataset_version(
+            parent=parent,
+            dataset_version=dataset_version,
+        )
+        return operation.result()
+
+    def _update_dataset(self, dataset: gca_dataset.Dataset) -> gca_dataset.Dataset:
+        dataset.metadata = self._format_dataset_metadata_dict()
+
+        updated_dataset = self._dataset_client.update_dataset(
+            dataset=dataset,
+            # Only the metadata is mutated here; UpdateDataset requires a mask.
+            update_mask=field_mask_pb2.FieldMask(paths=["metadata"]),
+        )
+        return updated_dataset
+
+    def _create_prompt_resource(self) -> None:
+        project = aiplatform_initializer.global_config.project
+        location = aiplatform_initializer.global_config.location
+
+        # Step 1: Create prompt dataset
+        parent = f"projects/{project}/locations/{location}"
+        dataset = self._create_dataset(parent=parent)
+
+        # Step 2: Create prompt version (snapshot)
+        dataset_version = self._create_dataset_version(dataset.name)
+
+        # Step 3: Update Prompt object
+        self._dataset = dataset
+        self._version_id = dataset_version.name.split("/")[-1]
+        prompt_id = self._dataset.name.split("/")[5]
+        _LOGGER.info(
+            f"Created prompt resource with id {prompt_id} and version number {self._version_id}"
+        )
+
+    def _create_prompt_version_resource(self) -> None:
+        # Step 1: Update prompt
+        updated_dataset = self._update_dataset(dataset=self._dataset)
+
+        # Step 2: Create prompt version (snapshot)
+        dataset_version = self._create_dataset_version(updated_dataset.name)
+
+        # Step 3: Update Prompt object
+        self._dataset = updated_dataset
+        self._version_id = dataset_version.name.split("/")[-1]
+        prompt_id = self._dataset.name.split("/")[5]
+        _LOGGER.info(
+            f"Updated prompt resource with id {prompt_id} as version number {self._version_id}"
+        )
+
     def get_unassembled_prompt_data(self) -> PartsType:
         """Returns the prompt data, without any variables replaced."""
         return self.prompt_data
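
For reviewers: a sketch of the metadata this change serializes into the prompt dataset, shown for a single-variable freeform prompt. The model name, texts, and the exact `parts` dicts (which come from `Content.to_dict()` and `Part.to_dict()`) are illustrative, not normative.

```python
# Illustrative output of Prompt._format_dataset_metadata_dict() for
# prompt_data="Translate {text}." with one saved variable set.
expected_metadata = {
    "promptType": "freeform",
    "promptApiSchema": {
        "apiSchemaVersion": "1.0.0",
        "multimodalPrompt": {
            "promptMessage": {
                "model": "publishers/google/models/gemini-1.5-flash-002",
                "contents": [
                    {"role": "user", "parts": [{"text": "Translate {text}."}]}
                ],
            }
        },
        "executions": [
            {"arguments": {"text": {"partList": {"parts": [{"text": "Hello!"}]}}}}
        ],
    },
}
```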
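And a minimal end-to-end sketch of the new `create_version()` flow. This assumes `vertexai.init()` has configured the project and location consumed by `_create_prompt_resource`, and that the existing `Prompt` constructor accepts `prompt_data`, `variables`, and `model_name`; the project, model, and prompt text below are placeholders.

```python
import vertexai
from vertexai.generative_models import Part
from vertexai.generative_models._prompts import Prompt  # private module path, for illustration

vertexai.init(project="my-project", location="us-central1")

# A template prompt with one variable and one saved variable set.
prompt = Prompt(
    prompt_data="Translate {text} into French.",
    variables=[{"text": [Part.from_text("Hello, world!")]}],
    model_name="gemini-1.5-flash-002",
)

# First call: creates the prompt dataset plus an initial version snapshot.
prompt.create_version()

# Later calls: update the dataset metadata and snapshot it as a new version.
prompt.create_version()

# Or force a brand-new prompt resource instead of a new version.
prompt.create_version(create_new_prompt=True)
```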