Skip to content

Commit

Permalink
feat: Prompt management for VertexSDK
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 692272106
  • Loading branch information
matthew29tang authored and copybara-github committed Nov 6, 2024
1 parent 169dd44 commit fe44ea3
Showing 1 changed file with 253 additions and 0 deletions.
253 changes: 253 additions & 0 deletions vertexai/generative_models/_prompts.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@

from google.cloud.aiplatform import base
from google.cloud.aiplatform import initializer as aiplatform_initializer
from google.cloud.aiplatform.compat.services import dataset_service_client
from google.cloud.aiplatform.compat.types import dataset as gca_dataset
from google.cloud.aiplatform_v1.types import dataset_version as gca_dataset_version
from vertexai.generative_models import (
Content,
Image,
Expand All @@ -39,6 +42,7 @@
SafetySettingsType,
)

import dataclasses
import re
from typing import (
Any,
Expand All @@ -51,9 +55,138 @@

_LOGGER = base.Logger(__name__)

DEFAULT_API_SCHEMA_VERSION = "1.0.0"
VARIABLE_NAME_REGEX = r"(\{[^\W0-9]\w*\})"


@dataclasses.dataclass
class PromptMessage:
"""PromptMessage.
Attributes:
model: The model name (in the format publishers/google/models/{model_name}).
contents: The contents of the prompt.
system_instruction: The system instruction of the prompt.
tools: The tools of the prompt.
tool_config: The tool config of the prompt.
generation_config: The generation config of the prompt.
safety_settings: The safety settings of the prompt.
"""
model: Optional[str] = None
contents: Optional[List[Content]] = None
system_instruction: Optional[Content] = None
tools: Optional[List[Tool]] = None
tool_config: Optional[ToolConfig] = None
generation_config: Optional[GenerationConfig] = None
safety_settings: Optional[SafetySetting] = None

def to_dict(self) -> Dict[str, Any]:
dct = {}
dct["model"] = self.model
if self.contents:
dct["contents"] = [content.to_dict() for content in self.contents]
if self.system_instruction:
dct["systemInstruction"] = self.system_instruction.to_dict()
if self.tools:
dct["tools"] = [tool.to_dict() for tool in self.tools]
# dct["toolConfig"] = self.tool_config.to_dict() if self.tool_config else None # no toolconfig for now
if self.generation_config:
dct["generationConfig"] = self.generation_config.to_dict()
if self.safety_settings:
dct["safetySettings"] = self.safety_settings.to_dict()
return dct


@dataclasses.dataclass
class Arguments:
"""Arguments.
Attributes:
variables: The arguments of the execution.
"""
variables: dict[str, list[Part]]

def to_dict(self) -> Dict[str, Any]:
dct = {}
for variable_name in self.variables:
dct[variable_name] = {
"partList": {
"parts": [
part.to_dict() for part in self.variables[variable_name]
]
}
}
return dct


@dataclasses.dataclass
class Execution:
"""Execution.
Attributes:
arguments: The arguments of the execution.
"""
arguments: Arguments

def __init__(self, arguments: dict[str, list[Part]]):
self.arguments = Arguments(variables=arguments)

def to_dict(self) -> Dict[str, Any]:
dct = {}
dct["arguments"] = self.arguments.to_dict()
return dct


@dataclasses.dataclass
class MultimodalPrompt:
"""MultimodalPrompt.
Attributes:
prompt_message: The schema for the prompt. Mirrors the GenerateContentRequest schema.
api_schema_version: The api schema version of the prompt when it was last modified.
executions: Contains data related to an execution of a prompt (ex. variables)
"""
prompt_message: PromptMessage
api_schema_version: Optional[str] = DEFAULT_API_SCHEMA_VERSION
executions: Optional[list[Execution]] = None

def to_dict(self) -> Dict[str, Any]:
dct = {
"multimodalPrompt": {}
}
dct["apiSchemaVersion"] = self.api_schema_version
dct["multimodalPrompt"]["promptMessage"] = self.prompt_message.to_dict()
if self.executions and self.executions[0]:
# Only add variable sets if they are non empty.
execution_dcts = []
for execution in self.executions:
exeuction_dct = execution.to_dict()
if exeuction_dct and exeuction_dct["arguments"]:
execution_dcts.append(exeuction_dct)
if execution_dcts:
dct["executions"] = execution_dcts
return dct


@dataclasses.dataclass
class PromptDatasetMetadata:
"""PromptDatasetMetadata.
Attributes:
prompt_type: Requird. SDK only supports "freeform" or "multimodalFreeform"
prompt_api_schema: Required. SDK only supports multimodalPrompt
"""

prompt_type: str
prompt_api_schema: MultimodalPrompt

def to_dict(self) -> Dict[str, Any]:
dct = {}
dct["promptType"] = self.prompt_type
dct["promptApiSchema"] = self.prompt_api_schema.to_dict()
return dct


class Prompt:
"""A prompt which may be a template with variables.
Expand Down Expand Up @@ -157,6 +290,10 @@ def __init__(
self._system_instruction = None
self._tools = None
self._tool_config = None
self._dataset_client_value = None
self._dataset = None
self._prompt_name = None
self._version_id = None

self.prompt_data = prompt_data
self.variables = variables if variables else [{}]
Expand Down Expand Up @@ -567,6 +704,122 @@ def generate_content(
stream=stream,
)

@property
def _dataset_client(self) -> dataset_service_client.DatasetServiceClient:
if not getattr(self, "_dataset_client_value", None):
self._dataset_client_value = (
aiplatform_initializer.global_config.create_client(
client_class=dataset_service_client.DatasetServiceClient,
)
)
return self._dataset_client_value

def _create_dataset(self, parent: str, contents: list[Content], model_name: str) -> gca_dataset.Dataset:
metadata_schema_uri = "gs://google-cloud-aiplatform/schema/dataset/metadata/text_prompt_1.0.0.yaml"
dataset_metadata = PromptDatasetMetadata(
prompt_type="freeform",
prompt_api_schema=MultimodalPrompt(
prompt_message=PromptMessage(
model=model_name,
contents=contents,
),
executions=[Execution(variable_set) for variable_set in self.variables],
),
)

dataset = gca_dataset.Dataset(
name=parent,
display_name=self._prompt_name or "test1",
metadata_schema_uri=metadata_schema_uri,
metadata=dataset_metadata.to_dict(),
model_reference = model_name,

)
operation = self._dataset_client.create_dataset(
parent=parent,
dataset=dataset,
)
dataset = operation.result()

# Purge labels
dataset.labels = None
return dataset

def _create_dataset_version(self, parent):
dataset_version = gca_dataset_version.DatasetVersion(
display_name=self._prompt_name,
)

dataset_version = self._dataset_client.create_dataset_version(
parent=parent,
dataset_version=dataset_version,
)
return dataset_version.result()

def _update_dataset(self, dataset: gca_dataset.Dataset, contents: list[Content], model_name: str) -> gca_dataset_version.DatasetVersion:
dataset_metadata = PromptDatasetMetadata(
prompt_type="freeform",
prompt_api_schema=MultimodalPrompt(
prompt_message=PromptMessage(
model=model_name,
contents=contents,
),
executions=[Execution(variable_set) for variable_set in self.variables],
),
)
dataset.metadata = dataset_metadata.to_dict()

dataset_version = self._dataset_client.update_dataset(
dataset=dataset,
)
return dataset_version

def create_version(
self,
create_new_prompt: bool = False,
) -> None:
"""Creates a Prompt in the online prompt store"""
if not self._dataset or create_new_prompt:
return self._create_prompt_resource()
else:
return self._create_prompt_version_resource()

def _create_prompt_resource(self) -> None:
project = aiplatform_initializer.global_config.project
location = aiplatform_initializer.global_config.location

# Step 1: Create prompt dataset
parent = f"projects/{project}/locations/{location}"
prompt_contents = [Content._from_gapic(_to_content(value=self.prompt_data))]
dataset = self._create_dataset(
parent=parent,
contents=prompt_contents,
model_name=self.model_name
)

# Step 2: Create prompt version (snapshot)
dataset_version = self._create_dataset_version(dataset.name)

# Step 3: Update Prompt object
self._dataset = dataset
self._version_id = dataset_version.name.split("/")[-1]

def _create_prompt_version_resource(self) -> None:
# Step 1: Update prompt
prompt_contents = [Content._from_gapic(_to_content(value=self.prompt_data))]
updated_dataset = self._update_dataset(
dataset=self._dataset,
contents=prompt_contents,
model_name=self.model_name
)

# Step 2: Create prompt version (snapshot)
dataset_version = self._create_dataset_version(updated_dataset.name)

# Step 3: Update Prompt object
self._dataset = updated_dataset
self._version_id = dataset_version.name.split("/")[-1]

def get_unassembled_prompt_data(self) -> PartsType:
"""Returns the prompt data, without any variables replaced."""
return self.prompt_data
Expand Down

0 comments on commit fe44ea3

Please sign in to comment.