From 1952d52629471b3bc74341101c63a2d825d56b9d Mon Sep 17 00:00:00 2001 From: anandrgitnirman Date: Thu, 19 May 2022 23:45:04 +0530 Subject: [PATCH] #565 refactored code to get all models associated with a given address --- training/service.go | 83 ++++++++++++++++++++++++++++------------- training/training.proto | 28 +++++++++----- 2 files changed, 75 insertions(+), 36 deletions(-) diff --git a/training/service.go b/training/service.go index b52b2407..e40cbfdd 100644 --- a/training/service.go +++ b/training/service.go @@ -25,9 +25,20 @@ type ModelService struct { storage *ModelStorage serviceUrl string } + +func (service ModelService) GetAllModels(c context.Context, request *AccessibleModelsRequest) (*AccessibleModelsResponse, error) { + //TODO implement me + panic("implement me") +} + type NoModelSupportService struct { } +func (n NoModelSupportService) GetAllModels(c context.Context, request *AccessibleModelsRequest) (*AccessibleModelsResponse, error) { + //TODO implement me + panic("implement me") +} + func (n NoModelSupportService) CreateModel(c context.Context, request *CreateModelRequest) (*ModelDetailsResponse, error) { return &ModelDetailsResponse{Status: Status_ERROR}, fmt.Errorf("service end point is not defined or is invalid , please contact the AI developer") @@ -77,7 +88,7 @@ func (service ModelService) storeModelDetails(request *CreateModelRequest, respo } func (service ModelService) deleteModelDetails(request *UpdateModelRequest, response *ModelDetailsResponse) (err error) { - key := service.getModelKeyToUpdate(request) + key := service.getModelKeyToUpdate(request.ModelDetailsRequest) data, ok, err := service.storage.Get(key) if ok && err != nil { data.Status = "DELETED" @@ -87,7 +98,7 @@ func (service ModelService) deleteModelDetails(request *UpdateModelRequest, resp } func (service ModelService) getModelDetails(request *UpdateModelRequest, response *ModelDetailsResponse) (data *ModelUserData, err error) { - key := service.getModelKeyToUpdate(request) + key := service.getModelKeyToUpdate(request.ModelDetailsRequest) data, ok, err := service.storage.Get(key) if err != nil { @@ -102,8 +113,23 @@ func (service ModelService) getModelDetails(request *UpdateModelRequest, respons } func (service ModelService) updateModelDetails(request *UpdateModelRequest, response *ModelDetailsResponse) (err error) { - key := service.getModelKeyToUpdate(request) + key := service.getModelKeyToUpdate(request.ModelDetailsRequest) if data, err := service.getModelDataForUpdate(request, response); err != nil { + if data, ok, err := service.storage.Get(key); err != nil && ok { + data.AuthorizedAddresses = request.AddressList + data.isPublic = request.IsPubliclyAccessible + data.UpdatedByAddress = request.ModelDetailsRequest.Authorization.SignerAddress + data.Status = string(response.Status) + } + + err = service.storage.Put(key, data) + } + return +} + +func (service ModelService) updateModelDetailsForStatus(request *ModelDetailsRequest, response *ModelDetailsResponse) (err error) { + key := service.getModelKeyToUpdate(request) + if data, err := service.getModelDataForStatusUpdate(request, response); err != nil { err = service.storage.Put(key, data) } return @@ -119,7 +145,7 @@ func (service ModelService) getModelKeyToCreate(request *CreateModelRequest, res return } -func (service ModelService) getModelKeyToUpdate(request *UpdateModelRequest) (key *ModelUserKey) { +func (service ModelService) getModelKeyToUpdate(request *ModelDetailsRequest) (key *ModelUserKey) { key = &ModelUserKey{ OrganizationId: config.GetString(config.OrganizationId), ServiceId: config.GetString(config.ServiceId), @@ -131,12 +157,16 @@ func (service ModelService) getModelKeyToUpdate(request *UpdateModelRequest) (ke } func (service ModelService) getModelDataForUpdate(request *UpdateModelRequest, response *ModelDetailsResponse) (data *ModelUserData, err error) { + data, err = service.getModelDataForStatusUpdate(request.ModelDetailsRequest, response) + return +} + +func (service ModelService) getModelDataForStatusUpdate(request *ModelDetailsRequest, response *ModelDetailsResponse) (data *ModelUserData, err error) { key := service.getModelKeyToUpdate(request) - if data, ok, err := service.storage.Get(key); err != nil && ok { - data.AuthorizedAddresses = request.AddressList - data.isPublic = request.IsPubliclyAccessible - data.UpdatedByAddress = request.Authorization.UserAddress - data.Status = string(response.Status) + ok := false + + if data, ok, err = service.storage.Get(key); err != nil && !ok { + log.WithError(fmt.Errorf("Issue with retrieving model data from storage")) } return } @@ -144,7 +174,7 @@ func (service ModelService) getModelDataForUpdate(request *UpdateModelRequest, r func (service ModelService) createModelData(request *CreateModelRequest, response *ModelDetailsResponse) (data *ModelUserData) { data = &ModelUserData{ Status: string(response.Status), - CreatedByAddress: request.Authorization.UserAddress, + CreatedByAddress: request.Authorization.SignerAddress, AuthorizedAddresses: request.AddressList, isPublic: request.IsPubliclyAccessible, ModelId: response.ModelDetails.ModelId, @@ -182,7 +212,7 @@ func (service ModelService) CreateModel(c context.Context, request *CreateModelR func (service ModelService) UpdateModelAccess(c context.Context, request *UpdateModelRequest) (response *ModelDetailsResponse, err error) { - if err = service.verifySignerForUpdateModel(request.Authorization); err != nil { + if err = service.verifySignerForUpdateModel(request.ModelDetailsRequest.Authorization); err != nil { return &ModelDetailsResponse{Status: Status_ERROR}, fmt.Errorf(" authentication FAILED , %v", err) } @@ -203,7 +233,7 @@ func (service ModelService) UpdateModelAccess(c context.Context, request *Update func (service ModelService) DeleteModel(c context.Context, request *UpdateModelRequest) (response *ModelDetailsResponse, err error) { - if err = service.verifySignerForDeleteModel(request.Authorization); err != nil { + if err = service.verifySignerForDeleteModel(request.ModelDetailsRequest.Authorization); err != nil { return &ModelDetailsResponse{Status: Status_ERROR}, fmt.Errorf(" authentication FAILED , %v", err) } @@ -224,7 +254,7 @@ func (service ModelService) DeleteModel(c context.Context, request *UpdateModelR func (service ModelService) GetModelStatus(c context.Context, request *ModelDetailsRequest) (response *ModelDetailsResponse, err error) { - if err = service.verifySignerForGetTrainingStatus(request.Authorization); err != nil { + if err = service.verifySignerForGetModelStatus(request.Authorization); err != nil { return &ModelDetailsResponse{Status: Status_ERROR}, fmt.Errorf(" authentication FAILED , %v", err) } @@ -233,9 +263,10 @@ func (service ModelService) GetModelStatus(c context.Context, request *ModelDeta if client, err := service.getServiceClient(); err != nil { response, err = client.GetModelStatus(ctx, request) - log.Infof("Updating model based on response from GetTrainingStatus") - //todo update data from client and return data stored in etcd - + log.Infof("Updating modelG based on response from UpdateModel") + if err = service.updateModelDetailsForStatus(request, response); err != nil { + return response, fmt.Errorf("issue with storing Model Id in the Daemon Storage %v", err) + } } else { return &ModelDetailsResponse{Status: Status_ERROR}, fmt.Errorf("error in invoking service for Model Training") } @@ -245,36 +276,36 @@ func (service ModelService) GetModelStatus(c context.Context, request *ModelDeta //message used to sign is of the form ("__create_model", mpe_address, current_block_number) func (service *ModelService) verifySignerForCreateModel(request *AuthorizationDetails) error { return utils.VerifySigner(service.getMessageBytes("__create_model", request), - request.GetSignature(), utils.ToChecksumAddress(request.UserAddress)) + request.GetSignature(), utils.ToChecksumAddress(request.SignerAddress)) } //message used to sign is of the form ("__update_model", mpe_address, current_block_number) func (service *ModelService) verifySignerForUpdateModel(request *AuthorizationDetails) error { return utils.VerifySigner(service.getMessageBytes("__update_model", request), - request.GetSignature(), utils.ToChecksumAddress(request.UserAddress)) + request.GetSignature(), utils.ToChecksumAddress(request.SignerAddress)) } //message used to sign is of the form ("__delete_model", mpe_address, current_block_number) func (service *ModelService) verifySignerForDeleteModel(request *AuthorizationDetails) error { return utils.VerifySigner(service.getMessageBytes("__delete_model", request), - request.GetSignature(), utils.ToChecksumAddress(request.UserAddress)) + request.GetSignature(), utils.ToChecksumAddress(request.SignerAddress)) } //message used to sign is of the form ("__delete_model", mpe_address, current_block_number) -func (service *ModelService) verifySignerForGetModelDetails(request *AuthorizationDetails) error { - return utils.VerifySigner(service.getMessageBytes("__get_model_details", request), - request.GetSignature(), utils.ToChecksumAddress(request.UserAddress)) +func (service *ModelService) verifySignerForGetModelStatus(request *AuthorizationDetails) error { + return utils.VerifySigner(service.getMessageBytes("__get_model_status", request), + request.GetSignature(), utils.ToChecksumAddress(request.SignerAddress)) } //message used to sign is of the form ("__get_training_status", mpe_address, current_block_number) -func (service *ModelService) verifySignerForGetTrainingStatus(request *AuthorizationDetails) error { - return utils.VerifySigner(service.getMessageBytes("__get_training_status", request), - request.GetSignature(), utils.ToChecksumAddress(request.UserAddress)) +func (service *ModelService) verifySignatureForGetAllModels(request *AuthorizationDetails) error { + return utils.VerifySigner(service.getMessageBytes("__get_all_models", request), + request.GetSignature(), utils.ToChecksumAddress(request.SignerAddress)) } //"__methodName", user_address, current_block_number func (service *ModelService) getMessageBytes(prefixMessage string, request *AuthorizationDetails) []byte { - userAddress := utils.ToChecksumAddress(request.UserAddress) + userAddress := utils.ToChecksumAddress(request.SignerAddress) message := bytes.Join([][]byte{ []byte(prefixMessage), userAddress.Bytes(), diff --git a/training/training.proto b/training/training.proto index 2adff7c6..e334c230 100644 --- a/training/training.proto +++ b/training/training.proto @@ -13,8 +13,8 @@ message AuthorizationDetails { //signature of the following message: //("__methodName", user_address, current_block_number) bytes signature = 2; - string user_address = 3; - int32 channel_id = 4; + string signer_address = 3; + } enum Status { @@ -39,6 +39,19 @@ message CreateModelRequest { string model_description = 6; } +message AccessibleModelsRequest { + string method_name = 1; + AuthorizationDetails authorization = 2; +} + +message AllAccessibleModelsRequest { + +} + +message AccessibleModelsResponse { + repeated ModelDetails list_of_models= 1; +} + message ModelDetailsRequest { ModelDetails model_details =1 ; AuthorizationDetails authorization = 2; @@ -47,14 +60,8 @@ message ModelDetailsRequest { message UpdateModelRequest { repeated string addressList = 1; - ModelDetails model_details = 2; - Status status = 3; - string message = 4; - bytes signature = 5; - //This should be in the list of address maintained against the model id - string signed_address = 6; - AuthorizationDetails authorization = 7; - bool is_publicly_accessible = 8; + ModelDetailsRequest model_details_request = 2; + bool is_publicly_accessible = 3; } @@ -79,6 +86,7 @@ service Model { rpc update_model_access(UpdateModelRequest) returns (ModelDetailsResponse) {} rpc delete_model(UpdateModelRequest) returns (ModelDetailsResponse) {} rpc get_model_status(ModelDetailsRequest) returns (ModelDetailsResponse) {} + rpc get_all_models(AccessibleModelsRequest) returns (AccessibleModelsResponse) {} }