Commit 8aed2d7

[MRG] New Quickstart guide and revamp User guide (#726)
* rename to user guide and add quickstart script
* remove import of deprecated module
* first draft of quickstart guide
* working on the guide
* add stuff
* comment unbalanced Gromov
* rename section
* first shot done
* better version quickstart guide
* better version quickstart guide
* fix doc
* cleanup example
* fix?
* add test
* remove sentence
* call it the unified vs classic API
1 parent 928a67a commit 8aed2d7

9 files changed: +694 -40 lines changed

Diff for: RELEASES.md

+1
@@ -15,6 +15,7 @@
 - Added `ot.gaussian.bures_wasserstein_distance` (PR #680)
 - `ot.gaussian.bures_wasserstein_distance` can be batched (PR #680)
 - Backend implementation of `ot.dist` for (PR #701)
+- Updated documentation Quickstart guide and User guide with new API (PR #726)
 
 #### Closed issues
 - Fixed `ot.mapping` solvers which depended on deprecated `cvxpy` `ECOS` solver (PR #692, Issue #668)

Diff for: docs/source/conf.py

+1-1
@@ -347,7 +347,7 @@ def __getattr__(cls, name):
 }
 
 sphinx_gallery_conf = {
-    "examples_dirs": ["../../examples", "../../examples/da"],
+    "examples_dirs": ["../../examples"],
     "gallery_dirs": "auto_examples",
     "filename_pattern": "plot_",  # (?!barycenter_fgw)
     "nested_sections": False,

Diff for: docs/source/index.rst

+3-2
@@ -17,9 +17,10 @@ Contents
    :maxdepth: 1
 
    self
-   quickstart
-   all
+   auto_examples/plot_quickstart_guide
    auto_examples/index
+   user_guide
+   all
    releases
    contributors
    contributing

Diff for: docs/source/quickstart.rst renamed to docs/source/user_guide.rst

+35-35
@@ -1,6 +1,6 @@
 
-Quick start guide
-=================
+User guide
+==========
 
 In the following we provide some pointers about which functions and classes
 to use for different problems related to optimal transport (OT) and machine
@@ -136,12 +136,12 @@ instance the memory cost for an OT problem is always :math:`\mathcal{O}(n^2)` in
 memory because the cost matrix has to be computed. The exact solver in of time
 complexity :math:`\mathcal{O}(n^3\log(n))` and the Sinkhorn solver has been
 proven to be nearly :math:`\mathcal{O}(n^2)` which is still too complex for very
-large scale solvers.
+large scale solvers. For all the generic solvers we need to compute the cost
+matrix and the OT matrix of memory size :math:`\mathcal{O}(n^2)` which can be
+prohibitive for very large scale problems.
 
-
-If you need to solve OT with large number of samples, we recommend to use
-entropic regularization and memory efficient implementation of Sinkhorn as
-proposed in `GeomLoss <https://www.kernel-operations.io/geomloss/>`_. This
+If you need to solve OT with large number of samples, we provide "lazy" memory efficient implementation of Sinkhorn in pure
+python and using `GeomLoss <https://www.kernel-operations.io/geomloss/>`_. This
 implementation is compatible with Pytorch and can handle large number of
 samples. Another approach to estimate the Wasserstein distance for very large
 number of sample is to use the trick from `Wasserstein GAN
@@ -193,15 +193,19 @@ that will return the optimal transport matrix :math:`\gamma^*`:
 
     # a and b are 1D histograms (sum to 1 and positive)
     # M is the ground cost matrix
+
+    # unified API
+    T = ot.solve(M, a, b).plan # exact linear program
+
+    # classical API
     T = ot.emd(a, b, M) # exact linear program
 
 The method implemented for solving the OT problem is the network simplex. It is
 implemented in C from [1]_. It has a complexity of :math:`O(n^3)` but the
 solver is quite efficient and uses sparsity of the solution.
 
 
-
-.. minigallery:: ot.emd
+.. minigallery:: ot.emd, ot.solve
     :add-heading: Examples of use for :any:`ot.emd`
     :heading-level: "
 
@@ -226,7 +230,12 @@ It can computed from an already estimated OT matrix with
 
     # a and b are 1D histograms (sum to 1 and positive)
     # M is the ground cost matrix
-    W = ot.emd2(a, b, M) # Wasserstein distance / EMD value
+
+    # Wasserstein distance / EMD value with unified API
+    W = ot.solve(M, a, b, return_matrix=False).value
+
+    # with classical API
+    W = ot.emd2(a, b, M)
 
 Note that the well known `Wasserstein distance
 <https://en.wikipedia.org/wiki/Wasserstein_metric>`_ between distributions a and
@@ -246,7 +255,7 @@ the :math:`W_1` Wasserstein distance can be done directly with :any:`ot.emd2`
 when providing :code:`M = ot.dist(xs, xt, metric='euclidean')` to use the Euclidean
 distance.
 
-.. minigallery:: ot.emd2
+.. minigallery:: ot.emd2, ot.solve
     :add-heading: Examples of use for :any:`ot.emd2`
     :heading-level: "
 
@@ -274,6 +283,10 @@ distributions. In the case when the finite sample dataset is supposed Gaussian,
 we provide :any:`ot.gaussian.bures_wasserstein_mapping` that returns the parameters for the
 Monge mapping.
 
+All those special cases are accessible with the unified API of POT through the
+function :any:`ot.solve_sample` with the parameter :code:`method` that allows to
+choose the method used to solve the problem (with :code:`method='1D'` or :code:`method='gaussian'`).
+
 
 Regularized Optimal Transport
 -----------------------------
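
As an illustration of the 1D special case mentioned in the hunk above: for two 1D samples of equal size with uniform weights, the optimal plan matches sorted points, so the distance can be computed by sorting alone. The sketch below is a pure-Python illustration under those assumptions, not POT's implementation; the helper name `wasserstein_1d_sorted` is hypothetical.

```python
def wasserstein_1d_sorted(xs, xt, p=1):
    """W_p distance between two uniform empirical 1D distributions.

    Assumes len(xs) == len(xt) and uniform weights on both samples
    (illustrative helper, not part of POT).
    """
    if len(xs) != len(xt):
        raise ValueError("this sketch only handles samples of equal size")
    # In 1D the optimal coupling pairs the i-th smallest source point
    # with the i-th smallest target point.
    pairs = zip(sorted(xs), sorted(xt))
    cost = sum(abs(u - v) ** p for u, v in pairs) / len(xs)
    return cost ** (1.0 / p)


xs = [0.0, 1.0, 2.0]
xt = [3.0, 5.0, 4.0]  # xs shifted by 3, given in a different order
print(wasserstein_1d_sorted(xs, xt))
```

Because the target sample is a shifted copy of the source, the printed distance equals the size of the shift, regardless of the order in which the samples are listed.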
@@ -330,13 +343,15 @@ The Sinkhorn-Knopp algorithm is implemented in :any:`ot.sinkhorn` and
 linear term. Note that the regularization parameter :math:`\lambda` in the
 equation above is given to those functions with the parameter :code:`reg`.
 
-    >>> import ot
-    >>> a = [.5, .5]
-    >>> b = [.5, .5]
-    >>> M = [[0., 1.], [1., 0.]]
-    >>> ot.sinkhorn(a, b, M, 1)
-    array([[ 0.36552929, 0.13447071],
-           [ 0.13447071, 0.36552929]])
+.. code:: python
+
+    # unified API
+    P = ot.solve(M, a, b, reg=1).plan # OT Sinkhorn matrix
+    loss = ot.solve(M, a, b, reg=1).value # OT Sinkhorn value
+
+    # classical API
+    P = ot.sinkhorn(a, b, M, reg=1) # OT Sinkhorn matrix
+    loss = ot.sinkhorn2(a, b, M, reg=1) # OT Sinkhorn value
 
 More details about the algorithms used are given in the following note.
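
The numbers in the doctest removed by this hunk can be checked by hand. The sketch below is a minimal pure-Python Sinkhorn-Knopp iteration, assuming strictly positive marginals and a fixed number of iterations; it is an illustration only, not POT's implementation, and the names `sinkhorn_plan` and `n_iter` are hypothetical.

```python
import math


def sinkhorn_plan(a, b, M, reg, n_iter=100):
    """Entropic OT plan via Sinkhorn-Knopp scaling (illustrative sketch)."""
    n, m = len(a), len(b)
    # Gibbs kernel K = exp(-M / reg)
    K = [[math.exp(-M[i][j] / reg) for j in range(m)] for i in range(n)]
    u = [1.0] * n
    v = [1.0] * m
    for _ in range(n_iter):
        # Alternate the scalings so that rows sum to a and columns to b.
        v = [b[j] / sum(K[i][j] * u[i] for i in range(n)) for j in range(m)]
        u = [a[i] / sum(K[i][j] * v[j] for j in range(m)) for i in range(n)]
    # Plan P = diag(u) K diag(v)
    return [[u[i] * K[i][j] * v[j] for j in range(m)] for i in range(n)]


a = [0.5, 0.5]
b = [0.5, 0.5]
M = [[0.0, 1.0], [1.0, 0.0]]
P = sinkhorn_plan(a, b, M, reg=1.0)
for row in P:
    print(row)
```

For this symmetric 2x2 problem the plan is the kernel rescaled to total mass 1, which gives diagonal entries of about 0.36552929 and off-diagonal entries of about 0.13447071, matching the removed doctest.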

@@ -406,13 +421,10 @@ implementations are not optimized for speed but provide a robust implementation
 of algorithms in [18]_ [19]_.
 
 
-.. minigallery:: ot.sinkhorn
-    :add-heading: Examples of use for :any:`ot.sinkhorn`
+.. minigallery:: ot.sinkhorn ot.sinkhorn2
+    :add-heading: Examples of use for Sinkhorn algorithm
     :heading-level: "
 
-.. minigallery:: ot.sinkhorn2
-    :add-heading: Examples of use for :any:`ot.sinkhorn2`
-    :heading-level: "
 
 
 Other regularizations
@@ -969,18 +981,6 @@ For instance, to disable TensorFlow, set `export POT_BACKEND_DISABLE_TENSORFLOW=
 It's important to note that the `numpy` backend cannot be disabled.
 
 
-List of compatible modules
-^^^^^^^^^^^^^^^^^^^^^^^^^^^^
-
-This list will get longer for new releases and will hopefully disappear when POT
-become fully implemented with the backend.
-
-- :any:`ot.bregman`
-- :any:`ot.gromov` (some functions use CPU only solvers with copy overhead)
-- :any:`ot.optim` (some functions use CPU only solvers with copy overhead)
-- :any:`ot.sliced`
-- :any:`ot.utils` (partial)
-
 
 FAQ
 ---

Diff for: examples/plot_OT_2D_samples.py

+1-1
@@ -65,7 +65,7 @@
 
 # %% EMD
 
-G0 = ot.emd(a, b, M)
+G0 = ot.solve(M, a, b).plan
 
 pl.figure(3)
 pl.imshow(G0, interpolation="nearest")
