
ConnorStoneAstro/caskade

caskade

Build scientific simulators, treating them as a directed acyclic graph. Handles argument passing for complex nested simulators.

Install

```bash
pip install caskade
```

If you want to use caskade with JAX, run:

```bash
pip install caskade[jax]
```

Alternatively, install jax and jaxlib separately with pip, since they are the only extra requirements.

Usage

Make a Module subclass, which may hold some Param objects, and define a forward method using the @forward decorator.

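```python
from caskade import Module, Param, forward

class MySim(Module):
    def __init__(self, a, b=None):
        super().__init__()
        self.a = a  # a plain attribute, not a Param
        # A Param; with b=None it has no value, so it is supplied at call time
        self.b = Param("b", b)

    @forward
    def myfun(self, x, b=None):
        # @forward fills in b from the parameters passed to the call
        return x + self.a + b
```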

We can now create an instance of the simulator and pass in values for its dynamic parameters.

```python
import torch

sim = MySim(1.0)

# one value per dynamic parameter; here only b is dynamic
params = [torch.tensor(2.0)]

print(sim.myfun(3.0, params=params))
```

This prints 6 (as a torch tensor, `tensor(6.)`), because b is filled automatically with the value from params.
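To make the argument filling concrete, here is a toy, pure-Python sketch of what a decorator like `@forward` might do. It is not caskade's implementation: `ToyParam`, `ToyModule`, `toy_params`, and `toy_forward` are hypothetical names, and the rule that a param whose value is `None` is dynamic is an assumption of this sketch.

```python
import functools


class ToyParam:
    """Hypothetical stand-in for a parameter: a name plus an optional value."""

    def __init__(self, name, value=None):
        self.name = name
        self.value = value

    @property
    def dynamic(self):
        # Assumption of this sketch: no value means "supply it at call time".
        return self.value is None


class ToyModule:
    """Hypothetical stand-in for a module (not caskade's Module)."""

    def toy_params(self):
        # ToyParam attributes, in the order they were assigned in __init__.
        return [v for v in vars(self).values() if isinstance(v, ToyParam)]


def toy_forward(method):
    """Fill keyword arguments named after the module's params, then call method."""

    @functools.wraps(method)
    def wrapper(self, *args, params=None, **kwargs):
        supplied = iter(params or [])
        for p in self.toy_params():
            if p.name in kwargs:
                continue  # an explicitly passed keyword argument wins
            if p.dynamic:
                value = next(supplied, None)
                if value is None:
                    raise ValueError(f"no value supplied for dynamic param {p.name!r}")
                kwargs[p.name] = value
            else:
                kwargs[p.name] = p.value  # static params use their stored value
        return method(self, *args, **kwargs)

    return wrapper


class MySim(ToyModule):
    def __init__(self, a, b=None):
        self.a = a
        self.b = ToyParam("b", b)

    @toy_forward
    def myfun(self, x, b=None):
        return x + self.a + b


print(MySim(1.0).myfun(3.0, params=[2.0]))  # 6.0
print(MySim(1.0, b=5.0).myfun(3.0))  # 9.0, a static b needs no params
```

The toy only covers a single flat module with a list of values; nesting, backends, and the other input forms are left out.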

Why do this?

The above example is not very impressive; the real power comes from the fact that Module objects can be nested arbitrarily, building a much more complicated analysis graph. Further, Param objects can be linked or have other complex relationships. All of the complexity of the nested structure and argument passing is abstracted away, so at the top level one need only pass a list with one tensor per parameter, a single large 1D tensor, or a dictionary with the same structure as the graph.
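The three input forms can be pictured with a small, stdlib-only sketch. It is a hypothetical model, not caskade's API: the graph is a nested dict of parameter sizes, parameters are visited depth-first in insertion order (an assumption; caskade's own ordering may differ), and each hypothetical helper (`leaves`, `from_list`, `from_flat`, `from_dict`) turns one input form into the same path-to-value mapping.

```python
def leaves(graph, prefix=()):
    """Yield (path, size) for every parameter, depth-first in insertion order."""
    for name, node in graph.items():
        path = prefix + (name,)
        if isinstance(node, dict):
            yield from leaves(node, path)  # a nested module
        else:
            yield path, node  # a parameter holding `node` scalar values


def from_list(graph, values):
    """Form 1: one entry per parameter, in traversal order."""
    slots = list(leaves(graph))
    if len(slots) != len(values):
        raise ValueError(f"expected {len(slots)} parameters, got {len(values)}")
    return {path: list(v) for (path, _), v in zip(slots, values)}


def from_flat(graph, flat):
    """Form 2: one flat vector, cut into consecutive chunks by parameter size."""
    out, start = {}, 0
    for path, size in leaves(graph):
        out[path] = list(flat[start:start + size])
        start += size
    if start != len(flat):
        raise ValueError(f"expected {start} values, got {len(flat)}")
    return out


def from_dict(graph, nested):
    """Form 3: a nested dict that mirrors the graph's structure."""
    out = {}
    for path, _ in leaves(graph):
        node = nested
        for key in path:
            node = node[key]  # walk down one module level at a time
        out[path] = list(node)
    return out


# Hypothetical graph: an outer module with a 2-element param "d" and a
# nested module "inner" holding a 1-element param "c".
graph = {"inner": {"c": 1}, "d": 2}

as_list = from_list(graph, [[2.0], [3.0, 4.0]])
as_flat = from_flat(graph, [2.0, 3.0, 4.0])
as_dict = from_dict(graph, {"inner": {"c": [2.0]}, "d": [3.0, 4.0]})
print(as_list == as_flat == as_dict)  # True
```

The flat form only works because every caller agrees on one traversal order; the dict form avoids that dependency by naming each value by its place in the graph.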

Use different backends

caskade can run with torch, numpy, or jax as its backend. See the Beginners Guide tutorial to learn more!

Documentation

The caskade interface is very flexible; check out the docs to learn more. For a quick start, jump right to the Jupyter notebook tutorial!
