-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
21 changed files
with
4,784 additions
and
0 deletions.
There are no files selected for viewing
140 changes: 140 additions & 0 deletions
140
src/main/java/io/orkes/conductor/client/AIOrchestrator.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,140 @@ | ||
package io.orkes.conductor.client; | ||
|
||
import java.util.List; | ||
import java.util.Map; | ||
import java.util.UUID; | ||
|
||
import com.netflix.conductor.sdk.workflow.executor.WorkflowExecutor; | ||
|
||
import io.orkes.conductor.client.http.ApiException; | ||
import io.orkes.conductor.client.model.integration.Category; | ||
import io.orkes.conductor.client.model.integration.Integration; | ||
import io.orkes.conductor.client.model.integration.IntegrationApi; | ||
import io.orkes.conductor.client.model.integration.IntegrationApiUpdate; | ||
import io.orkes.conductor.client.model.integration.IntegrationConfig; | ||
import io.orkes.conductor.client.model.integration.IntegrationUpdate; | ||
import io.orkes.conductor.client.model.integration.ai.PromptTemplate; | ||
|
||
public class AIOrchestrator { | ||
private IntegrationClient integrationClient; | ||
private WorkflowClient workflowClient; | ||
private WorkflowExecutor workflowExecutor; | ||
private PromptClient promptClient; | ||
private String promptTestWorkflowName; | ||
|
||
public enum VectorDB { | ||
PINECONE_DB("pineconedb"), | ||
WEAVIATE_DB("weaviatedb"); | ||
|
||
private final String value; | ||
|
||
private VectorDB(String value) { | ||
this.value = value; | ||
} | ||
|
||
public String getValue() { | ||
return this.value; | ||
} | ||
} | ||
|
||
public enum LLMProvider { | ||
AZURE_OPEN_AI("azure_openai"), | ||
OPEN_AI("openai"), | ||
GCP_VERTEX_AI("vertex_ai"), | ||
HUGGING_FACE("huggingface"); | ||
|
||
private final String value; | ||
|
||
private LLMProvider(String value) { | ||
this.value = value; | ||
} | ||
|
||
public String getValue() { | ||
return this.value; | ||
} | ||
} | ||
|
||
public AIOrchestrator(ApiClient apiConfiguration, String promptTestWorkflowName) { | ||
OrkesClients orkesClients = new OrkesClients(apiConfiguration); | ||
this.integrationClient = orkesClients.getIntegrationClient(); | ||
this.workflowClient = orkesClients.getWorkflowClient(); | ||
this.promptClient = orkesClients.getPromptClient(); | ||
this.promptTestWorkflowName = promptTestWorkflowName.isEmpty() ? "prompt_test_" + UUID.randomUUID().toString() : promptTestWorkflowName; | ||
} | ||
|
||
public AIOrchestrator addPromptTemplate(String name, String promptTemplate, String description) { | ||
promptClient.savePrompt(name, description, promptTemplate); | ||
return this; | ||
} | ||
|
||
public PromptTemplate getPromptTemplate(String templateName) { | ||
try { | ||
return promptClient.getPrompt(templateName); | ||
} catch (ApiException e) { | ||
if (e.getStatusCode() == 404) { | ||
return null; | ||
} | ||
throw e; | ||
} | ||
} | ||
|
||
public void associatePromptTemplate(String name, String aiIntegration, List<String> aiModels) { | ||
aiModels.forEach(aiModel -> integrationClient.associatePromptWithIntegration(aiIntegration, aiModel, name)); | ||
} | ||
|
||
public Object testPromptTemplate(String text, Map<String, Object> variables, String aiIntegration, String textCompleteModel, List<String> stopWords, Integer maxTokens, int temperature, int topP) { | ||
return promptClient.testPrompt(text, variables, aiIntegration, textCompleteModel, temperature, topP, stopWords); | ||
} | ||
|
||
public void addAIIntegration(String aiIntegrationName, LLMProvider provider, List<String> models, String description, IntegrationConfig config, boolean overwrite) { | ||
IntegrationUpdate details = new IntegrationUpdate(); | ||
details.setConfiguration(config.toMap()); | ||
details.setType(provider.toString()); | ||
details.setCategory(Category.AI_MODEL); | ||
details.setEnabled(true); | ||
details.setDescription(description); | ||
Integration existingIntegration = integrationClient.getIntegration(aiIntegrationName); | ||
if (existingIntegration == null || overwrite) { | ||
integrationClient.saveIntegration(aiIntegrationName, details); | ||
} | ||
models.forEach(model -> { | ||
IntegrationApiUpdate apiDetails = new IntegrationApiUpdate(); | ||
apiDetails.setEnabled(true); | ||
apiDetails.setDescription(description); | ||
IntegrationApi existingIntegrationApi = integrationClient.getIntegrationApi(aiIntegrationName, model); | ||
if (existingIntegrationApi == null || overwrite) { | ||
integrationClient.saveIntegrationApi(aiIntegrationName, model, apiDetails); | ||
} | ||
}); | ||
} | ||
|
||
public void addVectorStore(String dbIntegrationName, VectorDB provider, List<String> indices, IntegrationConfig config, String description, boolean overwrite) { | ||
IntegrationUpdate vectorDb = new IntegrationUpdate(); | ||
vectorDb.setConfiguration(config.toMap()); | ||
vectorDb.setType(provider.toString()); | ||
vectorDb.setCategory(Category.VECTOR_DB); | ||
vectorDb.setEnabled(true); | ||
vectorDb.setDescription(description != null ? description : dbIntegrationName); | ||
Integration existingIntegration = integrationClient.getIntegration(dbIntegrationName); | ||
if (existingIntegration == null || overwrite) { | ||
integrationClient.saveIntegration(dbIntegrationName, vectorDb); | ||
} | ||
indices.forEach(index -> { | ||
IntegrationApiUpdate apiDetails = new IntegrationApiUpdate(); | ||
apiDetails.setEnabled(true); | ||
apiDetails.setDescription(description); | ||
IntegrationApi existingIntegrationApi = integrationClient.getIntegrationApi(dbIntegrationName, index); | ||
if (existingIntegrationApi == null || overwrite) { | ||
integrationClient.saveIntegrationApi(dbIntegrationName, index, apiDetails); | ||
} | ||
}); | ||
} | ||
|
||
public Map<String, Integer> getTokenUsed(String aiIntegration) { | ||
return integrationClient.getTokenUsageForIntegrationProvider(aiIntegration); | ||
} | ||
|
||
public int getTokenUsedByModel(String aiIntegration, String model) { | ||
return integrationClient.getTokenUsageForIntegration(aiIntegration, model); | ||
} | ||
} |
65 changes: 65 additions & 0 deletions
65
src/main/java/io/orkes/conductor/client/IntegrationClient.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,65 @@ | ||
package io.orkes.conductor.client; | ||
|
||
import java.util.List; | ||
import java.util.Map; | ||
|
||
import io.orkes.conductor.client.model.TagObject; | ||
import io.orkes.conductor.client.model.integration.Integration; | ||
import io.orkes.conductor.client.model.integration.IntegrationApi; | ||
import io.orkes.conductor.client.model.integration.IntegrationApiUpdate; | ||
import io.orkes.conductor.client.model.integration.IntegrationUpdate; | ||
import io.orkes.conductor.client.model.integration.ai.PromptTemplate; | ||
|
||
public interface IntegrationClient { | ||
/** | ||
* Client for managing integrations with external systems. Some examples of integrations are: | ||
* 1. AI/LLM providers (e.g. OpenAI, HuggingFace) | ||
* 2. Vector DBs (Pinecone, Weaviate etc.) | ||
* 3. Kafka | ||
* 4. Relational databases | ||
* | ||
* Integrations are configured as integration -> api with 1->N cardinality. | ||
* APIs are the underlying resources for an integration and depending on the type of integration they represent underlying resources. | ||
* Examples: | ||
* LLM integrations | ||
* The integration specifies the name of the integration unique to your environment, api keys and endpoint used. | ||
* APIs are the models (e.g. text-davinci-003, or text-embedding-ada-002) | ||
* | ||
* Vector DB integrations, | ||
* The integration represents the cluster, specifies the name of the integration unique to your environment, api keys and endpoint used. | ||
* APIs are the indexes (e.g. pinecone) or class (e.g. for weaviate) | ||
* | ||
* Kafka | ||
* The integration represents the cluster, specifies the name of the integration unique to your environment, api keys and endpoint used. | ||
* APIs are the topics that are configured for use within this kafka cluster | ||
*/ | ||
|
||
void associatePromptWithIntegration(String aiIntegration, String modelName, String promptName); | ||
|
||
void deleteIntegrationApi(String apiName, String integrationName); | ||
|
||
void deleteIntegration(String integrationName); | ||
|
||
IntegrationApi getIntegrationApi(String apiName, String integrationName); | ||
|
||
List<IntegrationApi> getIntegrationApis(String integrationName); | ||
|
||
Integration getIntegration(String integrationName); | ||
|
||
List<Integration> getIntegrations(String category, Boolean activeOnly); | ||
|
||
List<PromptTemplate> getPromptsWithIntegration(String aiIntegration, String modelName); | ||
|
||
int getTokenUsageForIntegration(String name, String integrationName); | ||
|
||
Map<String, Integer> getTokenUsageForIntegrationProvider(String name); | ||
|
||
void saveIntegrationApi(String integrationName, String apiName, IntegrationApiUpdate apiDetails); | ||
|
||
void saveIntegration(String integrationName, IntegrationUpdate integrationDetails); | ||
|
||
// Tags | ||
void deleteTagForIntegrationProvider(List<TagObject> tags, String name); | ||
void saveTagForIntegrationProvider(List<TagObject> tags, String name); | ||
List<TagObject> getTagsForIntegrationProvider(String name); | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
67 changes: 67 additions & 0 deletions
67
src/main/java/io/orkes/conductor/client/OrkesPromptClient.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,67 @@ | ||
package io.orkes.conductor.client; | ||
|
||
import java.util.List; | ||
import java.util.Map; | ||
|
||
import io.orkes.conductor.client.http.api.PromptResourceApi; | ||
import io.orkes.conductor.client.model.TagObject; | ||
import io.orkes.conductor.client.model.integration.PromptTemplateTestRequest; | ||
import io.orkes.conductor.client.model.integration.ai.PromptTemplate; | ||
|
||
public class OrkesPromptClient implements PromptClient { | ||
|
||
private final PromptResourceApi promptResourceApi; | ||
|
||
public OrkesPromptClient(ApiClient apiClient) { | ||
this.promptResourceApi = new PromptResourceApi(apiClient); | ||
} | ||
|
||
@Override | ||
public void savePrompt(String promptName, String description, String promptTemplate) { | ||
promptResourceApi.savePromptTemplate(promptTemplate, description, promptName, List.of()); | ||
} | ||
|
||
@Override | ||
public PromptTemplate getPrompt(String promptName) { | ||
return promptResourceApi.getPromptTemplate(promptName); | ||
} | ||
|
||
@Override | ||
public List<PromptTemplate> getPrompts() { | ||
return promptResourceApi.getPromptTemplates(); | ||
} | ||
|
||
@Override | ||
public void deletePrompt(String promptName) { | ||
promptResourceApi.deletePromptTemplate(promptName); | ||
} | ||
|
||
@Override | ||
public List<TagObject> getTagsForPromptTemplate(String promptName) { | ||
return promptResourceApi.getTagsForPromptTemplate(promptName); | ||
} | ||
|
||
@Override | ||
public void updateTagForPromptTemplate(String promptName, List<TagObject> tags) { | ||
promptResourceApi.putTagForPromptTemplate(tags, promptName); | ||
} | ||
|
||
@Override | ||
public void deleteTagForPromptTemplate(String promptName, List<TagObject> tags) { | ||
promptResourceApi.deleteTagForPromptTemplate(tags, promptName); | ||
} | ||
|
||
@Override | ||
public String testPrompt(String promptText, Map<String, Object> variables, String aiIntegration, String textCompleteModel, float temperature, float topP, | ||
List<String> stopWords) { | ||
PromptTemplateTestRequest request = new PromptTemplateTestRequest(); | ||
request.setPrompt(promptText); | ||
request.setLlmProvider(aiIntegration); | ||
request.setModel(textCompleteModel); | ||
request.setTemperature((double) temperature); | ||
request.setTopP((double) topP); | ||
request.setStopWords(stopWords == null ? List.of() : stopWords); | ||
request.setPromptVariables(variables); | ||
return promptResourceApi.testMessageTemplate(request); | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
package io.orkes.conductor.client; | ||
|
||
import java.util.List; | ||
import java.util.Map; | ||
|
||
import io.orkes.conductor.client.model.TagObject; | ||
import io.orkes.conductor.client.model.integration.ai.PromptTemplate; | ||
|
||
public interface PromptClient { | ||
|
||
void savePrompt(String promptName, String description, String promptTemplate); | ||
|
||
PromptTemplate getPrompt(String promptName); | ||
|
||
List<PromptTemplate> getPrompts(); | ||
|
||
void deletePrompt(String promptName); | ||
|
||
List<TagObject> getTagsForPromptTemplate(String promptName); | ||
|
||
void updateTagForPromptTemplate(String promptName, List<TagObject> tags); | ||
|
||
void deleteTagForPromptTemplate(String promptName, List<TagObject> tags); | ||
|
||
/** | ||
* Tests a prompt template by substituting variables and processing through the specified AI model. | ||
* | ||
* @param promptText the text of the prompt template | ||
* @param variables a map containing variables to be replaced in the template | ||
* @param aiIntegration the AI integration context | ||
* @param textCompleteModel the AI model used for completing text | ||
* @param temperature the randomness of the output (optional, default is 0.1) | ||
* @param topP the probability mass to consider from the output distribution (optional, default is 0.9) | ||
* @param stopWords a list of words to stop generating further (can be null) | ||
* @return the processed prompt text | ||
*/ | ||
String testPrompt(String promptText, Map<String, Object> variables, String aiIntegration, | ||
String textCompleteModel, float temperature, float topP, List<String> stopWords); | ||
} |
Oops, something went wrong.