Skip to content

Macenko JAX backend support #36

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 16 commits into
base: development
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion .github/workflows/build.yaml
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
name: Build and upload to PyPI

on:
push
push:
branches:
- main

jobs:
build_wheels:
Expand Down
31 changes: 30 additions & 1 deletion .github/workflows/tests_full.yml
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
name: tests
name: full test

on:
push:
Expand Down Expand Up @@ -102,3 +102,32 @@ jobs:

- name: Run tests
run: pytest -vs tests/test_torch.py

test-jax:
needs: build
runs-on: ${{ matrix.os }}
strategy:
matrix:
os: [ windows-2019, ubuntu-18.04, macos-11 ]
python-version: [ 3.6, 3.7, 3.8, 3.9 ]

steps:
- uses: actions/checkout@v1
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v2
with:
python-version: ${{ matrix.python-version }}

- name: Download artifact
uses: actions/download-artifact@master
with:
name: "Python wheel"

- name: Install dependencies
run: pip install jax jaxlib opencv-python-headless scikit-image pytest

- name: Install wheel
run: pip install --find-links=${{github.workspace}} torchstain

- name: Run tests
run: pytest -vs tests/test_jax.py
28 changes: 27 additions & 1 deletion .github/workflows/tests_quick.yml
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
name: tests
name: quick tests

on:
push:
Expand Down Expand Up @@ -84,3 +84,29 @@ jobs:

- name: Run tests
run: pytest -vs tests/test_torch.py


test-jax:
needs: build
runs-on: ubuntu-latest

steps:
- uses: actions/checkout@v1
- name: Set up Python 3.7
uses: actions/setup-python@v2
with:
python-version: 3.7

- name: Download artifact
uses: actions/download-artifact@master
with:
name: "Python wheel"

- name: Install dependencies
run: pip install jax jaxlib opencv-python-headless scikit-image pytest

- name: Install wheel
run: pip install --find-links=${{github.workspace}} torchstain

