Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
refactored code to get all models associated with a given address
  • Loading branch information
anandrgitnirman committed May 19, 2022
1 parent d930bfb commit 1952d52
Show file tree
Hide file tree
Showing 2 changed files with 75 additions and 36 deletions.
83 changes: 57 additions & 26 deletions training/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,20 @@ type ModelService struct {
storage *ModelStorage
serviceUrl string
}

func (service ModelService) GetAllModels(c context.Context, request *AccessibleModelsRequest) (*AccessibleModelsResponse, error) {
//TODO implement me
panic("implement me")
}

type NoModelSupportService struct {
}

func (n NoModelSupportService) GetAllModels(c context.Context, request *AccessibleModelsRequest) (*AccessibleModelsResponse, error) {
//TODO implement me
panic("implement me")
}

func (n NoModelSupportService) CreateModel(c context.Context, request *CreateModelRequest) (*ModelDetailsResponse, error) {
return &ModelDetailsResponse{Status: Status_ERROR},
fmt.Errorf("service end point is not defined or is invalid , please contact the AI developer")
Expand Down Expand Up @@ -77,7 +88,7 @@ func (service ModelService) storeModelDetails(request *CreateModelRequest, respo
}

func (service ModelService) deleteModelDetails(request *UpdateModelRequest, response *ModelDetailsResponse) (err error) {
key := service.getModelKeyToUpdate(request)
key := service.getModelKeyToUpdate(request.ModelDetailsRequest)
data, ok, err := service.storage.Get(key)
if ok && err != nil {
data.Status = "DELETED"
Expand All @@ -87,7 +98,7 @@ func (service ModelService) deleteModelDetails(request *UpdateModelRequest, resp
}

func (service ModelService) getModelDetails(request *UpdateModelRequest, response *ModelDetailsResponse) (data *ModelUserData, err error) {
key := service.getModelKeyToUpdate(request)
key := service.getModelKeyToUpdate(request.ModelDetailsRequest)
data, ok, err := service.storage.Get(key)

if err != nil {
Expand All @@ -102,8 +113,23 @@ func (service ModelService) getModelDetails(request *UpdateModelRequest, respons
}

func (service ModelService) updateModelDetails(request *UpdateModelRequest, response *ModelDetailsResponse) (err error) {
key := service.getModelKeyToUpdate(request)
key := service.getModelKeyToUpdate(request.ModelDetailsRequest)
if data, err := service.getModelDataForUpdate(request, response); err != nil {
if data, ok, err := service.storage.Get(key); err != nil && ok {
data.AuthorizedAddresses = request.AddressList
data.isPublic = request.IsPubliclyAccessible
data.UpdatedByAddress = request.ModelDetailsRequest.Authorization.SignerAddress
data.Status = string(response.Status)
}

err = service.storage.Put(key, data)
}
return
}

func (service ModelService) updateModelDetailsForStatus(request *ModelDetailsRequest, response *ModelDetailsResponse) (err error) {
key := service.getModelKeyToUpdate(request)
if data, err := service.getModelDataForStatusUpdate(request, response); err != nil {
err = service.storage.Put(key, data)
}
return
Expand All @@ -119,7 +145,7 @@ func (service ModelService) getModelKeyToCreate(request *CreateModelRequest, res
return
}

func (service ModelService) getModelKeyToUpdate(request *UpdateModelRequest) (key *ModelUserKey) {
func (service ModelService) getModelKeyToUpdate(request *ModelDetailsRequest) (key *ModelUserKey) {
key = &ModelUserKey{
OrganizationId: config.GetString(config.OrganizationId),
ServiceId: config.GetString(config.ServiceId),
Expand All @@ -131,20 +157,24 @@ func (service ModelService) getModelKeyToUpdate(request *UpdateModelRequest) (ke
}

func (service ModelService) getModelDataForUpdate(request *UpdateModelRequest, response *ModelDetailsResponse) (data *ModelUserData, err error) {
data, err = service.getModelDataForStatusUpdate(request.ModelDetailsRequest, response)
return
}

func (service ModelService) getModelDataForStatusUpdate(request *ModelDetailsRequest, response *ModelDetailsResponse) (data *ModelUserData, err error) {
key := service.getModelKeyToUpdate(request)
if data, ok, err := service.storage.Get(key); err != nil && ok {
data.AuthorizedAddresses = request.AddressList
data.isPublic = request.IsPubliclyAccessible
data.UpdatedByAddress = request.Authorization.UserAddress
data.Status = string(response.Status)
ok := false

if data, ok, err = service.storage.Get(key); err != nil && !ok {
log.WithError(fmt.Errorf("Issue with retrieving model data from storage"))
}
return
}

func (service ModelService) createModelData(request *CreateModelRequest, response *ModelDetailsResponse) (data *ModelUserData) {
data = &ModelUserData{
Status: string(response.Status),
CreatedByAddress: request.Authorization.UserAddress,
CreatedByAddress: request.Authorization.SignerAddress,
AuthorizedAddresses: request.AddressList,
isPublic: request.IsPubliclyAccessible,
ModelId: response.ModelDetails.ModelId,
Expand Down Expand Up @@ -182,7 +212,7 @@ func (service ModelService) CreateModel(c context.Context, request *CreateModelR

func (service ModelService) UpdateModelAccess(c context.Context, request *UpdateModelRequest) (response *ModelDetailsResponse,
err error) {
if err = service.verifySignerForUpdateModel(request.Authorization); err != nil {
if err = service.verifySignerForUpdateModel(request.ModelDetailsRequest.Authorization); err != nil {
return &ModelDetailsResponse{Status: Status_ERROR},
fmt.Errorf(" authentication FAILED , %v", err)
}
Expand All @@ -203,7 +233,7 @@ func (service ModelService) UpdateModelAccess(c context.Context, request *Update

func (service ModelService) DeleteModel(c context.Context, request *UpdateModelRequest) (response *ModelDetailsResponse,
err error) {
if err = service.verifySignerForDeleteModel(request.Authorization); err != nil {
if err = service.verifySignerForDeleteModel(request.ModelDetailsRequest.Authorization); err != nil {
return &ModelDetailsResponse{Status: Status_ERROR},
fmt.Errorf(" authentication FAILED , %v", err)
}
Expand All @@ -224,7 +254,7 @@ func (service ModelService) DeleteModel(c context.Context, request *UpdateModelR

func (service ModelService) GetModelStatus(c context.Context, request *ModelDetailsRequest) (response *ModelDetailsResponse,
err error) {
if err = service.verifySignerForGetTrainingStatus(request.Authorization); err != nil {
if err = service.verifySignerForGetModelStatus(request.Authorization); err != nil {
return &ModelDetailsResponse{Status: Status_ERROR},
fmt.Errorf(" authentication FAILED , %v", err)
}
Expand All @@ -233,9 +263,10 @@ func (service ModelService) GetModelStatus(c context.Context, request *ModelDeta

if client, err := service.getServiceClient(); err != nil {
response, err = client.GetModelStatus(ctx, request)
log.Infof("Updating model based on response from GetTrainingStatus")
//todo update data from client and return data stored in etcd

log.Infof("Updating modelG based on response from UpdateModel")
if err = service.updateModelDetailsForStatus(request, response); err != nil {
return response, fmt.Errorf("issue with storing Model Id in the Daemon Storage %v", err)
}
} else {
return &ModelDetailsResponse{Status: Status_ERROR}, fmt.Errorf("error in invoking service for Model Training")
}
Expand All @@ -245,36 +276,36 @@ func (service ModelService) GetModelStatus(c context.Context, request *ModelDeta
//message used to sign is of the form ("__create_model", mpe_address, current_block_number)
func (service *ModelService) verifySignerForCreateModel(request *AuthorizationDetails) error {
return utils.VerifySigner(service.getMessageBytes("__create_model", request),
request.GetSignature(), utils.ToChecksumAddress(request.UserAddress))
request.GetSignature(), utils.ToChecksumAddress(request.SignerAddress))
}

//message used to sign is of the form ("__update_model", mpe_address, current_block_number)
func (service *ModelService) verifySignerForUpdateModel(request *AuthorizationDetails) error {
return utils.VerifySigner(service.getMessageBytes("__update_model", request),
request.GetSignature(), utils.ToChecksumAddress(request.UserAddress))
request.GetSignature(), utils.ToChecksumAddress(request.SignerAddress))
}

//message used to sign is of the form ("__delete_model", mpe_address, current_block_number)
func (service *ModelService) verifySignerForDeleteModel(request *AuthorizationDetails) error {
return utils.VerifySigner(service.getMessageBytes("__delete_model", request),
request.GetSignature(), utils.ToChecksumAddress(request.UserAddress))
request.GetSignature(), utils.ToChecksumAddress(request.SignerAddress))
}

//message used to sign is of the form ("__delete_model", mpe_address, current_block_number)
func (service *ModelService) verifySignerForGetModelDetails(request *AuthorizationDetails) error {
return utils.VerifySigner(service.getMessageBytes("__get_model_details", request),
request.GetSignature(), utils.ToChecksumAddress(request.UserAddress))
func (service *ModelService) verifySignerForGetModelStatus(request *AuthorizationDetails) error {
return utils.VerifySigner(service.getMessageBytes("__get_model_status", request),
request.GetSignature(), utils.ToChecksumAddress(request.SignerAddress))
}

//message used to sign is of the form ("__get_training_status", mpe_address, current_block_number)
func (service *ModelService) verifySignerForGetTrainingStatus(request *AuthorizationDetails) error {
return utils.VerifySigner(service.getMessageBytes("__get_training_status", request),
request.GetSignature(), utils.ToChecksumAddress(request.UserAddress))
func (service *ModelService) verifySignatureForGetAllModels(request *AuthorizationDetails) error {
return utils.VerifySigner(service.getMessageBytes("__get_all_models", request),
request.GetSignature(), utils.ToChecksumAddress(request.SignerAddress))
}

//"__methodName", user_address, current_block_number
func (service *ModelService) getMessageBytes(prefixMessage string, request *AuthorizationDetails) []byte {
userAddress := utils.ToChecksumAddress(request.UserAddress)
userAddress := utils.ToChecksumAddress(request.SignerAddress)
message := bytes.Join([][]byte{
[]byte(prefixMessage),
userAddress.Bytes(),
Expand Down
28 changes: 18 additions & 10 deletions training/training.proto
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@ message AuthorizationDetails {
//signature of the following message:
//("__methodName", user_address, current_block_number)
bytes signature = 2;
string user_address = 3;
int32 channel_id = 4;
string signer_address = 3;

}

enum Status {
Expand All @@ -39,6 +39,19 @@ message CreateModelRequest {
string model_description = 6;
}

message AccessibleModelsRequest {
string method_name = 1;
AuthorizationDetails authorization = 2;
}

message AllAccessibleModelsRequest {

}

message AccessibleModelsResponse {
repeated ModelDetails list_of_models= 1;
}

message ModelDetailsRequest {
ModelDetails model_details =1 ;
AuthorizationDetails authorization = 2;
Expand All @@ -47,14 +60,8 @@ message ModelDetailsRequest {

message UpdateModelRequest {
repeated string addressList = 1;
ModelDetails model_details = 2;
Status status = 3;
string message = 4;
bytes signature = 5;
//This should be in the list of address maintained against the model id
string signed_address = 6;
AuthorizationDetails authorization = 7;
bool is_publicly_accessible = 8;
ModelDetailsRequest model_details_request = 2;
bool is_publicly_accessible = 3;
}


Expand All @@ -79,6 +86,7 @@ service Model {
rpc update_model_access(UpdateModelRequest) returns (ModelDetailsResponse) {}
rpc delete_model(UpdateModelRequest) returns (ModelDetailsResponse) {}
rpc get_model_status(ModelDetailsRequest) returns (ModelDetailsResponse) {}
rpc get_all_models(AccessibleModelsRequest) returns (AccessibleModelsResponse) {}


}

0 comments on commit 1952d52

Please sign in to comment.