diff --git a/handlers/bounty.go b/handlers/bounty.go index 885391c93..5a1e5d13a 100644 --- a/handlers/bounty.go +++ b/handlers/bounty.go @@ -606,15 +606,14 @@ func formatPayError(errorMsg string) db.InvoicePayError { } } -func GetLightningInvoice(payment_request string) (db.InvoiceResult, db.InvoiceError) { +func (h *bountyHandler) GetLightningInvoice(payment_request string) (db.InvoiceResult, db.InvoiceError) { url := fmt.Sprintf("%s/invoice?payment_request=%s", config.RelayUrl, payment_request) - client := &http.Client{} req, err := http.NewRequest(http.MethodGet, url, nil) req.Header.Set("x-user-token", config.RelayAuthKey) req.Header.Set("Content-Type", "application/json") - res, _ := client.Do(req) + res, _ := h.httpClient.Do(req) if err != nil { log.Printf("Request Failed: %s", err) @@ -693,9 +692,9 @@ func (h *bountyHandler) PayLightningInvoice(payment_request string) (db.InvoiceP } } -func GetInvoiceData(w http.ResponseWriter, r *http.Request) { +func (h *bountyHandler) GetInvoiceData(w http.ResponseWriter, r *http.Request) { paymentRequest := chi.URLParam(r, "paymentRequest") - invoiceData, invoiceErr := GetLightningInvoice(paymentRequest) + invoiceData, invoiceErr := h.GetLightningInvoice(paymentRequest) if invoiceErr.Error != "" { w.WriteHeader(http.StatusForbidden) @@ -707,7 +706,7 @@ func GetInvoiceData(w http.ResponseWriter, r *http.Request) { json.NewEncoder(w).Encode(invoiceData) } -func PollInvoice(w http.ResponseWriter, r *http.Request) { +func (h *bountyHandler) PollInvoice(w http.ResponseWriter, r *http.Request) { ctx := r.Context() pubKeyFromAuth, _ := ctx.Value(auth.ContextKey).(string) paymentRequest := chi.URLParam(r, "paymentRequest") @@ -719,7 +718,7 @@ func PollInvoice(w http.ResponseWriter, r *http.Request) { return } - invoiceRes, invoiceErr := GetLightningInvoice(paymentRequest) + invoiceRes, invoiceErr := h.GetLightningInvoice(paymentRequest) if invoiceErr.Error != "" { w.WriteHeader(http.StatusForbidden) @@ -729,14 +728,14 @@ func PollInvoice(w http.ResponseWriter, r *http.Request) { if invoiceRes.Response.Settled { // Todo if an invoice is settled - invoice := db.DB.GetInvoice(paymentRequest) - invData := db.DB.GetUserInvoiceData(paymentRequest) - dbInvoice := db.DB.GetInvoice(paymentRequest) + invoice := h.db.GetInvoice(paymentRequest) + invData := h.db.GetUserInvoiceData(paymentRequest) + dbInvoice := h.db.GetInvoice(paymentRequest) // Make any change only if the invoice has not been settled if !dbInvoice.Status { if invoice.Type == "BUDGET" { - db.DB.AddAndUpdateBudget(invoice) + h.db.AddAndUpdateBudget(invoice) } else if invoice.Type == "KEYSEND" { url := fmt.Sprintf("%s/payment", config.RelayUrl) @@ -746,12 +745,11 @@ func PollInvoice(w http.ResponseWriter, r *http.Request) { jsonBody := []byte(bodyData) - client := &http.Client{} req, _ := http.NewRequest(http.MethodPost, url, bytes.NewBuffer(jsonBody)) req.Header.Set("x-user-token", config.RelayAuthKey) req.Header.Set("Content-Type", "application/json") - res, _ := client.Do(req) + res, _ := h.httpClient.Do(req) if err != nil { log.Printf("Request Failed: %s", err) @@ -767,13 +765,13 @@ func PollInvoice(w http.ResponseWriter, r *http.Request) { keysendRes := db.KeysendSuccess{} err = json.Unmarshal(body, &keysendRes) - bounty, err := db.DB.GetBountyByCreated(uint(invData.Created)) + bounty, err := h.db.GetBountyByCreated(uint(invData.Created)) if err == nil { bounty.Paid = true } - db.DB.UpdateBounty(bounty) + h.db.UpdateBounty(bounty) } else { // Unmarshal result keysendError := db.KeysendError{} @@ -782,7 +780,7 @@ func PollInvoice(w http.ResponseWriter, r *http.Request) { } } // Update the invoice status - db.DB.UpdateInvoice(paymentRequest) + h.db.UpdateInvoice(paymentRequest) } } diff --git a/handlers/bounty_test.go b/handlers/bounty_test.go index 410deb099..4a1a9291d 100644 --- a/handlers/bounty_test.go +++ b/handlers/bounty_test.go @@ -6,7 +6,6 @@ import ( "encoding/json" "errors" "fmt" - "github.com/stakwork/sphinx-tribes/utils" "io" "net/http" "net/http/httptest" @@ -16,6 +15,8 @@ import ( "testing" "time" + "github.com/stakwork/sphinx-tribes/utils" + "github.com/go-chi/chi" "github.com/lib/pq" "github.com/stakwork/sphinx-tribes/auth" @@ -1431,3 +1432,86 @@ func TestBountyBudgetWithdraw(t *testing.T) { mockHttpClient.AssertCalled(t, "Do", mock.AnythingOfType("*http.Request")) }) } + +func TestPollInvoice(t *testing.T) { + ctx := context.Background() + mockDb := &dbMocks.Database{} + mockHttpClient := &mocks.HttpClient{} + bHandler := NewBountyHandler(mockHttpClient, mockDb) + + unauthorizedCtx := context.WithValue(ctx, auth.ContextKey, "") + authorizedCtx := context.WithValue(ctx, auth.ContextKey, "valid-key") + + t.Run("Should test that a 401 error is returned if a user is unauthorized", func(t *testing.T) { + r := chi.NewRouter() + r.Post("/poll/invoice/{paymentRequest}", bHandler.PollInvoice) + + rr := httptest.NewRecorder() + req, err := http.NewRequestWithContext(unauthorizedCtx, http.MethodPost, "/poll/invoice/1", bytes.NewBufferString(`{}`)) + if err != nil { + t.Fatal(err) + } + + r.ServeHTTP(rr, req) + + assert.Equal(t, http.StatusUnauthorized, rr.Code, "Expected 401 error if a user is unauthorized") + }) + + t.Run("Should test that a 403 error is returned if there is an invoice error", func(t *testing.T) { + expectedUrl := fmt.Sprintf("%s/invoice?payment_request=%s", config.RelayUrl, "1") + + r := io.NopCloser(bytes.NewReader([]byte(`{"success": false, "error": "Internel server error"}`))) + mockHttpClient.On("Do", mock.MatchedBy(func(req *http.Request) bool { + return req.Method == http.MethodGet && expectedUrl == req.URL.String() && req.Header.Get("x-user-token") == config.RelayAuthKey + })).Return(&http.Response{ + StatusCode: 500, + Body: r, + }, nil).Once() + + ro := chi.NewRouter() + ro.Post("/poll/invoice/{paymentRequest}", bHandler.PollInvoice) + + rr := httptest.NewRecorder() + req, err := http.NewRequestWithContext(authorizedCtx, http.MethodPost, "/poll/invoice/1", bytes.NewBufferString(`{}`)) + if err != nil { + t.Fatal(err) + } + + ro.ServeHTTP(rr, req) + + assert.Equal(t, http.StatusForbidden, rr.Code, "Expected 403 error if there is an invoice error") + mockHttpClient.AssertExpectations(t) + }) + + t.Run("If the invoice is settled and the invoice.Type is equal to BUDGET the invoice amount should be added to the organization budget and the payment status of the related invoice should be sent to true on the payment history table", func(t *testing.T) { + expectedUrl := fmt.Sprintf("%s/invoice?payment_request=%s", config.RelayUrl, "1") + + r := io.NopCloser(bytes.NewReader([]byte(`{"success": true, "response": { "settled": true, "payment_request": "1", "payment_hash": "payment_hash", "preimage": "preimage", "Amount": "1000"}}`))) + mockHttpClient.On("Do", mock.MatchedBy(func(req *http.Request) bool { + return req.Method == http.MethodGet && expectedUrl == req.URL.String() && req.Header.Get("x-user-token") == config.RelayAuthKey + })).Return(&http.Response{ + StatusCode: 200, + Body: r, + }, nil).Once() + + mockDb.On("GetInvoice", "1").Return(db.InvoiceList{Type: "BUDGET"}) + mockDb.On("GetUserInvoiceData", "1").Return(db.UserInvoiceData{Amount: 1000, UserPubkey: "UserPubkey", RouteHint: "RouteHint", Created: 1234}) + mockDb.On("GetInvoice", "1").Return(db.InvoiceList{Status: false}) + mockDb.On("AddAndUpdateBudget", mock.Anything).Return(db.PaymentHistory{}) + mockDb.On("UpdateInvoice", "1").Return(db.InvoiceList{}).Once() + + ro := chi.NewRouter() + ro.Post("/poll/invoice/{paymentRequest}", bHandler.PollInvoice) + + rr := httptest.NewRecorder() + req, err := http.NewRequestWithContext(authorizedCtx, http.MethodPost, "/poll/invoice/1", bytes.NewBufferString(`{}`)) + if err != nil { + t.Fatal(err) + } + + ro.ServeHTTP(rr, req) + + assert.Equal(t, http.StatusOK, rr.Code) + mockHttpClient.AssertExpectations(t) + }) +} diff --git a/handlers/organizations.go b/handlers/organizations.go index 87a5dcdfe..7022c7eb5 100644 --- a/handlers/organizations.go +++ b/handlers/organizations.go @@ -18,6 +18,7 @@ import ( type organizationHandler struct { db db.Database generateBountyHandler func(bounties []db.Bounty) []db.BountyResponse + getLightningInvoice func(payment_request string) (db.InvoiceResult, db.InvoiceError) } func NewOrganizationHandler(db db.Database) *organizationHandler { @@ -25,6 +26,7 @@ func NewOrganizationHandler(db db.Database) *organizationHandler { return &organizationHandler{ db: db, generateBountyHandler: bHandler.GenerateBountyResponse, + getLightningInvoice: bHandler.GetLightningInvoice, } } @@ -630,7 +632,7 @@ func GetPaymentHistory(w http.ResponseWriter, r *http.Request) { json.NewEncoder(w).Encode(paymentHistoryData) } -func PollBudgetInvoices(w http.ResponseWriter, r *http.Request) { +func (oh *organizationHandler) PollBudgetInvoices(w http.ResponseWriter, r *http.Request) { ctx := r.Context() pubKeyFromAuth, _ := ctx.Value(auth.ContextKey).(string) uuid := chi.URLParam(r, "uuid") @@ -641,10 +643,10 @@ func PollBudgetInvoices(w http.ResponseWriter, r *http.Request) { return } - orgInvoices := db.DB.GetOrganizationInvoices(uuid) + orgInvoices := oh.db.GetOrganizationInvoices(uuid) for _, inv := range orgInvoices { - invoiceRes, invoiceErr := GetLightningInvoice(inv.PaymentRequest) + invoiceRes, invoiceErr := oh.getLightningInvoice(inv.PaymentRequest) if invoiceErr.Error != "" { w.WriteHeader(http.StatusForbidden) @@ -654,9 +656,9 @@ func PollBudgetInvoices(w http.ResponseWriter, r *http.Request) { if invoiceRes.Response.Settled { if !inv.Status && inv.Type == "BUDGET" { - db.DB.AddAndUpdateBudget(inv) + oh.db.AddAndUpdateBudget(inv) // Update the invoice status - db.DB.UpdateInvoice(inv.PaymentRequest) + oh.db.UpdateInvoice(inv.PaymentRequest) } } } diff --git a/routes/bounty.go b/routes/bounty.go index f25e6050e..6e5049747 100644 --- a/routes/bounty.go +++ b/routes/bounty.go @@ -25,7 +25,7 @@ func BountyRoutes() chi.Router { r.Get("/created/{created}", bountyHandler.GetBountyByCreated) r.Get("/count/{personKey}/{tabType}", handlers.GetUserBountyCount) r.Get("/count", handlers.GetBountyCount) - r.Get("/invoice/{paymentRequest}", handlers.GetInvoiceData) + r.Get("/invoice/{paymentRequest}", bountyHandler.GetInvoiceData) r.Get("/filter/count", handlers.GetFilterCount) }) diff --git a/routes/index.go b/routes/index.go index 3efe324fa..1615fa30e 100644 --- a/routes/index.go +++ b/routes/index.go @@ -24,6 +24,7 @@ func NewRouter() *http.Server { authHandler := handlers.NewAuthHandler(db.DB) channelHandler := handlers.NewChannelHandler(db.DB) botHandler := handlers.NewBotHandler(db.DB) + bHandler := handlers.NewBountyHandler(http.DefaultClient, db.DB) r.Mount("/tribes", TribeRoutes()) r.Mount("/bots", BotsRoutes()) @@ -75,7 +76,7 @@ func NewRouter() *http.Server { r.Post("/badges", handlers.AddOrRemoveBadge) r.Delete("/channel/{id}", channelHandler.DeleteChannel) r.Delete("/ticket/{pubKey}/{created}", handlers.DeleteTicketByAdmin) - r.Get("/poll/invoice/{paymentRequest}", handlers.PollInvoice) + r.Get("/poll/invoice/{paymentRequest}", bHandler.PollInvoice) r.Post("/meme_upload", handlers.MemeImageUpload) r.Get("/admin/auth", authHandler.GetIsAdmin) }) diff --git a/routes/organizations.go b/routes/organizations.go index feddba0df..66041c99a 100644 --- a/routes/organizations.go +++ b/routes/organizations.go @@ -35,7 +35,7 @@ func OrganizationRoutes() chi.Router { r.Get("/budget/{uuid}", organizationHandlers.GetOrganizationBudget) r.Get("/budget/history/{uuid}", organizationHandlers.GetOrganizationBudgetHistory) r.Get("/payments/{uuid}", handlers.GetPaymentHistory) - r.Get("/poll/invoices/{uuid}", handlers.PollBudgetInvoices) + r.Get("/poll/invoices/{uuid}", organizationHandlers.PollBudgetInvoices) r.Get("/invoices/count/{uuid}", handlers.GetInvoicesCount) r.Delete("/delete/{uuid}", organizationHandlers.DeleteOrganization) })