diff --git a/_posts/sklearn-perf-context.md b/_posts/sklearn-perf-context.md

Title: Performance and scikit-learn (1/4)
Date: 2021-12-16
Category: scikit-learn
Slug: sklearn-perf-context
Lang: en
Authors: Julien Jerphanion
Summary: Context: the current state of scikit-learn performance
Status: Published

## High-level overview of the scikit-learn dependencies

scikit-learn is mainly written in Python and is built on top of
some core libraries of the scientific Python ecosystem.

This ecosystem allows _high expressiveness_ and
_interactivity_: one can perform complex operations in a few
lines of code and get the results straight away.

It also allowed setting up simple conventions which make the
algorithms in the code-base easy for new contributors to understand
and improve.

It also allows delegating most of the complex operations
to well-tested third-party libraries. For instance, calls
to functions implemented in
[`numpy.linalg`](https://numpy.org/doc/stable/reference/routines.linalg.html),
[`scipy.linalg`](https://docs.scipy.org/doc/scipy/reference/linalg.html), and
[`scipy.sparse.linalg`](https://docs.scipy.org/doc/scipy/reference/sparse.linalg.html)
are delegated to
[BLAS](https://en.wikipedia.org/wiki/Basic_Linear_Algebra_Subprograms),
[LAPACK](https://www.netlib.org/lapack/),
and [ARPACK](https://www.caam.rice.edu/software/ARPACK/) interfaces.

## Main reasons for limited performance

The PyData stack is simple but is not tailored for optimal performance
for several reasons.

### CPython internals

CPython -- the main implementation of Python -- is slow.

First, CPython has _an interpreter_: there's a cost in converting Python
instructions into another intermediate representation --
the _byte-code_ -- and executing the instructions by interpreting their
byte-code.

Secondly, nearly every value in CPython is _boxed_ into a `PyObject` --
implemented as a C struct. As such, simple operations
(like adding two floats) come with a non-negligible dispatch
overhead, as the interpreter has to check the type of the operands,
which is unknown in advance.

Thirdly, for memory management CPython relies on a global
mutex on its interpreter called the _Global Interpreter Lock_[ref]For more information about the GIL, see
[this reference from the Python Wiki](https://wiki.python.org/moin/GlobalInterpreterLock).[/ref].
This mechanism simplifies memory management, but it restricts computations
in most cases to sequential execution in a single thread, removing most of
the benefit of using threads.

### Memory-hierarchy suboptimal implementations

`numpy` supports high-level operations, but this comes with intermediate
and dynamically-allocated arrays.

Moreover, this pattern is inefficient from a memory perspective:
during the execution, blocks of data are moved back and forth
between the RAM and the different CPU caches several times, not
making optimal use of the caches.

For instance, consider this minimalistic toy example:

```python
import numpy as np

A = np.random.rand(100_000_000)
B = np.random.rand(100_000_000)

X = np.exp(A + B)
```

The following is performed:

 - a first temporary array is allocated for `A + B`
 - a second array is allocated to store `np.exp(A + B)` and
   the first temporary array is discarded

These temporary allocations make the implementation suboptimal,
as memory allocation on the heap is slow.

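For what it's worth, the intermediate allocation can be avoided at this level by
reusing a preallocated buffer with numpy's `out=` arguments -- a small sketch of
the same computation, using the arrays of the toy example above:

```python
import numpy as np

A = np.random.rand(100_000_000)
B = np.random.rand(100_000_000)

# Allocate the output once, write `A + B` directly into it, then overwrite it
# in place with `np.exp`: no temporary array is created along the way.
X = np.empty_like(A)
np.add(A, B, out=X)
np.exp(X, out=X)
```

This removes the temporary allocation, but the data still travels between the RAM
and the caches once per operation, which is the second problem described below.
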
Furthermore, high-level operations on `X` come with more data
moves between the RAM and the CPU than needed to compute the
elements of `X`, and hardly make use of the memory hierarchy
and the size of the caches.

### No "bare-metal" data-structures

The Python ecosystem comes with a few high-level containers
such as numpy arrays and pandas DataFrames.

Contrary to other languages' standard libraries (like the one of
C++), no "bare-metal" data structures -- such as heaps or
memory-contiguous resizable buffers (as implemented in C++ by
[`std::priority_queue`](https://en.cppreference.com/w/cpp/container/priority_queue)
and [`std::vector`](https://en.cppreference.com/w/cpp/container/vector)) --
are available to implement some algorithms efficiently
from both a computational complexity and a technical perspective.

## Cython: combining the conciseness of Python and the speed of C

In brief, Cython allows transpiling a superset of Python to C code, and it allows using code that was written in C or C++, which makes bypassing some of CPython's internals possible. Moreover, Cython allows using [OpenMP](https://www.openmp.org/specifications/), an API providing lower-level parallelism primitives for implementations written in C or Fortran[ref]For more information on Cython, see [its documentation](https://cython.readthedocs.io/en/latest/).[/ref].

In most cases, the features provided by Cython are sufficient to reach optimal implementations for many scientific algorithms for which static task scheduling -- at the C level via OpenMP -- is the most natural and efficient choice.
Plus, its syntax makes this language expressive enough to get nearly optimal performance while keeping the instructions short and concise -- a real advantage for developers coming from Python who are looking for performance, and a relief for C or C++ developers[ref]Compared to C or C++, Cython is also less verbose and integrates more easily with a Python build system.[/ref].

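As a minimal illustration (a sketch, not scikit-learn code), here is what a Cython
kernel can look like: typed memoryviews remove the `PyObject` boxing, the loops
compile down to plain C, and `prange` runs iterations on several OpenMP threads
with the GIL released (the module has to be compiled with OpenMP enabled, e.g.
with `-fopenmp`):

```cython
# squared_row_norms.pyx -- an illustrative sketch, not scikit-learn code.
import numpy as np

cimport cython
from cython.parallel cimport prange


@cython.boundscheck(False)
@cython.wraparound(False)
def squared_row_norms(double[:, ::1] X, int num_threads=4):
    """Compute the squared Euclidean norm of each row of X."""
    cdef:
        Py_ssize_t i, j
        Py_ssize_t n = X.shape[0]
        Py_ssize_t p = X.shape[1]
        double[::1] out = np.zeros(n)

    # The loop body contains no Python object: it runs without the GIL,
    # each OpenMP thread processing a subset of the rows.
    for i in prange(n, nogil=True, num_threads=num_threads):
        for j in range(p):
            out[i] += X[i, j] * X[i, j]

    return np.asarray(out)
```
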
As such, many algorithms in scikit-learn are implemented in Cython for performance, some of them using OpenMP when possible. This is for instance the case of `KMeans`, which was initially written in Python using numpy and was rewritten in Cython by Jérémie du Boisberranger, improving the execution time by a factor of 5 for this algorithm[ref]For more information about `KMeans`, see the original contribution,
[`scikit-learn#11950`](https://github.com/scikit-learn/scikit-learn/pull/11950), and [this blog
post](https://scikit-learn.fondation-inria.fr/implementing-a-faster-kmeans-in-scikit-learn-0-23-2/).[/ref].

In the following posts, the case of $k$-nearest neighbors search -- the base routine
for `KNeighborsClassifier`, `KNeighborsRegressor`, and other scikit-learn interfaces -- is covered,
and a new Cython implementation is proposed.

---

## Notes

diff --git a/_posts/sklearn-perf-knn.md b/_posts/sklearn-perf-knn.md

Title: Performance and scikit-learn (2/4)
Date: 2021-12-17
Category: scikit-learn
Slug: sklearn-perf-knn
Lang: en
Authors: Julien Jerphanion
Summary: Hardware scalability issue: the k-neighbors search example
Status: Published

## $k$-nearest neighbors search in scikit-learn

$k$-nearest neighbors search is at the base of many implementations used within scikit-learn.

For instance, it is used in Affinity Propagation, BIRCH, Mean Shift, OPTICS,
Spectral Clustering, TSNE, KNeighborsRegressor, and KNeighborsClassifier.

Whilst many libraries implement approximate versions of $k$-nearest neighbors search to speed up
the execution time[ref]Approximate nearest neighbors search algorithms come in many different
flavors, and there's even [a benchmark suite comparing them!](https://ann-benchmarks.com/)[/ref],
scikit-learn's implementation aims at returning the exact $k$-nearest neighbors.


## Computing chunks of the distance matrix

The general steps for $k$-nearest neighbors search are:

 - Compute $\mathbf{D}_d(\mathbf{X}, \mathbf{Y})$, the distance matrix between the vectors of two
   arrays $\mathbf{X}$ and $\mathbf{Y}$.
 - Reduce rows of $\mathbf{D}_d(\mathbf{X}, \mathbf{Y})$ appropriately for the given algorithm:
   for instance, the adapted reduction for $k$-nn search returns, for each row, the indices of
   the $k$ smallest values. In what follows, we call this reduction $\texttt{argkmin}$.
 - Perform extra operations on the results of the reductions (here, sorting the values).

Generally, one does not compute $\mathbf{D}_d(\mathbf{X}, \mathbf{Y})$ entirely because its
space complexity is $\Theta(n_x \times n_y)$. Practically,
$\mathbf{D}_d(\mathbf{X}, \mathbf{Y})$ does not fit in RAM for datasets of reasonable size.

Thus, in practice, one computes chunks of this distance matrix and reduces them directly.
This is what is done as of scikit-learn 1.0[ref][`KNeighborsMixin.kneighbors`](https://github.com/scikit-learn/scikit-learn/blob/c762c407873b8d6417b1c2ff78d19d82550e48d3/sklearn/neighbors/_base.py#L650) is the interface to look for.[/ref].

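As an illustration of this chunked pattern, here is a minimal pure-numpy sketch
(not the scikit-learn code; sizes kept small on purpose): one chunk of the squared
distance matrix is computed at a time and immediately reduced with
$\texttt{argkmin}$, so the full $n_x \times n_y$ matrix is never allocated.

```python
import numpy as np


def chunked_argkmin(X, Y, k, chunk_size=128):
    """Indices of the k nearest neighbors in Y for each row of X."""
    n_x = X.shape[0]
    indices = np.empty((n_x, k), dtype=np.intp)
    for start in range(0, n_x, chunk_size):
        stop = min(start + chunk_size, n_x)
        # One chunk of the squared Euclidean distance matrix: (chunk, n_y).
        dist_chunk = ((X[start:stop, None, :] - Y[None, :, :]) ** 2).sum(axis=-1)
        # Reduce the chunk: indices of the k smallest values of each row...
        part = np.argpartition(dist_chunk, kth=k - 1, axis=1)[:, :k]
        # ...then order those k indices by increasing distance.
        order = np.argsort(np.take_along_axis(dist_chunk, part, axis=1), axis=1)
        indices[start:stop] = np.take_along_axis(part, order, axis=1)
    return indices


X, Y = np.random.rand(1_000, 30), np.random.rand(2_000, 30)
neighbors = chunked_argkmin(X, Y, k=10)
```
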
## Current issues

The current implementation relies on a general parallelization scheme using higher-level functions with
[`joblib.Parallel`](https://joblib.readthedocs.io/en/latest/generated/joblib.Parallel.html).

Technically, this is not the most efficient: working at this level with views on numpy arrays moves
large chunks of data back and forth several times between the RAM and the CPUs' caches, hardly
makes use of the caches, and allocates temporary results.

Moreover, the cost of manipulating those chunks of data in the CPython interpreter causes a
non-negligible overhead because they are associated with Python objects which are bound to the
Global Interpreter Lock for reference counting.

Hence, this does not allow getting proper _hardware scalability_: an ideal parallel
implementation of an algorithm would run close to $n$ times faster when running on
$n$ threads on $n$ CPU cores compared to sequentially on $1$ thread.

For instance, the current implementation of $k$-nearest neighbors search of scikit-learn
cannot efficiently leverage all the available CPUs on a machine -- as shown by the figure below.

![Scalability of `kneighbors` as of scikit-learn 1.0](https://user-images.githubusercontent.com/13029839/155144242-6c041729-154b-47aa-9069-3a7d26deef5a.svg)

When using $8$ threads, it only runs $\times 2$ faster than the sequential implementation,
and adding more threads and CPUs beyond $8$ does not yield better performance.

In the next blog post, we go over the design of a new implementation to improve the scalability of $k$-nn search.

---

## Notes

diff --git a/_posts/sklearn-perf-pdr-extra.md b/_posts/sklearn-perf-pdr-extra.md

Title: Performance and scikit-learn (4/4)
Date: 2021-12-19
Category: scikit-learn
Slug: sklearn-perf-pdr-extra
Lang: en
Authors: Julien Jerphanion
Summary: Pairwise Distances Reductions: extra notes on technical details, benchmarks and further work.
Status: Published

Following up on [this initial post on the design of `PairwiseDistancesReductions`](sklearn-perf-pdr.html),
this post gives more details about the experiments used for performance assessment and about the design
of future extensions.

## `PairwiseDistancesArgKmin`: Performance improvements

[`KNeighborsMixin.kneighbors`](https://github.com/scikit-learn/scikit-learn/blob/f924bc8a1da541fa63b649046cedbc51d1024464/sklearn/neighbors/_base.py#L647) is the _de facto_ best proxy
for assessing the performance of the implementation used in most cases: `EuclideanPairwiseDistancesArgKmin`.

In what follows, this interface is tested on two aspects: hardware scalability and computational efficiency.

### Hardware scalability

This is the hardware scalability of `kneighbors` in scikit-learn `1.0`:

![Scalability of argkmin reductions in scikit-learn 1.0](https://user-images.githubusercontent.com/13029839/155144242-6c041729-154b-47aa-9069-3a7d26deef5a.svg)

This is the hardware scalability of `kneighbors` as proposed in [`sklearn#22134`](https://github.com/scikit-learn/scikit-learn/pull/22134):

![Scalability of argkmin reductions using the proposed `PairwiseDistancesReductionArgKmin`](https://user-images.githubusercontent.com/13029839/155096010-d143649b-3904-4d80-b3d0-5017724d19ad.svg)

The proposed implementation provides better hardware scalability than the previous one.

The plateau after 64 cores can be explained by Amdahl's law[ref]Gene M. Amdahl. 1967. Validity of the single processor approach to achieving large scale computing capabilities. In Proceedings of the April 18-20, 1967, spring joint computer conference (AFIPS '67 (Spring)). Association for Computing Machinery, New York, NY, USA, 483–485. DOI: [`https://doi.org/10.1145/1465482.1465560`](https://doi.org/10.1145/1465482.1465560)[/ref]: as the number of threads grows, the time spent in the parallel portion of the algorithm becomes negligible compared to the time spent in its sequential portion, so the total execution time reaches a limit -- the execution time of the sequential part -- and the speed-up ratio stops increasing. Moreover, the small drop in speed-up for 128 threads can be explained by the overhead of setting up threads, which becomes non-negligible compared to the actual computations made in each thread.

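To make the plateau concrete, Amdahl's law bounds the attainable speed-up $S(n)$ on $n$
threads when a fraction $s$ of the execution remains sequential (the numbers below are
purely illustrative, not measurements from this benchmark):

$$
S(n) = \frac{1}{s + \frac{1 - s}{n}} \xrightarrow[n \to +\infty]{} \frac{1}{s}
$$

For instance, with $s = 1\%$, $S(64) \approx 39$ and $S(128) \approx 56$: even a very
small sequential portion caps the speed-up (here at $100$), no matter how many threads
are added.
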
### Computational efficiency of `EuclideanPairwiseDistancesArgKmin`

On GNU/Linux distributions, [`perf(1)`](https://man7.org/linux/man-pages/man1/perf.1.html) comes in handy to introspect a program's execution in detail[ref]If you are using another OS, `perf(1)` won't be usable. Still, you should be able to perform similar inspections using [`dtrace`](https://www.brendangregg.com/dtrace.html).[/ref].

Here, we inspect where CPU cycles are spent, as well as L3 cache misses and L3 cache hits, using the following script on a machine having 20 physical cores[ref]The CPUs used are: Intel(R) Xeon(R) CPU E5-2660 v2 @ 2.20GHz[/ref]:

```python
# kneighbors_perf.py

import numpy as np
import os
from sklearn.neighbors import NearestNeighbors


if __name__ == "__main__":

    n_train = 100_000
    n_test = 100_000
    n_features = 30

    rng = np.random.RandomState(0)

    # We persist datasets on disk so as to solely have
    # `perf(1)` introspect the events for the core
    # of the computations: `kneighbors`.

    X_train_file = "X_train.npy"
    X_test_file = "X_test.npy"

    if os.path.exists(X_train_file):
        X_train = np.load(X_train_file)
    else:
        X_train = rng.rand(n_train, n_features)
        np.save(X_train_file, X_train)

    if os.path.exists(X_test_file):
        X_test = np.load(X_test_file)
    else:
        X_test = rng.rand(n_test, n_features)
        np.save(X_test_file, X_test)

    est = NearestNeighbors(n_neighbors=10, algorithm="brute").fit(X=X_train)

    # FastEuclideanPairwiseDistancesArgKmin will be used under the hood.
    est.kneighbors(X_test)
```

And the following call to [`perf-record(1)`](https://man7.org/linux/man-pages/man1/perf-record.1.html)[ref]You might need to adapt the events because they change from one architecture to another. See [`perf-list(1)`](https://man7.org/linux/man-pages/man1/perf-list.1.html).[/ref]:

```sh
# Recorded events:
#   cycles                            CPU cycles
#   mem_load_uops_retired.llc_miss    L3 cache misses
#   mem_load_uops_retired.llc_hit     L3 cache hits
perf record -e cycles,mem_load_uops_retired.llc_miss,mem_load_uops_retired.llc_hit \
    python kneighbors_perf.py
```

This dumps a binary `perf.data` file which can be explored using [`perf-report(1)`](https://man7.org/linux/man-pages/man1/perf-report.1.html):

```sh
# --hierarchy: show overheads hierarchically
# --inline:    annotate with callgraph addresses
perf report --hierarchy --inline
```

**On CPU cycles**

This is the report for the `cycles` events.

```
Samples: 543K of event 'cycles:u', Event count (approx.): 335205056539

- 100.00% python
   - 68.07% libopenblasp-r0.3.18.so
        57.45% [.] dgemm_kernel_SANDYBRIDGE
         4.51% [.] dgemm_beta_SANDYBRIDGE
         3.33% [.] dgemm_incopy_SANDYBRIDGE
         2.59% [.] dgemm_oncopy_SANDYBRIDGE
         0.09% [.] dgemm_tn
         0.04% [.] blas_thread_server
         0.01% [.] dgemm_
         0.01% [.] ddot_kernel_8
         0.01% [.] blas_memory_free
         0.01% [.] blas_memory_alloc
         0.00% [.] dgemm_small_matrix_permit_SANDYBRIDGE
         0.00% [.] dot_compute
         0.00% [.] ddot_k_SANDYBRIDGE
         0.00% [.] ddot_
   - 22.17% _pairwise_distances_reduction.cpython-39-x86_64-linux-gnu.so
        22.16% [.] __pyx_f_7sklearn_7metrics_29_pairwise_distances_red
         0.00% [.] __pyx_memoryview_slice_memviewslice
         0.00% [.] __pyx_f_7sklearn_7metrics_29_pairwise_distances_red
         0.00% [.] __pyx_f_7sklearn_7metrics_29_pairwise_distances_red
         0.00% [.] __pyx_f_7sklearn_7metrics_29_pairwise_distances_red
         0.00% [.] __pyx_f_7sklearn_7metrics_29_pairwise_distances_red
   - 9.25% _heap.cpython-39-x86_64-linux-gnu.so
         9.24% [.] __pyx_fuse_1__pyx_f_7sklearn_5utils_5_heap_heap_pus
         0.01% [.] __pyx_fuse_1__pyx_f_7sklearn_5utils_5_heap_simultan
   + 0.20% python3.9
   - 0.15% libgomp.so.1.0.0
         0.15% [.] do_wait
         0.00% [.] gomp_barrier_wait_end
         0.00% [.] gomp_thread_start
         0.00% [.] gomp_team_barrier_wait_end
         0.00% [.] futex_wake
```

Most of the CPU cycles are spent in GEMM. The rest of them are mainly used
to iterate on the chunks of the distance matrix, pushing values and indices onto the
max-heaps.

Note that the calls to the OpenMP parallelisation (via Cython) and
to the CPython interpreter come with negligible overhead.

Assuming most readers are curious and like getting into details, we can actually look at the
kind of CPU instructions which are being used in `dgemm_kernel_SANDYBRIDGE`[ref]Unmangling `dgemm_kernel_SANDYBRIDGE`: this is the core (`kernel`) of the float64/double (`d`) implementation of GEMM for the [Sandy Bridge architecture](https://en.wikichip.org/wiki/intel/microarchitectures/sandy_bridge_(client)).[/ref], the critical
region.

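This can be done from the same profile, for instance with
[`perf-annotate(1)`](https://man7.org/linux/man-pages/man1/perf-annotate.1.html)
(shown here as a hedged example; the exact invocation may depend on your `perf` version):

```sh
# Show the annotated assembly of the hottest symbol found in perf.data.
perf annotate dgemm_kernel_SANDYBRIDGE
```
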
```
Samples: 543K of event 'cycles:u', 4000 Hz, Event count (approx.): 335205056539
dgemm_kernel_SANDYBRIDGE
  0.94 │ vmulpd      %ymm1,%ymm3,%ymm7
  0.50 │ vpermilpd   $0x5,%ymm2,%ymm3
  0.52 │ vaddpd      %ymm14,%ymm6,%ymm14
  1.11 │ vaddpd      %ymm12,%ymm7,%ymm12
  1.55 │ vmulpd      %ymm0,%ymm4,%ymm6
  0.25 │ vmulpd      %ymm0,%ymm5,%ymm7
  0.51 │ vmovapd     0xc0(%rdi),%ymm0
  1.81 │ vaddpd      %ymm11,%ymm6,%ymm11
  1.65 │ vaddpd      %ymm9,%ymm7,%ymm9
  0.71 │ vmulpd      %ymm1,%ymm4,%ymm6
  0.33 │ vmulpd      %ymm1,%ymm5,%ymm7
  0.77 │ vaddpd      %ymm10,%ymm6,%ymm10
  2.08 │ vaddpd      %ymm8,%ymm7,%ymm8
  0.86 │ vmovapd     0xe0(%rdi),%ymm1
  0.85 │ vmulpd      %ymm0,%ymm2,%ymm6
  0.85 │ vperm2f128  $0x3,%ymm2,%ymm2,%ymm4
  0.97 │ vmulpd      %ymm0,%ymm3,%ymm7
  0.85 │ vperm2f128  $0x3,%ymm3,%ymm3,%ymm5
  0.22 │ add         $0x100,%rdi
  0.38 │ vaddpd      %ymm15,%ymm6,%ymm15
  1.62 │ vaddpd      %ymm13,%ymm7,%ymm13
  1.12 │ prefetcht0  0x2c0(%rdi)
  0.23 │ vmulpd      %ymm1,%ymm2,%ymm6
  0.80 │ vmovapd     (%rsi),%ymm2
```

Most of the instructions there are SIMD instructions operating on `ymm` registers (AVX).

If the reader is interested in knowing how those instructions are used, they can have a look at [`OpenBLAS/kernel/x86_64/dgemm_kernel_4x8_sandy.S`](https://github.com/xianyi/OpenBLAS/blob/8cec83bdfb82effda2075309af5ca36df79f1a8e/kernel/x86_64/dgemm_kernel_4x8_sandy.S), which comes with a set of compiler macros to define the computations at a higher level in assembly.

**On L3 cache hits and L3 cache misses**

One can inspect the report of the `mem_load_uops_retired.llc_miss` events for L3 cache misses[ref]`"llc"` in `"llc_miss"` stands for "last level cache", which on most architectures is the L3 -- i.e. third level -- cache.[/ref]:

```
Samples: 88 of event 'mem_load_uops_retired.llc_miss:u', Event count (approx.):
543K cycles:u
- 100.00% python
   - 82.95% libopenblasp-r0.3.18.so
        81.82% [.] dgemm_incopy_SANDYBRIDGE
         1.14% [.] dgemm_kernel_SANDYBRIDGE
   + 7.95% [unknown]
   - 6.82% _pairwise_distances_reduction.cpython-39-x86_64-linux-gnu.so
         6.82% [.] __pyx_f_7sklearn_7metrics_29_pairwise_distances_red
   + 2.27% python3.9
```

One can inspect the report of the `mem_load_uops_retired.llc_hit` events for L3 cache hits:

```
Samples: 984 of event 'mem_load_uops_retired.llc_hit:u', Event count (approx.):
543K cycles:u
- 100.00% python
   - 66.26% libopenblasp-r0.3.18.so
        31.00% [.] dgemm_kernel_SANDYBRIDGE
        19.21% [.] dgemm_incopy_SANDYBRIDGE
        10.16% [.] dgemm_oncopy_SANDYBRIDGE
         3.66% [.] dgemm_tn
         1.12% [.] blas_memory_alloc
         0.51% [.] dgemm_
         0.41% [.] blas_memory_free
         0.20% [.] dgemm_beta_SANDYBRIDGE
   + 16.36% [unknown]
   + 8.84% python3.9
   - 5.08% _pairwise_distances_reduction.cpython-39-x86_64-linux-gnu.so
         4.98% [.] __pyx_f_7sklearn_7metrics_29_pairwise_distances_red
         0.10% [.] __pyx_memoryview_slice_memviewslice
   + 1.83% _heap.cpython-39-x86_64-linux-gnu.so
   + 0.71% libpthread-2.28.so
   + 0.30% ld-2.28.so
   + 0.30% libc-2.28.so
   + 0.20% _cython_blas.cpython-39-x86_64-linux-gnu.so
```

The L3 cache hits and misses happen exactly where we expect them to -- that is, in the
critical region computing chunks of the distance matrix with GEMM.

In the critical region, roughly one memory load out of ten[ref]This is a rough estimation based on the number
of sampled events, namely 984 for L3 cache hits and 88 for L3 cache misses.[/ref] misses the L3 cache,
showing that the data structures used to compute the chunks of the distance matrix generally stay in the
L3 caches as intended[ref]For maximum performance, one can adapt `chunk_size` to the L3 cache size of the
machine they use. This can be done by changing the `pairwise_dist_chunk_size` option with `sklearn.set_config`.[/ref].

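For instance, the option mentioned in the note above can be set globally (a small
illustration, assuming a scikit-learn version that includes the new implementation;
the best value depends on the L3 cache size of your machine):

```python
import sklearn

# Use chunks of 512 row vectors when computing chunks of the distance matrix,
# e.g. on machines with larger L3 caches.
sklearn.set_config(pairwise_dist_chunk_size=512)
```
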
### Conclusion

In what we have just covered:

 - The computations scale linearly with respect to the number of threads used, reaching theoretical limits.
 - The interactions with the CPython interpreter are minimized.
 - The L3 caches are properly used.
 - SIMD instructions are effectively used in critical sections.

Hence, this shows that the parallel execution of the algorithm is efficient[ref]If this can be made more efficient, feel free to propose improvements in another dedicated PR![/ref].

## float32 datasets pairs support for `PairwiseDistancesReduction`

### Design

The implementation whose details have been covered hereinbefore only addresses pairs of float64 datasets. Support for pairs of 32-bit datasets can be added using [Tempita](https://pyrocore.readthedocs.io/en/latest/tempita.html) so as to expand the previous float64 interfaces to float32[ref]Cython does not support templating, but Tempita covers most cases needing it.[/ref]. The full design proposal and performance assessment are given
in [`sklearn#22590`](https://github.com/scikit-learn/scikit-learn/pull/22590).

### Hardware scalability

The current experiments show that the port of `PairwiseDistancesArgKmin` to 32-bit datasets also has good hardware scalability:

![Hardware scalability of `PairwiseDistancesReductionArgKmin` on 32bit datasets](https://user-images.githubusercontent.com/13029839/155859972-637795e7-b959-4cba-afcc-369b0e84d92e.png)

Its integration first necessitates adapting the test suite to 32-bit datasets.

## `PairwiseDistancesRadiusNeighborhood`: a concrete `PairwiseDistancesReductions` for radius-based querying

### Design

The reductions for radius neighborhood queries can be implemented efficiently using resizable buffers. In Cython, this can easily be done using `std::vector`s, with [some adaptation to return them as numpy arrays safely](https://github.com/cython/cython/issues/4487). This has been implemented in [`sklearn#22320`](https://github.com/scikit-learn/scikit-learn/pull/22320).

### Hardware scalability

The implementation offers better hardware scalability than the previous one:

![Hardware scalability of `PairwiseDistancesRadiusNeighborhood` without mimalloc](https://user-images.githubusercontent.com/13029839/155114222-f6d0cc14-786b-4c3b-9bdb-c4a46ef7a944.png)

Yet, this new implementation suffers from concurrent reallocations in threads, namely when the vectors' buffers are reallocated as new elements are pushed back. These concurrent reallocations cause some drops in performance because calls to `malloc(3)` (used under the hood for the reallocations of the `std::vector` buffers) take a lock by default in standard libraries' implementations[ref]This is for instance the case in `malloc_state`, one of the main C structures in [the implementation of
`malloc(3)`
in `glibc`](https://sourceware.org/git/?p=glibc.git;a=blob;f=malloc/malloc.c;hb=HEAD#l1832).[/ref].

A simple alleviation for this is to use another implementation of `malloc(3)` such as [`mimalloc`](https://www.microsoft.com/en-us/research/publication/mimalloc-free-list-sharding-in-action/)'s[ref]For more information, see [this gist](https://gist.github.com/jjerphan/17d38a21a85931b448886087b11d2d19).[/ref], which limits contention between threads and thus improves the hardware scalability:

![Hardware scalability of `PairwiseDistancesRadiusNeighborhood` with mimalloc](https://user-images.githubusercontent.com/13029839/155114219-6c2d5434-52fd-4b22-a0dc-5f6655fae639.png)

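For instance, on GNU/Linux, mimalloc can be tried without recompiling anything by
preloading it (a hedged sketch: the library path and the benchmark script name are
placeholders that depend on your setup):

```sh
# Make every malloc(3)/free(3) call -- including the ones behind std::vector
# reallocations -- go through mimalloc instead of glibc's allocator.
LD_PRELOAD=/usr/lib/libmimalloc.so python benchmark_radius_neighborhood.py
```
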
## Subsequent work

Ideas for subsequent work are listed in [`sklearn#25888`](https://github.com/scikit-learn/scikit-learn/issues/25888).

Finally, many other extensions can be imagined: some similar patterns using Gram matrices of positive definite kernels[ref]Hofmann, Thomas and Schölkopf, Bernhard and Smola, Alexander J., Kernel methods in machine learning. DOI: [`https://dx.doi.org/10.1214/009053607000000677`](https://dx.doi.org/10.1214/009053607000000677)[/ref] instead of distance matrices exist and could be optimised in the same way.

---

## Notes

diff --git a/_posts/sklearn-perf-pdr.md b/_posts/sklearn-perf-pdr.md

Title: Performance and scikit-learn (3/4)
Date: 2021-12-18
Category: scikit-learn
Slug: sklearn-perf-pdr
Lang: en
Authors: Julien Jerphanion
Summary: Pairwise Distances Reductions: Abstracting the k-nn search pattern
Status: Published


## Context

We have seen that $\text{argkmin}$ is the reduction that is performed on pairwise distances for $k$-nearest neighbors search.

Yet, there exist other reductions over pairwise distances ($\text{argmin}$, threshold filtering, cumulative sum, etc.) which are at the core of the computational foundations of many machine learning algorithms.

This blog post presents a design that takes into account the requirements of the existing implementations to introduce a set of new abstractions for implementing reductions over pairwise distances: `PairwiseDistancesReduction`. This set of interfaces aims at reimplementing patterns that are similar to the $k$-nn search in Cython, to improve the performance of these computational foundations, and thus of the user-facing interfaces built on them.

To our knowledge, though some projects like [KeOps](https://www.kernel-operations.io/keops/index.html) implement those patterns efficiently for GPUs, no project implements such operations efficiently for CPUs.

> 💡 This blog post won't introduce every technical detail for the sake of conciseness, maintenance, and to respect the [single source of truth principle](https://en.wikipedia.org/wiki/Single_source_of_truth) as much as possible. The implementations are available in the [`sklearn.metrics._pairwise_distances_reduction`](https://github.com/scikit-learn/scikit-learn/tree/main/sklearn/metrics/_pairwise_distances_reduction) submodule.

> 💡 [This presentation](https://docs.google.com/presentation/d/1RwX_P9lnsb9_YRZ0cA88l3VoEYhiMndQYoKLLF_0Dv0/edit?usp=sharing) gives elements of the design of `PairwiseDistancesReductions`.

## Notation

In what follows, the following notations are used:

 - $p$: the dimension of vectors
 - $[n] \triangleq \{0, \cdots, n - 1\}$
 - $\mathbf{X} \in \mathbb{R}^{n_x \times p}$: a first dataset
 - $\mathbf{X}_{i\cdot} \in \mathbb{R}^{p}$: the $i$-th vector of $\mathbf{X}$
 - $\mathbf{Y} \in \mathbb{R}^{n_y \times p}$: a second dataset
 - $\mathbf{Y}_{j\cdot} \in \mathbb{R}^{p}$: the $j$-th vector of $\mathbf{Y}$
 - $c$: the chunk size, i.e. the number of vectors in a chunk (a group of adjacent vectors)
 - $c_x \triangleq \left\lceil \frac{n_x}{c} \right\rceil$, the number of chunks for $\mathbf{X}$
 - $c_y \triangleq \left\lceil \frac{n_y}{c} \right\rceil$, the number of chunks for $\mathbf{Y}$
 - $(\mathbf{X}_c^{(l)})_{l \in [c_x]}$: the ordered family of all the chunks of $\mathbf{X}$
 - $(\mathbf{Y}_c^{(k)})_{k \in [c_y]}$: the ordered family of all the chunks of $\mathbf{Y}$
 - $\mathbf{C}_\text{chunk_size}\mathbf{(X, Y)} \triangleq \left(\mathbf{X}_c^{(l)}, \mathbf{Y}_c^{(k)}\right)_{(l,k) \in [c_x] \times [c_y] }$: the ordered family of all the pairs of chunks
 - $d$: the distance metric to use

$$
d: \mathbb{R}^{p} \times \mathbb{R}^{p} \longrightarrow \mathbb{R}_+
$$

 - $\mathbf{D}_d(\mathbf{A}, \mathbf{B}) \in \mathbb{R}^{n_a \times n_b}$: the distance matrix for $d$ between the vectors of two matrices $\mathbf{A} \in \mathbb{R}^{n_a \times p}$ and $\mathbf{B} \in \mathbb{R}^{n_b \times p}$:

$$
\forall (i, j) \in [n_a]\times [n_b], \quad \mathbf{D}_d(\mathbf{A}, \mathbf{B})_{i,j} = d\left(\mathbf{A}_{i\cdot}, \mathbf{B}_{j\cdot}\right)
$$

 - $k$: parameter for the $\text{argkmin}$ operation at the base of the $k$-nearest neighbors search


Moreover, the terms "samples" and "vectors" will be used interchangeably.

## Requirements for reductions over pairwise distances

The following requirements are currently supported within scikit-learn's implementations:

 - Support for float32 datasets pairs and float64 datasets pairs
 - Support for fused $\{\text{sparse}, \text{dense}\}^2$ datasets pairs, i.e.:
     - dense $\mathbf{X}$ and dense $\mathbf{Y}$
     - sparse $\mathbf{X}$ and dense $\mathbf{Y}$
     - dense $\mathbf{X}$ and sparse $\mathbf{Y}$
     - sparse $\mathbf{X}$ and sparse $\mathbf{Y}$
 - Support for all the distance metrics as defined via [`sklearn.metrics.DistanceMetric`](https://scikit-learn.org/stable/modules/generated/sklearn.neighbors.DistanceMetric.html)
 - Parallelise computations effectively on all cores
 - Prevent threads' oversubscription[ref]Threads' oversubscription happens when threads are spawned at various levels of parallelism, causing the OS to schedule more threads than necessary for the execution of the program to be optimal.[/ref] (by OpenMP, joblib, or any BLAS implementation; see the sketch after this list)
 - Implement adapted operations for each reduction ($\text{argmin}$, $\text{argkmin}$, threshold filtering, cumulative sum, etc.)
 - Support generic return values for reductions (varying number, varying types, varying shapes, etc.)
 - Optimise the Euclidean distance metric computations

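To illustrate the oversubscription requirement: when an outer level of parallelism
(OpenMP or joblib) already uses all the cores, nested BLAS thread pools have to be
capped. Here is a minimal sketch of the kind of guard used, based on
[`threadpoolctl`](https://github.com/joblib/threadpoolctl) (illustrative, not
scikit-learn's internal code):

```python
import numpy as np
from threadpoolctl import threadpool_limits

A = np.random.rand(2_000, 100)
B = np.random.rand(100, 2_000)

# While this context is active, BLAS calls (here the GEMM behind `A @ B`)
# are limited to a single thread each, leaving the outer parallelism
# (e.g. an OpenMP parallel loop) in control of the cores.
with threadpool_limits(limits=1, user_api="blas"):
    C = A @ B
```
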
## Proposed design

The proposed design treats the given requirements as independently from one another as possible.

### `DatasetsPair`: an abstract class for manipulating datasets[ref]We use the term "abstract class" here to talk about the design: no such concept exists in Cython.[/ref]

This allows:

 - Supporting float32 datasets pairs and float64 datasets pairs
 - Supporting fused $\{\text{sparse}, \text{dense}\}^2$ datasets pairs via concrete implementations, i.e.:
     - `DenseDenseDatasetsPair`
     - `SparseDenseDatasetsPair`
     - `DenseSparseDatasetsPair`
     - `SparseSparseDatasetsPair`
 - Supporting all the distance metrics as defined via [`sklearn.metrics.DistanceMetric`](https://scikit-learn.org/stable/modules/generated/sklearn.neighbors.DistanceMetric.html)

Internally, a `DatasetsPair` wraps $(\mathbf{X}, \mathbf{Y}, d)$ and exposes an interface which allows computing $d(\mathbf{X}_{i\cdot}, \mathbf{Y}_{j\cdot})$ for a given tuple of indices $(i, j)$.

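In spirit (a deliberately simplified, pure-Python sketch -- the actual scikit-learn
classes are Cython extension types and are not written this way), a `DatasetsPair`
boils down to:

```python
import numpy as np


class DenseDenseDatasetsPair:
    """Wrap (X, Y, d) and answer: what is d(X_i, Y_j)?

    The parallel templates built on top of it never need to know how the
    datasets are stored (dense or sparse) nor which metric is used.
    """

    def __init__(self, X, Y, metric):
        self.X, self.Y, self.metric = X, Y, metric

    def n_samples_X(self):
        return self.X.shape[0]

    def n_samples_Y(self):
        return self.Y.shape[0]

    def dist(self, i, j):
        return self.metric(self.X[i], self.Y[j])


X, Y = np.random.rand(5, 3), np.random.rand(7, 3)
pair = DenseDenseDatasetsPair(X, Y, metric=lambda a, b: np.linalg.norm(a - b))
print(pair.dist(2, 4))
```
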
### `PairwiseDistancesReduction`: an abstract class defining parallelization templates

This allows:

 - Parallelising computations effectively on all cores
 - Preventing threads' oversubscription (by OpenMP, joblib, or any BLAS implementation)
 - Supporting generic return values for reductions (varying number, varying types, varying shapes, etc.)

This is made possible by:

 - setting up a general interface that performs the parallelization of computations on $\mathbf{C}_\text{chunk_size}\mathbf{(X, Y)}$: two parallelization strategies are implemented, as it is worth parallelizing on $\mathbf{X}$ or on $\mathbf{Y}$ depending on the context. To choose one or the other, a simple heuristic comparing $c_x$ and $c_y$ to the number of available threads is used, and it is sufficient.
 - using a [`threadpoolctl.threadpool_limits` context](https://github.com/joblib/threadpoolctl#setting-the-maximum-size-of-thread-pools) at the start of the execution of the generic parallel template
 - having a flexible Python interface to return results, with the parallel computations defined agnostically from the data structures being modified in the concrete classes[ref]A set of template methods is defined so as to have concrete implementations modify data structures when and where needed.[/ref].

The critical area of the computations -- that is, the computation of the chunk of the distance matrix associated with a pair of chunks from $\mathbf{C}_\text{chunk_size}\mathbf{(X, Y)}$ and its reduction -- is made abstract. This way, when defining a concrete `PairwiseDistancesReduction`, a single method has to be defined, up to a few optional Python helper methods[ref]If you are looking for the concrete implementations' critical regions, look for
`_compute_and_reduce_distances_on_chunks`.[/ref].


### `PairwiseDistancesReductionArgKmin`: a first concrete `PairwiseDistancesReduction` for $\text{argkmin}$

For this reduction, one can use [max-heaps](https://en.wikipedia.org/wiki/Heap_(data_structure)), which by design do the work of
keeping the $k$ minimum values seen so far together with their indices. scikit-learn's
current implementation of max-heaps is simple, readable and efficient[ref]Thanks to [Jake VanderPlas](https://vanderplas.com/)![/ref] and can be used to manipulate the data structures that we need[ref]We mainly use heap-allocated buffers that we manipulate through pointers and offsets at the lowest level of this new implementation for maximum efficiency.[/ref].

### Specialising reductions for the Euclidean distance metric

Generally, the distances associated with the neighbors aren't returned to the user. This allows some optimizations.

In the case of the Euclidean distance metric, one can use the squared Euclidean distance metric as a proxy: it is less costly, it preserves ordering, and it can be computed efficiently.

Indeed, $\mathbf{D}^{\odot 2}_d(\mathbf{X}_c^{(l)}, \mathbf{Y}_c^{(k)})$ -- the element-wise squared version of the $(l,k)$-th chunk of $\mathbf{D}_d(\mathbf{X}, \mathbf{Y})$ -- can be computed as follows:

$$
\mathbf{D}^{\odot 2}_d(\mathbf{X}_c^{(l)}, \mathbf{Y}_c^{(k)}) \triangleq \left[\Vert \mathbf{X}_{i\cdot}^{(l)} - \mathbf{Y}_{j\cdot}^{(k)} \Vert^2_2\right]_{(i,j)\in [c]^2} = \left[\Vert \mathbf{X}_{i\cdot}^{(l)}\Vert^2_2 \right]_{(i,j)\in [c]^2} + \left[\Vert \mathbf{Y}_{j\cdot }^{(k)}\Vert^2_2 \right]_{(i, j)\in [c]^2} - 2\, \mathbf{X}_c^{(l)} {\mathbf{Y}_c^{(k)}}^\top
$$

This allows using two optimizations, illustrated by the sketch after this list:

 1. $\left[\Vert \mathbf{X}_{i\cdot}\Vert_2^2\right]_{i \in [n_x]}$ and $\left[\Vert \mathbf{Y}_{j\cdot}\Vert_2^2\right]_{j \in [n_y]}$ can be computed once and for all at the start and cached. Those two vectors are reused on each chunk of the distance matrix.

 2. More importantly, $- 2\, \mathbf{X}_c^{(l)} {\mathbf{Y}_c^{(k)}}^\top$ can be computed using the [GEneral Matrix Multiplication from BLAS Level 3](https://www.netlib.org/lapack/explore-html/d1/d54/group__double__blas__level3_gaeda3cbd99c8fb834a60a6412878226e1.html) -- hereinafter referred to as GEMM. This gives the maximum arithmetic intensity for the computations, making use of recent BLAS back-ends implementing vectorised kernels, such as [OpenBLAS](https://www.openblas.net/).

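Here is what this decomposition looks like with numpy (a sketch on full matrices rather
than on chunks; in the real implementation the matrix product is the part handled by a
single GEMM call):

```python
import numpy as np

X = np.random.rand(1_000, 30)
Y = np.random.rand(2_000, 30)

# Computed once and cached (optimization 1).
sq_norms_X = (X ** 2).sum(axis=1)
sq_norms_Y = (Y ** 2).sum(axis=1)

# Squared Euclidean distances: ||x||^2 + ||y||^2 - 2 x.y (optimization 2:
# the -2 X Y^T term maps to one GEMM call performed by the BLAS).
sq_dists = sq_norms_X[:, np.newaxis] + sq_norms_Y[np.newaxis, :] - 2 * X @ Y.T

# Squared distances preserve ordering, so argkmin can be applied directly.
neighbors = np.argsort(sq_dists, axis=1)[:, :10]
```
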
For instance, `EuclideanPairwiseDistancesArgKmin` is the main specialization of `PairwiseDistancesArgKmin` for the
Euclidean distance metric. This specialization only recomputes the actual Euclidean distances when the caller asks for them to be returned.


### Interfacing `PairwiseDistancesReductions` with scikit-learn's algorithms

So far, the overall design has been covered without mentioning how these interfaces can be plugged into the existing scikit-learn algorithms so as to progressively migrate most algorithms' back-ends to the new implementations.

Furthermore, in the future, specialized implementations for various vendors of CPUs and GPUs can be created.
In this case, we want to have such specialized implementations separated from the scikit-learn source code (e.g. by having them in optional and vendor-specific packages) so as to keep the `PairwiseDistancesReductions` interfaces vendor-specialization-agnostic, while still being able to dispatch the computations to the most adapted and available implementations.

To touch two birds with one tiny stone[ref]Disclaimer: during this work, no animal was killed nor hurt; nor are, nor will be.[/ref], the new implementations can be used conditionally -- only for the cases they already support, based on the provided datasets -- while callers stay agnostic of which implementation is executed.

This can be implemented by a `PairwiseDistancesReduction.{is_usable_for,compute}` pattern:

 - `PairwiseDistancesReduction.is_usable_for` returns `True` if an implementation for the provided $(\mathbf{X}, \mathbf{Y}, d)$ is available. If none is, the caller can default to the current implementation within scikit-learn.
 - `PairwiseDistancesReduction.compute` returns the results of the reduction. Internally, it is responsible for choosing the most appropriate implementation prior to executing it.

In this context, the aforementioned vendor-specific packages could register custom implementations explicitly (e.g. with a Python context manager, as suggested by Olivier Grisel) or implicitly (by some package reflection when importing the relevant interfaces).

## Implementing the design

A few first experiments have been made and converged to [`sklearn#22134`](https://github.com/scikit-learn/scikit-learn/pull/22134), a contribution that proposes integrating the previous interfaces progressively via a feature branch.

## Future work

Further work would treat the remaining requirements:

 - Support for 32-bit datasets pairs
 - Support for the remaining fused $\{\text{sparse}, \text{dense}\}^2$ datasets pairs, i.e.:
     - sparse $\mathbf{X}$ and dense $\mathbf{Y}$
     - dense $\mathbf{X}$ and sparse $\mathbf{Y}$
     - sparse $\mathbf{X}$ and sparse $\mathbf{Y}$
 - Implement adapted operations for each reduction (radius neighborhood, threshold filtering, cumulative sum, etc.)

> If you are interested in reading more about this, read [this section from the extra notes](sklearn-perf-pdr-extra.html).

## Acknowledgement

This was joint work with other core developers -- namely [Olivier Grisel](https://ogrisel.com/), [Jérémie du Boisberranger](https://github.com/jeremiedbb), [Thomas J. Fan](https://www.thomasjpfan.com/) and [Christian Lorentzen](https://github.com/lorentzenchr).

Finally and more importantly, the implementations presented here are made possible thanks to other notable open-source
projects, especially Cython, but also OpenBLAS, which provides fast vectorized BLAS kernels implemented in C and assembly.


---

## Notes

diff --git a/_posts/sklearn-perf.md b/_posts/sklearn-perf.md

Title: Performance and scikit-learn (0/4)
Date: 2021-12-15
Category: scikit-learn
Slug: sklearn-perf
Lang: en
Authors: Julien Jerphanion
Summary:
Status: Published

For more than 10 years, scikit-learn has been bringing machine learning and
data science to the world. Since its creation, the library has always aimed at
delivering quality implementations to its users.

This series of blog posts aims at explaining the ongoing work of the
scikit-learn developers to improve the performance of the library by several
orders of magnitude.

This series should be read as follows:

 - [Context: the current state of scikit-learn](sklearn-perf-context.html)
 - [Hardware scalability issue: the k-neighbors search example](sklearn-perf-knn.html)
 - [Pairwise Distances Reduction: abstracting the $k$-nn search pattern](sklearn-perf-pdr.html)
 - [Pairwise Distances Reduction: extra notes](sklearn-perf-pdr-extra.html)

Knowing about the following topics can help understand the blog posts:

 - the main algorithms in machine learning, especially $k$-nearest neighbors
 - basic data structures and algorithmic complexity
 - RAM and the hierarchy of CPU caches
 - some elements of linear algebra
 - some elements of object-oriented design (abstract classes, template methods)
 - some elements of C programming (allocation on the heap, pointer arithmetic)
 - some elements of OpenMP (static scheduling and parallel for-loops)
 - Cython