Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
Code for Authorization checks on calls made to CRUD operations for Model ID
  • Loading branch information
anandrgitnirman committed May 12, 2022
1 parent 6192e3f commit b0cc8df
Show file tree
Hide file tree
Showing 6 changed files with 236 additions and 69 deletions.
4 changes: 2 additions & 2 deletions snetd/cmd/components.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ type Components struct {
freeCallLockerStorage *storage.PrefixedAtomicStorage
tokenManager token.Manager
tokenService *escrow.TokenService
modelService *training.ModelService
modelService training.ModelServer
modelUserStorage *training.ModelStorage
}

Expand Down Expand Up @@ -558,7 +558,7 @@ func (components *Components) ModelUserStorage() *training.ModelStorage {
return components.modelUserStorage
}

func (components *Components) ModelService() *training.ModelService {
func (components *Components) ModelService() training.ModelServer {
if components.modelService != nil {
return components.modelService
}
Expand Down
204 changes: 172 additions & 32 deletions training/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,26 +2,59 @@
package training

import (
"bytes"
"fmt"
"github.com/ethereum/go-ethereum/accounts/abi"
"github.com/singnet/snet-daemon/blockchain"
"github.com/singnet/snet-daemon/config"
"github.com/singnet/snet-daemon/escrow"
"github.com/singnet/snet-daemon/utils"
log "github.com/sirupsen/logrus"
"golang.org/x/net/context"
"google.golang.org/grpc"
"math/big"
"time"
)

type IService interface {
}
type ModelService struct {
serviceMetaData *blockchain.ServiceMetadata
organizationMetaData *blockchain.OrganizationMetaData
channelService escrow.PaymentChannelService
storage *ModelStorage
serviceUrl string
}
type NoModelSupportService struct {
}

func getServiceClient() (client ModelClient, err error) {
serviceURL := config.GetString(config.ModelTrainingEndpoint)
conn, err := grpc.Dial(serviceURL, grpc.WithInsecure())
func (n NoModelSupportService) CreateModel(c context.Context, request *CreateModelRequest) (*ModelDetailsResponse, error) {
return &ModelDetailsResponse{Status: Status_ERROR},
fmt.Errorf("service end point is not defined or is invalid , please contact the AI developer")
}

func (n NoModelSupportService) UpdateModelAccess(c context.Context, request *UpdateModelRequest) (*ModelDetailsResponse, error) {
return &ModelDetailsResponse{Status: Status_ERROR},
fmt.Errorf("service end point is not defined or is invalid , please contact the AI developer")
}

func (n NoModelSupportService) DeleteModel(c context.Context, request *UpdateModelRequest) (*ModelDetailsResponse, error) {
return &ModelDetailsResponse{Status: Status_ERROR},
fmt.Errorf("service end point is not defined or is invalid , please contact the AI developer")
}

func (n NoModelSupportService) GetModelDetails(c context.Context, id *ModelDetailsRequest) (*ModelDetailsResponse, error) {
return &ModelDetailsResponse{Status: Status_ERROR},
fmt.Errorf("service end point is not defined or is invalid , please contact the AI developer")
}

func (n NoModelSupportService) GetTrainingStatus(c context.Context, id *ModelDetailsRequest) (*ModelDetailsResponse, error) {
return &ModelDetailsResponse{Status: Status_ERROR},
fmt.Errorf("service end point is not defined or is invalid for training , please contact the AI developer")
}

func (m ModelService) getServiceClient() (client ModelClient, err error) {
conn, err := grpc.Dial(m.serviceUrl, grpc.WithInsecure())
if err != nil {
log.WithError(err).Warningf("unable to connect to grpc endpoint: %v", err)
return nil, err
Expand All @@ -31,75 +64,182 @@ func getServiceClient() (client ModelClient, err error) {
client = NewModelClient(conn)
return
}
func (m ModelService) storeModelDetails(request *CreateModelRequest, response *ModelDetailsResponse) (err error) {
key := m.getModelKeyToCreate(request, response)
data := m.createModelData(request, response)
err = m.storage.Put(key, data)
return
}

func (m ModelService) CreateModel(c context.Context, request *CreateModelRequest) (*ModelDetailsResponse, error) {
func (m ModelService) updateModelDetails(request *UpdateModelRequest, response *ModelDetailsResponse) (err error) {
key := m.getModelKeyToUpdate(request)
data := m.getModelDataForUpdate(request)
err = m.storage.Put(key, data)
return
}
func (m ModelService) getModelKeyToCreate(request *CreateModelRequest, response *ModelDetailsResponse) (key *ModelUserKey) {
key = &ModelUserKey{
OrganizationId: config.GetString(config.OrganizationId),
ServiceId: config.GetString(config.ServiceId),
GroupID: m.organizationMetaData.GetGroupIdString(),
MethodName: request.MethodName,
ModelId: response.ModelDetails.ModelId,
}
return
}

func (m ModelService) getModelKeyToUpdate(request *UpdateModelRequest) (key *ModelUserKey) {
key = &ModelUserKey{
OrganizationId: config.GetString(config.OrganizationId),
ServiceId: config.GetString(config.ServiceId),
GroupID: m.organizationMetaData.GetGroupIdString(),
MethodName: request.ModelDetails.MethodName,
ModelId: request.ModelDetails.ModelId,
}
return
}

func (m ModelService) getModelDataForUpdate(request *UpdateModelRequest) (key *ModelUserData) {
key = &ModelUserData{

ModelId: request.ModelDetails.ModelId,
}
return
}

func (m ModelService) createModelData(request *CreateModelRequest, response *ModelDetailsResponse) (data *ModelUserData) {
data = &ModelUserData{
Status: string(response.Status),
CreatedByAddress: request.Authorization.UserAddress,
AuthorizedAddresses: request.Address,
isPublic: request.IsPubliclyAccessible,
ModelId: response.ModelDetails.ModelId,
}
return
}

func (m ModelService) CreateModel(c context.Context, request *CreateModelRequest) (response *ModelDetailsResponse,
err error) {
// verify the request that has come in
// make a call to the client
// if the response is successful , store details in etcd
// send back the response to the client
//TODO implement me
fmt.Println("Adding addresses to etcd ")
ctx, cancel := context.WithTimeout(context.Background(), time.Second*200)
defer cancel()
if client, err := getServiceClient(); err == nil {
return client.CreateModel(ctx, request)
if client, err := m.getServiceClient(); err == nil {
response, err = client.CreateModel(ctx, request)
if err == nil {
//store the details in etcd
log.Debugf("Adding addresses to etcd ...")
if err = m.storeModelDetails(request, response); err != nil {
return response, fmt.Errorf("issue with storing Model Id in the Daemon Storage %v", err)
}
}
} else {
return &ModelDetailsResponse{Status: Status_ERROR}, fmt.Errorf("Error in invoking service for Model Training")
return &ModelDetailsResponse{Status: Status_ERROR},
fmt.Errorf("error in invoking service for Model Training %v", err)
}

return
}

func (m ModelService) UpdateModelAccess(c context.Context, request *UpdateModelRequest) (*ModelDetailsResponse, error) {
func (m ModelService) UpdateModelAccess(c context.Context, request *UpdateModelRequest) (response *ModelDetailsResponse,
err error) {
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()
fmt.Println("Updating model access addresses to etcd ")
if client, err := getServiceClient(); err != nil {
return client.UpdateModelAccess(ctx, request)
if client, err := m.getServiceClient(); err != nil {
response, err = client.UpdateModelAccess(ctx, request)
if err = m.updateModelDetails(request, response); err != nil {
return response, fmt.Errorf("issue with storing Model Id in the Daemon Storage %v", err)
}
} else {
return &ModelDetailsResponse{Status: Status_ERROR}, fmt.Errorf("error in invoking service for Model Training")
}
return
}

func (m ModelService) DeleteModel(c context.Context, request *UpdateModelRequest) (*ModelDetailsResponse, error) {
fmt.Println("Deleting model addresses from etcd ")
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()
if client, err := getServiceClient(); err != nil {
if client, err := m.getServiceClient(); err != nil {
return client.DeleteModel(ctx, request)
} else {
return &ModelDetailsResponse{Status: Status_ERROR}, fmt.Errorf("error in invoking service for Model Training")
}
}

func (m ModelService) GetModelDetails(c context.Context, id *ModelId) (*ModelDetailsResponse, error) {
fmt.Println("Getting model Details from etcd ......")
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()
if client, err := getServiceClient(); err != nil {
return client.GetModelDetails(ctx, id)
} else {
return &ModelDetailsResponse{Status: Status_ERROR}, fmt.Errorf("error in invoking service method GetModelDetails for Model Training")
}
func (m ModelService) GetModelDetails(c context.Context, id *ModelDetailsRequest) (*ModelDetailsResponse, error) {
fmt.Println("Just get the model details stored in ETCD")
return &ModelDetailsResponse{Status: Status_ERROR}, fmt.Errorf("error in invoking service method GetModelDetails for Model Training")

}

func (m ModelService) GetTrainingStatus(c context.Context, id *ModelId) (*ModelDetailsResponse, error) {
func (m ModelService) GetTrainingStatus(c context.Context, id *ModelDetailsRequest) (*ModelDetailsResponse, error) {
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
fmt.Println("Getting Training Status details from etcd .....")
fmt.Println("Update the Training Status details from etcd .....")
defer cancel()
if client, err := getServiceClient(); err != nil {
if client, err := m.getServiceClient(); err != nil {
return client.GetModelDetails(ctx, id)
} else {
return &ModelDetailsResponse{Status: Status_ERROR}, fmt.Errorf("error in invoking service method GetTrainingStatus for Model Training")
}
}

//message used to sign is of the form ("__create_model", mpe_address, current_block_number)
func (service *ModelService) verifySignerForCreateModel(request *AuthorizationDetails) error {
return utils.VerifySigner(service.getMessageBytes("__create_model", request),
request.GetSignature(), utils.ToChecksumAddress(request.UserAddress))
}

//message used to sign is of the form ("__update_model", mpe_address, current_block_number)
func (service *ModelService) verifySignerForUpdateModel(request *AuthorizationDetails) error {
return utils.VerifySigner(service.getMessageBytes("__update_model", request),
request.GetSignature(), utils.ToChecksumAddress(request.UserAddress))
}

//message used to sign is of the form ("__delete_model", mpe_address, current_block_number)
func (service *ModelService) verifySignerForDeleteModel(request *AuthorizationDetails) error {
return utils.VerifySigner(service.getMessageBytes("__delete_model", request),
request.GetSignature(), utils.ToChecksumAddress(request.UserAddress))
}

//message used to sign is of the form ("__delete_model", mpe_address, current_block_number)
func (service *ModelService) verifySignerForGetModelDetails(request *AuthorizationDetails) error {
return utils.VerifySigner(service.getMessageBytes("__get_model_details", request),
request.GetSignature(), utils.ToChecksumAddress(request.UserAddress))
}

//message used to sign is of the form ("__get_training_status", mpe_address, current_block_number)
func (service *ModelService) verifySignerForGetTrainingStatus(request *AuthorizationDetails) error {
return utils.VerifySigner(service.getMessageBytes("__get_training_status", request),
request.GetSignature(), utils.ToChecksumAddress(request.UserAddress))
}

//"__methodName", user_address, current_block_number
func (service *ModelService) getMessageBytes(prefixMessage string, request *AuthorizationDetails) []byte {
userAddress := utils.ToChecksumAddress(request.UserAddress)
message := bytes.Join([][]byte{
[]byte(prefixMessage),
userAddress.Bytes(),
abi.U256(big.NewInt(int64(request.CurrentBlock))),
}, nil)
return message
}

func NewModelService(channelService escrow.PaymentChannelService, serMetaData *blockchain.ServiceMetadata,
orgMetadata *blockchain.OrganizationMetaData, storage *ModelStorage) *ModelService {
return &ModelService{
channelService: channelService,
serviceMetaData: serMetaData,
organizationMetaData: orgMetadata,
storage: storage,
orgMetadata *blockchain.OrganizationMetaData, storage *ModelStorage) ModelServer {
serviceURL := config.GetString(config.ModelTrainingEndpoint)
if config.IsValidUrl(serviceURL) {
return &ModelService{
channelService: channelService,
serviceMetaData: serMetaData,
organizationMetaData: orgMetadata,
storage: storage,
serviceUrl: serviceURL,
}
} else {
return &NoModelSupportService{}
}
}

Expand Down
6 changes: 3 additions & 3 deletions training/storage.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ import (
"fmt"
"github.com/singnet/snet-daemon/storage"
"github.com/singnet/snet-daemon/utils"
"math/big"
"reflect"
)

Expand All @@ -25,20 +24,21 @@ type ModelUserKey struct {
OrganizationId string
ServiceId string
GroupID string
ChannelId *big.Int
MethodName string
ModelId string
}

func (key *ModelUserKey) String() string {
return fmt.Sprintf("{ID:%v/%v/%v/%v/%v}", key.OrganizationId,
key.ServiceId, key.GroupID, key.ChannelId, key.ModelId)
key.ServiceId, key.GroupID, key.MethodName, key.ModelId)
}

type ModelUserData struct {
isPublic bool
AuthorizedAddresses []string
Status string
CreatedByAddress string
ModelId string
}

func serializeModelKey(key interface{}) (serialized string, err error) {
Expand Down
21 changes: 14 additions & 7 deletions training/training.proto
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
syntax = "proto3";
package training;
//Please note that the AI developers need to provide a server implementation of the gprc server of this proto.
message ModelId {
message ModelDetails {
string model_id = 1;
string method_name = 2;
AuthorizationDetails authorization = 5;
string description = 3;

}

message AuthorizationDetails {
Expand Down Expand Up @@ -34,12 +35,19 @@ message CreateModelRequest {
bool is_publicly_accessible = 4;

AuthorizationDetails authorization = 5;

string model_description = 6;
}

message ModelDetailsRequest {
ModelDetails model_details =1 ;
AuthorizationDetails authorization = 2;
}


message UpdateModelRequest {
AuthorizedAddressList addressList = 1;
string model_id = 2;
ModelDetails model_details = 2;
Status status = 3;
string message = 4;
bytes signature = 5;
Expand All @@ -54,10 +62,9 @@ message AuthorizedAddressList {

message ModelDetailsResponse {
Status status = 1;
string model_id =2;
ModelDetails model_details =2;
string message = 3;
AuthorizedAddressList address_list = 4;
string method_name =5;
//Added the below two for more details
string service_id = 6;
string organization_id = 7;
Expand All @@ -72,7 +79,7 @@ service Model {
rpc create_model(CreateModelRequest) returns (ModelDetailsResponse) {}
rpc update_model_access(UpdateModelRequest) returns (ModelDetailsResponse) {}
rpc delete_model(UpdateModelRequest) returns (ModelDetailsResponse) {}
rpc get_model_details(ModelId) returns (ModelDetailsResponse) {}
rpc get_training_status (ModelId) returns (ModelDetailsResponse) {}
rpc get_model_details(ModelDetailsRequest) returns (ModelDetailsResponse) {}
rpc get_training_status (ModelDetailsRequest) returns (ModelDetailsResponse) {}

}
Loading

0 comments on commit b0cc8df

Please sign in to comment.