Commit 3b6f70a

Merge pull request #636 from denizyuret/dy/fix632
CuArray.ptr => CuArray.baseptr or pointer(CuArray) in CUDA-2.1.0
denizyuret authored Nov 28, 2020
2 parents aa3a861 + b941a9b commit 3b6f70a
Showing 22 changed files with 196 additions and 101 deletions.
13 changes: 13 additions & 0 deletions NEWS.md
@@ -1,5 +1,18 @@
Knet v1.4.4 Release Notes
=========================

* Serialization and JLD2 support for KnetArray and RNN.
* Changed eltype to Any in container types in serialize.jl.
* Compat fixes for CUDA 2.3 and Julia 1.6.
* Fixed #638, a KnetArray broadcast/materialize!/dotview issue.
* Fixed a Knet.seed! bug. (@egeonat)
* Added PowerPC support. (@jdad)
* Fixed MNIST labels in examples.


Knet v1.4.3 Release Notes
=========================
8a4fdbf 2020-10-16

* Upgrade to CUDA 2.0.
* Doc fixes.
7 changes: 4 additions & 3 deletions Project.toml
@@ -1,7 +1,7 @@
name = "Knet"
uuid = "1902f260-5fb4-5aff-8c31-6271790ab950"
authors = ["Deniz Yuret <denizyuret@gmail.com>"]
version = "1.4.3"
version = "1.4.4"

[deps]
AutoGrad = "6710c13c-97f1-543f-91c5-74e8f7d95b35"
@@ -14,16 +14,17 @@ NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Serialization = "9e88b42a-f829-5b0c-bbe9-9e923198166b"
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"

[compat]
AutoGrad = "1.2"
CUDA = "1.0, 2.0"
FileIO = "1.0"
JLD2 = "0.1, 0.2"
JLD2 = "0.1, 0.2, 0.3"
NNlib = "0.6, 0.7"
SpecialFunctions = "0.8, 0.9, 0.10"
SpecialFunctions = "0.8, 0.9, 0.10, 1.0"
julia = "1.0"

[extras]
2 changes: 1 addition & 1 deletion README.md
@@ -83,4 +83,4 @@ Knet is an open-source project and we are always open to new contributions: bug
fixes, feature requests and contributions, new machine learning models and operators,
inspiring examples, benchmarking results are all welcome. See [Tips for Developers](https://denizyuret.github.io/Knet.jl/latest/install/#Tips-for-developers) for instructions.

Contributors: Can Gümeli, Carlo Lucibello, Ekin Akyürek, Ekrem Emre Yurdakul, Emre Ünal, Emre Yolcu, Enis Berk, Erenay Dayanık, İlker Kesen, Kai Xu, Meriç Melike Softa, Mike Innes, Onur Kuru, Ozan Arkan Can, Ömer Kırnap, Phuoc Nguyen, Rene Donner, Tim Besard, Zhang Shiwei.
Contributors: Can Gümeli, Carlo Lucibello, Ege Onat, Ekin Akyürek, Ekrem Emre Yurdakul, Emre Ünal, Emre Yolcu, Enis Berk, Erenay Dayanık, İlker Kesen, Kai Xu, Meriç Melike Softa, Mike Innes, Onur Kuru, Ozan Arkan Can, Ömer Kırnap, Phuoc Nguyen, Rene Donner, Tim Besard, Zhang Shiwei.
2 changes: 0 additions & 2 deletions src/Knet.jl
@@ -16,7 +16,6 @@ include("ops20/Ops20.jl")
include("ops20_gpu/Ops20_gpu.jl")
include("ops21/Ops21.jl")
include("ops21_gpu/Ops21_gpu.jl")
include("fileio_gpu/FileIO_gpu.jl")
include("train20/Train20.jl")
# include("layers21/Layers21.jl")

@@ -40,7 +39,6 @@ end
using AutoGrad #: @diff, AutoGrad, Param, cat1d, grad, gradloss, params, value
using Knet.LibKnet8 #: libknet8, @knet8, @knet8r, gpu
using Knet.KnetArrays #: KnetArray, gc, knetgc, ka, setseed, seed!
using Knet.FileIO_gpu #: cpucopy, gpucopy
using Knet.Ops20 #: RNN, accuracy, batchnorm, bce, bmm, bnmoments, bnparams, conv4, deconv4, dropout, elu, invx, logistic, logp, logsoftmax, logsumexp, mat, nll, pool, relu, rnnforw, rnninit, rnnparam, rnnparams, selu, sigm, softmax, unpool, zeroone
using Knet.Train20 #: Adadelta, Adagrad, Adam, Momentum, Nesterov, Rmsprop, SGD, Sgd, adadelta, adadelta!, adagrad, adagrad!, adam, adam!, atype, bilinear, converge, converge!, gaussian, goldensection, hyperband, minibatch, momentum, momentum!, nesterov, nesterov!, optimizers, param, param0, progress, progress!, rmsprop, rmsprop!, sgd, sgd!, train!, training, update!, xavier, xavier_normal, xavier_uniform

1 change: 0 additions & 1 deletion src/README.md
@@ -14,7 +14,6 @@ to models built with older operator / layer sets.

* **autograd_gpu:** implementations of AutoGrad functions for GPU arrays.
* **cuarrays:** implementations of Base functions for CuArrays.
* **fileio_gpu:** implementations of FileIO functions for GPU arrays.
* **knetarrays:** KnetArrays and their Base functions.
* **libknet8:** hand-written CUDA kernels.
* **ops20:** the Knet.Ops20 operator set.
1 change: 1 addition & 0 deletions src/cuarrays/CuArrays.jl
@@ -9,5 +9,6 @@ include("convert.jl")
include("getindex.jl")
include("reduction.jl")
include("cubytes.jl"); export cuarrays, cubytes
include("jld2.jl")

end
8 changes: 8 additions & 0 deletions src/cuarrays/jld2.jl
@@ -0,0 +1,8 @@
using JLD2
using CUDA: CuArray

struct JLD2CuArray{T,N}; array::Array{T,N}; end
JLD2.writeas(::Type{CuArray{T,N}}) where {T,N} = JLD2CuArray{T,N}
JLD2.wconvert(::Type{JLD2CuArray{T,N}}, x::CuArray{T,N}) where {T,N} = JLD2CuArray(Array(x))
JLD2.rconvert(::Type{CuArray{T,N}}, x::JLD2CuArray{T,N}) where {T,N} = CuArray(x.array)

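With these three converters registered, ordinary JLD2 saving and loading of CuArrays should round-trip through a CPU copy. A minimal sketch, assuming a functional GPU; the file and variable names are illustrative:

```julia
using JLD2, CUDA

x = CUDA.rand(Float32, 3, 3)
@save "cuarray_demo.jld2" x    # wconvert stores the data as a CPU Array inside a JLD2CuArray

@load "cuarray_demo.jld2" x    # rconvert rebuilds a CuArray on the GPU
@assert x isa CuArray
```
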
13 changes: 0 additions & 13 deletions src/fileio_gpu/FileIO_gpu.jl

This file was deleted.

4 changes: 0 additions & 4 deletions src/fileio_gpu/README.md

This file was deleted.

4 changes: 4 additions & 0 deletions src/knetarrays/KnetArrays.jl
@@ -9,6 +9,7 @@ include("broadcast.jl")
include("cat.jl")
include("comparison.jl")
include("copy.jl")
include("deepcopy.jl"); export cpucopy, gpucopy
include("dotview.jl")
include("linalg.jl")
include("random.jl"); export setseed, seed!
@@ -20,4 +21,7 @@ include("binary.jl")
include("unary.jl")
include("reduction.jl")

include("serialization.jl") # serialize and deserialize of KnetArrays
include("jld2.jl"); export save, load, @save, @load # deprecated, use FileIO and JLD2

end
56 changes: 29 additions & 27 deletions src/fileio_gpu/serialize.jl → src/knetarrays/deepcopy.jl
@@ -1,14 +1,13 @@
using Knet.KnetArrays: KnetPtr, KnetArray, Cptr
using Knet.Ops20: RNN
using AutoGrad: Param
using CUDA: CUDA, functional, CuPtr
using CUDA: CUDA, functional, CuPtr, CuArray

const JLDMODE=Val(0)
const GPUMODE=Val(1)
const CPUMODE=Val(2)

# Do not use type asserts because type may change
serialize(x) = _ser(x,IdDict(),JLDMODE)
jld2serialize(x) = _ser(x,IdDict(),JLDMODE) # Regular JLD2 functions now work, this is no longer needed
gpucopy(x) = _ser(x,IdDict(),GPUMODE)
cpucopy(x) = _ser(x,IdDict(),CPUMODE)

@@ -43,38 +42,34 @@ function _ser(x::KnetArray{T,N},s::IdDict,m::typeof(JLDMODE)) where {T,N}
return s[x]
end

function _ser(x::RNN, s::IdDict, m::Val)
if !haskey(s,x)
# we need rd,dd only if there is a gpu, we are not in cpumode,
# and if we are in jldmode we are loading, not saving
# if (CUDA.functional() && m != CPUMODE && !(m == JLDMODE && x.rnnDesc != nothing))
# dd = DD(dropout=x.dropout,seed=x.seed)
# rd = RD(x.hiddenSize,x.numLayers,dd,x.inputMode,x.direction,x.mode,x.algo,x.dataType)
# else
# rd = dd = nothing
# end
_ser(x::KnetArray,s::IdDict,::typeof(GPUMODE))=x
_ser(x::KnetArray,s::IdDict,::typeof(CPUMODE))=(haskey(s,x) ? s[x] : s[x]=Array(x))

# 20200806: We no longer need to load/save rd/dd: rnnforw will construct as needed.
rd = dd = nothing
_ser(x::CuArray,s::IdDict,::typeof(GPUMODE))=x
_ser(x::CuArray,s::IdDict,::typeof(CPUMODE))=(haskey(s,x) ? s[x] : s[x]=Array(x))

# dx, dhx, dcx are temporary fields used by rnnback, they do not need to be copied
# gcnode sets dx.ptr to C_NULL which breaks serialize, best not to try
s[x] = RNN(_ser(x.w,s,m), _ser(x.h,s,m), _ser(x.c,s,m), x.inputSize, x.hiddenSize, x.numLayers, x.dropout, x.seed, x.inputMode, x.direction, x.mode, x.algo, x.dataType, rd, dd, nothing, nothing, nothing)
end
return s[x]
end

# Partially fixes the issue: when KA converts to A because no gpu, surrounding parametric types remain Param{KA}.
# However other container types that include KnetArray may still have an inconsistent parametric type problem.
_ser(x::Param, s::IdDict, m::Val)=(haskey(s,x) ? s[x] : s[x]=Param(_ser(x.value,s,m),_ser(x.opt,s,m)))

_ser(x::KnetArray,s::IdDict,::typeof(GPUMODE))=x
_ser(x::KnetArray,s::IdDict,::typeof(CPUMODE))=(haskey(s,x) ? s[x] : s[x]=Array(x))
_ser(x::Array, s::IdDict, m::Val) = (haskey(s, x) ? s[x] : s[x] = _ser_array_t(x, eltype(x), s, m))

function _ser_array_t(@nospecialize(x), T, s::IdDict, m::Val)
if !isbitstype(T)
map(xi->_ser(xi,s,m), x)
# map(xi->_ser(xi,s,m), x) # this fails with unassigned values
dest = similar(x,Any) # convert eltype to Any because it may change
# stackdict[x] = dest # we do this in the caller
for i = 1:(length(x)::Int)
if ccall(:jl_array_isassigned, Cint, (Any, Csize_t), x, i-1) != 0
xi = ccall(:jl_arrayref, Any, (Any, Csize_t), x, i-1)
if !isbits(xi)
xi = _ser(xi, s, m) # deepcopy_internal(xi, stackdict)::typeof(xi)
end
ccall(:jl_arrayset, Cvoid, (Any, Any, Csize_t), dest, xi, i-1)
end
end
return dest
elseif m === GPUMODE
KnetArray(x)
else
@@ -125,14 +120,21 @@ function _ser(@nospecialize(x), stackdict::IdDict, m::Val)
return y
end

function _ser(x::Union{Dict,IdDict}, s::IdDict, m::Val)
function _ser(x::Union{Dict{K,V},IdDict{K,V}}, s::IdDict, m::Val) where {K,V}
if haskey(s, x)
return s[x] # removed ::typeof(x)
end
if isbitstype(eltype(x))
if isbitstype(K) && isbitstype(V)
return (s[x] = x)
end
dest = empty(x)
# Type of key or value may change unless isbitstype
if isbitstype(K)
dest = empty(x, K, Any)
elseif isbitstype(V)
dest = empty(x, Any, V)
else
dest = empty(x, Any, Any)
end
s[x] = dest
for (k, v) in x
dest[_ser(k, s, m)] = _ser(v, s, m)
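The cpucopy/gpucopy entry points defined above are deep copies that relocate arrays between devices. A minimal usage sketch, assuming a working GPU; the parameter shape is illustrative:

```julia
using Knet  # exports Param, KnetArray, cpucopy, gpucopy

w = Param(KnetArray(rand(Float32, 3, 3)))  # a GPU-backed parameter
wc = cpucopy(w)   # deep copy with every KnetArray/CuArray replaced by a CPU Array
wg = gpucopy(wc)  # deep copy with the arrays moved back to the GPU
```
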
13 changes: 12 additions & 1 deletion src/knetarrays/getindex.jl
@@ -56,12 +56,23 @@ function materialize!(A::SubArray{T,N,<:KnetArray}, B) where {T,N}
materialize!(_A, _B)
end

# Ambiguity fix:
# For contiguous I, dotview(A, I...) gives a shared-memory KnetArray rather than a view
function materialize!(A::KnetArray, B) where {T,N}
_A = CuArray(A)
_B = (B isa KnetArray || B isa AbstractArray ? CuArray(B) : B)
materialize!(_A, _B)
end

# Ambiguity fixes:
function materialize!(A::SubArray{T,N,<:KnetArray}, B::Broadcasted{S}) where {T,N,S}
_A = view(CuArray(A.parent), A.indices...)
materialize!(_A, B)
end

function materialize!(A::KnetArray{T,N}, B::Broadcasted{S}) where {T,N,S}
_A = CuArray(A)
materialize!(_A, B)
end


# The following fallback version tried to do all allocations using KnetArrays but was recently broken (Issue 618).
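For context, a sketch of the broadcast-assignment pattern from issue #638 that these methods are intended to route through CuArray (the array values are illustrative):

```julia
using Knet

a = KnetArray(zeros(Float32, 4, 4))
a[:, 1] .= 1f0   # for contiguous indices, dotview returns a shared-memory KnetArray,
                 # so the new materialize!(A::KnetArray, B) handles the in-place assignment
```
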
60 changes: 24 additions & 36 deletions src/fileio_gpu/jld.jl → src/knetarrays/jld2.jl
@@ -1,52 +1,38 @@
export save, load, @save, @load
import FileIO # save, load
using JLD2: JLD2, JLDWriteSession, jldopen, isgroup, lookup_offset
#include("serialize.jl") ## serialize ## TODO: extend Base.Serialization, cover CuArrays
import JLD2, FileIO

"""
Knet.save(filename, args...; kwargs...)
# With the following converters, the standard FileIO.save, FileIO.load, JLD2.@save, and JLD2.@load should all work
struct JLD2KnetArray{T,N}; array::Array{T,N}; end
JLD2.writeas(::Type{KnetArray{T,N}}) where {T,N} = JLD2KnetArray{T,N}
JLD2.wconvert(::Type{JLD2KnetArray{T,N}}, x::KnetArray{T,N}) where {T,N} = JLD2KnetArray(Array(x))
JLD2.rconvert(::Type{KnetArray{T,N}}, x::JLD2KnetArray{T,N}) where {T,N} = KnetArray(x.array)

Call `FileIO.save` after serializing Knet specific args.

File format is determined by the filename extension. JLD and JLD2 are supported. Other formats
may work if supported by FileIO, please refer to the documentation of FileIO and the specific
format. Example:
# These are deprecated functions and macros for backward compatibility and loading old files

Knet.save("foo.jld2", "name1", value1, "name2", value2)
"""
function save(fname,args...;kwargs...)
FileIO.save(fname,serialize.(args)...;kwargs...)
function save(file, args...; options...)
@warn "Knet.save is deprecated, please use FileIO.save/load instead" maxlog=1
FileIO.save(file, jld2serialize.(args)...; options...)
end

"""
Knet.load(filename, args...; kwargs...)
Call `FileIO.load` then deserialize Knet specific values.
File format is determined by FileIO. JLD and JLD2 are supported. Other formats may work if
supported by FileIO, please refer to the documentation of FileIO and the specific format.
Example:
Knet.load("foo.jld2") # returns a ("name"=>value) dictionary
Knet.load("foo.jld2", "name1") # returns the value of "name1" in "foo.jld2"
Knet.load("foo.jld2", "name1", "name2") # returns tuple (value1, value2)
"""
function load(fname,args...;kwargs...)
serialize(FileIO.load(fname,args...;kwargs...))
function load(file, args...; options...)
@warn "Knet.load is deprecated, please use FileIO.save/load instead" maxlog=1
jld2serialize(FileIO.load(file, args...; options...))
end


"""
Knet.@save "filename" variable1 variable2...
Save the values of the specified variables to filename in JLD2 format.
When called with no variable arguments, write all variables in the global scope of the current
module to filename. See [JLD2](https://github.com/JuliaIO/JLD2.jl).
This macro is deprecated, please use `JLD2.@save` instead.
"""
macro save(filename, vars...)
if isempty(vars)
# Save all variables in the current module
quote
@warn "Knet.@save is deprecated, please use JLD2.@save/@load instead" maxlog=1
let
m = $(__module__)
f = JLD2.jldopen($(esc(filename)), "w")
@@ -58,7 +44,7 @@ macro save(filename, vars...)
v = getfield(m, vname)
if !isa(v, Module)
try
write(f, s, serialize(v), wsession)
write(f, s, jld2serialize(v), wsession)
catch e
if isa(e, PointerException)
@warn("skipping $vname because it contains a pointer")
@@ -77,10 +63,11 @@
else
writeexprs = Vector{Expr}(undef, length(vars))
for i = 1:length(vars)
writeexprs[i] = :(write(f, $(string(vars[i])), serialize($(esc(vars[i]))), wsession))
writeexprs[i] = :(write(f, $(string(vars[i])), jld2serialize($(esc(vars[i]))), wsession))
end

quote
@warn "Knet.@save is deprecated, please use JLD2.@save/@load instead" maxlog=1
JLD2.jldopen($(esc(filename)), "w") do f
wsession = JLD2.JLDWriteSession()
$(Expr(:block, writeexprs...))
@@ -91,11 +78,11 @@ end

"""
Knet.@load "filename" variable1 variable2...
Load the values of the specified variables from filename in JLD2 format.
When called with no variable arguments, load all variables in filename. See
[JLD2](https://github.com/JuliaIO/JLD2.jl).
This macro is deprecated, please use `JLD2.@load` instead.
"""
macro load(filename, vars...)
if isempty(vars)
@@ -117,8 +104,9 @@ macro load(filename, vars...)
end
end
return quote
@warn "Knet.@load is deprecated, please use JLD2.@save/@load instead" maxlog=1
($([esc(x) for x in vars]...),) = JLD2.jldopen($(esc(filename))) do f
($([:(serialize(read(f, $(string(x))))) for x in vars]...),)
($([:(jld2serialize(read(f, $(string(x))))) for x in vars]...),)
end
$(Symbol[v for v in vars]) # convert to Array
end
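Since the deprecation warnings point at the standard stack, here is a minimal sketch of the recommended replacement (the file and key names are illustrative):

```julia
using FileIO, JLD2, Knet

w = param(5, 5)                      # a Param-wrapped KnetArray when a GPU is available
FileIO.save("model.jld2", "w", w)    # plain FileIO/JLD2 now handles KnetArrays directly
w2 = FileIO.load("model.jld2", "w")
```
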
2 changes: 1 addition & 1 deletion src/knetarrays/karray.jl
@@ -154,7 +154,7 @@ end

# Extend function KnetArray to create a memory shared KnetArray from CuArray:
function KnetArray(x::CuArray{T,N}) where {T,N}
p = Base.bitcast(Cptr, x.ptr)
p = Base.bitcast(Cptr, pointer(x))
k = KnetPtr(p, sizeof(x), Int(CUDA.device().handle), x)
KnetArray{T,N}(k, size(x))
end
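The one-line fix replaces the x.ptr field access, removed in CUDA 2.1.0, with the public pointer(x) accessor; the conversion still wraps the same device buffer rather than copying. A sketch, assuming a functional GPU:

```julia
using Knet, CUDA

c = CUDA.ones(Float32, 2, 2)
k = KnetArray(c)   # shares c's device memory via pointer(c); no copy is made
k[1, 1] = 0f0      # the mutation should be visible through c as well
```
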
18 changes: 18 additions & 0 deletions src/knetarrays/serialization.jl
@@ -0,0 +1,18 @@
# From @jbaron https://github.com/denizyuret/Knet.jl/issues/587

using Serialization

"""
Enable saving and loading of models by specializing KnetArray methods for Julia serialization.
This effectively moves a GPU weight to the CPU before serializing it and moves it back to
the GPU when deserializing.
"""
function Serialization.serialize(s::Serialization.AbstractSerializer, p::KnetArray)
Serialization.serialize_type(s, typeof(p))
Serialization.serialize(s, Array(p))
end

function Serialization.deserialize(s::Serialization.AbstractSerializer, t::Type{<:KnetArray})
arr = Serialization.deserialize(s)
return KnetArray(arr)
end
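
These methods let plain Base serialization round-trip a KnetArray through a CPU Array. A minimal sketch, assuming a working GPU:

```julia
using Serialization, Knet

k = KnetArray(rand(Float32, 3))
buf = IOBuffer()
serialize(buf, k)        # written to the stream as a plain Array
seekstart(buf)
k2 = deserialize(buf)    # reconstructed as a KnetArray on the GPU
@assert k2 isa KnetArray
```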

2 comments on commit 3b6f70a

@denizyuret (Owner, Author)

@JuliaRegistrator

Registration pull request created: JuliaRegistries/General/25457

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the GitHub interface, or via:

git tag -a v1.4.4 -m "<description of version>" 3b6f70a00329038e7e5ceb4acf57d1ff12777689
git push origin v1.4.4
