`flox>=0.9` adds heuristics for automatically choosing an appropriate strategy with dask arrays! Here I describe how.
`flox` implements grouped reductions for chunked array types like cubed and dask using tree reductions. Tree reductions (example) are a parallel-friendly way of computing common reduction operations like `sum`, `mean`, etc.
Without flox, Xarray effectively shuffles — sorts the data to extract all values in a single group — and then runs the reduction group-by-group.
Depending on data layout or "chunking" this shuffle can be quite expensive.
With flox installed, Xarray instead uses its parallel-friendly tree reduction.
In many cases, this is a massive improvement.
Notice how much cleaner the graph is in this image:
See our previous blog post for more.
Two key realizations influenced the development of flox:

1. Grouping labels are often periodic: `"time.month"` is exactly periodic, while `"time.dayofyear"` is approximately periodic (depending on the calendar).
2. `"time.year"` is commonly a monotonic increasing array.

These two properties are particularly relevant for "climatology" calculations (e.g. `groupby("time.month").mean()`) — a common Xarray workload in the Earth Sciences.
Consider `ds.groupby("time.year").mean()`, or the equivalent `ds.resample(time="Y").mean()`, for a 100-year-long dataset of monthly averages with a chunk size of 1 (or 4) along the time dimension. This is a fairly common format for climate model output. The small chunk size along time is offset by much larger chunk sizes along the other dimensions — commonly horizontal space (`x, y` or `latitude, longitude`).
A naive tree reduction would accumulate all averaged values into a single output chunk of size 100 — one value per year for 100 years. Depending on the chunking of the input dataset, this may overload the final worker's memory and fail catastrophically. More importantly, there is a lot of wasteful communication — computing on the last year of data is completely independent of computing on the first year of the data, and there is no reason the results for the two years need to reside in the same output chunk. This issue does not arise for regular reductions where the final result depends on the values in all chunks, and all data along the reduced axes are reduced down to one final value.
Thus `flox` quickly grew two new modes of computing the groupby reduction.

First, `method="blockwise"`, which applies the grouped reduction in a blockwise fashion. This is great for `resample(time="Y").mean()` where we group by `"time.year"`, which is a monotonic increasing array. With an appropriate (and usually quite cheap) rechunking, the problem is embarrassingly parallel.
Second, `method="cohorts"`, which is a bit more subtle. Consider `groupby("time.month")` for the monthly mean dataset, i.e. grouping by an exactly periodic array. When the chunk size along the core dimension "time" is a divisor of the period (so 1, 2, 3, 4, or 6 in this case), groups tend to occur in cohorts ("groups of groups"). For example, with a chunk size of 4, monthly mean input data for the "cohort" Jan/Feb/Mar/Apr are always in the same chunk, and totally separate from any of the other months. Here is a schematic illustration where each month is represented by a different shade of red:

This means that we can run the tree reduction for each cohort (three cohorts in total: `JFMA | MJJA | SOND`) independently and expose more parallelism. Doing so can significantly reduce compute times and, in particular, the memory required for the computation.

Importantly, if there isn't much separation of groups into cohorts (for example, when the groups are randomly distributed), then it's hard to do better than the standard `method="map-reduce"`.
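Before the automatic selection described below, you chose between these strategies by hand. Here is a rough sketch of what that looks like; the dataset is a made-up stand-in for the 100-year monthly-mean example (the variable name `temp` and the grid size are illustrative), and it assumes dask and flox are installed so that Xarray forwards the `method` keyword on to flox.

```python
# A minimal sketch, not a definitive recipe: random data standing in for the
# 100-year monthly-mean dataset described above. Requires dask and flox.
import numpy as np
import pandas as pd
import xarray as xr

time = pd.date_range("1900-01-01", periods=100 * 12, freq="MS")  # 100 years of monthly means
ds = xr.Dataset(
    {"temp": (("time", "lat", "lon"), np.random.rand(time.size, 18, 36))},
    coords={"time": time},
)

# "blockwise": every group must live entirely within one chunk, so rechunk to
# one year (12 months) per chunk first; the problem is then embarrassingly parallel.
yearly = ds.chunk({"time": 12}).groupby("time.year").mean(method="blockwise")

# "cohorts": suits the exactly periodic "time.month" grouping; with a chunk
# size of 4 the cohorts are JFMA | MJJA | SOND.
monthly = ds.chunk({"time": 4}).groupby("time.month").mean(method="cohorts")
```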
These strategies are great, but the downside is that some sophistication is required to apply them. Worse, they are hard to explain conceptually! I've tried! (example 1, example 2).
What we need is to choose the appropriate strategy automatically.
Fundamentally, we know two things up front: the chunking of the array along the reduction axis, and the labels we are grouping by.
We want to find all sets of groups that occupy similar sets of chunks.
For groups `A`, `B`, `D`, `E`, `X` that occupy the following chunks (chunk 0 is the first chunk along the core dimension, or the axis of reduction):

```
A: [0, 1, 2]
B: [1, 2, 3]
D: [5, 6, 7, 8]
E: [8]
X: [0, 3]
```

we want to detect the cohorts `{A, B, X}` and `{D, E}` with the following chunks:

```
[A, B, X]: [0, 1, 2, 3]
[D, E]: [5, 6, 7, 8]
```
Importantly, we do not want to be dependent on detecting exact patterns, and prefer approximate solutions and heuristics.
After a fun exploration involving ideas such as locality-sensitive hashing and all-pair set similarity search, I settled on the following algorithm.
I use set containment, or a "normalized intersection", to determine the similarity between the sets of chunks occupied by two different groups (`Q` and `X`):

C = |Q ∩ X| / |Q| ≤ 1, where ∩ denotes set intersection.
Unlike Jaccard similarity, containment isn't skewed when one of the sets is much larger than the other.
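For a quick worked example using the groups above: with `Q = {0, 1, 2}` (the chunks of `A`) and `X = {1, 2, 3}` (the chunks of `B`), the containment is `|{1, 2}| / 3 ≈ 0.67`. With `Q = {8}` (the chunks of `E`) and `X = {5, 6, 7, 8}` (the chunks of `D`), the containment is `1` even though `D` occupies four times as many chunks, whereas the Jaccard similarity of that pair is only `1/4`.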
The steps are as follows:

1. Construct a boolean matrix `S[chunks, labels]`, where `S[i, j] = 1` when label `j` is present in chunk `i`.
2. Use `"blockwise"` when every group is contained to one block each.
3. Use `"cohorts"` when every chunk only has a single group, but that group might extend across multiple chunks.
4. Use `S` to compute an initial set of cohorts whose groups are in the same exact chunks (this is another groupby!). Later we will want to merge together the detected cohorts when they occupy approximately the same chunks, using the containment metric.
5. Compute the containment of every group `i` against every other group `j` as `C = S.T @ S / number_chunks_per_group`.
6. To choose between `"map-reduce"` and `"cohorts"`, we need a summary measure of the degree to which the labels overlap with each other. We use sparsity — the number of non-zero elements in `C` divided by the number of elements in `C`, i.e. `C.nnz / C.size`. When sparsity is relatively high, we use `"map-reduce"`; otherwise we use `"cohorts"`.

For more detail see the docs or the code. Suggestions and improvements are very welcome!
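To make the containment and sparsity computations concrete, here is a rough sketch using `numpy` and `scipy.sparse` on the `A/B/D/E/X` example from earlier. This is an illustration of the idea, not flox's actual implementation, and every variable name here is made up.

```python
# A rough sketch of the containment/sparsity computation; not flox's code.
import numpy as np
import scipy.sparse

# Chunks occupied by each group, from the A/B/D/E/X example above.
occupied = {
    "A": [0, 1, 2],
    "B": [1, 2, 3],
    "D": [5, 6, 7, 8],
    "E": [8],
    "X": [0, 3],
}
labels = list(occupied)
nchunks = 1 + max(c for chunks in occupied.values() for c in chunks)

# S[i, j] = 1 when label j is present in chunk i.
S = np.zeros((nchunks, len(labels)), dtype=int)
for j, label in enumerate(labels):
    S[occupied[label], j] = 1
S = scipy.sparse.csr_matrix(S)

# Containment C[j, k] = |Q_j ∩ Q_k| / |Q_j|, where Q_j is the set of chunks
# occupied by label j; (S.T @ S)[j, k] counts the chunks containing both labels.
intersection = np.asarray((S.T @ S).todense())
chunks_per_label = np.asarray(S.sum(axis=0)).ravel()
C = intersection / chunks_per_label[:, np.newaxis]

# Summary measure: the fraction of non-zero entries in C.
sparsity = np.count_nonzero(C) / C.size
print(labels)
print(np.round(C, 2))
print("sparsity:", sparsity)
```

With these five groups, `C` splits into two blocks that match the `{A, B, X}` and `{D, E}` cohorts above.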
Here is `C` for a range of chunk sizes from 1 to 12, for computing `groupby("time.month")` of a monthly mean dataset [the title on each image is (chunk size, sparsity)].

Given the above `C`, flox will choose:

- `"blockwise"` for chunk size 1,
- `"cohorts"` for chunk sizes (2, 3, 4, 6, 12),
- `"map-reduce"` for the rest.

Cool, isn't it?!
Importantly, this inference is fast — 400ms for the US county GroupBy problem in our previous post! But we have not tried it with bigger problems (for example, a GroupBy over 100,000 watersheds in the US).
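In practice, the user-facing change is simply that you no longer need to specify `method`: with flox >= 0.9 installed, leaving it unset lets flox infer a strategy from the chunking and the group labels (here `ds` is a dask-backed dataset, e.g. the one sketched earlier).

```python
# flox >= 0.9: with `method` unspecified, flox picks a strategy automatically.
monthly_climatology = ds.groupby("time.month").mean()
```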
flox's ability to make such inferences relies entirely on the input chunking, a big knob. A recent Xarray feature makes such rechunking a lot easier for time grouping:
```python
from xarray.groupers import TimeResampler

rechunked = ds.chunk(time=TimeResampler("YE"))
```
will rechunk so that a year of data is in a single chunk.
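With one year per chunk, every `"time.year"` group then lives in a single chunk, so (per the heuristic above) flox should pick `"blockwise"` for the yearly reduction:

```python
# Every "time.year" group now fits in one chunk, so flox's heuristic should
# choose the embarrassingly parallel "blockwise" strategy here.
rechunked.groupby("time.year").mean()
```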
Even so, it would be nice to automatically rechunk to minimize the number of cohorts detected, or to enable a perfectly blockwise application when that's cheap.
A challenge here is that we have lost context when moving from Xarray to flox.
The string "time.month"
tells Xarray that I am grouping a perfectly periodic array with period 12; similarly
the string "time.dayofyear"
tells Xarray that I am grouping by a (quasi-)periodic array with period 365, and that group 366
may occur occasionally (depending on calendar).
But Xarray passes flox an array of integer group labels [1, 2, 3, 4, 5, ..., 1, 2, 3, 4, 5]
.
It's hard to infer the context from that!
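As a small self-contained illustration of those labels (the dates here are made up):

```python
import pandas as pd
import xarray as xr

# Two years of monthly timestamps: the group labels flox receives are just
# these integers, with no hint that they are periodic with period 12.
time = xr.DataArray(pd.date_range("2000-01-01", periods=24, freq="MS"), dims="time")
print(time.dt.month.values)  # [ 1  2  3 ... 11 12  1  2 ... 11 12]
```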
Get in touch if you have ideas for how to do this inference.
One way to preserve context may be to have Xarray's new Grouper objects report "preferred chunks" for a particular grouping. This would allow a downstream system like `flox` or `cubed` or `dask-expr` to take this into account later (or even earlier!) in the pipeline.
That is an experiment for another day.