diff --git a/.github/workflows/build.yaml b/.github/workflows/build.yaml index e6719d7..b3d9efc 100644 --- a/.github/workflows/build.yaml +++ b/.github/workflows/build.yaml @@ -1,7 +1,9 @@ name: Build and upload to PyPI on: - push + push: + branches: + - main jobs: build_wheels: diff --git a/.github/workflows/tests_full.yml b/.github/workflows/tests_full.yml index 0865876..b5e096a 100644 --- a/.github/workflows/tests_full.yml +++ b/.github/workflows/tests_full.yml @@ -1,4 +1,4 @@ -name: tests +name: full test on: push: @@ -102,3 +102,32 @@ jobs: - name: Run tests run: pytest -vs tests/test_torch.py + + test-jax: + needs: build + runs-on: ${{ matrix.os }} + strategy: + matrix: + os: [ windows-2019, ubuntu-18.04, macos-11 ] + python-version: [ 3.6, 3.7, 3.8, 3.9 ] + + steps: + - uses: actions/checkout@v1 + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v2 + with: + python-version: ${{ matrix.python-version }} + + - name: Download artifact + uses: actions/download-artifact@master + with: + name: "Python wheel" + + - name: Install dependencies + run: pip install jax jaxlib opencv-python-headless scikit-image pytest + + - name: Install wheel + run: pip install --find-links=${{github.workspace}} torchstain + + - name: Run tests + run: pytest -vs tests/test_jax.py diff --git a/.github/workflows/tests_quick.yml b/.github/workflows/tests_quick.yml index 3e08828..9f82d5c 100644 --- a/.github/workflows/tests_quick.yml +++ b/.github/workflows/tests_quick.yml @@ -1,4 +1,4 @@ -name: tests +name: quick tests on: push: @@ -84,3 +84,29 @@ jobs: - name: Run tests run: pytest -vs tests/test_torch.py + + + test-jax: + needs: build + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v1 + - name: Set up Python 3.7 + uses: actions/setup-python@v2 + with: + python-version: 3.7 + + - name: Download artifact + uses: actions/download-artifact@master + with: + name: "Python wheel" + + - name: Install dependencies + run: pip install jax jaxlib opencv-python-headless scikit-image pytest + + - name: Install wheel + run: pip install --find-links=${{github.workspace}} torchstain + + - name: Run tests + run: pytest -vs tests/test_jax.py diff --git a/README.md b/README.md index 74af142..4c793b2 100644 --- a/README.md +++ b/README.md @@ -5,7 +5,7 @@ [![Pip Downloads](https://img.shields.io/pypi/dm/torchstain?label=pip%20downloads&logo=python)](https://pypi.org/project/torchstain/) [![DOI](https://zenodo.org/badge/323590093.svg)](https://zenodo.org/badge/latestdoi/323590093) -GPU-accelerated stain normalization tools for histopathological images. Compatible with PyTorch, TensorFlow, and Numpy. +GPU-accelerated stain normalization tools for histopathological images. Compatible with PyTorch, TensorFlow, Numpy, and JAX. Normalization algorithms currently implemented: - Macenko [\[1\]](#reference) (ported from [numpy implementation](https://github.com/schaugf/HEnorm_python)) @@ -47,11 +47,11 @@ norm, H, E = normalizer.normalize(I=t_to_transform, stains=True) ## Implemented algorithms -| Algorithm | numpy | torch | tensorflow | -|-|-|-|-| -| Macenko | ✓ | ✓ | ✓ | -| Reinhard | ✓ | ✓ | ✓ | -| Modified Reinhard | ✓ | ✓ | ✓ | +| Algorithm | numpy | torch | tensorflow | jax | +|-|-|-|-|-| +| Macenko | ✓ | ✓ | ✓ | ✓ | +| Reinhard | ✓ | ✓ | ✓ | ✗ | +| Modified Reinhard | ✓ | ✓ | ✓ | ✗ | ## Backend comparison diff --git a/setup.py b/setup.py index d27acf7..cb0c60b 100644 --- a/setup.py +++ b/setup.py @@ -22,6 +22,7 @@ extras_require={ "tf": ["tensorflow"], "torch": ["torch"], + "jax": ["jax", "jaxlib"], }, classifiers=[ 'Development Status :: 4 - Beta', diff --git a/tests/test_jax.py b/tests/test_jax.py new file mode 100644 index 0000000..e5a033f --- /dev/null +++ b/tests/test_jax.py @@ -0,0 +1,32 @@ +import os +import cv2 +import torchstain +import torchstain.jax +import time +from skimage.metrics import structural_similarity as ssim +import numpy as np +from jax import numpy as jnp + +def test_macenko_jax(): + size = 1024 + curr_file_path = os.path.dirname(os.path.realpath(__file__)) + target = cv2.resize(cv2.cvtColor(cv2.imread(os.path.join(curr_file_path, "../data/target.png")), cv2.COLOR_BGR2RGB), (size, size)) + to_transform = cv2.resize(cv2.cvtColor(cv2.imread(os.path.join(curr_file_path, "../data/source.png")), cv2.COLOR_BGR2RGB), (size, size)) + + # initialize normalizers for each backend and fit to target image + normalizer = torchstain.normalizers.MacenkoNormalizer(backend='numpy') + normalizer.fit(target) + + jax_normalizer = torchstain.normalizers.MacenkoNormalizer(backend='jax') + jax_normalizer.fit(target) + + # transform + result_numpy, _, _ = normalizer.normalize(I=to_transform) + result_jax, _, _ = jax_normalizer.normalize(I=to_transform) + + # convert to numpy and set dtype + result_numpy = result_numpy.astype("float32") + result_jax = np.asarray(result_jax).astype("float32") + + # assess whether the normalized images are identical across backends + np.testing.assert_almost_equal(ssim(result_numpy.flatten(), result_jax.flatten()), 1.0, decimal=4, verbose=True) diff --git a/torchstain/base/normalizers/macenko.py b/torchstain/base/normalizers/macenko.py index 2c0c988..db806e3 100644 --- a/torchstain/base/normalizers/macenko.py +++ b/torchstain/base/normalizers/macenko.py @@ -8,5 +8,8 @@ def MacenkoNormalizer(backend='torch'): elif backend == "tensorflow": from torchstain.tf.normalizers.macenko import TensorFlowMacenkoNormalizer return TensorFlowMacenkoNormalizer() + elif backend == "jax": + from torchstain.jax.normalizers.macenko import JaxMacenkoNormalizer + return JaxMacenkoNormalizer() else: raise Exception(f'Unknown backend {backend}') diff --git a/torchstain/jax/__init__.py b/torchstain/jax/__init__.py new file mode 100644 index 0000000..c36ce2a --- /dev/null +++ b/torchstain/jax/__init__.py @@ -0,0 +1 @@ +from torchstain.jax import normalizers, utils diff --git a/torchstain/jax/normalizers/__init__.py b/torchstain/jax/normalizers/__init__.py new file mode 100644 index 0000000..de86fc2 --- /dev/null +++ b/torchstain/jax/normalizers/__init__.py @@ -0,0 +1 @@ +from torchstain.jax.normalizers.macenko import JaxMacenkoNormalizer diff --git a/torchstain/jax/normalizers/macenko.py b/torchstain/jax/normalizers/macenko.py new file mode 100644 index 0000000..9cd4379 --- /dev/null +++ b/torchstain/jax/normalizers/macenko.py @@ -0,0 +1,112 @@ +import jax +from jax import lax +from jax import numpy as jnp +from torchstain.base.normalizers import HENormalizer +from jax import tree_util + +class JaxMacenkoNormalizer(HENormalizer): + def __init__(self): + super().__init__() + + self.HERef = jnp.array([[0.5626, 0.2159], + [0.7201, 0.8012], + [0.4062, 0.5581]], dtype=jnp.float32) + self.maxCRef = jnp.array([1.9705, 1.0308], dtype=jnp.float32) + + @jax.jit + def __find_concentration(self, OD, HE): + # rows correspond to channels (RGB), columns to OD values + Y = jnp.reshape(OD, (-1, 3)).T + + # determine concentrations of the individual stains + C = jnp.linalg.lstsq(HE, Y, rcond=None)[0] + + return C + + @jax.jit + def __compute_matrices(self, I, Io, alpha, beta): + I = I.reshape((-1, 3)) + + # calculate optical density + OD = -jnp.log((I.astype(jnp.float32) + 1) / Io) + + mask = ~jnp.any(OD < beta, axis=1) # to remove transparent pixels + cov = jnp.cov(OD.T, fweights=mask.astype(jnp.int32)) + _, eigvecs = jnp.linalg.eigh(cov) + + Th = OD.dot(eigvecs[:, 1:3]) + + phi = jnp.arctan2(Th[:, 1], Th[:, 0]) + + phi = jnp.where(mask, phi, jnp.inf) + pvalid = mask.mean() # proportion that is valid and not masked + + minPhi = jnp.percentile(phi, alpha * pvalid) + maxPhi = jnp.percentile(phi, (100 - alpha) * pvalid) + + vMin = eigvecs[:, 1:3].dot(jnp.array([(jnp.cos(minPhi), jnp.sin(minPhi))]).T) + vMax = eigvecs[:, 1:3].dot(jnp.array([(jnp.cos(maxPhi), jnp.sin(maxPhi))]).T) + + # a heuristic to make the vector corresponding to hematoxylin first and the + # one corresponding to eosin second + HE = lax.cond( + vMin[0, 0] > vMax[0, 0], + lambda x: jnp.array((x[0], x[1])).T, + lambda x: jnp.array((x[1], x[0])).T, + (vMin[:, 0], vMax[:, 0]) + ) + C = self.__find_concentration(OD, HE) + + # normalize stain concentrations + maxC = jnp.array([jnp.percentile(C[0, :], 99), jnp.percentile(C[1, :],99)]) + + return HE, C, maxC + + #@jax.jit + def fit(self, I, Io=240, alpha=1, beta=0.15): + HE, _, maxC = self.__compute_matrices(I, Io, alpha, beta) + + self.HERef = HE + self.maxCRef = maxC + + #@jax.jit + def normalize(self, I, Io=240, alpha=1, beta=0.15, stains=True): + h, w, c = I.shape + I = I.reshape((-1, 3)) + + HE, C, maxC = self.__compute_matrices(I, Io, alpha, beta) + + maxC = jnp.divide(maxC, self.maxCRef) + C2 = jnp.divide(C, maxC[:, jnp.newaxis]) + + # recreate the image using reference mixing matrix + Inorm = jnp.multiply(Io, jnp.exp(-self.HERef.dot(C2))) + Inorm = jnp.clip(Inorm, 0, 255) + Inorm = jnp.reshape(Inorm.T, (h, w, c)).astype(jnp.uint8) + + H, E = None, None + + if stains: + # unmix hematoxylin and eosin + H = jnp.multiply(Io, jnp.exp(jnp.expand_dims(-self.HERef[:, 0], axis=1).dot(jnp.expand_dims(C2[0, :], axis=0)))) + H = jnp.clip(H, 0, 255) + H = jnp.reshape(H.T, (h, w, c)).astype(jnp.uint8) + + E = jnp.multiply(Io, jnp.exp(jnp.expand_dims(-self.HERef[:, 1], axis=1).dot(jnp.expand_dims(C2[1, :], axis=0)))) + E = jnp.clip(E, 0, 255) + E = jnp.reshape(E.T, (h, w, c)).astype(jnp.uint8) + + return Inorm, H, E + + def _tree_flatten(self): + children = () # arrays / dynamic values + aux = () # static values + return children, aux + + @classmethod + def _tree_unflatten(cls, aux, children): + return cls(*children, *aux) + +tree_util.register_pytree_node( + JaxMacenkoNormalizer, JaxMacenkoNormalizer._tree_flatten,JaxMacenkoNormalizer._tree_unflatten +) diff --git a/torchstain/jax/utils/__init__.py b/torchstain/jax/utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/torchstain/numpy/normalizers/macenko.py b/torchstain/numpy/normalizers/macenko.py index faada97..968b716 100644 --- a/torchstain/numpy/normalizers/macenko.py +++ b/torchstain/numpy/normalizers/macenko.py @@ -10,13 +10,13 @@ def __init__(self): super().__init__() self.HERef = np.array([[0.5626, 0.2159], - [0.7201, 0.8012], - [0.4062, 0.5581]]) + [0.7201, 0.8012], + [0.4062, 0.5581]]) self.maxCRef = np.array([1.9705, 1.0308]) def __convert_rgb2od(self, I, Io=240, beta=0.15): # calculate optical density - OD = -np.log((I.astype(float)+1)/Io) + OD = -np.log((I.astype(np.float32)+1)/Io) # remove transparent pixels ODhat = OD[~np.any(OD < beta, axis=1)] diff --git a/torchstain/tf/normalizers/macenko.py b/torchstain/tf/normalizers/macenko.py index bf21666..578560a 100644 --- a/torchstain/tf/normalizers/macenko.py +++ b/torchstain/tf/normalizers/macenko.py @@ -1,7 +1,6 @@ import tensorflow as tf from torchstain.base.normalizers.he_normalizer import HENormalizer from torchstain.tf.utils import cov, percentile, solveLS -import numpy as np import tensorflow.keras.backend as K