forked from gorgonia/gorgonia
-
Notifications
You must be signed in to change notification settings - Fork 0
/
nn_test.go
58 lines (47 loc) · 1.39 KB
/
nn_test.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
package gorgonia
import (
"io/ioutil"
"runtime"
"testing"
"github.com/chewxy/gorgonia/tensor"
)
// dropoutTest builds a small two-layer graph of element type dt in which each
// Cube node is followed by a Dropout node, takes the gradient of the summed
// output with respect to x, w and w2, then compiles the graph and executes it
// once on a TapeMachine.
//
// It returns the first error produced by Grad, Compile or RunAll so that the
// caller can report it per dtype. If Grad fails, the graph is additionally
// written to fullGraph.dot in the working directory (best-effort) for visual
// inspection.
func dropoutTest(t *testing.T, dt tensor.Dtype) error {
	g := NewGraph()
	x := NewVector(g, dt, WithShape(10), WithName("x"), WithInit(RangedFrom(0)))
	w := NewMatrix(g, dt, WithShape(20, 10), WithName("w"), WithInit(RangedFrom(0)))
	w2 := NewMatrix(g, dt, WithShape(10, 20), WithName("w2"), WithInit(RangedFrom(0)))

	// First layer: Cube(Mul(w, x)) followed by Dropout with argument 0.5.
	wx := Must(Mul(w, x))
	act := Must(Cube(wx))
	do := Must(Dropout(act, 0.5))

	// Second layer: Cube(Mul(w2, do)) followed by Dropout with argument 0.1.
	act2 := Must(Cube(Must(Mul(w2, do))))
	do2 := Must(Dropout(act2, 0.1))
	cost := Must(Sum(do2))

	if _, err := Grad(cost, x, w, w2); err != nil {
		// Debug dump only; failing to write it must not mask the Grad error.
		if werr := ioutil.WriteFile("fullGraph.dot", []byte(g.ToDot()), 0644); werr != nil {
			t.Logf("could not write fullGraph.dot: %v", werr)
		}
		return err
	}

	prog, locMap, err := Compile(g)
	if err != nil {
		// This error used to be ignored, letting whatever Compile returned
		// flow into NewTapeMachine and obscuring the real failure.
		return err
	}

	// For step-by-step tracing, add WithLogger(log.New(os.Stderr, "", 0)) and
	// WithWatchlist() to the machine options below.
	m := NewTapeMachine(prog, locMap, TraceExec(), BindDualValues())
	defer runtime.GC()

	if err := m.RunAll(); err != nil {
		return err
	}
	return nil
}
func TestDropout(t *testing.T) {
// t.Skip()
if err := dropoutTest(t, Float64); err != nil {
t.Errorf("%+v", err)
}
if err := dropoutTest(t, Float32); err != nil {
t.Errorf("%+v", err)
}
// visual inspection
// ioutil.WriteFile("fullGraph.dot", []byte(g.ToDot()), 0644)
}