Skip to content

Commit

Permalink
add global lock
Browse files Browse the repository at this point in the history
  • Loading branch information
junqiang.zhang committed Sep 8, 2022
1 parent db91132 commit 7c85447
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 1 deletion.
12 changes: 12 additions & 0 deletions gdi.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ const (

//GDIPool 依赖注入容器
type GDIPool struct {
lock sync.Mutex //global lock
debug bool
scanPkgPaths []string
ignoreInterface bool
Expand Down Expand Up @@ -59,6 +60,7 @@ func NewGDIPool() *GDIPool {
modName := "main"
pool := &GDIPool{
debug: true,
lock: sync.Mutex{},
scanPkgPaths: []string{modName},
ignoreInterface: false,
autoCreate: true,
Expand Down Expand Up @@ -243,6 +245,8 @@ func (gdi *GDIPool) RegisterReadOnly(funcObjOrPtrs ...interface{}) {

//Init 在使用前必须先调用它
func (gdi *GDIPool) Init() *GDIPool {
gdi.lock.Lock()
defer gdi.lock.Unlock()

for k := 0; k < len(gdi.creator)*len(gdi.creator); k++ {
i := len(gdi.creator)
Expand Down Expand Up @@ -396,10 +400,12 @@ func (gdi *GDIPool) build(v reflect.Value, exitOnError bool, buildForTest bool)
if field.IsValid() && !field.IsNil() {
//TODO may be pannic
if field.Elem().Kind() == reflect.Struct {
gdi.ttvLocker.Lock()
if _, ok := gdi.typeToValuesForTest[field.Type()]; !ok {
gdi.log(fmt.Sprintf("Register field type %v from %v pkgPath:%v", field.Type(), v.Type(), pkgPath))
gdi.typeToValuesForTest[field.Type()] = field
}
gdi.ttvLocker.Unlock()
}
continue
}
Expand Down Expand Up @@ -459,11 +465,13 @@ func (gdi *GDIPool) build(v reflect.Value, exitOnError bool, buildForTest bool)
continue
} else {
if buildForTest {
gdi.ttvLocker.Lock()
if fv, ok := gdi.typeToValuesForTest[field.Type()]; ok {
gdi.warn(fmt.Sprintf("inject For Test fieldName:%v->%v of %v pkgPath:%v", fieldName, field.Type(), v.Type(), pkgPath))
field.Set(fv)
continue
}
gdi.ttvLocker.Unlock()
}
if gdi.autoCreate {
value := reflect.New(field.Type().Elem())
Expand Down Expand Up @@ -585,6 +593,8 @@ func (gdi *GDIPool) GetAllTypesByPackName(packageRegexp string) ([]reflect.Type,

// DIForTest 自动依懒注入
func (gdi *GDIPool) DIForTest(pointer interface{}) (e error) {
gdi.lock.Lock()
defer gdi.lock.Unlock()
defer func() {
if err := recover(); err != nil {
gdi.warn(fmt.Sprintf("%v", err))
Expand All @@ -609,6 +619,8 @@ func (gdi *GDIPool) DIForTest(pointer interface{}) (e error) {

// DI 自动依懒注入
func (gdi *GDIPool) DI(pointer interface{}) (e error) {
gdi.lock.Lock()
defer gdi.lock.Unlock()
defer func() {
if err := recover(); err != nil {
gdi.warn(fmt.Sprintf("%v", err))
Expand Down
10 changes: 9 additions & 1 deletion gdi_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package gdi
import (
"fmt"
"regexp"
"sync"
"testing"
)

Expand Down Expand Up @@ -148,7 +149,14 @@ func TestAll3(t *testing.T) {
Register(&c)
Init()

DIForTest(&s)
wg:=sync.WaitGroup{}
wg.Add(1)
go func() {
DIForTest(&s)
wg.Done()
}()
go DIForTest(&s)
wg.Wait()

if s.Student.Name != c.Student.Name {
t.Fail()
Expand Down

0 comments on commit 7c85447

Please sign in to comment.