Skip to content

Commit

Permalink
RFC: Support JuliaFolds interfaces
Browse files Browse the repository at this point in the history
  • Loading branch information
tkf committed May 14, 2021
1 parent fe14a82 commit ec17e62
Show file tree
Hide file tree
Showing 4 changed files with 43 additions and 2 deletions.
6 changes: 5 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,17 @@ version = "1.2.16"

[deps]
Dates = "ade2ca70-3891-5945-98fb-dc099432e06a"
FGenerators = "4fd0377b-cfdc-4941-97f4-8d7ddbb8981e"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
SplittablesBase = "171d559e-b47b-412a-8079-5efa626c420e"

[compat]
julia = "1"

[extras]
Folds = "41a02a25-b8f0-4f67-bc48-60067656b558"
SplittablesTesting = "3bda5eb5-c32a-4f64-8618-df3be8968470"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["Test"]
test = ["Folds", "SplittablesTesting", "Test"]
1 change: 1 addition & 0 deletions src/SentinelArrays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -451,6 +451,7 @@ end

include("chainedvector.jl")
include("missingvector.jl")
include("folds.jl")

include("precompile.jl")
_precompile_()
Expand Down
17 changes: 17 additions & 0 deletions src/folds.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
using FGenerators: @fgenerator, @yield
using SplittablesBase: SplittablesBase

@fgenerator(A::ChainedVector) do
for array in A.arrays
for x in array
@yield x
end
end
end

function SplittablesBase.halve(A::ChainedVector)
chunk = searchsortedfirst(A.inds, length(A) ÷ 2)
left = @view A.arrays[1:chunk]
right = @view A.arrays[chunk+1:end]
return (Iterators.flatten(left), Iterators.flatten(right))
end
21 changes: 20 additions & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
using SentinelArrays, Test, Random
using SentinelArrays, Test, Random, SplittablesTesting, Folds

@testset "SentinelArrays" begin

Expand Down Expand Up @@ -580,4 +580,23 @@ deleteat!(c2, Int[])
end
end

unitrange_chainedvectors = map(1:100) do seed
rng = MersenneTwister(seed)
ends = cumsum(rand(rng, 0:9, rand(rng, 1:100)))
starts = pushfirst!(ends[1:end-1] .+ 1, 1)
cv = ChainedVector((:).(starts, ends))
return (label = "seed=$seed", data = cv)
end

@testset "Folds" begin
@testset "$label" for (label, data) in unitrange_chainedvectors
@test Folds.collect(data, SequentialEx()) == 1:length(data)
@test Folds.collect(data) == 1:length(data)
end
end

@testset "SplittablesBase" begin
SplittablesTesting.test_unordered(unitrange_chainedvectors)
end

end

0 comments on commit ec17e62

Please sign in to comment.