forked from writefreely/writefreely
-
Notifications
You must be signed in to change notification settings - Fork 0
/
main_test.go
153 lines (134 loc) · 3.58 KB
/
main_test.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
package writefreely
import (
"context"
"database/sql"
"encoding/gob"
"errors"
"fmt"
uuid "github.com/nu7hatch/gouuid"
"github.com/stretchr/testify/assert"
"math/rand"
"os"
"strings"
"testing"
"time"
)
// testDB is the shared connection to the reference MySQL database. TestMain
// opens it only when MySQL tests are enabled (TEST_MYSQL is non-empty);
// otherwise it stays nil.
var testDB *sql.DB

// ScopedTestBody is a test callback that receives a database connection
// scoped to a single test; see withTestDB.
type ScopedTestBody func(*sql.DB)

// TestMain provides testing infrastructure within this package. It seeds
// math/rand, registers *User with encoding/gob and, when MySQL tests are
// enabled, connects to the database described by WF_USER, WF_PASSWORD,
// WF_DB and WF_HOST before running the tests. If that connection cannot be
// established the test binary exits with status 1 without running any test.
func TestMain(m *testing.M) {
	rand.Seed(time.Now().UTC().UnixNano())
	gob.Register(&User{})

	if runMySQLTests() {
		var err error
		testDB, err = initMySQL(os.Getenv("WF_USER"), os.Getenv("WF_PASSWORD"), os.Getenv("WF_DB"), os.Getenv("WF_HOST"))
		if err != nil {
			// Exit non-zero explicitly: returning from TestMain before m.Run
			// has been called does not produce a failing exit status, so a
			// broken database setup would otherwise look like a passing run.
			fmt.Fprintln(os.Stderr, err)
			os.Exit(1)
		}
	}

	code := m.Run()

	// Guard on the connection itself rather than re-reading the environment.
	if testDB != nil {
		if closeErr := testDB.Close(); closeErr != nil {
			fmt.Fprintln(os.Stderr, closeErr)
		}
	}
	os.Exit(code)
}
// runMySQLTests reports whether the MySQL-backed tests should run. They are
// enabled whenever the TEST_MYSQL environment variable holds a non-empty
// value; the value itself is not inspected.
func runMySQLTests() bool {
	enabled := os.Getenv("TEST_MYSQL") != ""
	return enabled
}
// initMySQL opens a MySQL connection pool and verifies it with ensureMySQL
// before returning it.
//
// dbUser and dbPassword are required; dbHost defaults to "localhost" and
// dbName to "writefreely". The server is always contacted over TCP on port
// 3306, and the DSN requests the utf8mb4 charset with parseTime enabled.
// On any failure a nil *sql.DB is returned along with the error, and a pool
// that was already opened is closed rather than leaked.
func initMySQL(dbUser, dbPassword, dbName, dbHost string) (*sql.DB, error) {
	if dbUser == "" || dbPassword == "" {
		return nil, errors.New("database user or password not set")
	}
	if dbHost == "" {
		dbHost = "localhost"
	}
	if dbName == "" {
		dbName = "writefreely"
	}
	dsn := fmt.Sprintf("%s:%s@tcp(%s:3306)/%s?charset=utf8mb4&parseTime=true", dbUser, dbPassword, dbHost, dbName)
	// NOTE(review): relies on a "mysql" driver being registered elsewhere in
	// the package (presumably a blank driver import) — confirm.
	db, err := sql.Open("mysql", dsn)
	if err != nil {
		return nil, err
	}
	if err := ensureMySQL(db); err != nil {
		// sql.Open already created a pool; release it instead of leaking it
		// when the server is unreachable. The ping error is the one to report.
		_ = db.Close()
		return nil, err
	}
	return db, nil
}
// ensureMySQL checks that db can reach its server. Only when the ping
// succeeds is the pool capped at 250 open connections; otherwise the pool
// settings are left untouched and the ping error is returned unchanged.
func ensureMySQL(db *sql.DB) error {
	pingErr := db.Ping()
	if pingErr == nil {
		db.SetMaxOpenConns(250)
	}
	return pingErr
}
// withTestDB provides a scoped database connection. It creates a temporary
// copy of the reference database (testDB) via newTestDatabase, using the
// WF_USER, WF_PASSWORD, WF_DB and WF_HOST environment variables, runs
// testBody against it, and then drops the copy. Cleanup errors are reported
// as test failures. If the copy cannot be created, the test is marked as
// failed and testBody is not run.
func withTestDB(t *testing.T, testBody ScopedTestBody) {
	t.Helper()
	db, cleanup, err := newTestDatabase(testDB,
		os.Getenv("WF_USER"),
		os.Getenv("WF_PASSWORD"),
		os.Getenv("WF_DB"),
		os.Getenv("WF_HOST"),
	)
	// On error newTestDatabase returns a nil db and a nil cleanup func;
	// deferring cleanup or running testBody would then panic.
	if !assert.NoError(t, err) {
		return
	}
	defer func() {
		assert.NoError(t, cleanup())
	}()
	testBody(db)
}
// newTestDatabase creates a new temporary test database. When a test
// database connection is returned, it will have created a new database and
// initialized it with tables from a reference database.
//
// The reference database is dbName, or the current database of base when
// dbName is empty. The new database is named after the reference plus a
// random UUID-derived suffix. Each table listed by SHOW TABLES for the
// reference is recreated in it with CREATE TABLE ... LIKE, which copies
// structure only, not rows. The connection to the new database is opened
// with initMySQL using the given credentials and host.
//
// The returned cleanup func closes that connection, printing any close
// error, and drops the temporary database, returning the DROP error.
// If any step fails, everything created so far is released first, and
// both the *sql.DB and the cleanup func returned are nil.
func newTestDatabase(base *sql.DB, dbUser, dbPassword, dbName, dbHost string) (*sql.DB, func() error, error) {
	// quote wraps a MySQL identifier in backticks, doubling embedded
	// backticks, so names that are not plain identifiers still parse.
	quote := func(ident string) string {
		return "`" + strings.Replace(ident, "`", "``", -1) + "`"
	}

	baseName := dbName
	if baseName == "" {
		if err := base.QueryRow("SELECT DATABASE()").Scan(&baseName); err != nil {
			return nil, nil, err
		}
	}

	tUUID, err := uuid.NewV4()
	if err != nil {
		return nil, nil, err
	}
	newDBName := baseName + strings.Replace(tUUID.String(), "-", "_", -1)

	if _, err := base.Exec("CREATE DATABASE " + quote(newDBName)); err != nil {
		return nil, nil, err
	}
	dropDB := func() error {
		_, dropErr := base.Exec("DROP DATABASE " + quote(newDBName))
		return dropErr
	}

	newDB, err := initMySQL(dbUser, dbPassword, newDBName, dbHost)
	if err != nil {
		// Don't leave the freshly created database behind; report the
		// original failure and only print a secondary drop error.
		if dropErr := dropDB(); dropErr != nil {
			fmt.Println(dropErr)
		}
		return nil, nil, err
	}

	cleanup := func() error {
		if closeErr := newDB.Close(); closeErr != nil {
			fmt.Println(closeErr)
		}
		return dropDB()
	}

	// cloneTables recreates every table of the reference database in the new
	// one. The unqualified table name resolves to the new database because it
	// is the default database of newDB's DSN.
	cloneTables := func() error {
		rows, err := base.Query("SHOW TABLES IN " + quote(baseName))
		if err != nil {
			return err
		}
		// Closing the result set returns its connection to base's pool.
		defer rows.Close()
		for rows.Next() {
			var tableName string
			if err := rows.Scan(&tableName); err != nil {
				return err
			}
			query := fmt.Sprintf("CREATE TABLE %s LIKE %s.%s", quote(tableName), quote(baseName), quote(tableName))
			if _, err := newDB.Exec(query); err != nil {
				return err
			}
		}
		// Surface iteration errors that rows.Next reports by returning false.
		return rows.Err()
	}

	if err := cloneTables(); err != nil {
		if cleanupErr := cleanup(); cleanupErr != nil {
			fmt.Println(cleanupErr)
		}
		return nil, nil, err
	}
	return newDB, cleanup, nil
}
// countRows runs query with args against db and asserts that the single
// integer column of its first row equals count. The query must produce
// exactly one column, such as a SELECT COUNT(*) ... statement, because the
// result is scanned into one int. If the query or scan fails, only that error
// is reported and the count comparison is skipped.
func countRows(t *testing.T, ctx context.Context, db *sql.DB, count int, query string, args ...interface{}) {
	t.Helper()
	var returned int
	err := db.QueryRowContext(ctx, query, args...).Scan(&returned)
	// After a failed query returned is still 0; comparing it would only add a
	// second, misleading failure.
	if !assert.NoError(t, err, "error executing query %s and args %v", query, args) {
		return
	}
	// %v, not %s: args is a []interface{} that usually holds non-strings.
	assert.Equal(t, count, returned, "unexpected return count %d, expected %d from %s and args %v", returned, count, query, args)
}