diff --git a/snetd/cmd/components.go b/snetd/cmd/components.go index 81cd80af..5c719a08 100644 --- a/snetd/cmd/components.go +++ b/snetd/cmd/components.go @@ -58,7 +58,7 @@ type Components struct { freeCallLockerStorage *storage.PrefixedAtomicStorage tokenManager token.Manager tokenService *escrow.TokenService - modelService *training.ModelService + modelService training.ModelServer modelUserStorage *training.ModelStorage } @@ -558,7 +558,7 @@ func (components *Components) ModelUserStorage() *training.ModelStorage { return components.modelUserStorage } -func (components *Components) ModelService() *training.ModelService { +func (components *Components) ModelService() training.ModelServer { if components.modelService != nil { return components.modelService } diff --git a/training/service.go b/training/service.go index 90e256a5..93c23363 100644 --- a/training/service.go +++ b/training/service.go @@ -2,26 +2,59 @@ package training import ( + "bytes" "fmt" + "github.com/ethereum/go-ethereum/accounts/abi" "github.com/singnet/snet-daemon/blockchain" "github.com/singnet/snet-daemon/config" "github.com/singnet/snet-daemon/escrow" + "github.com/singnet/snet-daemon/utils" log "github.com/sirupsen/logrus" "golang.org/x/net/context" "google.golang.org/grpc" + "math/big" "time" ) +type IService interface { +} type ModelService struct { serviceMetaData *blockchain.ServiceMetadata organizationMetaData *blockchain.OrganizationMetaData channelService escrow.PaymentChannelService storage *ModelStorage + serviceUrl string +} +type NoModelSupportService struct { } -func getServiceClient() (client ModelClient, err error) { - serviceURL := config.GetString(config.ModelTrainingEndpoint) - conn, err := grpc.Dial(serviceURL, grpc.WithInsecure()) +func (n NoModelSupportService) CreateModel(c context.Context, request *CreateModelRequest) (*ModelDetailsResponse, error) { + return &ModelDetailsResponse{Status: Status_ERROR}, + fmt.Errorf("service end point is not defined or is invalid , please contact the AI developer") +} + +func (n NoModelSupportService) UpdateModelAccess(c context.Context, request *UpdateModelRequest) (*ModelDetailsResponse, error) { + return &ModelDetailsResponse{Status: Status_ERROR}, + fmt.Errorf("service end point is not defined or is invalid , please contact the AI developer") +} + +func (n NoModelSupportService) DeleteModel(c context.Context, request *UpdateModelRequest) (*ModelDetailsResponse, error) { + return &ModelDetailsResponse{Status: Status_ERROR}, + fmt.Errorf("service end point is not defined or is invalid , please contact the AI developer") +} + +func (n NoModelSupportService) GetModelDetails(c context.Context, id *ModelDetailsRequest) (*ModelDetailsResponse, error) { + return &ModelDetailsResponse{Status: Status_ERROR}, + fmt.Errorf("service end point is not defined or is invalid , please contact the AI developer") +} + +func (n NoModelSupportService) GetTrainingStatus(c context.Context, id *ModelDetailsRequest) (*ModelDetailsResponse, error) { + return &ModelDetailsResponse{Status: Status_ERROR}, + fmt.Errorf("service end point is not defined or is invalid for training , please contact the AI developer") +} + +func (m ModelService) getServiceClient() (client ModelClient, err error) { + conn, err := grpc.Dial(m.serviceUrl, grpc.WithInsecure()) if err != nil { log.WithError(err).Warningf("unable to connect to grpc endpoint: %v", err) return nil, err @@ -31,75 +64,182 @@ func getServiceClient() (client ModelClient, err error) { client = NewModelClient(conn) return } +func (m ModelService) storeModelDetails(request *CreateModelRequest, response *ModelDetailsResponse) (err error) { + key := m.getModelKeyToCreate(request, response) + data := m.createModelData(request, response) + err = m.storage.Put(key, data) + return +} -func (m ModelService) CreateModel(c context.Context, request *CreateModelRequest) (*ModelDetailsResponse, error) { +func (m ModelService) updateModelDetails(request *UpdateModelRequest, response *ModelDetailsResponse) (err error) { + key := m.getModelKeyToUpdate(request) + data := m.getModelDataForUpdate(request) + err = m.storage.Put(key, data) + return +} +func (m ModelService) getModelKeyToCreate(request *CreateModelRequest, response *ModelDetailsResponse) (key *ModelUserKey) { + key = &ModelUserKey{ + OrganizationId: config.GetString(config.OrganizationId), + ServiceId: config.GetString(config.ServiceId), + GroupID: m.organizationMetaData.GetGroupIdString(), + MethodName: request.MethodName, + ModelId: response.ModelDetails.ModelId, + } + return +} + +func (m ModelService) getModelKeyToUpdate(request *UpdateModelRequest) (key *ModelUserKey) { + key = &ModelUserKey{ + OrganizationId: config.GetString(config.OrganizationId), + ServiceId: config.GetString(config.ServiceId), + GroupID: m.organizationMetaData.GetGroupIdString(), + MethodName: request.ModelDetails.MethodName, + ModelId: request.ModelDetails.ModelId, + } + return +} + +func (m ModelService) getModelDataForUpdate(request *UpdateModelRequest) (key *ModelUserData) { + key = &ModelUserData{ + + ModelId: request.ModelDetails.ModelId, + } + return +} + +func (m ModelService) createModelData(request *CreateModelRequest, response *ModelDetailsResponse) (data *ModelUserData) { + data = &ModelUserData{ + Status: string(response.Status), + CreatedByAddress: request.Authorization.UserAddress, + AuthorizedAddresses: request.Address, + isPublic: request.IsPubliclyAccessible, + ModelId: response.ModelDetails.ModelId, + } + return +} + +func (m ModelService) CreateModel(c context.Context, request *CreateModelRequest) (response *ModelDetailsResponse, + err error) { // verify the request that has come in // make a call to the client // if the response is successful , store details in etcd // send back the response to the client - //TODO implement me - fmt.Println("Adding addresses to etcd ") ctx, cancel := context.WithTimeout(context.Background(), time.Second*200) defer cancel() - if client, err := getServiceClient(); err == nil { - return client.CreateModel(ctx, request) + if client, err := m.getServiceClient(); err == nil { + response, err = client.CreateModel(ctx, request) + if err == nil { + //store the details in etcd + log.Debugf("Adding addresses to etcd ...") + if err = m.storeModelDetails(request, response); err != nil { + return response, fmt.Errorf("issue with storing Model Id in the Daemon Storage %v", err) + } + } } else { - return &ModelDetailsResponse{Status: Status_ERROR}, fmt.Errorf("Error in invoking service for Model Training") + return &ModelDetailsResponse{Status: Status_ERROR}, + fmt.Errorf("error in invoking service for Model Training %v", err) } - + return } -func (m ModelService) UpdateModelAccess(c context.Context, request *UpdateModelRequest) (*ModelDetailsResponse, error) { +func (m ModelService) UpdateModelAccess(c context.Context, request *UpdateModelRequest) (response *ModelDetailsResponse, + err error) { ctx, cancel := context.WithTimeout(context.Background(), time.Second) defer cancel() fmt.Println("Updating model access addresses to etcd ") - if client, err := getServiceClient(); err != nil { - return client.UpdateModelAccess(ctx, request) + if client, err := m.getServiceClient(); err != nil { + response, err = client.UpdateModelAccess(ctx, request) + if err = m.updateModelDetails(request, response); err != nil { + return response, fmt.Errorf("issue with storing Model Id in the Daemon Storage %v", err) + } } else { return &ModelDetailsResponse{Status: Status_ERROR}, fmt.Errorf("error in invoking service for Model Training") } + return } func (m ModelService) DeleteModel(c context.Context, request *UpdateModelRequest) (*ModelDetailsResponse, error) { fmt.Println("Deleting model addresses from etcd ") ctx, cancel := context.WithTimeout(context.Background(), time.Second) defer cancel() - if client, err := getServiceClient(); err != nil { + if client, err := m.getServiceClient(); err != nil { return client.DeleteModel(ctx, request) } else { return &ModelDetailsResponse{Status: Status_ERROR}, fmt.Errorf("error in invoking service for Model Training") } } -func (m ModelService) GetModelDetails(c context.Context, id *ModelId) (*ModelDetailsResponse, error) { - fmt.Println("Getting model Details from etcd ......") - ctx, cancel := context.WithTimeout(context.Background(), time.Second) - defer cancel() - if client, err := getServiceClient(); err != nil { - return client.GetModelDetails(ctx, id) - } else { - return &ModelDetailsResponse{Status: Status_ERROR}, fmt.Errorf("error in invoking service method GetModelDetails for Model Training") - } +func (m ModelService) GetModelDetails(c context.Context, id *ModelDetailsRequest) (*ModelDetailsResponse, error) { + fmt.Println("Just get the model details stored in ETCD") + return &ModelDetailsResponse{Status: Status_ERROR}, fmt.Errorf("error in invoking service method GetModelDetails for Model Training") + } -func (m ModelService) GetTrainingStatus(c context.Context, id *ModelId) (*ModelDetailsResponse, error) { +func (m ModelService) GetTrainingStatus(c context.Context, id *ModelDetailsRequest) (*ModelDetailsResponse, error) { ctx, cancel := context.WithTimeout(context.Background(), time.Second) - fmt.Println("Getting Training Status details from etcd .....") + fmt.Println("Update the Training Status details from etcd .....") defer cancel() - if client, err := getServiceClient(); err != nil { + if client, err := m.getServiceClient(); err != nil { return client.GetModelDetails(ctx, id) } else { return &ModelDetailsResponse{Status: Status_ERROR}, fmt.Errorf("error in invoking service method GetTrainingStatus for Model Training") } } +//message used to sign is of the form ("__create_model", mpe_address, current_block_number) +func (service *ModelService) verifySignerForCreateModel(request *AuthorizationDetails) error { + return utils.VerifySigner(service.getMessageBytes("__create_model", request), + request.GetSignature(), utils.ToChecksumAddress(request.UserAddress)) +} + +//message used to sign is of the form ("__update_model", mpe_address, current_block_number) +func (service *ModelService) verifySignerForUpdateModel(request *AuthorizationDetails) error { + return utils.VerifySigner(service.getMessageBytes("__update_model", request), + request.GetSignature(), utils.ToChecksumAddress(request.UserAddress)) +} + +//message used to sign is of the form ("__delete_model", mpe_address, current_block_number) +func (service *ModelService) verifySignerForDeleteModel(request *AuthorizationDetails) error { + return utils.VerifySigner(service.getMessageBytes("__delete_model", request), + request.GetSignature(), utils.ToChecksumAddress(request.UserAddress)) +} + +//message used to sign is of the form ("__delete_model", mpe_address, current_block_number) +func (service *ModelService) verifySignerForGetModelDetails(request *AuthorizationDetails) error { + return utils.VerifySigner(service.getMessageBytes("__get_model_details", request), + request.GetSignature(), utils.ToChecksumAddress(request.UserAddress)) +} + +//message used to sign is of the form ("__get_training_status", mpe_address, current_block_number) +func (service *ModelService) verifySignerForGetTrainingStatus(request *AuthorizationDetails) error { + return utils.VerifySigner(service.getMessageBytes("__get_training_status", request), + request.GetSignature(), utils.ToChecksumAddress(request.UserAddress)) +} + +//"__methodName", user_address, current_block_number +func (service *ModelService) getMessageBytes(prefixMessage string, request *AuthorizationDetails) []byte { + userAddress := utils.ToChecksumAddress(request.UserAddress) + message := bytes.Join([][]byte{ + []byte(prefixMessage), + userAddress.Bytes(), + abi.U256(big.NewInt(int64(request.CurrentBlock))), + }, nil) + return message +} + func NewModelService(channelService escrow.PaymentChannelService, serMetaData *blockchain.ServiceMetadata, - orgMetadata *blockchain.OrganizationMetaData, storage *ModelStorage) *ModelService { - return &ModelService{ - channelService: channelService, - serviceMetaData: serMetaData, - organizationMetaData: orgMetadata, - storage: storage, + orgMetadata *blockchain.OrganizationMetaData, storage *ModelStorage) ModelServer { + serviceURL := config.GetString(config.ModelTrainingEndpoint) + if config.IsValidUrl(serviceURL) { + return &ModelService{ + channelService: channelService, + serviceMetaData: serMetaData, + organizationMetaData: orgMetadata, + storage: storage, + serviceUrl: serviceURL, + } + } else { + return &NoModelSupportService{} } } diff --git a/training/storage.go b/training/storage.go index f9c471f0..a9287120 100644 --- a/training/storage.go +++ b/training/storage.go @@ -4,7 +4,6 @@ import ( "fmt" "github.com/singnet/snet-daemon/storage" "github.com/singnet/snet-daemon/utils" - "math/big" "reflect" ) @@ -25,13 +24,13 @@ type ModelUserKey struct { OrganizationId string ServiceId string GroupID string - ChannelId *big.Int + MethodName string ModelId string } func (key *ModelUserKey) String() string { return fmt.Sprintf("{ID:%v/%v/%v/%v/%v}", key.OrganizationId, - key.ServiceId, key.GroupID, key.ChannelId, key.ModelId) + key.ServiceId, key.GroupID, key.MethodName, key.ModelId) } type ModelUserData struct { @@ -39,6 +38,7 @@ type ModelUserData struct { AuthorizedAddresses []string Status string CreatedByAddress string + ModelId string } func serializeModelKey(key interface{}) (serialized string, err error) { diff --git a/training/training.proto b/training/training.proto index 7caf09b9..5f9309aa 100644 --- a/training/training.proto +++ b/training/training.proto @@ -1,10 +1,11 @@ syntax = "proto3"; package training; //Please note that the AI developers need to provide a server implementation of the gprc server of this proto. -message ModelId { +message ModelDetails { string model_id = 1; string method_name = 2; - AuthorizationDetails authorization = 5; + string description = 3; + } message AuthorizationDetails { @@ -34,12 +35,19 @@ message CreateModelRequest { bool is_publicly_accessible = 4; AuthorizationDetails authorization = 5; + + string model_description = 6; +} + +message ModelDetailsRequest { + ModelDetails model_details =1 ; + AuthorizationDetails authorization = 2; } message UpdateModelRequest { AuthorizedAddressList addressList = 1; - string model_id = 2; + ModelDetails model_details = 2; Status status = 3; string message = 4; bytes signature = 5; @@ -54,10 +62,9 @@ message AuthorizedAddressList { message ModelDetailsResponse { Status status = 1; - string model_id =2; + ModelDetails model_details =2; string message = 3; AuthorizedAddressList address_list = 4; - string method_name =5; //Added the below two for more details string service_id = 6; string organization_id = 7; @@ -72,7 +79,7 @@ service Model { rpc create_model(CreateModelRequest) returns (ModelDetailsResponse) {} rpc update_model_access(UpdateModelRequest) returns (ModelDetailsResponse) {} rpc delete_model(UpdateModelRequest) returns (ModelDetailsResponse) {} - rpc get_model_details(ModelId) returns (ModelDetailsResponse) {} - rpc get_training_status (ModelId) returns (ModelDetailsResponse) {} + rpc get_model_details(ModelDetailsRequest) returns (ModelDetailsResponse) {} + rpc get_training_status (ModelDetailsRequest) returns (ModelDetailsResponse) {} } diff --git a/utils/common.go b/utils/common.go new file mode 100644 index 00000000..33fbab80 --- /dev/null +++ b/utils/common.go @@ -0,0 +1,45 @@ +package utils + +import ( + "bytes" + "encoding/gob" + "github.com/ethereum/go-ethereum/common" + "github.com/singnet/snet-daemon/authutils" + log "github.com/sirupsen/logrus" +) + +func Serialize(value interface{}) (slice string, err error) { + var b bytes.Buffer + e := gob.NewEncoder(&b) + err = e.Encode(value) + if err != nil { + return + } + + slice = string(b.Bytes()) + return +} + +func Deserialize(slice string, value interface{}) (err error) { + b := bytes.NewBuffer([]byte(slice)) + d := gob.NewDecoder(b) + err = d.Decode(value) + return +} +func VerifySigner(message []byte, signature []byte, signer common.Address) error { + derivedSigner, err := authutils.GetSignerAddressFromMessage(message, signature) + if err != nil { + log.Error(err) + return err + } + if err = authutils.VerifyAddress(*derivedSigner, signer); err != nil { + return err + } + return nil +} + +func ToChecksumAddress(hexAddress string) common.Address { + address := common.HexToAddress(hexAddress) + mixedAddress := common.NewMixedcaseAddress(address) + return mixedAddress.Address() +} diff --git a/utils/serialize.go b/utils/serialize.go deleted file mode 100644 index d4b4ce06..00000000 --- a/utils/serialize.go +++ /dev/null @@ -1,25 +0,0 @@ -package utils - -import ( - "bytes" - "encoding/gob" -) - -func Serialize(value interface{}) (slice string, err error) { - var b bytes.Buffer - e := gob.NewEncoder(&b) - err = e.Encode(value) - if err != nil { - return - } - - slice = string(b.Bytes()) - return -} - -func Deserialize(slice string, value interface{}) (err error) { - b := bytes.NewBuffer([]byte(slice)) - d := gob.NewDecoder(b) - err = d.Decode(value) - return -}