- name: Run tests
run: pytest -vs tests/test_jax.py
12 changes: 6 additions & 6 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
[![Pip Downloads](https://img.shields.io/pypi/dm/torchstain?label=pip%20downloads&logo=python)](https://pypi.org/project/torchstain/)
[![DOI](https://zenodo.org/badge/323590093.svg)](https://zenodo.org/badge/latestdoi/323590093)

GPU-accelerated stain normalization tools for histopathological images. Compatible with PyTorch, TensorFlow, and Numpy.
GPU-accelerated stain normalization tools for histopathological images. Compatible with PyTorch, TensorFlow, Numpy, and JAX.
Normalization algorithms currently implemented:

- Macenko [\[1\]](#reference) (ported from [numpy implementation](https://github.com/schaugf/HEnorm_python))
Expand Down Expand Up @@ -47,11 +47,11 @@ norm, H, E = normalizer.normalize(I=t_to_transform, stains=True)

## Implemented algorithms

| Algorithm | numpy | torch | tensorflow |
|-|-|-|-|
| Macenko | ✓ | ✓ | ✓ |
| Reinhard | ✓ | ✓ | ✓ |
| Modified Reinhard | ✓ | ✓ | ✓ |
| Algorithm | numpy | torch | tensorflow | jax |
|-|-|-|-|-|
| Macenko | ✓ | ✓ | ✓ | ✓ |
| Reinhard | ✓ | ✓ | ✓ | ✗ |
| Modified Reinhard | ✓ | ✓ | ✓ | ✗ |

## Backend comparison

Expand Down
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
extras_require={
"tf": ["tensorflow"],
"torch": ["torch"],
"jax": ["jax", "jaxlib"],
},
classifiers=[
'Development Status :: 4 - Beta',
Expand Down
32 changes: 32 additions & 0 deletions tests/test_jax.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
import os
import cv2
import torchstain
import torchstain.jax
import time
from skimage.metrics import structural_similarity as ssim
import numpy as np
from jax import numpy as jnp

def test_macenko_jax():
size = 1024
curr_file_path = os.path.dirname(os.path.realpath(__file__))
target = cv2.resize(cv2.cvtColor(cv2.imread(os.path.join(curr_file_path, "../data/target.png")), cv2.COLOR_BGR2RGB), (size, size))
to_transform = cv2.resize(cv2.cvtColor(cv2.imread(os.path.join(curr_file_path, "../data/source.png")), cv2.COLOR_BGR2RGB), (size, size))

# initialize normalizers for each backend and fit to target image
normalizer = torchstain.normalizers.MacenkoNormalizer(backend='numpy')
normalizer.fit(target)

jax_normalizer = torchstain.normalizers.MacenkoNormalizer(backend='jax')
jax_normalizer.fit(target)

# transform
result_numpy, _, _ = normalizer.normalize(I=to_transform)
result_jax, _, _ = jax_normalizer.normalize(I=to_transform)

# convert to numpy and set dtype
result_numpy = result_numpy.astype("float32")
result_jax = np.asarray(result_jax).astype("float32")

# assess whether the normalized images are identical across backends
np.testing.assert_almost_equal(ssim(result_numpy.flatten(), result_jax.flatten()), 1.0, decimal=4, verbose=True)
3 changes: 3 additions & 0 deletions torchstain/base/normalizers/macenko.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,5 +8,8 @@ def MacenkoNormalizer(backend='torch'):
elif backend == "tensorflow":
from torchstain.tf.normalizers.macenko import TensorFlowMacenkoNormalizer
return TensorFlowMacenkoNormalizer()
elif backend == "jax":
from torchstain.jax.normalizers.macenko import JaxMacenkoNormalizer
return JaxMacenkoNormalizer()
else:
raise Exception(f'Unknown backend {backend}')
1 change: 1 addition & 0 deletions torchstain/jax/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from torchstain.jax import normalizers, utils
1 change: 1 addition & 0 deletions torchstain/jax/normalizers/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from torchstain.jax.normalizers.macenko import JaxMacenkoNormalizer
112 changes: 112 additions & 0 deletions torchstain/jax/normalizers/macenko.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
import jax
from jax import lax
from jax import numpy as jnp
from torchstain.base.normalizers import HENormalizer
from jax import tree_util

class JaxMacenkoNormalizer(HENormalizer):
def __init__(self):
super().__init__()

self.HERef = jnp.array([[0.5626, 0.2159],
[0.7201, 0.8012],
[0.4062, 0.5581]], dtype=jnp.float32)
self.maxCRef = jnp.array([1.9705, 1.0308], dtype=jnp.float32)

@jax.jit
def __find_concentration(self, OD, HE):
# rows correspond to channels (RGB), columns to OD values
Y = jnp.reshape(OD, (-1, 3)).T

# determine concentrations of the individual stains
C = jnp.linalg.lstsq(HE, Y, rcond=None)[0]

return C

@jax.jit
def __compute_matrices(self, I, Io, alpha, beta):
I = I.reshape((-1, 3))

# calculate optical density
OD = -jnp.log((I.astype(jnp.float32) + 1) / Io)

mask = ~jnp.any(OD < beta, axis=1) # to remove transparent pixels
cov = jnp.cov(OD.T, fweights=mask.astype(jnp.int32))
_, eigvecs = jnp.linalg.eigh(cov)

Th = OD.dot(eigvecs[:, 1:3])

phi = jnp.arctan2(Th[:, 1], Th[:, 0])

phi = jnp.where(mask, phi, jnp.inf)
pvalid = mask.mean() # proportion that is valid and not masked

minPhi = jnp.percentile(phi, alpha * pvalid)
maxPhi = jnp.percentile(phi, (100 - alpha) * pvalid)

vMin = eigvecs[:, 1:3].dot(jnp.array([(jnp.cos(minPhi), jnp.sin(minPhi))]).T)
vMax = eigvecs[:, 1:3].dot(jnp.array([(jnp.cos(maxPhi), jnp.sin(maxPhi))]).T)

# a heuristic to make the vector corresponding to hematoxylin first and the
# one corresponding to eosin second
HE = lax.cond(
vMin[0, 0] > vMax[0, 0],
lambda x: jnp.array((x[0], x[1])).T,
lambda x: jnp.array((x[1], x[0])).T,
(vMin[:, 0], vMax[:, 0])
)
C = self.__find_concentration(OD, HE)

# normalize stain concentrations
maxC = jnp.array([jnp.percentile(C[0, :], 99), jnp.percentile(C[1, :],99)])

return HE, C, maxC

#@jax.jit
def fit(self, I, Io=240, alpha=1, beta=0.15):
HE, _, maxC = self.__compute_matrices(I, Io, alpha, beta)

self.HERef = HE
self.maxCRef = maxC

#@jax.jit
def normalize(self, I, Io=240, alpha=1, beta=0.15, stains=True):
h, w, c = I.shape
I = I.reshape((-1, 3))

HE, C, maxC = self.__compute_matrices(I, Io, alpha, beta)

maxC = jnp.divide(maxC, self.maxCRef)
C2 = jnp.divide(C, maxC[:, jnp.newaxis])

# recreate the image using reference mixing matrix
Inorm = jnp.multiply(Io, jnp.exp(-self.HERef.dot(C2)))
Inorm = jnp.clip(Inorm, 0, 255)
Inorm = jnp.reshape(Inorm.T, (h, w, c)).astype(jnp.uint8)

H, E = None, None

if stains:
# unmix hematoxylin and eosin
H = jnp.multiply(Io, jnp.exp(jnp.expand_dims(-self.HERef[:, 0], axis=1).dot(jnp.expand_dims(C2[0, :], axis=0))))
H = jnp.clip(H, 0, 255)
H = jnp.reshape(H.T, (h, w, c)).astype(jnp.uint8)

E = jnp.multiply(Io, jnp.exp(jnp.expand_dims(-self.HERef[:, 1], axis=1).dot(jnp.expand_dims(C2[1, :], axis=0))))
E = jnp.clip(E, 0, 255)
E = jnp.reshape(E.T, (h, w, c)).astype(jnp.uint8)

return Inorm, H, E

def _tree_flatten(self):
children = () # arrays / dynamic values
aux = () # static values
return children, aux

@classmethod
def _tree_unflatten(cls, aux, children):
return cls(*children, *aux)

tree_util.register_pytree_node(
JaxMacenkoNormalizer, JaxMacenkoNormalizer._tree_flatten,JaxMacenkoNormalizer._tree_unflatten
)
Empty file.
6 changes: 3 additions & 3 deletions torchstain/numpy/normalizers/macenko.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,13 @@ def __init__(self):
super().__init__()

self.HERef = np.array([[0.5626, 0.2159],
[0.7201, 0.8012],
[0.4062, 0.5581]])
[0.7201, 0.8012],
[0.4062, 0.5581]])
self.maxCRef = np.array([1.9705, 1.0308])

def __convert_rgb2od(self, I, Io=240, beta=0.15):
# calculate optical density
OD = -np.log((I.astype(float)+1)/Io)
OD = -np.log((I.astype(np.float32)+1)/Io)

# remove transparent pixels
ODhat = OD[~np.any(OD < beta, axis=1)]
Expand Down
1 change: 0 additions & 1 deletion torchstain/tf/normalizers/macenko.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import tensorflow as tf
from torchstain.base.normalizers.he_normalizer import HENormalizer
from torchstain.tf.utils import cov, percentile, solveLS
import numpy as np
import tensorflow.keras.backend as K


Expand Down