Skip to content
This repository was archived by the owner on May 5, 2019. It is now read-only.

WIP: Fix type instability in groupby() #12

Closed
wants to merge 4 commits into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 29 additions & 30 deletions src/groupeddatatable/grouping.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,8 @@ function groupsort_indexer(x::AbstractVector, ngroups::Integer, null_last::Bool=

# count group sizes, location 0 for NULL
n = length(x)
# counts = x.pool
counts = fill(0, ngroups + 1)
for i = 1:n
@inbounds for i in 1:n
counts[x[i] + 1] += 1
end

Expand All @@ -52,14 +51,31 @@ function groupsort_indexer(x::AbstractVector, ngroups::Integer, null_last::Bool=

# this is our indexer
result = fill(0, n)
for i = 1:n
@inbounds for i in 1:n
label = x[i] + 1
result[where[label]] = i
where[label] += 1
end
result, where, counts
end

function fill_groups!{T}(x::AbstractVector, v::AbstractVector{T}, ngroups::Integer)
if T <: Nullable || Nullable <: T
cv = convert(NullableCategoricalVector, v)
anynulls = findfirst(cv.refs, 0) > 0
order = CategoricalArrays.order(cv.pool) .+ anynulls .- 1
else
cv = convert(CategoricalVector, v)
order = CategoricalArrays.order(cv.pool)
end

refs = cv.refs
@inbounds for i in eachindex(x, v)
x[i] += order[refs[i]] * ngroups
end
length(order)
end

"""
A view of an AbstractDataTable split into row groups

Expand Down Expand Up @@ -122,38 +138,21 @@ function groupby{T}(d::AbstractDataTable, cols::Vector{T})
## http://wesmckinney.com/blog/?p=489

ncols = length(cols)
# use CategoricalArray to get a set of integer references for each unique item
nv = NullableCategoricalArray(d[cols[ncols]])
# if there are NULLs, add 1 to the refs to avoid underflows in x later
anynulls = (findfirst(nv.refs, 0) > 0 ? 1 : 0)
# use UInt32 instead of the original array's integer size since the number of levels can be high
x = similar(nv.refs, UInt32)
for i = 1:nrow(d)
if nv.refs[i] == 0
x[i] = 1
else
x[i] = CategoricalArrays.order(nv.pool)[nv.refs[i]] + anynulls
end
end
# also compute the number of groups, which is the product of the set lengths
ngroups = length(levels(nv)) + anynulls
# if there's more than 1 column, do roughly the same thing repeatedly
for j = (ncols - 1):-1:1
nv = NullableCategoricalArray(d[cols[j]])
anynulls = (findfirst(nv.refs, 0) > 0 ? 1 : 0)
for i = 1:nrow(d)
if nv.refs[i] != 0
x[i] += (CategoricalArrays.order(nv.pool)[nv.refs[i]] + anynulls - 1) * ngroups
end
end
ngroups = ngroups * (length(levels(nv)) + anynulls)
x = ones(UInt32, nrow(d))
ngroups = 1
for j in ncols:-1:1
newgroups = fill_groups!(x, d[j], ngroups)
# compute the number of groups, which is the product of the set lengths
ngroups = ngroups * newgroups
# TODO if ngroups is really big, shrink it
end
(idx, starts) = groupsort_indexer(x, ngroups)
# Remove zero-length groupings
starts = _uniqueofsorted(starts)
ends = starts[2:end] - 1
GroupedDataTable(d, cols, idx, starts[1:end-1], ends)
ends = starts[2:end]
ends .-= 1
pop!(starts)
GroupedDataTable(d, cols, idx, starts, ends)
end
groupby(d::AbstractDataTable, cols) = groupby(d, [cols])

Expand Down