Skip to content

Commit

Permalink
Merge pull request #12628 from PasanT9/api-key-mediator
Browse files Browse the repository at this point in the history
Move endpoint configurations and backend throttling from AiConfigurations
  • Loading branch information
PasanT9 authored Oct 10, 2024
2 parents 652ab60 + 641df8b commit d21ce73
Show file tree
Hide file tree
Showing 38 changed files with 443 additions and 903 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,9 @@ public class APIConstants {
public static final String ENDPOINT_CONFIG_SESSION_TIMEOUT = "sessionTimeOut";

public static class AIAPIConstants {
public static final String API_KEY_IDENTIFIER_TYPE = "API_KEY_IDENTIFIER_TYPE";
public static final String API_KEY_IDENTIFIER_TYPE_HEADER = "HEADER";
public static final String API_KEY_IDENTIFIER_TYPE_QUERY_PARAMETER = "QUERY_PARAMETER";
public static final String AI_API_REQUEST_METADATA = "AI_API_REQUEST_METADATA";
public static final String AI_API_RESPONSE_METADATA = "AI_API_RESPONSE_METADATA";
public static final String INPUT_SOURCE_PAYLOAD = "payload";
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2024, WSO2 LLC. (http://www.wso2.org) All Rights Reserved.
* Copyright (c) 2024, WSO2 LLC. (http://www.wso2.com) All Rights Reserved.
*
* WSO2 LLC. licenses this file to you under the Apache License,
* Version 2.0 (the "License"); you may not use this file except
Expand All @@ -18,7 +18,7 @@

package org.wso2.carbon.apimgt.api;

public class TokenBaseThrottlingCountHolder {
public class TokenBasedThrottlingCountHolder {

private String productionMaxPromptTokenCount;
private String productionMaxCompletionTokenCount;
Expand All @@ -28,11 +28,11 @@ public class TokenBaseThrottlingCountHolder {
private String sandboxMaxTotalTokenCount;
private Boolean isTokenBasedThrottlingEnabled = false;

public TokenBaseThrottlingCountHolder() {
public TokenBasedThrottlingCountHolder() {

}

public TokenBaseThrottlingCountHolder(String productionMaxPromptTokenCount, String productionMaxCompletionTokenCount,
public TokenBasedThrottlingCountHolder(String productionMaxPromptTokenCount, String productionMaxCompletionTokenCount,
String productionMaxTotalTokenCount, String sandboxMaxPromptTokenCount,
String sandboxMaxCompletionTokenCount, String sandboxMaxTotalTokenCount,
boolean isTokenBasedThrottlingEnabled) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,34 +18,10 @@

package org.wso2.carbon.apimgt.api.model;

import org.wso2.carbon.apimgt.api.TokenBaseThrottlingCountHolder;

public class AIConfiguration {
private String llmProviderId;
private AIEndpointConfiguration aiEndpointConfiguration;
private String llmProviderName;
private String llmProviderApiVersion;
private TokenBaseThrottlingCountHolder tokenBasedThrottlingConfiguration;

public AIEndpointConfiguration getAiEndpointConfiguration() {

return aiEndpointConfiguration;
}

public void setAiEndpointConfiguration(AIEndpointConfiguration aiEndpointConfiguration) {

this.aiEndpointConfiguration = aiEndpointConfiguration;
}

public TokenBaseThrottlingCountHolder getTokenBasedThrottlingConfiguration() {

return tokenBasedThrottlingConfiguration;
}

public void setTokenBasedThrottlingConfiguration(TokenBaseThrottlingCountHolder tokenBasedThrottlingConfiguration) {

this.tokenBasedThrottlingConfiguration = tokenBasedThrottlingConfiguration;
}

public String getLlmProviderName() {

Expand Down

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -236,6 +236,8 @@ public void setSoapToRestSequences(List<SOAPToRestSequence> soapToRestSequences)

private AIConfiguration aiConfiguration;

private BackendThrottlingConfiguration backendThrottlingConfiguration;

public AIConfiguration getAiConfiguration() {

return aiConfiguration;
Expand All @@ -246,6 +248,14 @@ public void setAiConfiguration(AIConfiguration AiConfiguration) {
this.aiConfiguration = AiConfiguration;
}

public BackendThrottlingConfiguration getBackendThrottlingConfiguration() {
return backendThrottlingConfiguration;
}

public void setBackendThrottlingConfiguration(BackendThrottlingConfiguration backendThrottlingConfiguration) {
this.backendThrottlingConfiguration = backendThrottlingConfiguration;
}

public String getAudience() {
return audience;
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
/*
* Copyright (c) 2024, WSO2 LLC. (http://www.wso2.com) All Rights Reserved.
*
* WSO2 LLC. licenses this file to you under the Apache License,
* Version 2.0 (the "License"); you may not use this file except
* in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

package org.wso2.carbon.apimgt.api.model;

import org.wso2.carbon.apimgt.api.TokenBasedThrottlingCountHolder;

public class BackendThrottlingConfiguration {

private String productionMaxTps;
private String productionTimeUnit = "1000";
private String sandboxMaxTps;
private String sandboxTimeUnit = "1000";
private TokenBasedThrottlingCountHolder tokenBasedThrottlingConfiguration;

public String getProductionMaxTps() {

return productionMaxTps;
}

public void setProductionMaxTps(String productionMaxTps) {

this.productionMaxTps = productionMaxTps;
}

public String getProductionTimeUnit() {
return productionTimeUnit;
}

public void setProductionTimeUnit(String productionTimeUnit) {

this.productionTimeUnit = productionTimeUnit;
}

public String getSandboxMaxTps() {

return sandboxMaxTps;
}

public void setSandboxMaxTps(String sandboxMaxTps) {

this.sandboxMaxTps = sandboxMaxTps;
}

public String getSandboxTimeUnit() {

return sandboxTimeUnit;
}

public void setSandboxTimeUnit(String sandboxTimeUnit) {

this.sandboxTimeUnit = sandboxTimeUnit;
}

public TokenBasedThrottlingCountHolder getTokenBasedThrottlingConfiguration() {

return tokenBasedThrottlingConfiguration;
}

public void setTokenBasedThrottlingConfiguration(TokenBasedThrottlingCountHolder tokenBasedThrottlingConfiguration) {

this.tokenBasedThrottlingConfiguration = tokenBasedThrottlingConfiguration;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,8 @@ public class EndpointSecurity {

private String apiKeyValue = null;

private String apiKeyIdentifierType = null;

private String customParameters = null;

private Map additionalProperties = new HashMap();
Expand Down Expand Up @@ -205,6 +207,16 @@ public void setApiKeyValue(String apiKeyValue) {
this.apiKeyValue = apiKeyValue;
}

public String getApiKeyIdentifierType() {

return apiKeyIdentifierType;
}

public void setApiKeyIdentifierType(String apiKeyIdentifierType) {

this.apiKeyIdentifierType = apiKeyIdentifierType;
}

public int getConnectionTimeoutDuration() {
return connectionTimeoutDuration;
}
Expand Down Expand Up @@ -244,6 +256,7 @@ public String toString() {
", clientSecret='" + clientSecret + '\'' +
", apiKeyIdentifier='" + apiKeyIdentifier + '\'' +
", apiKeyValue='" + apiKeyValue + '\'' +
", apiKeyIdentifierType='" + apiKeyIdentifierType + '\'' +
", customParameters='" + customParameters + '\'' +
", additionalProperties=" + additionalProperties +
", connectionTimeoutDuration=" + connectionTimeoutDuration +
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
import com.google.gson.Gson;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.http.client.utils.URIBuilder;
import org.apache.synapse.MessageContext;
import org.apache.synapse.api.ApiUtils;
import org.apache.synapse.core.axis2.Axis2MessageContext;
Expand All @@ -28,14 +27,11 @@
import org.apache.synapse.transport.passthru.util.RelayUtils;
import org.json.XML;
import org.wso2.carbon.apimgt.api.model.AIConfiguration;
import org.wso2.carbon.apimgt.api.model.AIEndpointConfiguration;
import org.wso2.carbon.apimgt.api.APIConstants;
import org.wso2.carbon.apimgt.api.APIManagementException;
import org.wso2.carbon.apimgt.api.LLMProviderService;
import org.wso2.carbon.apimgt.api.model.LLMProvider;
import org.wso2.carbon.apimgt.gateway.APIMgtGatewayConstants;
import org.wso2.carbon.apimgt.gateway.handlers.security.APISecurityUtils;
import org.wso2.carbon.apimgt.gateway.handlers.security.AuthenticationContext;
import org.wso2.carbon.apimgt.gateway.internal.DataHolder;
import org.wso2.carbon.apimgt.gateway.internal.ServiceReferenceHolder;
import org.wso2.carbon.apimgt.api.LLMProviderConfiguration;
Expand All @@ -48,7 +44,6 @@
import javax.ws.rs.core.MediaType;
import javax.xml.stream.XMLStreamException;
import java.io.IOException;
import java.net.URI;
import java.net.URISyntaxException;
import java.util.HashMap;
import java.util.Map;
Expand Down Expand Up @@ -133,8 +128,11 @@ private boolean processMessage(MessageContext messageContext, boolean isRequest)
LLMProviderConfiguration providerConfiguration = new Gson().fromJson(config,
LLMProviderConfiguration.class);
if (isRequest) {
addEndpointConfigurationToMessageContext(messageContext, aiConfiguration.getAiEndpointConfiguration(),
providerConfiguration);
Map<String, String> transportHeaders =
(Map<String, String>) ((Axis2MessageContext) messageContext)
.getAxis2MessageContext()
.getProperty(org.apache.axis2.context.MessageContext.TRANSPORT_HEADERS);
transportHeaders.remove(HttpHeaders.ACCEPT_ENCODING);
}
LLMProviderService llmProviderService = ServiceReferenceHolder.getInstance()
.getLLMProviderService(providerConfiguration.getConnectorType());
Expand Down Expand Up @@ -176,45 +174,6 @@ private boolean processMessage(MessageContext messageContext, boolean isRequest)
}
}

/**
* Adds endpoint configuration to the message context.
*
* @param messageContext the Synapse MessageContext
* @param aiEndpointConfiguration AI endpoint configuration
* @param providerConfiguration LLM provider configuration
*/
private void addEndpointConfigurationToMessageContext(MessageContext messageContext,
AIEndpointConfiguration aiEndpointConfiguration,
LLMProviderConfiguration providerConfiguration)
throws CryptoException, URISyntaxException {

if (aiEndpointConfiguration != null) {
org.apache.axis2.context.MessageContext axCtx =
((Axis2MessageContext) messageContext).getAxis2MessageContext();
AuthenticationContext authContext =
(AuthenticationContext) messageContext.getProperty(APISecurityUtils.API_AUTH_CONTEXT);
String authValue = authContext.getKeyType().equals(org.wso2.carbon.apimgt.impl.
APIConstants.API_KEY_TYPE_PRODUCTION)
? aiEndpointConfiguration.getProductionAuthValue()
: aiEndpointConfiguration.getSandboxAuthValue();
if (providerConfiguration.getAuthHeader() != null) {
Map<String, String> transportHeaders =
(Map<String, String>) axCtx.getProperty(org.apache.axis2.context.MessageContext
.TRANSPORT_HEADERS);
transportHeaders.put(providerConfiguration.getAuthHeader(), decryptSecret(authValue));

// TODO: Handle encoded scenario
transportHeaders.remove(HttpHeaders.ACCEPT_ENCODING);
} else if (providerConfiguration.getAuthQueryParameter() != null) {
URI updatedFullPath =
new URIBuilder((String) axCtx.getProperty(APIMgtGatewayConstants.REST_URL_POSTFIX))
.addParameter(providerConfiguration.getAuthQueryParameter(),
decryptSecret(authValue)).build();
axCtx.setProperty(APIMgtGatewayConstants.REST_URL_POSTFIX, updatedFullPath.toString());
}
}
}

/**
* Decrypts the secret value.
*
Expand Down
Loading

0 comments on commit d21ce73

Please sign in to comment.