From 7f5d8db856ce10a7460a5b8a77d32c246782f8bf Mon Sep 17 00:00:00 2001 From: Ryan Yanulites Date: Thu, 21 Nov 2024 09:35:42 -0700 Subject: [PATCH] initial txns support with wrapper func --- service/pkg/db/db.go | 1 + service/policy/attributes/attributes.go | 27 ++++++++++++++++--------- service/policy/db/policy.go | 19 +++++++++++++++++ 3 files changed, 37 insertions(+), 10 deletions(-) diff --git a/service/pkg/db/db.go b/service/pkg/db/db.go index 038ee0543..3db39d11e 100644 --- a/service/pkg/db/db.go +++ b/service/pkg/db/db.go @@ -54,6 +54,7 @@ func (t Table) Field(field string) string { // We can rename this but wanted to get mocks working. type PgxIface interface { Acquire(ctx context.Context) (*pgxpool.Conn, error) + Begin(ctx context.Context) (pgx.Tx, error) Exec(context.Context, string, ...any) (pgconn.CommandTag, error) QueryRow(context.Context, string, ...any) pgx.Row Query(context.Context, string, ...any) (pgx.Rows, error) diff --git a/service/policy/attributes/attributes.go b/service/policy/attributes/attributes.go index da59557ff..80cc9038d 100644 --- a/service/policy/attributes/attributes.go +++ b/service/policy/attributes/attributes.go @@ -54,19 +54,26 @@ func (s AttributesService) CreateAttribute(ctx context.Context, ActionType: audit.ActionTypeCreate, } - item, err := s.dbClient.CreateAttribute(ctx, req.Msg) + err := s.dbClient.RunInTx(ctx, func(txClient *policydb.PolicyDBClient) error { + item, err := s.dbClient.CreateAttribute(ctx, req.Msg) + if err != nil { + s.logger.Audit.PolicyCRUDFailure(ctx, auditParams) + return db.StatusifyError(err, db.ErrTextCreationFailed, slog.String("attribute", req.Msg.String())) + } + + s.logger.Debug("created new attribute definition", slog.String("name", req.Msg.GetName())) + + auditParams.ObjectID = item.GetId() + auditParams.Original = item + s.logger.Audit.PolicyCRUDSuccess(ctx, auditParams) + + rsp.Attribute = item + return nil + }) if err != nil { - s.logger.Audit.PolicyCRUDFailure(ctx, auditParams) - return nil, db.StatusifyError(err, db.ErrTextCreationFailed, slog.String("attribute", req.Msg.String())) + return nil, err } - s.logger.Debug("created new attribute definition", slog.String("name", req.Msg.GetName())) - - auditParams.ObjectID = item.GetId() - auditParams.Original = item - s.logger.Audit.PolicyCRUDSuccess(ctx, auditParams) - - rsp.Attribute = item return connect.NewResponse(rsp), nil } diff --git a/service/policy/db/policy.go b/service/policy/db/policy.go index 390362b3d..221ea0fa5 100644 --- a/service/policy/db/policy.go +++ b/service/policy/db/policy.go @@ -1,6 +1,9 @@ package db import ( + "context" + "fmt" + "github.com/opentdf/platform/protocol/go/common" "github.com/opentdf/platform/service/logger" "github.com/opentdf/platform/service/pkg/db" @@ -31,6 +34,22 @@ func NewClient(c *db.Client, logger *logger.Logger, configuredListLimitMax, conf return PolicyDBClient{c, logger, New(c.Pgx), ListConfig{limitDefault: configuredListLimitDefault, limitMax: configuredListLimitMax}} } +func (c *PolicyDBClient) RunInTx(ctx context.Context, query func(txClient *PolicyDBClient) error) error { + tx, err := c.Client.Pgx.Begin(ctx) + if err != nil { + return fmt.Errorf("failed to begin DB transaction: %w", err) + } + + txClient := &PolicyDBClient{c.Client, c.logger, c.Queries, c.listCfg} + + err = query(txClient) + if err != nil { + return tx.Rollback(ctx) + } + + return tx.Commit(ctx) +} + func getDBStateTypeTransformedEnum(state common.ActiveStateEnum) transformedState { switch state.String() { case common.ActiveStateEnum_ACTIVE_STATE_ENUM_ACTIVE.String():