From 20c9ac0a018255fd56bbeca739bd1306082659e2 Mon Sep 17 00:00:00 2001 From: TomasPegado Date: Mon, 3 Feb 2025 11:51:53 -0300 Subject: [PATCH 01/36] docs: add a Getting Started Section --- nx/guides/getting_started/introduction.md | 86 +++++ nx/guides/getting_started/quickstart.livemd | 395 ++++++++++++++++++++ nx/mix.exs | 3 + 3 files changed, 484 insertions(+) create mode 100644 nx/guides/getting_started/introduction.md create mode 100644 nx/guides/getting_started/quickstart.livemd diff --git a/nx/guides/getting_started/introduction.md b/nx/guides/getting_started/introduction.md new file mode 100644 index 0000000000..8a3586a86c --- /dev/null +++ b/nx/guides/getting_started/introduction.md @@ -0,0 +1,86 @@ +# What is Nx? + +Nx is the numerical computing library of Elixir. Since Elixir´s primary numerical datatypes and structures are not optimized for numerical programming, Nx is the fundamental package built to bridge this gap. + +[Elixir Nx](https://github.com/elixir-nx/nx) smoothly integrate to typed, multidimensional data implemented on other +platforms (called [tensors](introduction.html#what-are-tensors)). This support extends to the compilers and +libraries that support those tensors. Nx has four primary capabilities: + +- In Nx, tensors hold typed data in multiple, named dimensions. +- Numerical definitions, known as `defn`, support custom code with + tensor-aware operators and functions. +- [Automatic differentiation](https://arxiv.org/abs/1502.05767), also known as + autograd or autodiff, supports common computational scenarios + such as machine learning, simulations, curve fitting, and probabilistic models. +- Broadcasting, which is term for element-by-element operations. Most of the Nx operations + automatically broadcast using an effective algorithm. You can see more on broadcast + [here.](intro-to-nx.html#broadcasts) + +Here's more about each of those capabilities. Nx tensors can hold +unsigned integers (u2, u4, u8, u16, u32, u64), +signed integers (s2, s4, s8, s16, s32, s64), +floats (f32, f64), brain floats (bf16), and complex (c64, c128). +Tensors support backends implemented outside of Elixir, including Google's +Accelerated Linear Algebra (XLA) and LibTorch. + +Numerical definitions have compiler support to allow just-in-time compilation +that support specialized processors to speed up numeric computation including +TPUs and GPUs. + +## What are Tensors? + +In Nx, we express multi-dimensional data using typed tensors. Simply put, +a tensor is a multi-dimensional array with a predetermined shape and +type. To interact with them, Nx relies on tensor-aware operators rather +than `Enum.map/2` and `Enum.reduce/3`. + +It allows us to work with the central theme in numerical computing, systems of equations, +which are often expressed and solved with multidimensional arrays. + +For example, this is a two dimensional array: + +$$ +\begin{bmatrix} + 1 & 2 \\\\ + 3 & 4 +\end{bmatrix} +$$ + +As elixir programmers, we can typically express a similar data structure using a list of lists, +like this: + +```elixir +[ + [1, 2], + [3, 4] +] +``` + +This data structure works fine within many functional programming +algorithms, but breaks down with deep nesting and random access. + +On top of that, Elixir numeric types lack optimization for many numerical +applications. They work fine when programs +need hundreds or even thousands of calculations. However, they tend to break +down with traditional STEM applications when a typical problem +needs millions of calculations. + +To solve for this, we can simply use Nx tensors, for example: + +```elixir +Nx.tensor([[1,2],[3,4]]) + +Output: +#Nx.Tensor< +s32[2][2] +[ +[1, 2], +[3, 4] +] + +``` + +To know Nx, we'll get to know tensors first. The following overview will touch +on the major libraries. Then, future notebooks will take a deep dive into working +with tensors in detail, autograd, and backends. Then, we'll dive into specific +problem spaces like Axon, the machine learning library. diff --git a/nx/guides/getting_started/quickstart.livemd b/nx/guides/getting_started/quickstart.livemd new file mode 100644 index 0000000000..f0e0472284 --- /dev/null +++ b/nx/guides/getting_started/quickstart.livemd @@ -0,0 +1,395 @@ +# Nx quickstart + +## Prerequisites + +You will need to know a bit of Elixir. For a refresher, check out the +[Elixir Getting Started Guide](https://hexdocs.pm/elixir/introduction.html). + +To work the examples you can run using the livebook buttom in this page. + +#### Learning Objectives + +This is a overview of Nx tensors. In this section, we'll look at some of the various tools for +creating and interacting with tensors. The IEx helpers will assist our +exploration of the core tensor concepts. + +```elixir +import IEx.Helpers +``` + +After reading, you should be able to understand: + +- Create 1, 2 and N-dimensional tensors in `Nx`; +- How to index, slice and iterate through tensors; +- Basic tensor functions; +- How to apply some linear algebra operations to n-dimensional tensors without using for-loops; +- Axis and shape properties for n-dimensional tensors. + +## The Basics + +Now, everything is set up, so we're ready to create some tensors. + +```elixir +Mix.install([ + {:nx, "~> 0.5"} +]) +``` + +### Creating tensors + +The argument must be one of: + +- a tensor +- a number (which means the tensor is scalar/zero-dimensional) +- a boolean (also scalar/zero-dimensional) +- an arbitrarily nested list of numbers and booleans + +If a new tensor is allocated, it will be allocated in the backend defined by +`Nx.default_backend/0`, unless the `:backend` option is given, which overrides the +default. + +#### Examples + +A number returns a tensor of zero dimensions: + +```elixir +Nx.tensor(0) +``` + +```elixir +Nx.tensor(1.0) +``` + +Giving a list returns a vector (a one-dimensional tensor): + +```elixir +Nx.tensor([1, 2, 3]) +``` + +```elixir +Nx.tensor([1.2, 2.3, 3.4, 4.5]) +``` + +Multi-dimensional tensors are also possible: + +```elixir +Nx.tensor([[1, 2, 3], [4, 5, 6]]) +``` + +```elixir +Nx.tensor([[1, 2], [3, 4], [5, 6]]) +``` + +```elixir +Nx.tensor([[[1, 2], [3, 4], [5, 6]], [[-1, -2], [-3, -4], [-5, -6]]]) +``` + +Tensors can also be given as inputs, which is useful for functions that don´t want to care +about the input kind: + +```elixir +Nx.tensor(Nx.tensor([1, 2, 3])) +``` + +### Naming dimensions + +You can provide names for tensor dimensions. Names are atoms: + +```elixir +Nx.tensor([[1, 2, 3], [4, 5, 6]], names: [:x, :y]) +``` + +Names make your code more expressive: + +```elixir +Nx.tensor([[[1, 2, 3], [4, 5, 6], [7, 8, 9]]], names: [:batch, :height, :width]) +``` + +You can also leave dimension names as `nil`: + +```elixir +Nx.tensor([[[1, 2, 3], [4, 5, 6], [7, 8, 9]]], names: [:batch, nil, nil]) +``` + +However, you must provide a name for every dimension in the tensor. For example, +the following code snippet raises an error: + +```elixir +Nx.tensor([[[1, 2, 3], [4, 5, 6], [7, 8, 9]]], names: [:batch]) +``` + +### Indexing and Slicing tensor values + +We can get any cell of the tensor: + +```elixir +tensor = Nx.tensor([[1, 2], [3, 4]], names: [:y, :x]) +tensor[[0, 1]] +``` + +```elixir +tensor = Nx.tensor([[1, 2], [3, 4], [5, 6]], names: [:y, :x]) +tensor[[-1, -1]] +``` + +Now, try getting the first row of the tensor: + +```elixir +# ...your code here... +``` + +We can also get a whole dimension: + +```elixir +tensor[x: 1] +``` + +or a range: + +```elixir +tensor[y: 0..1] +``` + +`tensor[[.., 1]]` will achieve the same result as `tensor[x: 1]`. +This is because Elixir has the syntax sugar `..` for a `0..-1//1` range. + +Now, + +- create your own `{3, 3}` tensor with named dimensions +- return a `{2, 2}` tensor containing the first two columns + of the first two rows + +```elixir +# ...your code here... +``` + +### Floats and Complex numbers + +Besides single-precision (32 bits), floats can have other kinds of precision, such as half-precision (16) or +double-precision (64): + +```elixir +Nx.tensor([1, 2, 3], type: :f16) +``` + +```elixir +Nx.tensor([1, 2, 3], type: :f64) +``` + +Brain-floating points are also supported: + +```elixir +Nx.tensor([1, 2, 3], type: :bf16) +``` + +Certain backends and compilers support 8-bit floats. The precision +implementation of 8-bit floats may change per backend, so you must be careful +when transferring data across. The binary backend implements F8E5M2: + +```elixir +Nx.tensor([1, 2, 3], type: :f8) +``` + +In all cases, the non-finite values negative infinity (-Inf), infinity (Inf), +and "not a number" (NaN) can be represented by the atoms `:neg_infinity`, +`:infinity`, and `:nan respectively`: + +```elixir +Nx.tensor([:neg_infinity, :nan, :infinity]) +``` + +Finally, complex numbers are also supported in tensors: + +```elixir +Nx.tensor(Complex.new(1, -1)) +``` + +Check out the documentation for `Nx.tensor/2` for more documentation on the accepted options. + +## Basic operations + +Nx supports element-wise arithmetic operations for tensors and broadcasting when necessary. + +### Addition + +`Nx.add/2`: Adds corresponding elements of two tensors. + +```elixir +a = Nx.tensor([1, 2, 3]) +b = Nx.tensor([0, 1, 2]) +Nx.add(a , b) +``` + +### Subtraction + +`Nx.subtract/2`: Subtracts the elements of the second tensor from the first. + +```elixir +a = Nx.tensor([10, 20, 30]) +b = Nx.tensor([0, 1, 2]) +Nx.subtract(a , b) +``` + +### Multiplication + +`Nx.multiply/2`: Multiplies corresponding elements of two tensors. + +```elixir +a = Nx.tensor([2, 3, 4]) +b = Nx.tensor([0, 1, 2]) +Nx.multiply(a , b) +``` + +### Division + +`Nx.divide/2`: Divides the elements of the first tensor by the second tensor. + +```elixir +a = Nx.tensor([10, 30, 40]) +b = Nx.tensor([5, 6, 8]) +Nx.divide(a , b) +``` + +### Exponentiation + +`Nx.pow/2`: Raises each element of the first tensor to the power of the corresponding element in the second tensor. + +```elixir +a = Nx.tensor([2, 3, 4]) +b = Nx.tensor([2]) +Nx.pow(a , b) +``` + +### Quotient + +`Nx.quotient/2`: Returns a new tensor where each element is the integer division (div/2) of left by right. + +```elixir +a = Nx.tensor([10, 20, 30]) +b = Nx.tensor([3, 7, 4]) + +Nx.quotient(a, b) +``` + +### Remainder + +`Nx.remainder/2`: Computes the remainder of the division of two integer tensors. + +```elixir +a = Nx.tensor([27, 32, 43]) +b = Nx.tensor([2, 3, 4]) +Nx.remainder(a , b) +``` + +### Negation + +`Nx.negate/1`: Negates each element of a tensor. + +```elixir +a = Nx.tensor([2, 3, 4]) +Nx.negate(a) +``` + +### Square Root + +`Nx.sqrt/1`: It computes the element-wise square root of the given tensor. + +```elixir +a = Nx.tensor([4, 9, 16]) +Nx.sqrt(a) +``` + +## Element-Wise Comparison + +Returns 1 when true and 0 when false + +### Equality and Inequality + +`Nx.equal/2`, `Nx.not_equal/2` + +```elixir +a = Nx.tensor([4, 9, 16]) +b = Nx.tensor([4, 9, 16]) +Nx.equal(a, b) +``` + +```elixir +a = Nx.tensor([4, 9, 16]) +b = Nx.tensor([4.0, 9.0, 16.0]) +Nx.not_equal(a, b) +``` + +### Greater and Less + +`Nx.greater/2`, `Nx.less/2` + +```elixir +a = Nx.tensor([4, 9, 16]) +b = Nx.tensor([4, 8, 17]) +Nx.greater(a, b) +``` + +```elixir +a = Nx.tensor([4, 9, 16]) +b = Nx.tensor([4.2, 9.0, 16.7]) +Nx.less(a, b) +``` + +### Greater_Equal and Less_Equal + +`Nx.greater_equal/2`, `Nx.less_equal/2` + +```elixir +a = Nx.tensor([3, 5, 2]) +b = Nx.tensor([2, 5, 4]) + +Nx.greater_equal(a, b) +``` + +```elixir +a = Nx.tensor([3, 5, 2]) +b = Nx.tensor([2, 5, 4]) + +Nx.less_equal(a, b) +``` + +## Aggregate functions + +These operations aggregate values across tensor axes. + +### Sum + +`Nx.sum/1`: Sums all elements + +```elixir +a = Nx.tensor([[4, 9, 16], [4.2, 9.0, 16.7]]) +Nx.sum(a) +``` + +### Mean + +`Nx.mean/1`: Computes the mean value of the tensor + +```elixir +a = Nx.tensor([[4, 9, 16], [4.2, 9.0, 16.7]]) +Nx.mean(a) +``` + +### Product + +`Nx.product/1`: Computes the product of all elements. + +```elixir +a = Nx.tensor([[4, 9, 16], [4.2, 9.0, 16.7]]) +Nx.product(a) +``` + +## Matrix Multiplication + +`Nx.dot/4`: Computes the generalized dot product between two tensors, given the contracting axes.hyunnnn + +```elixir +t1 = Nx.tensor([[1, 2], [3, 4]], names: [:x, :y]) +t2 = Nx.tensor([[10, 20], [30, 40]], names: [:height, :width]) +Nx.dot(t1, [0], t2, [0]) +``` diff --git a/nx/mix.exs b/nx/mix.exs index a43972cf17..6e682112cc 100644 --- a/nx/mix.exs +++ b/nx/mix.exs @@ -58,6 +58,8 @@ defmodule Nx.MixProject do extras: [ "CHANGELOG.md", "guides/intro-to-nx.livemd", + "guides/getting_started/introduction.md", + "guides/getting_started/quickstart.livemd", "guides/advanced/vectorization.livemd", "guides/advanced/aggregation.livemd", "guides/exercises/exercises-1-20.livemd" @@ -112,6 +114,7 @@ defmodule Nx.MixProject do ] ], groups_for_extras: [ + Getting_Started: ~r"^guides/getting_started/", Exercises: ~r"^guides/exercises/", Advanced: ~r"^guides/advanced/" ] From ba529c6bfaf27463beeabd2c1c328b9c44bb2371 Mon Sep 17 00:00:00 2001 From: Paulo Valente <16843419+polvalente@users.noreply.github.com> Date: Tue, 4 Feb 2025 09:26:00 -0300 Subject: [PATCH 02/36] Update introduction.md --- nx/guides/getting_started/introduction.md | 38 ++++++++++------------- 1 file changed, 17 insertions(+), 21 deletions(-) diff --git a/nx/guides/getting_started/introduction.md b/nx/guides/getting_started/introduction.md index 8a3586a86c..d70b48167d 100644 --- a/nx/guides/getting_started/introduction.md +++ b/nx/guides/getting_started/introduction.md @@ -1,30 +1,28 @@ # What is Nx? -Nx is the numerical computing library of Elixir. Since Elixir´s primary numerical datatypes and structures are not optimized for numerical programming, Nx is the fundamental package built to bridge this gap. +Nx is the numerical computing library of Elixir. Since Elixir's primary numerical datatypes and structures are not optimized for numerical programming, Nx is the fundamental package built to bridge this gap. -[Elixir Nx](https://github.com/elixir-nx/nx) smoothly integrate to typed, multidimensional data implemented on other -platforms (called [tensors](introduction.html#what-are-tensors)). This support extends to the compilers and -libraries that support those tensors. Nx has four primary capabilities: +[Elixir Nx](https://github.com/elixir-nx/nx) smoothly integrates typed, multidimensional data called [tensors](introduction.html#what-are-tensors)). +Nx has four primary capabilities: -- In Nx, tensors hold typed data in multiple, named dimensions. +- In Nx, tensors hold typed data in multiple, optionally named dimensions. - Numerical definitions, known as `defn`, support custom code with tensor-aware operators and functions. - [Automatic differentiation](https://arxiv.org/abs/1502.05767), also known as autograd or autodiff, supports common computational scenarios such as machine learning, simulations, curve fitting, and probabilistic models. -- Broadcasting, which is term for element-by-element operations. Most of the Nx operations - automatically broadcast using an effective algorithm. You can see more on broadcast +- Broadcasting, which is a term for element-by-element operations. Most of the Nx operations + make use of automatic implicit broadcasting. You can see more on broadcasting [here.](intro-to-nx.html#broadcasts) -Here's more about each of those capabilities. Nx tensors can hold -unsigned integers (u2, u4, u8, u16, u32, u64), +Nx tensors can hold unsigned integers (u2, u4, u8, u16, u32, u64), signed integers (s2, s4, s8, s16, s32, s64), -floats (f32, f64), brain floats (bf16), and complex (c64, c128). -Tensors support backends implemented outside of Elixir, including Google's -Accelerated Linear Algebra (XLA) and LibTorch. +floats (f8, f16, f32, f64), brain floats (bf16), and complex (c64, c128). +Tensors support backends implemented outside of Elixir, such as Google's +Accelerated Linear Algebra (XLA) and PyTorch. -Numerical definitions have compiler support to allow just-in-time compilation -that support specialized processors to speed up numeric computation including +Numerical definitions provide compiler support to allow just-in-time compilation +targetting specialized processors to speed up numeric computation including TPUs and GPUs. ## What are Tensors? @@ -74,13 +72,11 @@ Output: #Nx.Tensor< s32[2][2] [ -[1, 2], -[3, 4] + [1, 2], + [3, 4] ] - ``` -To know Nx, we'll get to know tensors first. The following overview will touch -on the major libraries. Then, future notebooks will take a deep dive into working -with tensors in detail, autograd, and backends. Then, we'll dive into specific -problem spaces like Axon, the machine learning library. +To learn Nx, we'll get to know tensors first. The following overview will touch +on the major features. The advanced section of the documentation will take a deep dive into working +with tensors in detail, autodiff, and backends. From 0a97e25bb462a6b290aa5a3a9c0426e2c6678e89 Mon Sep 17 00:00:00 2001 From: Paulo Valente <16843419+polvalente@users.noreply.github.com> Date: Tue, 4 Feb 2025 10:02:48 -0300 Subject: [PATCH 03/36] Update quickstart.livemd --- nx/guides/getting_started/quickstart.livemd | 100 ++++++++++---------- 1 file changed, 52 insertions(+), 48 deletions(-) diff --git a/nx/guides/getting_started/quickstart.livemd b/nx/guides/getting_started/quickstart.livemd index f0e0472284..35ac538e79 100644 --- a/nx/guides/getting_started/quickstart.livemd +++ b/nx/guides/getting_started/quickstart.livemd @@ -2,20 +2,15 @@ ## Prerequisites -You will need to know a bit of Elixir. For a refresher, check out the +To properly use Nx, you will need to know a bit of Elixir. For a refresher, check out the [Elixir Getting Started Guide](https://hexdocs.pm/elixir/introduction.html). -To work the examples you can run using the livebook buttom in this page. +To work on the examples you can run using the "Run in Livebook" button in this page. #### Learning Objectives This is a overview of Nx tensors. In this section, we'll look at some of the various tools for -creating and interacting with tensors. The IEx helpers will assist our -exploration of the core tensor concepts. - -```elixir -import IEx.Helpers -``` +creating and interacting with tensors. After reading, you should be able to understand: @@ -27,7 +22,7 @@ After reading, you should be able to understand: ## The Basics -Now, everything is set up, so we're ready to create some tensors. +First, let's install Nx with `Mix.install`. ```elixir Mix.install([ @@ -35,22 +30,28 @@ Mix.install([ ]) ``` +The `IEx.Helpers` module will assist our exploration of the core tensor concepts. + +```elixir +import IEx.Helpers +``` + ### Creating tensors -The argument must be one of: +The argument for `Nx.tensor/1` must be one of: -- a tensor -- a number (which means the tensor is scalar/zero-dimensional) -- a boolean (also scalar/zero-dimensional) +- a tensor; +- a number (which means the tensor is scalar/zero-dimensional); +- a boolean (also scalar/zero-dimensional); - an arbitrarily nested list of numbers and booleans +- the special atoms `:nan`, `:infinity`, `:neg_infinity`, which represent values not supported by Elixir floats. -If a new tensor is allocated, it will be allocated in the backend defined by -`Nx.default_backend/0`, unless the `:backend` option is given, which overrides the -default. +If a new tensor is allocated, it will be allocated in the backend defined by the `:backend` option. +If it is not provided, `Nx.default_backend/0` will be used instead. #### Examples -A number returns a tensor of zero dimensions: +A number returns a tensor of zero dimensions, also known as a scalar: ```elixir Nx.tensor(0) @@ -60,7 +61,7 @@ Nx.tensor(0) Nx.tensor(1.0) ``` -Giving a list returns a vector (a one-dimensional tensor): +A list returns a one-dimensional tensor, also known as a vector: ```elixir Nx.tensor([1, 2, 3]) @@ -70,7 +71,7 @@ Nx.tensor([1, 2, 3]) Nx.tensor([1.2, 2.3, 3.4, 4.5]) ``` -Multi-dimensional tensors are also possible: +Higher dimensional tensors are also possible: ```elixir Nx.tensor([[1, 2, 3], [4, 5, 6]]) @@ -105,14 +106,14 @@ Names make your code more expressive: Nx.tensor([[[1, 2, 3], [4, 5, 6], [7, 8, 9]]], names: [:batch, :height, :width]) ``` -You can also leave dimension names as `nil`: +You can also leave dimension names as `nil` (which is the default): ```elixir Nx.tensor([[[1, 2, 3], [4, 5, 6], [7, 8, 9]]], names: [:batch, nil, nil]) ``` However, you must provide a name for every dimension in the tensor. For example, -the following code snippet raises an error: +the following code snippet raises an error because 1 name is given, but there are 3 dimensions: ```elixir Nx.tensor([[[1, 2, 3], [4, 5, 6], [7, 8, 9]]], names: [:batch]) @@ -127,6 +128,9 @@ tensor = Nx.tensor([[1, 2], [3, 4]], names: [:y, :x]) tensor[[0, 1]] ``` +Negative indices will start counting from the end of the axis. +`-1` is the last entry, `-2` the second to last and so on. + ```elixir tensor = Nx.tensor([[1, 2], [3, 4], [5, 6]], names: [:y, :x]) tensor[[-1, -1]] @@ -169,22 +173,22 @@ Besides single-precision (32 bits), floats can have other kinds of precision, su double-precision (64): ```elixir -Nx.tensor([1, 2, 3], type: :f16) +Nx.tensor([0.0, 0.2, 0.4, 1.0], type: :f16) ``` ```elixir -Nx.tensor([1, 2, 3], type: :f64) +Nx.tensor([0.0, 0.2, 0.4, 1.0, type: :f64) ``` -Brain-floating points are also supported: +Brain floats are also supported: ```elixir -Nx.tensor([1, 2, 3], type: :bf16) +Nx.tensor([0.0, 0.2, 0.4, 1.0, type: :bf16) ``` Certain backends and compilers support 8-bit floats. The precision implementation of 8-bit floats may change per backend, so you must be careful -when transferring data across. The binary backend implements F8E5M2: +when transferring data across different backends. The binary backend implements F8E5M2: ```elixir Nx.tensor([1, 2, 3], type: :f8) @@ -192,13 +196,13 @@ Nx.tensor([1, 2, 3], type: :f8) In all cases, the non-finite values negative infinity (-Inf), infinity (Inf), and "not a number" (NaN) can be represented by the atoms `:neg_infinity`, -`:infinity`, and `:nan respectively`: +`:infinity`, and `:nan`, respectively: ```elixir Nx.tensor([:neg_infinity, :nan, :infinity]) ``` -Finally, complex numbers are also supported in tensors: +Finally, complex numbers are also supported in tensors, in both 32-bit and 64-bit precision: ```elixir Nx.tensor(Complex.new(1, -1)) @@ -217,7 +221,7 @@ Nx supports element-wise arithmetic operations for tensors and broadcasting when ```elixir a = Nx.tensor([1, 2, 3]) b = Nx.tensor([0, 1, 2]) -Nx.add(a , b) +Nx.add(a, b) ``` ### Subtraction @@ -227,7 +231,7 @@ Nx.add(a , b) ```elixir a = Nx.tensor([10, 20, 30]) b = Nx.tensor([0, 1, 2]) -Nx.subtract(a , b) +Nx.subtract(a, b) ``` ### Multiplication @@ -237,7 +241,7 @@ Nx.subtract(a , b) ```elixir a = Nx.tensor([2, 3, 4]) b = Nx.tensor([0, 1, 2]) -Nx.multiply(a , b) +Nx.multiply(a, b) ``` ### Division @@ -247,7 +251,7 @@ Nx.multiply(a , b) ```elixir a = Nx.tensor([10, 30, 40]) b = Nx.tensor([5, 6, 8]) -Nx.divide(a , b) +Nx.divide(a, b) ``` ### Exponentiation @@ -257,12 +261,12 @@ Nx.divide(a , b) ```elixir a = Nx.tensor([2, 3, 4]) b = Nx.tensor([2]) -Nx.pow(a , b) +Nx.pow(a, b) ``` ### Quotient -`Nx.quotient/2`: Returns a new tensor where each element is the integer division (div/2) of left by right. +`Nx.quotient/2`: Returns a new tensor where each element is the integer division (`div/2`). ```elixir a = Nx.tensor([10, 20, 30]) @@ -273,7 +277,7 @@ Nx.quotient(a, b) ### Remainder -`Nx.remainder/2`: Computes the remainder of the division of two integer tensors. +`Nx.remainder/2`: Computes the remainder of the integer division. ```elixir a = Nx.tensor([27, 32, 43]) @@ -292,7 +296,7 @@ Nx.negate(a) ### Square Root -`Nx.sqrt/1`: It computes the element-wise square root of the given tensor. +`Nx.sqrt/1`: Computes the element-wise square root. ```elixir a = Nx.tensor([4, 9, 16]) @@ -301,7 +305,7 @@ Nx.sqrt(a) ## Element-Wise Comparison -Returns 1 when true and 0 when false +The following operations returns a u8 tensor where 1 represents `true` and 0 represents `false`. ### Equality and Inequality @@ -309,13 +313,13 @@ Returns 1 when true and 0 when false ```elixir a = Nx.tensor([4, 9, 16]) -b = Nx.tensor([4, 9, 16]) +b = Nx.tensor([4, 9, -16]) Nx.equal(a, b) ``` ```elixir a = Nx.tensor([4, 9, 16]) -b = Nx.tensor([4.0, 9.0, 16.0]) +b = Nx.tensor([4.0, 9.0, -16.0]) Nx.not_equal(a, b) ``` @@ -331,7 +335,7 @@ Nx.greater(a, b) ```elixir a = Nx.tensor([4, 9, 16]) -b = Nx.tensor([4.2, 9.0, 16.7]) +b = Nx.tensor([4.2, 9.0, 15.9]) Nx.less(a, b) ``` @@ -340,15 +344,15 @@ Nx.less(a, b) `Nx.greater_equal/2`, `Nx.less_equal/2` ```elixir -a = Nx.tensor([3, 5, 2]) -b = Nx.tensor([2, 5, 4]) +a = Nx.tensor([4, 9, 16]) +b = Nx.tensor([4, 8, 17]) Nx.greater_equal(a, b) ``` ```elixir -a = Nx.tensor([3, 5, 2]) -b = Nx.tensor([2, 5, 4]) +a = Nx.tensor([4, 9, 16]) +b = Nx.tensor([4.2, 9.0, 15.9]) Nx.less_equal(a, b) ``` @@ -359,7 +363,7 @@ These operations aggregate values across tensor axes. ### Sum -`Nx.sum/1`: Sums all elements +`Nx.sum/1`: Sums all elements. ```elixir a = Nx.tensor([[4, 9, 16], [4.2, 9.0, 16.7]]) @@ -368,7 +372,7 @@ Nx.sum(a) ### Mean -`Nx.mean/1`: Computes the mean value of the tensor +`Nx.mean/1`: Computes the mean value of the elements. ```elixir a = Nx.tensor([[4, 9, 16], [4.2, 9.0, 16.7]]) @@ -377,7 +381,7 @@ Nx.mean(a) ### Product -`Nx.product/1`: Computes the product of all elements. +`Nx.product/1`: Computes the product of the elements. ```elixir a = Nx.tensor([[4, 9, 16], [4.2, 9.0, 16.7]]) @@ -386,7 +390,7 @@ Nx.product(a) ## Matrix Multiplication -`Nx.dot/4`: Computes the generalized dot product between two tensors, given the contracting axes.hyunnnn +`Nx.dot/4`: Computes the generalized dot product between two tensors, given the contracting axes. ```elixir t1 = Nx.tensor([[1, 2], [3, 4]], names: [:x, :y]) From 22aab8f08475805f77206288591d3ea5b687f59e Mon Sep 17 00:00:00 2001 From: TomasPegado Date: Mon, 10 Feb 2025 10:35:53 -0300 Subject: [PATCH 04/36] docs: add installation guide --- nx/guides/getting_started/installation.md | 186 ++++++++++++++++++++ nx/guides/getting_started/quickstart.livemd | 25 ++- nx/mix.exs | 3 +- 3 files changed, 210 insertions(+), 4 deletions(-) create mode 100644 nx/guides/getting_started/installation.md diff --git a/nx/guides/getting_started/installation.md b/nx/guides/getting_started/installation.md new file mode 100644 index 0000000000..f00b2f0655 --- /dev/null +++ b/nx/guides/getting_started/installation.md @@ -0,0 +1,186 @@ +# Installation + +The only prerequisite for installing Nx is Elixir itself. If you don´t have Elixir installed +in your machine you can visit this [intallation page](https://elixir-lang.org/install.html). + +There are several ways to install Nx (Numerical Elixir), depending on your project type and needs. + +## Using Mix in a standardElixir Project + +If you are working inside a Mix project, the recommended way to install Nx is by adding it to your mix.exs dependencies: + +1. Open mix.exs and modify the deps function: + +```elixir +defp deps do + [ + {:nx, "~> 0.5"} # Install the latest stable version + ] +end +``` + +2. Fetch the dependencies, run on the terminal: + +```sh +mix deps.get +``` + +## Installing Nx from GitHub (Latest Development Version) + +If you need the latest, unreleased features, install Nx directly from the GitHub repository. + +1. Modify mix.exs: + +```elixir +defp deps do + [ + {:nx, github: "elixir-nx/nx", branch: "main"} + ] +end + +``` + +2. Fetch dependencies: + +```sh +mix deps.get + +``` + +## Installing Nx in a Standalone Script (Without a Mix Project) + +If you don’t have a Mix project and just want to run a standalone script, use Mix.install/1 to dynamically fetch and install Nx. + +```elixir +Mix.install([:nx]) + +require Nx + +tensor = Nx.tensor([1, 2, 3]) +IO.inspect(tensor) + +``` + +Run the script with: + +```sh +elixir my_script.exs + +``` + +Best for: Quick experiments, small scripts, or one-off computations. + +## Installing the Latest Nx from GitHub in a Standalone Script + +To use the latest development version in a script (without a Mix project): + +```elixir +Mix.install([ + {:nx, github: "elixir-nx/nx", branch: "main"} +]) + +require Nx + +tensor = Nx.tensor([1, 2, 3]) +IO.inspect(tensor) +``` + +Run: + +```sh +elixir my_script.exs + +``` + +Best for: Trying new features from Nx without creating a full project. + +## Installing Nx with EXLA for GPU Acceleration + +To enable GPU/TPU acceleration with Google’s XLA backend, install Nx along with EXLA: + +1. Modify mix.exs: + +```elixir +defp deps do + [ + {:nx, "~> 0.5"}, + {:exla, "~> 0.5"} # EXLA (Google XLA Backend) + ] +end +``` + +2. Fetch dependencies: + +```sh +mix deps.get +``` + +3. Run with EXLA enabled: + +```elixir +EXLA.set_preferred_backend(:tpu) +``` + +Best for: Running Nx on GPUs or TPUs using Google’s XLA compiler. + +## Installing Nx with Torchx for PyTorch Acceleration + +To run Nx operations on PyTorch’s backend (LibTorch): + +1. Modify mix.exs: + +```elixir +defp deps do + [ + {:nx, "~> 0.5"}, + {:torchx, "~> 0.5"} # PyTorch Backend + ] +end + +``` + +2. Fetch dependencies: + +```sh +mix deps.get +``` + +3. Run with EXLA enabled: + +```elixir +Torchx.set_preferred_backend() +``` + +Best for: Deep learning applications with PyTorch acceleration. + +## Installing Nx with OpenBLAS for CPU Optimization + +To optimize CPU performance with OpenBLAS: + +1. Install OpenBLAS (libopenblas): + - Ubuntu/Debian: + ```sh + sudo apt install libopenblas-dev + ``` + - MacOS (using Homebrew): + ```sh + brew install openblas + ``` +2. Modify mix.exs: + +```elixir +defp deps do + [ + {:nx, "~> 0.5"}, + {:openblas, "~> 0.5"} # CPU-optimized BLAS backend + ] +end +``` + +3. Fetch dependencies: + +```sh +mix deps.get +``` + +Best for: Optimizing CPU-based tensor computations. diff --git a/nx/guides/getting_started/quickstart.livemd b/nx/guides/getting_started/quickstart.livemd index 35ac538e79..36064a7971 100644 --- a/nx/guides/getting_started/quickstart.livemd +++ b/nx/guides/getting_started/quickstart.livemd @@ -106,6 +106,8 @@ Names make your code more expressive: Nx.tensor([[[1, 2, 3], [4, 5, 6], [7, 8, 9]]], names: [:batch, :height, :width]) ``` +We created a tensor of the shape `{3, 3}`, and two axes named `height` and `width`. + You can also leave dimension names as `nil` (which is the default): ```elixir @@ -128,7 +130,7 @@ tensor = Nx.tensor([[1, 2], [3, 4]], names: [:y, :x]) tensor[[0, 1]] ``` -Negative indices will start counting from the end of the axis. +Negative indices will start counting from the end of the axis. `-1` is the last entry, `-2` the second to last and so on. ```elixir @@ -167,6 +169,23 @@ Now, # ...your code here... ``` +### Tensor shape and reshape + +```elixir +Nx.shape(tensor) +``` + +We can also create a new tensor with a new shape using `Nx.reshape/2`: + +```elixir +Nx.reshape(tensor, {1, 4}, names: [:batches, :values]) +``` + +This operation reuses all of the tensor data and simply +changes the metadata, so it has no notable cost. + +The new tensor has the same type, but a new shape. + ### Floats and Complex numbers Besides single-precision (32 bits), floats can have other kinds of precision, such as half-precision (16) or @@ -177,13 +196,13 @@ Nx.tensor([0.0, 0.2, 0.4, 1.0], type: :f16) ``` ```elixir -Nx.tensor([0.0, 0.2, 0.4, 1.0, type: :f64) +Nx.tensor([0.0, 0.2, 0.4, 1.0], type: :f64) ``` Brain floats are also supported: ```elixir -Nx.tensor([0.0, 0.2, 0.4, 1.0, type: :bf16) +Nx.tensor([0.0, 0.2, 0.4, 1.0], type: :bf16) ``` Certain backends and compilers support 8-bit floats. The precision diff --git a/nx/mix.exs b/nx/mix.exs index 6e682112cc..90ffa9c51e 100644 --- a/nx/mix.exs +++ b/nx/mix.exs @@ -59,6 +59,7 @@ defmodule Nx.MixProject do "CHANGELOG.md", "guides/intro-to-nx.livemd", "guides/getting_started/introduction.md", + "guides/getting_started/installation.md", "guides/getting_started/quickstart.livemd", "guides/advanced/vectorization.livemd", "guides/advanced/aggregation.livemd", @@ -114,7 +115,7 @@ defmodule Nx.MixProject do ] ], groups_for_extras: [ - Getting_Started: ~r"^guides/getting_started/", + "Getting Started": ~r"^guides/getting_started/", Exercises: ~r"^guides/exercises/", Advanced: ~r"^guides/advanced/" ] From e58062b9903b145bc4e43e5dd95e9976d76665ea Mon Sep 17 00:00:00 2001 From: TomasPegado Date: Wed, 26 Mar 2025 16:09:07 -0300 Subject: [PATCH 05/36] docs: getting started section --- nx/guides/getting_started/broadcast.livemd | 153 ++++++++++++++++++ .../numerical_definitions.livemd | 121 ++++++++++++++ nx/mix.exs | 2 + 3 files changed, 276 insertions(+) create mode 100644 nx/guides/getting_started/broadcast.livemd create mode 100644 nx/guides/getting_started/numerical_definitions.livemd diff --git a/nx/guides/getting_started/broadcast.livemd b/nx/guides/getting_started/broadcast.livemd new file mode 100644 index 0000000000..ecd46f1b7a --- /dev/null +++ b/nx/guides/getting_started/broadcast.livemd @@ -0,0 +1,153 @@ +# Broadcasts + +Often, the dimensions of tensors in an operator don't match. +For example, you might want to subtract a `1` from every +element of a `{2, 2}` tensor, like this: + +$$ +\begin{bmatrix} + 1 & 2 \\\\ + 3 & 4 +\end{bmatrix} - 1 = +\begin{bmatrix} + 0 & 1 \\\\ + 2 & 3 +\end{bmatrix} +$$ + +Mathematically, it's the same as this: + +$$ +\begin{bmatrix} + 1 & 2 \\\\ + 3 & 4 +\end{bmatrix} - +\begin{bmatrix} + 1 & 1 \\\\ + 1 & 1 +\end{bmatrix} = +\begin{bmatrix} + 0 & 1 \\\\ + 2 & 3 +\end{bmatrix} +$$ + +That means we need a way to convert `1` to a `{2, 2}` tensor. +`Nx.broadcast/2` solves that problem. This function takes +a tensor or a scalar and a shape. + +```elixir +Mix.install([ + {:nx, "~> 0.5"} +]) + + +Nx.broadcast(1, {2, 2}) +``` + +This broadcast takes the scalar `1` and translates it +to a compatible shape by copying it. Sometimes, it's easier +to provide a tensor as the second argument, and let `broadcast/2` +extract its shape: + +```elixir +tensor = Nx.tensor([[1, 2], [3, 4]]) +Nx.broadcast(1, tensor) +``` + +The code broadcasts `1` to the shape of `tensor`. In many operators +and functions, the broadcast happens automatically: + +```elixir +Nx.subtract(tensor, 1) +``` + +This result is possible because Nx broadcasts _both tensors_ +in `subtract/2` to compatible shapes. That means you can provide +scalar values as either argument: + +```elixir +Nx.subtract(10, tensor) +``` + +Or subtract a row or column. Mathematically, it would look like this: + +$$ +\begin{bmatrix} + 1 & 2 \\\\ + 3 & 4 +\end{bmatrix} - +\begin{bmatrix} + 1 & 2 +\end{bmatrix} = +\begin{bmatrix} + 0 & 0 \\\\ + 2 & 2 +\end{bmatrix} +$$ + +which is the same as this: + +$$ +\begin{bmatrix} + 1 & 2 \\\\ + 3 & 4 +\end{bmatrix} - +\begin{bmatrix} + 1 & 2 \\\\ + 1 & 2 +\end{bmatrix} = +\begin{bmatrix} + 0 & 0 \\\\ + 2 & 2 +\end{bmatrix} +$$ + +This rewrite happens in Nx too, also through a broadcast. We want to +broadcast the tensor `[1, 2]` to match the `{2, 2}` shape, like this: + +```elixir +Nx.broadcast(Nx.tensor([1, 2]), {2, 2}) +``` + +The `subtract` function in `Nx` takes care of that broadcast +implicitly, as before: + +```elixir +Nx.subtract(tensor, Nx.tensor([1, 2])) +``` + +The broadcast worked as advertised, copying the `[1, 2]` row +enough times to fill a `{2, 2}` tensor. A tensor with a +dimension of `1` will broadcast to fill the tensor: + +```elixir +[[1], [2]] |> Nx.tensor() |> Nx.broadcast({1, 2, 2}) +``` + +```elixir +[[[1, 2, 3]]] +|> Nx.tensor() +|> Nx.broadcast({4, 2, 3}) +``` + +Both of these examples copy parts of the tensor enough +times to fill out the broadcast shape. You can check out the +Nx broadcasting documentation for more details: + + + +```elixir +h Nx.broadcast +``` + +Much of the time, you won't have to broadcast yourself. Many of +the functions and operators Nx supports will do so automatically. + +We can use tensor-aware operators via various `Nx` functions and +many of them implicitly broadcast tensors. + +Throughout this section, we have been invoking `Nx.subtract/2` and +our code would be more expressive if we could use its equivalent +mathematical operator. Fortunately, Nx provides a way. Next, we'll +dive into numerical definitions using `defn`. diff --git a/nx/guides/getting_started/numerical_definitions.livemd b/nx/guides/getting_started/numerical_definitions.livemd new file mode 100644 index 0000000000..3f7b607042 --- /dev/null +++ b/nx/guides/getting_started/numerical_definitions.livemd @@ -0,0 +1,121 @@ +# Numerical definitions (defn) + +The `defn` macro simplifies the expression of mathematical formulas +containing tensors. Numerical definitions have two primary benefits +over classic Elixir functions. + +- They are _tensor-aware_. Nx replaces operators like `Kernel.-/2` + with the `Defn` counterparts — which in turn use `Nx` functions + optimized for tensors — so the formulas we express can use + tensors out of the box. + +- `defn` definitions allow for building computation graph of all the + individual operations and using a just-in-time (JIT) compiler to emit + highly specialized native code for the desired computation unit. + +We don't have to do anything special to get access to +get tensor awareness beyond importing `Nx.Defn` and writing +our code within a `defn` block. + +To use Nx in a Mix project or a notebook, we need to include +the `:nx` dependency and import the `Nx.Defn` module, +like this: + +```elixir +Mix.install([ + {:nx, "~> 0.5"} +]) +``` + +```elixir +import Nx.Defn +``` + +Just as the Elixir language supports `def`, `defmacro`, and `defp`, +Nx supports `defn`. There are a few restrictions. It allows only +numerical arguments in the form of primitives or tensors as arguments +or return values, and supports only a subset of the language. + +The subset of Elixir allowed within `defn` is quite broad, though. We can +use macros, pipes, and even conditionals, so we're not giving up +much when you're declaring mathematical functions. + +Additionally, despite these small concessions, `defn` provides huge benefits. +Code in a `defn` block uses tensor aware operators and types, so the math +beneath your functions has a better chance to shine through. Numerical +definitions can also run on accelerated numerical processors like GPUs and +TPUs. Here's an example numerical definition: + +```elixir +defmodule TensorMath do + import Nx.Defn + + defn subtract(a, b) do + a - b + end +end +``` + +This module has a numerical definition that will be compiled. +If we wanted to specify a compiler for this module, we could add +a module attribute before the `defn` clause. One of such compilers +is [the EXLA compiler](https://github.com/elixir-nx/nx/tree/main/exla). +You'd add the `mix` dependency for EXLA and do this: + + + +```elixir +@defn_compiler EXLA +defn subtract(a, b) do + a - b +end +``` + +Now, it's your turn. Add a `defn` to `TensorMath` +that accepts two tensors representing the lengths of sides of a +right triangle and uses the pythagorean theorem to return the +[length of the hypotenuse](https://www.mathsisfun.com/pythagoras.html). +Add your function directly to the previous Code cell. + +## deftransform + +The defn macro in Nx allows you to define functions that compile to efficient +numerical computations, but it comes with certain limitations—such as restrictions +on argument types, return values, and the subset of Elixir that it supports. +To overcome many of these limitations, Nx offers the deftransform macro. + +deftransform lets you perform computations or execute code that isn't directly +supported by defn, and then incorporate those results back into your numerical +function. This separation lets you use standard Elixir features where necessary +while keeping your core numerical logic optimized. + +In the following example, we define a deftransform function called +compute_tensor_from_list/1 that receives a list, which is not allowed +inside defn. Inside this transform function, we convert the list to a tensor +using Nx.tensor/1, and then pass it to a defn function called double_tensor/1, +which performs the actual numerical computation. + +```elixir +defmodule MyMath do + import Nx.Defn + + defn double_tensor(tensor) do + tensor * 2 + end + + deftransform compute_tensor_from_list(list) do + tensor = Nx.tensor(list) + double_tensor(tensor) + end +end + +``` + +```elixir +input = [1, 2, 3, 4] +result = MyMath.compute_tensor_from_list(input) +``` + +This setup allows us to keep our defn code clean and focused only on tensor +operations, while using deftransform to handle Elixir-native types and +preprocessing. diff --git a/nx/mix.exs b/nx/mix.exs index fb0af74bcd..46ac5ad508 100644 --- a/nx/mix.exs +++ b/nx/mix.exs @@ -61,6 +61,8 @@ defmodule Nx.MixProject do "guides/getting_started/introduction.md", "guides/getting_started/installation.md", "guides/getting_started/quickstart.livemd", + "guides/getting_started/broadcast.livemd", + "guides/getting_started/numerical_definitions.livemd", "guides/advanced/vectorization.livemd", "guides/advanced/aggregation.livemd", "guides/advanced/automatic_differentiation.livemd", From d006d3d3ffacf9d76f1987abef8e4a1d4cc8d0aa Mon Sep 17 00:00:00 2001 From: TomasPegado Date: Wed, 26 Mar 2025 19:07:54 -0300 Subject: [PATCH 06/36] docs: Adding Cheatsheet --- nx/guides/cheatsheet/cheatsheet.cheatmd | 7 +++++++ nx/mix.exs | 2 ++ 2 files changed, 9 insertions(+) create mode 100644 nx/guides/cheatsheet/cheatsheet.cheatmd diff --git a/nx/guides/cheatsheet/cheatsheet.cheatmd b/nx/guides/cheatsheet/cheatsheet.cheatmd new file mode 100644 index 0000000000..b9bdb1cc8a --- /dev/null +++ b/nx/guides/cheatsheet/cheatsheet.cheatmd @@ -0,0 +1,7 @@ +# Cheatsheet + +This cheatsheet is designed to assist Python developers in transitioning to Elixir, +specifically by providing equivalent commands and code examples between NumPy and Nx. + +## Numpy -> Nx + diff --git a/nx/mix.exs b/nx/mix.exs index 46ac5ad508..6a73c427d6 100644 --- a/nx/mix.exs +++ b/nx/mix.exs @@ -63,6 +63,7 @@ defmodule Nx.MixProject do "guides/getting_started/quickstart.livemd", "guides/getting_started/broadcast.livemd", "guides/getting_started/numerical_definitions.livemd", + "guides/cheatsheet/cheatsheet.cheatmd", "guides/advanced/vectorization.livemd", "guides/advanced/aggregation.livemd", "guides/advanced/automatic_differentiation.livemd", @@ -119,6 +120,7 @@ defmodule Nx.MixProject do ], groups_for_extras: [ "Getting Started": ~r"^guides/getting_started/", + Cheatsheet: ~r"^guides/cheatsheet/", Exercises: ~r"^guides/exercises/", Advanced: ~r"^guides/advanced/" ] From 07f05998d177e2794b4b9c1592f349d3390fb04d Mon Sep 17 00:00:00 2001 From: TomasPegado Date: Thu, 27 Mar 2025 14:21:40 -0300 Subject: [PATCH 07/36] docs: cheatsheet on array creation --- nx/guides/cheatsheet/cheatsheet.cheatmd | 47 +++++++++++++++++++++++++ 1 file changed, 47 insertions(+) diff --git a/nx/guides/cheatsheet/cheatsheet.cheatmd b/nx/guides/cheatsheet/cheatsheet.cheatmd index b9bdb1cc8a..69f46c9563 100644 --- a/nx/guides/cheatsheet/cheatsheet.cheatmd +++ b/nx/guides/cheatsheet/cheatsheet.cheatmd @@ -4,4 +4,51 @@ This cheatsheet is designed to assist Python developers in transitioning to Elix specifically by providing equivalent commands and code examples between NumPy and Nx. ## Numpy -> Nx +{: .col-2} +### Array Creation + + +#### Python code + +```python +import numpy as np + +# From list or nested list +a = np.array([1, 2, 3]) +b = np.array([[1, 2], [3, 4]]) + +# zeros and ones +np.zeros((2, 3)) # 2x3 array filled with zeros +np.ones((2, 3)) # 2x3 array filled with ones + +# Range of Numbers (like range()) +np.arange(0, 10, 2) # [0 2 4 6 8] + +# Linearly Spaced Values +np.linspace(0, 1, 5) # [0. 0.25 0.5 0.75 1. ] + +``` + +### Tensor Creation + +#### Elixir code + +```elixir +Mix.install([:nx]) + +# From list or nested list +a = Nx.tensor([1, 2, 3]) +b = Nx.tensor([[1, 2], [3, 4]]) + +# zeros and ones +Nx.broadcast(0, {2, 3}) # 2x3 tensor filled with zeros +Nx.broadcast(1, {2, 3}) # 2x3 tensor filled with ones + +# Range of Numbers (like range()) +Nx.iota({5}, axis: 0) |> Nx.multiply(2) # [0 2 4 6 8] + +# Linearly Spaced Values +Nx.iota({5}) |> Nx.divide(4) # [0.0, 0.25, 0.5, 0.75, 1.0] + +``` \ No newline at end of file From 1b658f4950501f7cd456b23ddc1e7539805d08a1 Mon Sep 17 00:00:00 2001 From: Peter Richards Date: Tue, 4 Feb 2025 04:23:40 -0800 Subject: [PATCH 08/36] Fixed to_pointer/2 docs :kind -> :mode (#1578) --- nx/lib/nx.ex | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/nx/lib/nx.ex b/nx/lib/nx.ex index 8206f4106b..bde368fad7 100644 --- a/nx/lib/nx.ex +++ b/nx/lib/nx.ex @@ -16812,7 +16812,7 @@ defmodule Nx do ## Options - * `:kind` - one of `:local`, `:ipc`. `:local` means the returned value + * `:mode` - one of `:local`, `:ipc`. `:local` means the returned value represents a pointer internal to the current process. `:ipc` means the returned value represents an IPC handle that can be shared between processes. Defaults to `:local`. @@ -16822,11 +16822,11 @@ defmodule Nx do ## Examples t = Nx.u8([10, 20, 30]) - Nx.to_pointer(t, kind: :local) + Nx.to_pointer(t, mode: :local) %Nx.Pointer{kind: :local, address: 1234, data_size: 3, handle: nil} t = Nx.s32([1, 2, 3]) - Nx.to_pointer(t, kind: :ipc) + Nx.to_pointer(t, mode: :ipc) %Nx.Pointer{kind: :ipc, address: nil, data_size: 32, handle: "some-ipc-handle"} """ @doc type: :creation From 2af0d8db1ef9bbfde882d7b749b03aabaced3bb4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jonatan=20K=C5=82osko?= Date: Wed, 19 Feb 2025 11:21:26 +0100 Subject: [PATCH 09/36] Refactor EXLA NIFs to use Fine (#1581) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: José Valim --- exla/Makefile | 8 +- exla/c_src/exla/exla.cc | 1183 +++++++--------------------- exla/c_src/exla/exla_client.cc | 271 +++---- exla/c_src/exla/exla_client.h | 30 +- exla/c_src/exla/exla_cuda.cc | 37 +- exla/c_src/exla/exla_cuda.h | 6 +- exla/c_src/exla/exla_log_sink.h | 94 +-- exla/c_src/exla/exla_mlir.cc | 76 +- exla/c_src/exla/exla_mlir.h | 26 +- exla/c_src/exla/exla_nif_util.cc | 317 -------- exla/c_src/exla/exla_nif_util.h | 528 ++++++------- exla/c_src/exla/ipc.cc | 4 +- exla/c_src/exla/ipc.h | 2 +- exla/lib/exla/backend.ex | 61 +- exla/lib/exla/client.ex | 28 +- exla/lib/exla/device_buffer.ex | 21 +- exla/lib/exla/executable.ex | 23 +- exla/lib/exla/mlir/context_pool.ex | 4 +- exla/lib/exla/mlir/function.ex | 10 +- exla/lib/exla/mlir/module.ex | 16 +- exla/lib/exla/mlir/value.ex | 7 - exla/lib/exla/nif.ex | 123 +-- exla/lib/exla/typespec.ex | 41 - exla/mix.exs | 2 + exla/mix.lock | 1 + nx/lib/nx.ex | 4 +- 26 files changed, 886 insertions(+), 2037 deletions(-) delete mode 100644 exla/c_src/exla/exla_nif_util.cc diff --git a/exla/Makefile b/exla/Makefile index ada1bb7001..77b863dd7c 100644 --- a/exla/Makefile +++ b/exla/Makefile @@ -22,10 +22,8 @@ XLA_EXTENSION_LIB_LINK_PATH = ../$(CWD_RELATIVE_TO_PRIV_PATH)/$(XLA_EXTENSION_LI EXLA_CACHE_SO_LINK_PATH = $(CWD_RELATIVE_TO_PRIV_PATH)/$(EXLA_CACHE_SO) # Build flags -# c++17 is needed, otherwise xla headers -# break on some conflicting llvm/std definitions -# Note: this is on :xla 0.5.0 -- things can change with later versions -CFLAGS = -fPIC -I$(ERTS_INCLUDE_DIR) -I$(XLA_INCLUDE_PATH) -Wall -Wno-sign-compare \ +# Note that XLA requires c++17, Fine as well +CFLAGS = -fPIC -I$(ERTS_INCLUDE_DIR) -I$(FINE_INCLUDE_DIR) -I$(XLA_INCLUDE_PATH) -Wall -Wno-sign-compare \ -Wno-unused-parameter -Wno-missing-field-initializers -Wno-comment \ -std=c++17 -w -DLLVM_VERSION_STRING= @@ -82,7 +80,7 @@ $(EXLA_SO): $(EXLA_CACHE_SO) ln -sf $(EXLA_CACHE_SO_LINK_PATH) $(EXLA_SO) ; \ fi -SOURCES = $(EXLA_DIR)/exla.cc $(EXLA_DIR)/exla_client.cc $(EXLA_DIR)/exla_mlir.cc $(EXLA_DIR)/custom_calls.cc $(EXLA_DIR)/exla_nif_util.cc $(EXLA_DIR)/ipc.cc +SOURCES = $(EXLA_DIR)/exla.cc $(EXLA_DIR)/exla_client.cc $(EXLA_DIR)/exla_mlir.cc $(EXLA_DIR)/custom_calls.cc $(EXLA_DIR)/ipc.cc SOURCES += $(wildcard $(EXLA_DIR)/custom_calls/*.cc) HEADERS = $(EXLA_DIR)/exla_mlir.h $(EXLA_DIR)/custom_calls/qr.h $(EXLA_DIR)/custom_calls/eigh.h $(EXLA_DIR)/exla_client.h $(EXLA_DIR)/exla_nif_util.h $(EXLA_DIR)/exla_log_sink.h $(EXLA_DIR)/ipc.h OBJECTS = $(patsubst $(EXLA_DIR)/%.cc,$(EXLA_CACHE_OBJ_DIR)/%.o,$(SOURCES)) $(EXLA_CACHE_OBJ_DIR)/exla_cuda.o diff --git a/exla/c_src/exla/exla.cc b/exla/c_src/exla/exla.cc index 42d0cbb9a4..2108de5c6e 100644 --- a/exla/c_src/exla/exla.cc +++ b/exla/c_src/exla/exla.cc @@ -1,5 +1,7 @@ -#include +#include +#include #include +#include #include "exla_client.h" #include "exla_cuda.h" @@ -7,1085 +9,500 @@ #include "exla_mlir.h" #include "exla_nif_util.h" #include "ipc.h" +#include "mlir/IR/MLIRContext.h" #include "stablehlo/dialect/ChloOps.h" #include "stablehlo/dialect/StablehloOps.h" #include "xla/pjrt/pjrt_api.h" #include "xla/service/platform_util.h" +#include "xla/statusor.h" #include "llvm/Support/ThreadPool.h" -// All of these are created with calls to `new` and subsequently -// passed to the VM as pointers-to-pointers so we balance it out -// with calls to delete rather than just using the default destructor. +namespace exla { -void free_exla_executable(ErlNifEnv* env, void* obj) { - exla::ExlaExecutable** executable = reinterpret_cast(obj); - if (*executable != nullptr) { - delete *executable; - *executable = nullptr; - } -} - -void free_exla_client(ErlNifEnv* env, void* obj) { - exla::ExlaClient** client = reinterpret_cast(obj); - if (*client != nullptr) { - delete *client; - *client = nullptr; - } -} - -void free_exla_buffer(ErlNifEnv* env, void* obj) { - exla::ExlaBuffer** buffer = reinterpret_cast(obj); - if (*buffer != nullptr) { - delete *buffer; - *buffer = nullptr; - } -} - -static int open_resources(ErlNifEnv* env) { - const char* mod = "EXLA"; - - if (!exla::nif::open_resource(env, mod, "Executable", free_exla_executable)) { - return -1; - } - if (!exla::nif::open_resource(env, mod, "ExlaClient", free_exla_client)) { - return -1; - } - if (!exla::nif::open_resource(env, mod, "ExlaBuffer", free_exla_buffer)) { - return -1; - } - // MLIR - if (!exla::nif::open_resource(env, mod, "MLIRFunction")) { - return -1; - } - if (!exla::nif::open_resource(env, mod, "MLIRValue")) { - return -1; - } - if (!exla::nif::open_resource(env, mod, "MLIRRegion")) { - return -1; - } - if (!exla::nif::open_resource(env, mod, "ExlaMLIRModule")) { - return -1; - } - - if (!exla::nif::open_resource(env, mod, "MLIRContext")) { - return -1; - } - - if (!exla::nif::open_resource(env, mod, "TheadPool")) { - return -1; - } - return 1; -} - -static int load(ErlNifEnv* env, void** priv, ERL_NIF_TERM load_info) { - if (open_resources(env) == -1) return -1; - - return 0; -} - -static int upgrade(ErlNifEnv* env, void** priv_data, void** old_priv_data, ERL_NIF_TERM load_info) { - // Silence "unused var" warnings. - (void)(env); - (void)(priv_data); - (void)(old_priv_data); - (void)(load_info); - - return 0; -} +FINE_RESOURCE(llvm::StdThreadPool); +FINE_RESOURCE(mlir::MLIRContext); +FINE_RESOURCE(mlir::Value); +FINE_RESOURCE(mlir::Region); +FINE_RESOURCE(exla::ExlaClient); +FINE_RESOURCE(exla::ExlaBuffer); +FINE_RESOURCE(exla::ExlaExecutable); +FINE_RESOURCE(exla::MLIRModule); +FINE_RESOURCE(exla::MLIRFunction); // MLIR Functions -ERL_NIF_TERM type_parsing_error(ErlNifEnv* env, std::string type_string) { - return exla::nif::make(env, "Unable to parse MLIR type: " + type_string); -} - -ERL_NIF_TERM attribute_parsing_error(ErlNifEnv* env, std::string attribute_string) { - return exla::nif::make(env, "Unable to parse MLIR attribute: " + attribute_string); -} - -ERL_NIF_TERM mlir_compile(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]) { - if (argc != 7) { - return exla::nif::error(env, "Bad argument count."); - } - - exla::ExlaClient** client; - exla::MLIRModule** module; - std::vector argument_layouts; - xla::ExecutableBuildOptions build_options; - int num_replicas; - int num_partitions; - bool use_spmd; - int device_id; - - if (!exla::nif::get(env, argv[0], client)) { - return exla::nif::error(env, "Unable to get client."); - } - if (!exla::nif::get(env, argv[1], module)) { - return exla::nif::error(env, "Unable to get module."); - } - if (!exla::nif::get_list(env, argv[2], argument_layouts)) { - return exla::nif::error(env, "Unable to get argument layouts."); - } - if (!exla::nif::get(env, argv[3], &num_replicas)) { - return exla::nif::error(env, "Unable to get Number of Replicas."); - } - if (!exla::nif::get(env, argv[4], &num_partitions)) { - return exla::nif::error(env, "Unable to get Number of Partitions."); - } - if (!exla::nif::get(env, argv[5], &use_spmd)) { - return exla::nif::error(env, "Unable to get SPMD Partitioning Flag."); - } - if (!exla::nif::get(env, argv[6], &device_id)) { - return exla::nif::error(env, "Unable to get device ID."); - } - - build_options.set_num_replicas(num_replicas); - build_options.set_num_partitions(num_partitions); - build_options.set_use_spmd_partitioning(use_spmd); - - bool compile_portable_executable = false; - if (device_id >= 0) { - compile_portable_executable = true; - build_options.set_device_ordinal(device_id); +fine::ResourcePtr decode_exla_buffer(ErlNifEnv *env, + fine::Term buffer_term) { + try { + return fine::decode>(env, buffer_term); + } catch (std::invalid_argument) { + throw std::invalid_argument( + "unable to get buffer. It may belong to another node, " + "consider using Nx.backend_transfer/1"); } - - EXLA_ASSIGN_OR_RETURN_NIF(exla::ExlaExecutable * executable, - (*client)->Compile((*module)->module(), argument_layouts, build_options, compile_portable_executable), env); - - return exla::nif::ok(env, exla::nif::make(env, executable)); } - -ERL_NIF_TERM mlir_new_thread_pool(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]) { - if (argc != 1) { - return exla::nif::error(env, "Bad argument count."); - } - - int concurrency; - - if (!exla::nif::get(env, argv[0], &concurrency)) { - return exla::nif::error(env, "Unable to get concurrency."); - } - - llvm::ThreadPoolStrategy strategy = llvm::hardware_concurrency(concurrency); - llvm::StdThreadPool* pool = new llvm::StdThreadPool(strategy); - - auto ret = exla::nif::make(env, pool); - return exla::nif::ok(env, ret); +fine::ResourcePtr +mlir_new_thread_pool(ErlNifEnv *env, int64_t concurrency) { + auto strategy = llvm::hardware_concurrency(concurrency); + return fine::make_resource(strategy); } -ERL_NIF_TERM mlir_new_context(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]) { - if (argc != 1) { - return exla::nif::error(env, "Bad argument count."); - } - - llvm::StdThreadPool** thread_pool; - - if (!exla::nif::get(env, argv[0], thread_pool)) { - return exla::nif::error(env, "Unable to get thread pool."); - } +FINE_NIF(mlir_new_thread_pool, 0); - mlir::MLIRContext* context = new mlir::MLIRContext(mlir::MLIRContext::Threading::DISABLED); +fine::ResourcePtr +mlir_new_context(ErlNifEnv *env, + fine::ResourcePtr thread_pool) { + auto context = fine::make_resource( + mlir::MLIRContext::Threading::DISABLED); - auto interface_ptr = reinterpret_cast(*thread_pool); - context->setThreadPool(*interface_ptr); + context->setThreadPool(*thread_pool); context->getOrLoadDialect(); context->getOrLoadDialect(); context->getOrLoadDialect(); - auto ret = exla::nif::make(env, context); - return exla::nif::ok(env, ret); + return context; } -ERL_NIF_TERM mlir_new_module(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]) { - if (argc != 1) { - return exla::nif::error(env, "Bad argument count."); - } - - mlir::MLIRContext** ctx; - - if (!exla::nif::get(env, argv[0], ctx)) { - return exla::nif::error(env, "Unable to get context."); - } +FINE_NIF(mlir_new_context, 0); - exla::MLIRModule* module = new exla::MLIRModule(*ctx); - - return exla::nif::ok(env, exla::nif::make(env, module)); +fine::ResourcePtr +mlir_new_module(ErlNifEnv *env, fine::ResourcePtr ctx) { + return fine::make_resource(ctx); } -ERL_NIF_TERM mlir_create_function(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]) { - if (argc != 5) { - return exla::nif::error(env, "Bad argument count."); - } - - exla::MLIRModule** module; - std::string func_name; - std::vector arg_type_strings; - std::vector ret_type_strings; - bool is_public; - - if (!exla::nif::get(env, argv[0], module)) { - return exla::nif::error(env, "Unable to get module."); - } - if (!exla::nif::get(env, argv[1], func_name)) { - return exla::nif::error(env, "Unable to get function name."); - } - if (!exla::nif::get_list(env, argv[2], arg_type_strings)) { - return exla::nif::error(env, "Unable to get args."); - } - if (!exla::nif::get_list(env, argv[3], ret_type_strings)) { - return exla::nif::error(env, "Unable to get return."); - } - if (!exla::nif::get(env, argv[4], &is_public)) { - return exla::nif::error(env, "Unable to get is_public."); - } +FINE_NIF(mlir_new_module, 0); +fine::ResourcePtr mlir_create_function( + ErlNifEnv *env, fine::ResourcePtr module, std::string func_name, + std::vector arg_type_strings, + std::vector ret_type_strings, bool is_public) { auto arg_types = std::vector{}; - for (auto const& type_string : arg_type_strings) { - auto type = (*module)->ParseType(type_string); - if (type == nullptr) { - return type_parsing_error(env, type_string); - } + for (auto const &type_string : arg_type_strings) { + auto type = module->ParseType(type_string); arg_types.push_back(type); } auto ret_types = std::vector{}; - for (auto const& type_string : ret_type_strings) { - auto type = (*module)->ParseType(type_string); - if (type == nullptr) { - return type_parsing_error(env, type_string); - } + for (auto const &type_string : ret_type_strings) { + auto type = module->ParseType(type_string); ret_types.push_back(type); } - exla::MLIRFunction* func = (*module)->CreateFunction(func_name, arg_types, ret_types, is_public); - - return exla::nif::ok(env, exla::nif::make(env, func)); + auto func_op = + module->CreateFunction(func_name, arg_types, ret_types, is_public); + return fine::make_resource(module, std::move(func_op)); } -ERL_NIF_TERM mlir_get_function_arguments(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]) { - if (argc != 1) { - return exla::nif::error(env, "Bad argument count."); - } - - exla::MLIRFunction** function; - - if (!exla::nif::get(env, argv[0], function)) { - return exla::nif::error(env, "Unable to get function."); - } +FINE_NIF(mlir_create_function, 0); - llvm::MutableArrayRef args = (*function)->GetArguments(); - std::vector terms; - terms.reserve(args.size()); +std::vector> +mlir_get_function_arguments(ErlNifEnv *env, + fine::ResourcePtr function) { + auto args = function->GetArguments(); + std::vector> values; + values.reserve(args.size()); - for (auto arg : args) { - ERL_NIF_TERM term = exla::nif::make(env, arg); - terms.push_back(term); + for (const auto &arg : args) { + values.push_back(fine::make_resource(arg)); } - return exla::nif::ok(env, enif_make_list_from_array(env, terms.data(), terms.size())); + return values; } -ERL_NIF_TERM mlir_op(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]) { - if (argc != 6) { - return exla::nif::error(env, "Bad argument count."); - } - - exla::MLIRFunction** function; - std::string op_name; - std::vector operands; - std::vector result_type_strings; - std::vector> attributes_kwlist; - std::vector regions; - - if (!exla::nif::get(env, argv[0], function)) { - return exla::nif::error(env, "Unable to get function."); - } - if (!exla::nif::get(env, argv[1], op_name)) { - return exla::nif::error(env, "Unable to get op name."); - } - if (!exla::nif::get_list(env, argv[2], operands)) { - return exla::nif::error(env, "Unable to get operands."); - } - if (!exla::nif::get_list(env, argv[3], result_type_strings)) { - return exla::nif::error(env, "Unable to get result types."); - } - if (!exla::nif::get_keyword_list(env, argv[4], attributes_kwlist)) { - return exla::nif::error(env, "Unable to get attributes."); - } - if (!exla::nif::get_list(env, argv[5], regions)) { - return exla::nif::error(env, "Unable to get regions."); - } +FINE_NIF(mlir_get_function_arguments, 0); +std::vector> +mlir_op(ErlNifEnv *env, fine::ResourcePtr function, + std::string op_name, + std::vector> operands, + std::vector result_type_strings, + std::vector> attributes_kwlist, + std::vector> regions) { auto result_types = std::vector{}; - for (auto const& type_string : result_type_strings) { - auto type = (*function)->module()->ParseType(type_string); - if (type == nullptr) { - return type_parsing_error(env, type_string); - } + for (auto const &type_string : result_type_strings) { + auto type = function->module()->ParseType(type_string); result_types.push_back(type); } - auto attributes = std::vector>{}; + auto attributes = std::vector>{}; - for (auto const& pair : attributes_kwlist) { - auto attribute_value = (*function)->module()->ParseAttribute(pair.second); - if (attribute_value == nullptr) { - return attribute_parsing_error(env, pair.second); - } - attributes.push_back(std::pair{pair.first, attribute_value}); + for (auto const &[key, value] : attributes_kwlist) { + auto attribute_value = function->module()->ParseAttribute(value); + attributes.push_back(std::make_tuple(key.to_string(), attribute_value)); } - auto results = (*function)->Op(op_name, operands, result_types, attributes, regions); - - return exla::nif::ok(env, exla::nif::make_list(env, results)); + return function->Op(op_name, operands, result_types, attributes, regions); } -ERL_NIF_TERM mlir_push_region(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]) { - if (argc != 2) { - return exla::nif::error(env, "Bad argument count."); - } - - exla::MLIRFunction** function; - std::vector arg_types; - - if (!exla::nif::get(env, argv[0], function)) { - return exla::nif::error(env, "Unable to get function."); - } - if (!exla::nif::get_list(env, argv[1], arg_types)) { - return exla::nif::error(env, "Unable to get arg types."); - } +FINE_NIF(mlir_op, 0); +std::tuple, + std::vector>> +mlir_push_region(ErlNifEnv *env, fine::ResourcePtr function, + std::vector arg_types) { auto types = std::vector{}; - for (auto const& type_string : arg_types) { - auto type = (*function)->module()->ParseType(type_string); - if (type == nullptr) { - return type_parsing_error(env, type_string); - } + for (auto const &type_string : arg_types) { + auto type = function->module()->ParseType(type_string); types.push_back(type); } - mlir::Region* region; - std::vector args; - std::tie(region, args) = (*function)->PushRegion(types); - - return exla::nif::ok(env, enif_make_tuple2(env, exla::nif::make(env, region), exla::nif::make_list(env, args))); + return function->PushRegion(types); } -ERL_NIF_TERM mlir_pop_region(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]) { - if (argc != 1) { - return exla::nif::error(env, "Bad argument count."); - } - - exla::MLIRFunction** function; - - if (!exla::nif::get(env, argv[0], function)) { - return exla::nif::error(env, "Unable to get function."); - } +FINE_NIF(mlir_push_region, 0); - (*function)->PopRegion(); - return exla::nif::ok(env); +fine::Ok<> mlir_pop_region(ErlNifEnv *env, + fine::ResourcePtr function) { + function->PopRegion(); + return fine::Ok(); } -std::string mlir_numeric_type_to_string(mlir::Type type) { - if (type.isSignlessInteger(1)) { - return "pred"; - } - if (auto integer_type = type.dyn_cast()) { - if (integer_type.isUnsigned()) { - return "u" + std::to_string(integer_type.getWidth()); - } else { - return "s" + std::to_string(integer_type.getWidth()); - } - } - if (type.isBF16()) { - return "bf16"; - } - if (auto float_type = type.dyn_cast()) { - return "f" + std::to_string(float_type.getWidth()); - } - if (auto complex_type = type.dyn_cast()) { - auto element_type = complex_type.getElementType(); - return "c" + std::to_string(element_type.cast().getWidth() * 2); - } +FINE_NIF(mlir_pop_region, 0); - std::cerr << "Unexpected mlir type" << std::endl; - exit(1); +mlir::Type mlir_get_typespec(ErlNifEnv *env, + fine::ResourcePtr value) { + return value->getType(); } -ERL_NIF_TERM make_typespec(ErlNifEnv* env, mlir::Type type) { - if (type.isa()) { - auto type_term = exla::nif::make(env, "token"); - auto shape_term = enif_make_tuple(env, 0); - - return enif_make_tuple(env, 2, type_term, shape_term); - } - - if (type.isa()) { - auto tensor_type = type.cast(); - auto dims = tensor_type.getShape(); - auto element_type = tensor_type.getElementType(); +FINE_NIF(mlir_get_typespec, 0); - auto dims_array = std::vector{}; - dims_array.reserve(dims.size()); - - for (auto dim : dims) { - dims_array.push_back(enif_make_int(env, dim)); - } +std::string mlir_module_to_string(ErlNifEnv *env, + fine::ResourcePtr module) { + return module->ToString(); +} - auto type_term = exla::nif::make(env, mlir_numeric_type_to_string(element_type)); - auto shape_term = enif_make_tuple_from_array(env, dims_array.data(), dims_array.size()); +FINE_NIF(mlir_module_to_string, 0); - return enif_make_tuple(env, 2, type_term, shape_term); +template T unwrap(xla::StatusOr status_or) { + if (!status_or.ok()) { + throw std::runtime_error(status_or.status().message().data()); } - std::cerr << "Unexpected mlir type" << std::endl; - exit(1); + return std::move(status_or.value()); } -ERL_NIF_TERM mlir_get_typespec(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]) { - if (argc != 1) { - return exla::nif::error(env, "Bad argument count."); +void unwrap(xla::Status status) { + if (!status.ok()) { + throw std::runtime_error(status.message().data()); } - - mlir::Value* t; - - if (!exla::nif::get(env, argv[0], t)) { - return exla::nif::error(env, "Unable to get tensor."); - } - - mlir::Type type = t->getType(); - - return exla::nif::ok(env, make_typespec(env, type)); } -ERL_NIF_TERM mlir_module_to_string(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]) { - if (argc != 1) { - return exla::nif::error(env, "Bad argument count."); - } +fine::ResourcePtr +mlir_compile(ErlNifEnv *env, fine::ResourcePtr client, + fine::ResourcePtr module, + std::vector argument_layouts, int64_t num_replicas, + int64_t num_partitions, bool use_spmd, int64_t device_id) { + auto build_options = xla::ExecutableBuildOptions(); - exla::MLIRModule** module; + build_options.set_num_replicas(num_replicas); + build_options.set_num_partitions(num_partitions); + build_options.set_use_spmd_partitioning(use_spmd); - if (!exla::nif::get(env, argv[0], module)) { - return exla::nif::error(env, "Unable to get builder."); + auto compile_portable_executable = false; + if (device_id >= 0) { + compile_portable_executable = true; + build_options.set_device_ordinal(device_id); } - std::string string = (*module)->ToString(); - - ErlNifBinary bin; - enif_alloc_binary(string.size(), &bin); - memcpy(bin.data, string.c_str(), string.size()); - - return exla::nif::ok(env, exla::nif::make(env, bin)); + return unwrap(client->Compile(module->module(), argument_layouts, + build_options, compile_portable_executable)); } +FINE_NIF(mlir_compile, ERL_NIF_DIRTY_JOB_CPU_BOUND); + // ExlaBuffer Functions -ERL_NIF_TERM get_buffer_device_pointer(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]) { - if (argc != 3) { - return exla::nif::error(env, "Bad argument count."); - } +std::variant, + fine::Ok, + fine::Ok, fine::Error> +get_buffer_device_pointer(ErlNifEnv *env, fine::ResourcePtr client, + fine::Term buffer_term, fine::Atom pointer_kind) { + auto buffer = decode_exla_buffer(env, buffer_term); - exla::ExlaClient** client; - exla::ExlaBuffer** buffer; - std::string pointer_kind; + uint64_t device_size = unwrap(buffer->GetOnDeviceSizeInBytes()); + uint64_t ptr = unwrap(buffer->GetDevicePointer(client->client())); - if (!exla::nif::get(env, argv[0], client)) { - return exla::nif::error(env, "Unable to get client."); - } - if (!exla::nif::get(env, argv[1], buffer)) { - return exla::nif::error(env, "Unable to get buffer (it may belong to another node or have been garbage collected, consider using Nx.backend_transfer/1)."); - } - if (!exla::nif::get_atom(env, argv[2], pointer_kind)) { - return exla::nif::error(env, "Unable to get device pointer kind."); + if (pointer_kind == "local") { + return fine::Ok(ptr, device_size); } - EXLA_ASSIGN_OR_RETURN_NIF(unsigned long device_size, (*buffer)->GetOnDeviceSizeInBytes(), env); - - EXLA_ASSIGN_OR_RETURN_NIF(std::uintptr_t ptr, - (*buffer)->GetDevicePointer((*client)->client()), env); - - ERL_NIF_TERM out_term; - if (pointer_kind == "local") { - ERL_NIF_TERM ptr_term = enif_make_ulong(env, ptr); - ERL_NIF_TERM size_term = enif_make_ulong(env, device_size); - out_term = enif_make_tuple2(env, ptr_term, size_term); - } else if (pointer_kind == "host_ipc") { - std::ostringstream handle_name_stream; - handle_name_stream << "exla:ipc:" << device_size << ":" << ptr; - std::string handle_name = handle_name_stream.str(); - int fd = get_ipc_handle((char*)handle_name.c_str(), device_size); + if (pointer_kind == "host_ipc") { + auto handle_name = + "exla:ipc:" + std::to_string(device_size) + ":" + std::to_string(ptr); + auto fd = get_ipc_handle(handle_name.c_str(), device_size); if (fd == -1) { - return exla::nif::error(env, "Unable to get IPC handle"); + return fine::Error(std::string("unable to get IPC handle")); } - void* ipc_ptr = open_ipc_handle(fd, device_size); + auto ipc_ptr = open_ipc_handle(fd, device_size); if (ipc_ptr == nullptr) { - return exla::nif::error(env, "Unable to open IPC handle"); + return fine::Error(std::string("unable to open IPC handle")); } - memcpy(ipc_ptr, (void*)ptr, device_size); + memcpy(ipc_ptr, reinterpret_cast(ptr), device_size); - ErlNifBinary handle_name_bin; - enif_alloc_binary(handle_name.size(), &handle_name_bin); - for (int i = 0; i < handle_name.size(); i++) { - handle_name_bin.data[i] = handle_name[i]; - } - ERL_NIF_TERM handle_name_term = enif_make_binary(env, &handle_name_bin); - ERL_NIF_TERM size_term = enif_make_uint64(env, device_size); - ERL_NIF_TERM fd_term = enif_make_int(env, fd); - out_term = enif_make_tuple3(env, handle_name_term, fd_term, size_term); - } else if (pointer_kind == "cuda_ipc") { - auto result = get_cuda_ipc_handle(ptr); - if (result.second) { - return exla::nif::error(env, "Unable to get cuda IPC handle"); - } - auto pointer_vec = result.first; + return fine::Ok(handle_name, static_cast(fd), device_size); + } - ErlNifBinary handle_bin; - enif_alloc_binary(pointer_vec.size(), &handle_bin); - for (int i = 0; i < pointer_vec.size(); i++) { - handle_bin.data[i] = pointer_vec[i]; + if (pointer_kind == "cuda_ipc") { + auto maybe_handle = get_cuda_ipc_handle(ptr); + if (!maybe_handle) { + return fine::Error(std::string("unable to get cuda IPC handle")); } - ERL_NIF_TERM handle_term = enif_make_binary(env, &handle_bin); - ERL_NIF_TERM size_term = enif_make_uint64(env, device_size); - out_term = enif_make_tuple2(env, handle_term, size_term); + + return fine::Ok(maybe_handle.value(), device_size); } - return exla::nif::ok(env, out_term); + throw std::invalid_argument("unexpected pointer type"); } -ERL_NIF_TERM create_buffer_from_device_pointer(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]) { - if (argc != 5) { - return exla::nif::error(env, "Bad argument count."); - } +FINE_NIF(get_buffer_device_pointer, 0); - exla::ExlaClient** client; - ErlNifBinary cuda_ipc_handle_bin; - int cuda_ipc_handle_size = 0; - xla::Shape shape; - int device_id; - std::string pointer_kind; - void* ptr; - int fd = -1; - std::string memname; - - if (!exla::nif::get(env, argv[0], client)) { - return exla::nif::error(env, "Unable to get client."); - } - if (!exla::nif::get_atom(env, argv[1], pointer_kind)) { - return exla::nif::error(env, "Unable to get device pointer kind."); - } +std::variant>, fine::Error> +create_buffer_from_device_pointer(ErlNifEnv *env, + fine::ResourcePtr client, + fine::Atom pointer_kind, + fine::Term pointer_data, xla::Shape shape, + int64_t device_id) { + void *ptr = nullptr; + std::function on_delete_callback = []() {}; if (pointer_kind == "cuda_ipc") { - if (!enif_inspect_binary(env, argv[2], &cuda_ipc_handle_bin)) { - return exla::nif::error(env, "Unable to get CUDA IPC handle."); + auto cuda_ipc_handle_bin = fine::decode(env, pointer_data); + auto maybe_pointer = get_pointer_for_ipc_handle( + cuda_ipc_handle_bin.data, cuda_ipc_handle_bin.size, device_id); + if (!maybe_pointer) { + return fine::Error("unable to get pointer for IPC handle"); } + ptr = maybe_pointer.value(); } else if (pointer_kind == "host_ipc") { - const ERL_NIF_TERM* tuple; - int arity; - if ( - !enif_get_tuple(env, argv[2], &arity, &tuple) || - (arity != 2) || - !exla::nif::get(env, tuple[0], &fd) || - (fd == -1) || - !exla::nif::get(env, tuple[1], memname)) { - return exla::nif::error(env, "Unable to get IPC handle."); - } - } else if (pointer_kind == "local") { - int64_t ptr_int; - if (!exla::nif::get(env, argv[2], &ptr_int)) { - return exla::nif::error(env, "Unable to get pointer."); - } - - ptr = (void*)ptr_int; - } - - if (!exla::nif::get_typespec_as_xla_shape(env, argv[3], &shape)) { - return exla::nif::error(env, "Unable to get shape."); - } - if (!exla::nif::get(env, argv[4], &device_id)) { - return exla::nif::error(env, "Unable to get device ordinal."); - } - - std::function on_delete_callback = []() {}; - - if (pointer_kind == "host_ipc") { - size_t device_size = (size_t)xla::ShapeUtil::ByteSizeOf(shape); - + auto tuple = + fine::decode>(env, pointer_data); + auto fd = std::get<0>(tuple); + auto memname = std::get<1>(tuple); + auto device_size = xla::ShapeUtil::ByteSizeOf(shape); ptr = open_ipc_handle(fd, device_size); if (ptr == nullptr) { - return exla::nif::error(env, "Unable to get pointer for IPC handle."); + return fine::Error("unable to get pointer for IPC handle"); } - on_delete_callback = [fd, memname, ptr, device_size]() { - close_ipc_handle(fd, ptr, (char*)memname.c_str(), device_size); + close_ipc_handle(fd, ptr, memname.c_str(), device_size); }; - } else if (pointer_kind == "cuda_ipc") { - auto result = get_pointer_for_ipc_handle(cuda_ipc_handle_bin.data, cuda_ipc_handle_bin.size, device_id); - if (result.second) { - return exla::nif::error(env, "Unable to get pointer for IPC handle."); - } - ptr = result.first; + } else if (pointer_kind == "local") { + auto ptr_int = fine::decode(env, pointer_data); + ptr = reinterpret_cast(ptr_int); + } else { + throw std::invalid_argument("unexpected pointer type"); } - EXLA_ASSIGN_OR_RETURN_NIF(xla::PjRtDevice * device, (*client)->client()->LookupDevice(xla::PjRtGlobalDeviceId(device_id)), env); - - EXLA_ASSIGN_OR_RETURN_NIF(std::unique_ptr buffer, (*client)->client()->CreateViewOfDeviceBuffer(ptr, shape, device, on_delete_callback), env); - - exla::ExlaBuffer* exla_buffer = new exla::ExlaBuffer(std::move(buffer)); - return exla::nif::ok(env, exla::nif::make(env, exla_buffer)); + auto device = unwrap( + client->client()->LookupDevice(xla::PjRtGlobalDeviceId(device_id))); + auto buffer = unwrap(client->client()->CreateViewOfDeviceBuffer( + ptr, shape, device, on_delete_callback)); + return fine::Ok(fine::make_resource(std::move(buffer))); } -ERL_NIF_TERM binary_to_device_mem(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]) { - if (argc != 4) { - return exla::nif::error(env, "Bad argument count."); - } - - xla::Shape shape; - exla::ExlaClient** client; - int device_id; - - if (!exla::nif::get(env, argv[0], client)) { - return exla::nif::error(env, "Unable to get client."); - } - if (!exla::nif::get_typespec_as_xla_shape(env, argv[2], &shape)) { - return exla::nif::error(env, "Unable to get shape."); - } - if (!exla::nif::get(env, argv[3], &device_id)) { - return exla::nif::error(env, "Unable to get device ordinal."); - } +FINE_NIF(create_buffer_from_device_pointer, 0); - EXLA_ASSIGN_OR_RETURN_NIF(exla::ExlaBuffer * buffer, - (*client)->BufferFromBinary(env, argv[1], shape, device_id), env); - return exla::nif::ok(env, exla::nif::make(env, buffer)); +fine::ResourcePtr +binary_to_device_mem(ErlNifEnv *env, fine::ResourcePtr client, + fine::Term data, xla::Shape shape, int64_t device_id) { + return unwrap(client->BufferFromBinary(data, shape, device_id)); } -ERL_NIF_TERM read_device_mem(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]) { - if (argc != 2) { - return exla::nif::error(env, "Bad argument count."); - } - - exla::ExlaBuffer** buffer; - exla::int64 size; - - if (!exla::nif::get(env, argv[0], buffer)) { - return exla::nif::error(env, "Unable to get buffer (it may belong to another node or have been garbage collected, consider using Nx.backend_transfer/1)."); - } - if (!exla::nif::get(env, argv[1], &size)) { - return exla::nif::error(env, "Unable to get size."); - } +FINE_NIF(binary_to_device_mem, ERL_NIF_DIRTY_JOB_IO_BOUND); - EXLA_ASSIGN_OR_RETURN_NIF(ERL_NIF_TERM binary, (*buffer)->ToBinary(env, size), env); - - return exla::nif::ok(env, binary); +fine::Term read_device_mem(ErlNifEnv *env, fine::Term buffer_term, + int64_t size) { + auto buffer = decode_exla_buffer(env, buffer_term); + return unwrap(buffer->ToBinary(env, size)); } -ERL_NIF_TERM deallocate_device_mem(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]) { - if (argc != 1) { - return exla::nif::error(env, "Bad argument count."); - } - - exla::ExlaBuffer** buffer; +FINE_NIF(read_device_mem, ERL_NIF_DIRTY_JOB_IO_BOUND); - if (!exla::nif::get(env, argv[0], buffer)) { - return exla::nif::error(env, "Unable to get buffer (it may belong to another node or have been garbage collected, consider using Nx.backend_transfer/1)."); - } +std::variant, fine::Error> +deallocate_device_mem(ErlNifEnv *env, fine::Term buffer_term) { + auto buffer = decode_exla_buffer(env, buffer_term); - xla::Status dealloc_status = (*buffer)->Deallocate(); + xla::Status dealloc_status = buffer->Deallocate(); if (!dealloc_status.ok()) { - return exla::nif::atom(env, "already_deallocated"); + return fine::Error(atoms::already_deallocated); } else { - return exla::nif::ok(env); + return fine::Ok(); } } -ERL_NIF_TERM transfer_to_infeed(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]) { - if (argc != 3) { - return exla::nif::error(env, "Bad argument count."); - } - - exla::ExlaClient** client; - int device_id; - ERL_NIF_TERM data = argv[2]; - - if (!exla::nif::get(env, argv[0], client)) { - return exla::nif::error(env, "Unable to get client."); - } - if (!exla::nif::get(env, argv[1], &device_id)) { - return exla::nif::error(env, "Unable to get device ID."); - } +FINE_NIF(deallocate_device_mem, ERL_NIF_DIRTY_JOB_IO_BOUND); - std::vector buffer_bins; - std::vector shapes; +fine::Ok<> transfer_to_infeed(ErlNifEnv *env, + fine::ResourcePtr client, + int64_t device_id, + std::vector buffers, + std::vector shapes) { + unwrap(client->TransferToInfeed(env, buffers, shapes, device_id)); - ERL_NIF_TERM head, tail; - while (enif_get_list_cell(env, data, &head, &tail)) { - const ERL_NIF_TERM* terms; - int count; - - if (!enif_get_tuple(env, head, &count, &terms) && count != 2) { - return exla::nif::error(env, "Unable to {binary, shape} tuple."); - } - - ErlNifBinary buffer_bin; - if (!exla::nif::get_binary(env, terms[0], &buffer_bin)) { - return exla::nif::error(env, "Unable to binary."); - } - - xla::Shape shape; - if (!exla::nif::get_typespec_as_xla_shape(env, terms[1], &shape)) { - return exla::nif::error(env, "Unable to get shape."); - } - - buffer_bins.push_back(buffer_bin); - shapes.push_back(shape); - - data = tail; - } - - xla::Status transfer_status = (*client)->TransferToInfeed(env, buffer_bins, shapes, device_id); - - if (!transfer_status.ok()) { - return exla::nif::error(env, transfer_status.message().data()); - } - - return exla::nif::ok(env); + return fine::Ok(); } -ERL_NIF_TERM transfer_from_outfeed(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]) { - if (argc != 5) { - return exla::nif::error(env, "Bad argument count."); - } - - exla::ExlaClient** client; - int device_id; - ErlNifPid pid; - - if (!exla::nif::get(env, argv[0], client)) { - return exla::nif::error(env, "Unable to get client."); - } - if (!exla::nif::get(env, argv[1], &device_id)) { - return exla::nif::error(env, "Unable to get device ID."); - } - if (!enif_get_local_pid(env, argv[3], &pid)) { - return exla::nif::error(env, "Unable to get pid."); - } - - ERL_NIF_TERM data = argv[2]; - ERL_NIF_TERM head, tail; - while (enif_get_list_cell(env, data, &head, &tail)) { - xla::Shape shape; - - if (!exla::nif::get_typespec_as_xla_shape(env, head, &shape)) { - return exla::nif::error(env, "Unable to get shape."); - } - - ErlNifEnv* penv = enif_alloc_env(); - ERL_NIF_TERM ref = enif_make_copy(penv, argv[4]); - auto statusor = (*client)->TransferFromOutfeed(penv, device_id, shape); - - if (!statusor.ok()) { - enif_clear_env(penv); - return exla::nif::error(env, statusor.status().message().data()); - } - - ERL_NIF_TERM msg = std::move(statusor.value()); +FINE_NIF(transfer_to_infeed, ERL_NIF_DIRTY_JOB_IO_BOUND); - if (!enif_send(env, &pid, penv, enif_make_tuple(penv, 2, ref, msg))) { - enif_clear_env(penv); - } - - data = tail; +fine::Ok<> transfer_from_outfeed(ErlNifEnv *env, + fine::ResourcePtr client, + int64_t device_id, + std::vector shapes, ErlNifPid pid, + fine::Term ref) { + for (auto &shape : shapes) { + auto msg_env = enif_alloc_env(); + auto msg = unwrap(client->TransferFromOutfeed(msg_env, device_id, shape)); + enif_send(env, &pid, msg_env, enif_make_tuple(msg_env, 2, ref, msg)); + enif_free_env(msg_env); } - return exla::nif::ok(env); + return fine::Ok(); } -ERL_NIF_TERM copy_buffer_to_device(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]) { - if (argc != 3) { - return exla::nif::error(env, "Bad argument count."); - } - - exla::ExlaClient** client; - exla::ExlaBuffer** buffer; - int device_id; +FINE_NIF(transfer_from_outfeed, ERL_NIF_DIRTY_JOB_IO_BOUND); - if (!exla::nif::get(env, argv[0], client)) { - return exla::nif::error(env, "Unable to get client."); - } - if (!exla::nif::get(env, argv[1], buffer)) { - return exla::nif::error(env, "Unable to get buffer (it may belong to another node or have been garbage collected, consider using Nx.backend_transfer/1)."); - } - if (!exla::nif::get(env, argv[2], &device_id)) { - return exla::nif::error(env, "Unable to get device ID."); - } +fine::ResourcePtr +copy_buffer_to_device(ErlNifEnv *env, fine::ResourcePtr client, + fine::Term buffer_term, int64_t device_id) { + auto buffer = decode_exla_buffer(env, buffer_term); - EXLA_ASSIGN_OR_RETURN_NIF(xla::PjRtDevice * device, - (*client)->client()->LookupDevice(xla::PjRtGlobalDeviceId(device_id)), env); - EXLA_ASSIGN_OR_RETURN_NIF(exla::ExlaBuffer * buf, - (*buffer)->CopyToDevice(device), env); + auto device = unwrap( + client->client()->LookupDevice(xla::PjRtGlobalDeviceId(device_id))); - return exla::nif::ok(env, exla::nif::make(env, buf)); + return unwrap(buffer->CopyToDevice(device)); } -// ExlaClient Functions +FINE_NIF(copy_buffer_to_device, ERL_NIF_DIRTY_JOB_IO_BOUND); -ERL_NIF_TERM get_host_client(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]) { - if (argc != 0) { - return exla::nif::error(env, "Bad argument count."); - } - - EXLA_ASSIGN_OR_RETURN_NIF(exla::ExlaClient * client, exla::GetHostClient(), env); +// ExlaClient Functions - return exla::nif::ok(env, exla::nif::make(env, client)); +fine::ResourcePtr get_host_client(ErlNifEnv *env) { + return unwrap(GetHostClient()); } -ERL_NIF_TERM get_gpu_client(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]) { - if (argc != 2) { - return exla::nif::error(env, "Bad argument count."); - } - - double memory_fraction; - bool preallocate; +FINE_NIF(get_host_client, 0); - if (!exla::nif::get(env, argv[0], &memory_fraction)) { - return exla::nif::error(env, "Unable to get memory fraction."); - } - if (!exla::nif::get(env, argv[1], &preallocate)) { - return exla::nif::error(env, "Unable to get preallocate flag."); - } - EXLA_ASSIGN_OR_RETURN_NIF(exla::ExlaClient * client, - exla::GetGpuClient(memory_fraction, preallocate, xla::GpuAllocatorConfig::Kind::kBFC), env); - - return exla::nif::ok(env, exla::nif::make(env, client)); +fine::ResourcePtr +get_gpu_client(ErlNifEnv *env, double memory_fraction, bool preallocate) { + return unwrap(GetGpuClient(memory_fraction, preallocate, + xla::GpuAllocatorConfig::Kind::kBFC)); } -ERL_NIF_TERM get_tpu_client(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]) { - if (argc != 0) { - return exla::nif::error(env, "Bad argument count."); - } - - EXLA_ASSIGN_OR_RETURN_NIF(exla::ExlaClient * client, exla::GetTpuClient(), env); +FINE_NIF(get_gpu_client, 0); - return exla::nif::ok(env, exla::nif::make(env, client)); +fine::ResourcePtr get_tpu_client(ErlNifEnv *env) { + return unwrap(GetTpuClient()); } -ERL_NIF_TERM get_c_api_client(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]) { - if (argc != 1) { - return exla::nif::error(env, "Bad argument count."); - } - - std::string device_type; - if (!exla::nif::get(env, argv[0], device_type)) { - return exla::nif::error(env, "Unable to get device type."); - } - - EXLA_ASSIGN_OR_RETURN_NIF(exla::ExlaClient * client, exla::GetCApiClient(device_type), env); +FINE_NIF(get_tpu_client, 0); - return exla::nif::ok(env, exla::nif::make(env, client)); +fine::ResourcePtr get_c_api_client(ErlNifEnv *env, + std::string device_type) { + return unwrap(GetCApiClient(device_type)); } -ERL_NIF_TERM load_pjrt_plugin(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]) { - if (argc != 2) { - return exla::nif::error(env, "Bad argument count."); - } - - std::string device_type; - std::string library_path; - if (!exla::nif::get(env, argv[0], device_type)) { - return exla::nif::error(env, "Unable to get device type."); - } - if (!exla::nif::get(env, argv[1], library_path)) { - return exla::nif::error(env, "Unable to get library path."); - } +FINE_NIF(get_c_api_client, 0); - auto result = pjrt::LoadPjrtPlugin(device_type, library_path); - - if (!result.ok()) { - return exla::nif::error(env, result.status().message().data()); - } else { - return exla::nif::ok(env); - } +fine::Ok<> load_pjrt_plugin(ErlNifEnv *env, std::string device_type, + std::string library_path) { + unwrap(pjrt::LoadPjrtPlugin(device_type, library_path)); + return fine::Ok(); } -ERL_NIF_TERM get_device_count(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]) { - if (argc != 1) { - return exla::nif::error(env, "Bad argument count."); - } - - exla::ExlaClient** client; - - if (!exla::nif::get(env, argv[0], client)) { - return exla::nif::error(env, "Unable to get client."); - } +FINE_NIF(load_pjrt_plugin, 0); - int device_count = (*client)->client()->device_count(); - - return exla::nif::ok(env, exla::nif::make(env, device_count)); +int64_t get_device_count(ErlNifEnv *env, fine::ResourcePtr client) { + return client->client()->device_count(); } -ERL_NIF_TERM get_supported_platforms(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]) { - if (argc != 0) { - return exla::nif::error(env, "Bad argument count."); - } +FINE_NIF(get_device_count, 0); - EXLA_ASSIGN_OR_RETURN_NIF( - std::vector platforms, - xla::PlatformUtil::GetSupportedPlatforms(), - env); +std::map get_supported_platforms(ErlNifEnv *env) { + auto platforms = unwrap(xla::PlatformUtil::GetSupportedPlatforms()); - std::vector platform_names; - std::map platform_info; + std::map platform_info; - for (auto& platform : platforms) { - std::string key = platform->Name(); - int device_count = platform->VisibleDeviceCount(); + for (auto &platform : platforms) { + auto key = fine::Atom(absl::AsciiStrToLower(platform->Name())); + auto device_count = platform->VisibleDeviceCount(); platform_info.insert({key, device_count}); } - return exla::nif::ok(env, exla::nif::make_map(env, platform_info)); + return platform_info; } -// ExlaExecutable Functions - -ERL_NIF_TERM run(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]) { - if (argc != 4) { - return exla::nif::error(env, "Bad argument count."); - } - - exla::ExlaClient** client; - exla::ExlaExecutable** executable; - int device_id; +FINE_NIF(get_supported_platforms, 0); - ERL_NIF_TERM arguments = argv[2]; - - if (!exla::nif::get(env, argv[0], client)) { - return exla::nif::error(env, "Unable to get client."); - } - if (!exla::nif::get(env, argv[1], executable)) { - return exla::nif::error(env, "Unable to get executable."); - } - if (!exla::nif::get(env, argv[3], &device_id)) { - return exla::nif::error(env, "Unable to get device ID."); - } - - EXLA_ASSIGN_OR_RETURN_NIF(ERL_NIF_TERM term, - (*executable)->Run(env, arguments, device_id), env); +// ExlaExecutable Functions - return term; +ExlaExecutable::RunResult run(ErlNifEnv *env, + fine::ResourcePtr executable, + ExlaExecutable::RunArguments arguments, + int64_t device_id) { + return unwrap(executable->Run(env, arguments, device_id)); } -// Serialization Functions - -ERL_NIF_TERM serialize_executable(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]) { - if (argc != 1) { - return exla::nif::error(env, "Bad argument count."); - } - - exla::ExlaExecutable** executable; - - if (!exla::nif::get(env, argv[0], executable)) { - return exla::nif::error(env, "Unable to get executable."); - } - - EXLA_ASSIGN_OR_RETURN_NIF(std::string serialized, (*executable)->SerializeExecutable(), env); - ErlNifBinary raw; - enif_alloc_binary(serialized.size(), &raw); - std::memcpy((&raw)->data, serialized.data(), serialized.size()); - - return exla::nif::ok(env, exla::nif::make(env, raw)); +ExlaExecutable::RunResult run_cpu(ErlNifEnv *env, + fine::ResourcePtr executable, + ExlaExecutable::RunArguments arguments, + int64_t device_id) { + return run(env, executable, arguments, device_id); } -ERL_NIF_TERM deserialize_executable(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]) { - if (argc != 2) { - return exla::nif::error(env, "Bad argument count."); - } +FINE_NIF(run_cpu, ERL_NIF_DIRTY_JOB_CPU_BOUND); - exla::ExlaClient** client; - std::string serialized; +ExlaExecutable::RunResult run_io(ErlNifEnv *env, + fine::ResourcePtr executable, + ExlaExecutable::RunArguments arguments, + int64_t device_id) { + return run(env, executable, arguments, device_id); +} - if (!exla::nif::get(env, argv[0], client)) { - return exla::nif::error(env, "Unable to get client."); - } - if (!exla::nif::get(env, argv[1], serialized)) { - return exla::nif::error(env, "Unable to get executable."); - } +FINE_NIF(run_io, ERL_NIF_DIRTY_JOB_IO_BOUND); - EXLA_ASSIGN_OR_RETURN_NIF(exla::ExlaExecutable * executable, - (*client)->DeserializeExecutable(serialized), env); +// Serialization Functions - return exla::nif::ok(env, exla::nif::make(env, executable)); +std::string serialize_executable(ErlNifEnv *env, + fine::ResourcePtr executable) { + return unwrap(executable->SerializeExecutable()); } -// Logging +FINE_NIF(serialize_executable, 0); -ERL_NIF_TERM start_log_sink(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]) { - if (argc != 1) { - return exla::nif::error(env, "Bad argument count."); - } +fine::ResourcePtr +deserialize_executable(ErlNifEnv *env, fine::ResourcePtr client, + std::string serialized) { + return unwrap(client->DeserializeExecutable(serialized)); +} - ErlNifPid logger_pid; +FINE_NIF(deserialize_executable, 0); - if (!enif_get_local_pid(env, argv[0], &logger_pid)) { - return exla::nif::error(env, "Unable to get logger pid"); - } +// Logging - exla::ExlaLogSink* sink = new exla::ExlaLogSink(logger_pid); +fine::Ok<> start_log_sink(ErlNifEnv *env, ErlNifPid logger_pid) { + ExlaLogSink *sink = new ExlaLogSink(logger_pid); // NO_DEFAULT_LOGGER doesn't behave right - for (auto* log_sink : tsl::TFGetLogSinks()) { + for (auto *log_sink : tsl::TFGetLogSinks()) { tsl::TFRemoveLogSink(log_sink); } tsl::TFAddLogSink(sink); - return exla::nif::ok(env); + return fine::Ok(); } -static ErlNifFunc exla_funcs[] = { - // MLIR Builder - {"mlir_new_thread_pool", 1, mlir_new_thread_pool}, - {"mlir_new_context", 1, mlir_new_context}, - {"mlir_new_module", 1, mlir_new_module}, - {"mlir_create_function", 5, mlir_create_function}, - {"mlir_get_function_arguments", 1, mlir_get_function_arguments}, - {"mlir_op", 6, mlir_op}, - {"mlir_push_region", 2, mlir_push_region}, - {"mlir_get_typespec", 1, mlir_get_typespec}, - {"mlir_pop_region", 1, mlir_pop_region}, - {"mlir_module_to_string", 1, mlir_module_to_string}, - // ExlaClient - {"get_host_client", 0, get_host_client}, - {"get_gpu_client", 2, get_gpu_client}, - {"get_tpu_client", 0, get_tpu_client}, - {"get_c_api_client", 1, get_c_api_client}, - {"load_pjrt_plugin", 2, load_pjrt_plugin}, - {"get_device_count", 1, get_device_count}, - {"get_supported_platforms", 0, get_supported_platforms}, - {"mlir_compile", 7, mlir_compile, ERL_NIF_DIRTY_JOB_CPU_BOUND}, - // ExlaBuffer - {"get_buffer_device_pointer", 3, get_buffer_device_pointer}, - {"create_buffer_from_device_pointer", 5, create_buffer_from_device_pointer}, - {"binary_to_device_mem", 4, binary_to_device_mem, ERL_NIF_DIRTY_JOB_IO_BOUND}, - {"read_device_mem", 2, read_device_mem, ERL_NIF_DIRTY_JOB_IO_BOUND}, - {"deallocate_device_mem", 1, deallocate_device_mem, ERL_NIF_DIRTY_JOB_IO_BOUND}, - {"transfer_to_infeed", 3, transfer_to_infeed, ERL_NIF_DIRTY_JOB_IO_BOUND}, - {"transfer_from_outfeed", 5, transfer_from_outfeed, ERL_NIF_DIRTY_JOB_IO_BOUND}, - {"copy_buffer_to_device", 3, copy_buffer_to_device, ERL_NIF_DIRTY_JOB_IO_BOUND}, - // ExlaExecutable - {"run_io", 4, run, ERL_NIF_DIRTY_JOB_IO_BOUND}, - {"run_cpu", 4, run, ERL_NIF_DIRTY_JOB_CPU_BOUND}, - // Log Sink - {"start_log_sink", 1, start_log_sink}, - // Serialization - {"serialize_executable", 1, serialize_executable}, - {"deserialize_executable", 2, deserialize_executable}}; - -ERL_NIF_INIT(Elixir.EXLA.NIF, exla_funcs, &load, NULL, &upgrade, NULL); +FINE_NIF(start_log_sink, 0); + +} // namespace exla + +FINE_INIT("Elixir.EXLA.NIF"); diff --git a/exla/c_src/exla/exla_client.cc b/exla/c_src/exla/exla_client.cc index b6bba1806b..a13f87082c 100644 --- a/exla/c_src/exla/exla_client.cc +++ b/exla/c_src/exla/exla_client.cc @@ -1,5 +1,7 @@ #include "exla_client.h" +#include +#include #include "exla_nif_util.h" #include "xla/layout_util.h" #include "xla/pjrt/gpu/gpu_helpers.h" @@ -9,6 +11,7 @@ #include "xla/pjrt/pjrt_compiler.h" #include "xla/pjrt/tfrt_cpu_pjrt_client.h" #include "xla/shape_util.h" +#include namespace exla { @@ -23,9 +26,15 @@ void CopyLiteralToBinary(xla::Literal* literal, ErlNifBinary* binary, exla::int6 xla::StatusOr ExlaBuffer::ToBinary(ErlNifEnv* env, exla::int64 size) { EXLA_ASSIGN_OR_RETURN(std::shared_ptr literal, buffer_->ToLiteralSync()); - ErlNifBinary binary; - CopyLiteralToBinary(literal.get(), &binary, size); - return nif::make(env, binary); + + exla::int64 actual_size = literal->size_bytes(); + if (size < 0 or size > actual_size) size = actual_size; + + ERL_NIF_TERM binary_term; + auto data = enif_make_new_binary(env, size, &binary_term); + memcpy(data, literal->untyped_data(), size); + + return binary_term; } xla::Status ExlaBuffer::Deallocate() { @@ -37,10 +46,10 @@ xla::Status ExlaBuffer::Deallocate() { } } -xla::StatusOr ExlaBuffer::CopyToDevice(xla::PjRtDevice* dst_device) { +xla::StatusOr> ExlaBuffer::CopyToDevice(xla::PjRtDevice* dst_device) { EXLA_ASSIGN_OR_RETURN(std::unique_ptr buf, buffer_->CopyToDevice(dst_device)); - return new ExlaBuffer(std::move(buf)); + return fine::make_resource(std::move(buf)); } ExlaExecutable::ExlaExecutable(std::unique_ptr executable, @@ -50,17 +59,17 @@ ExlaExecutable::ExlaExecutable(std::unique_ptr execut client_(client) {} xla::StatusOr> PjRtBufferFromBinary(xla::PjRtClient* client, - ErlNifEnv* env, ERL_NIF_TERM source_term, const xla::Shape& shape, int device_id) { + // We copy the binary term into a new env and point the buffer to + // the binary content. Since larger binaries are shared and refcounted + // this should be zero-copy. + ErlNifEnv* copy_env = enif_alloc_env(); ERL_NIF_TERM dest_term = enif_make_copy(copy_env, source_term); - ErlNifBinary binary; - if (!nif::get_binary(copy_env, dest_term, &binary)) { - return xla::InvalidArgument("Expected buffer to be binary."); - } + auto binary = fine::decode(copy_env, dest_term); xla::PjRtClient::HostBufferSemantics semantics = xla::PjRtClient::HostBufferSemantics::kImmutableZeroCopy; std::function on_done_with_host_buffer = [copy_env]() { enif_free_env(copy_env); }; @@ -75,164 +84,86 @@ xla::StatusOr> PjRtBufferFromBinary(xla::PjRtCl return std::move(buffer); } -xla::StatusOr> -UnpackReplicaArguments(ErlNifEnv* env, - ERL_NIF_TERM replica_arguments, - ExlaClient* client, - int device) { - unsigned int length; - if (!enif_get_list_length(env, replica_arguments, &length)) { - return xla::InvalidArgument("Argument is not a list."); - } - - ERL_NIF_TERM head, tail; - std::vector replica_buffers; - replica_buffers.reserve(length); - - // for a single replica, the argument is a flat list of buffers where - // each buffer can either be an erlang binary or a reference to another - // EXLA buffer, it is not possible for any of the arguments to be nested - // tuples because we handle normalization/flattening of tuples on the - // Elixir side - while (enif_get_list_cell(env, replica_arguments, &head, &tail)) { - int arity; - const ERL_NIF_TERM* tuple; - ExlaBuffer** buffer; - - if (enif_get_tuple(env, head, &arity, &tuple)) { - // if the term is a tuple, that means it represents a {shape, binary} - // tuple which we must convert into an exla buffer for use in the computation - xla::Shape shape; - - if (!nif::get_typespec_as_xla_shape(env, tuple[1], &shape)) { - return xla::InvalidArgument("Expected argument to be a typespec."); - } - - // we convert the binary into a buffer and transfer it to the correct device, - // this buffer is not managed by the erlang vm so it must be deallocated explicitly - // after use by the execution - EXLA_ASSIGN_OR_RETURN(std::unique_ptr buf, - PjRtBufferFromBinary(client->client(), env, tuple[0], shape, device)); - replica_buffers.push_back(buf.release()); - } else if (nif::get(env, head, buffer)) { - // if the buffer is not a tuple it must be a reference to an exla buffer - // which means the resource is already managed by the vm, and should already - // be on the correct device, if it is not, we will not do any implicit transfers - // and instead raise an error - if ((*buffer)->device_id() != device) { - return xla::InvalidArgument("Expected buffer to be placed on device %d", device); - } - replica_buffers.push_back((*buffer)->buffer()); - } else { - return xla::InvalidArgument("Expected argument to be buffer reference."); - } - - replica_arguments = tail; - } - - return replica_buffers; -} - xla::StatusOr>> UnpackRunArguments(ErlNifEnv* env, - ERL_NIF_TERM arguments, + ExlaExecutable::RunArguments arguments, + std::vector> &transient_buffers, ExlaClient* client, xla::DeviceAssignment device_assignment, int device_id) { - unsigned int length; - if (!enif_get_list_length(env, arguments, &length)) { - return xla::InvalidArgument("Argument is not a list."); - } - - ERL_NIF_TERM head, tail; std::vector> arg_buffers; - arg_buffers.reserve(length); + arg_buffers.reserve(arguments.size()); int replica = 0; - int device; - while (enif_get_list_cell(env, arguments, &head, &tail)) { - device = device_id >= 0 ? device_id : device_assignment(replica, 0); + for (const auto & replica_arguments : arguments) { + auto device = device_id >= 0 ? device_id : device_assignment(replica, 0); + + auto replica_buffers = std::vector(); + replica_buffers.reserve(replica_arguments.size()); + + // For a single replica, the argument is a flat list of buffers where + // each buffer can either be an erlang binary or a reference to another + // EXLA buffer, it is not possible for any of the arguments to be nested + // tuples because we handle normalization/flattening of tuples on the + // Elixir side + for (const auto & argument : replica_arguments) { + if (auto value = std::get_if>(&argument)) { + auto [term, shape] = *value; + // We convert the binary into a buffer and transfer it to the + // correct device, this buffer is not managed by the erlang vm + // so it must be deallocated explicitly after use by the execution + EXLA_ASSIGN_OR_RETURN(std::unique_ptr buf, + PjRtBufferFromBinary(client->client(), term, shape, device)); + replica_buffers.push_back(buf.get()); + // Keep track of the buffer pointer, for automatic deallocation later + transient_buffers.push_back(std::move(buf)); + } else if (auto value = std::get_if>(&argument)) { + auto buffer = *value; + // if the buffer is not a tuple it must be a reference to an exla buffer + // which means the resource is already managed by the vm, and should already + // be on the correct device, if it is not, we will not do any implicit transfers + // and instead raise an error + if (buffer->device_id() != device) { + return xla::InvalidArgument("Expected buffer to be placed on device %d", device); + } + replica_buffers.push_back(buffer->buffer()); + } + } - EXLA_ASSIGN_OR_RETURN(std::vector replica_buffers, - UnpackReplicaArguments(env, head, client, device)); + arg_buffers.push_back(std::move(replica_buffers)); - arg_buffers.push_back(replica_buffers); replica++; - arguments = tail; } return arg_buffers; } -xla::StatusOr UnpackResult(ErlNifEnv* env, - std::vector>> result, - xla::DeviceAssignment device_assignment, - int device_id) { - std::vector per_replica_results; +ExlaExecutable::RunResult UnpackResult(ErlNifEnv* env, + std::vector>> result, + xla::DeviceAssignment device_assignment, + int device_id) { + auto per_replica_results = std::vector>, int64_t>>(); for (int i = 0; i < result.size(); i++) { - std::vector terms; - terms.reserve(result.size()); - int device = device_id >= 0 ? device_id : device_assignment(i, 0); + auto replica_results = std::vector>(); + int64_t device = device_id >= 0 ? device_id : device_assignment(i, 0); for (auto& pjrt_buf : result.at(i)) { pjrt_buf->BlockHostUntilReady(); - ExlaBuffer* buf = new ExlaBuffer(std::move(pjrt_buf)); - ERL_NIF_TERM term = nif::make(env, buf); - terms.push_back(term); - } - - ERL_NIF_TERM replica_term = enif_make_int(env, device); - ERL_NIF_TERM replica_results = enif_make_list_from_array(env, terms.data(), terms.size()); - per_replica_results.push_back(enif_make_tuple2(env, replica_results, replica_term)); - } - - ERL_NIF_TERM per_replica_term = enif_make_list_from_array(env, per_replica_results.data(), per_replica_results.size()); - - return nif::ok(env, per_replica_term); -} - -void FreeReplicaArguments(ErlNifEnv* env, ERL_NIF_TERM replica_arguments, std::vector buffers) { - unsigned int length; - if (!enif_get_list_length(env, replica_arguments, &length)) { - return; - } - - ERL_NIF_TERM head, tail; - int arg = 0; - - while (enif_get_list_cell(env, replica_arguments, &head, &tail)) { - xla::PjRtBuffer* buffer = buffers.at(arg); - - if (enif_is_tuple(env, head)) { - delete buffer; + auto result = fine::make_resource(std::move(pjrt_buf)); + replica_results.push_back(result); } - arg++; - replica_arguments = tail; - } -} - -void FreeRunArguments(ErlNifEnv* env, ERL_NIF_TERM arguments, std::vector> buffers) { - unsigned int length; - if (!enif_get_list_length(env, arguments, &length)) { - return; + per_replica_results.push_back(std::make_tuple(std::move(replica_results), device)); } - ERL_NIF_TERM head, tail; - int replica = 0; - - while (enif_get_list_cell(env, arguments, &head, &tail)) { - FreeReplicaArguments(env, head, buffers.at(replica)); - arguments = tail; - replica++; - } + return per_replica_results; } -xla::StatusOr ExlaExecutable::Run(ErlNifEnv* env, - ERL_NIF_TERM arguments, - int device_id) { +xla::StatusOr ExlaExecutable::Run(ErlNifEnv* env, + ExlaExecutable::RunArguments arguments, + int device_id) { xla::ExecuteOptions options; // arguments are not passed as a single PjRt tuple buffer, but instead // as multiple pjrt buffers @@ -271,6 +202,12 @@ xla::StatusOr ExlaExecutable::Run(ErlNifEnv* env, client_->client()->GetDefaultDeviceAssignment(num_replicas, 1)); } + // Buffers allocated from binaries for this specific run need to be + // freed at the end. We store the corresponding pointers in this + // vector, so they are all freed automatically when this function + // finishes + auto transient_buffers = std::vector>(); + if (device_id >= 0 && num_replicas > 1) { // if the device id is greater than or equal to 1, that means we've specified // a portable executable which cannot be pmapped, this code path should never @@ -279,7 +216,7 @@ xla::StatusOr ExlaExecutable::Run(ErlNifEnv* env, } else { // else we handle unpacking/validating the run arguments to the correct devices // according to the device id and the device assignment - EXLA_ASSIGN_OR_RETURN(input_buffers, UnpackRunArguments(env, arguments, client_, device_assignment, device_id)); + EXLA_ASSIGN_OR_RETURN(input_buffers, UnpackRunArguments(env, arguments, transient_buffers, client_, device_assignment, device_id)); } // at this point input buffers is a vector of arguments per replica @@ -324,24 +261,18 @@ xla::StatusOr ExlaExecutable::Run(ErlNifEnv* env, // the number of replicas and each individual replica is a vector of buffers, the // inner buffer represents a flattened output because we told PjRt we would always // return a tuple from the computation - EXLA_ASSIGN_OR_RETURN(ERL_NIF_TERM ret, - UnpackResult(env, std::move(per_replica_results), device_assignment, device_id)); - - // finally, we need to free any of the arguments we created for this computation - FreeRunArguments(env, arguments, input_buffers); + auto ret = UnpackResult(env, std::move(per_replica_results), device_assignment, device_id); return ret; } ExlaClient::ExlaClient(std::shared_ptr client) : client_(std::move(client)) {} -xla::StatusOr ExlaClient::BufferFromBinary(ErlNifEnv* env, - ERL_NIF_TERM source_term, +xla::StatusOr> ExlaClient::BufferFromBinary(ERL_NIF_TERM source_term, xla::Shape& shape, int device_id) { - EXLA_ASSIGN_OR_RETURN(auto buffer, PjRtBufferFromBinary(client(), env, source_term, shape, device_id)); - ExlaBuffer* exla_buffer = new ExlaBuffer(std::move(buffer)); - return exla_buffer; + EXLA_ASSIGN_OR_RETURN(auto buffer, PjRtBufferFromBinary(client(), source_term, shape, device_id)); + return fine::make_resource(std::move(buffer)); } xla::StatusOr> ExecutableFingerprint(std::unique_ptr& executable) { @@ -357,17 +288,17 @@ xla::StatusOr> ExecutableFingerprint(std::unique_ptr< } } -xla::StatusOr ExlaClient::DeserializeExecutable(std::string deserialized_executable) { +xla::StatusOr> ExlaClient::DeserializeExecutable(std::string deserialized_executable) { EXLA_ASSIGN_OR_RETURN(std::unique_ptr executable, client_->DeserializeExecutable(deserialized_executable, std::nullopt)); EXLA_ASSIGN_OR_RETURN(absl::optional fingerprint, ExecutableFingerprint(executable)); - return new ExlaExecutable(std::move(executable), std::move(fingerprint), this); + return fine::make_resource(std::move(executable), std::move(fingerprint), this); } -xla::StatusOr ExlaClient::Compile(const mlir::OwningOpRef& module, +xla::StatusOr> ExlaClient::Compile(mlir::ModuleOp module, std::vector argument_layouts, xla::ExecutableBuildOptions& options, bool compile_portable_executable) { @@ -386,11 +317,11 @@ xla::StatusOr ExlaClient::Compile(const mlir::OwningOpRef executable, - client_->Compile(*module, std::move(compile_opts))); + client_->Compile(module, std::move(compile_opts))); EXLA_ASSIGN_OR_RETURN(absl::optional fingerprint, ExecutableFingerprint(executable)); - return new ExlaExecutable(std::move(executable), std::move(fingerprint), this); + return fine::make_resource(std::move(executable), std::move(fingerprint), this); } xla::Status ExlaClient::TransferToInfeed(ErlNifEnv* env, @@ -440,27 +371,29 @@ xla::StatusOr ExlaClient::TransferFromOutfeed(ErlNifEnv* env, int auto literal = std::make_shared(shape); - xla::Status transfer_status = device->TransferFromOutfeed(literal.get()); + auto transfer_status = device->TransferFromOutfeed(literal.get()); if (!transfer_status.ok()) { return transfer_status; } - ErlNifBinary binary; - enif_alloc_binary(literal->size_bytes(), &binary); - std::memcpy(binary.data, literal->untyped_data(), literal->size_bytes()); + auto size = literal->size_bytes(); + + ERL_NIF_TERM binary_term; + auto data = enif_make_new_binary(env, size, &binary_term); + memcpy(data, literal->untyped_data(), size); - return nif::make(env, binary); + return binary_term; } -xla::StatusOr GetHostClient() { +xla::StatusOr> GetHostClient() { EXLA_ASSIGN_OR_RETURN(std::unique_ptr client, xla::GetTfrtCpuClient(false)); - return new ExlaClient(std::move(client)); + return fine::make_resource(std::move(client)); } -xla::StatusOr GetGpuClient(double memory_fraction, +xla::StatusOr> GetGpuClient(double memory_fraction, bool preallocate, xla::GpuAllocatorConfig::Kind kind) { xla::GpuAllocatorConfig allocator_config = { @@ -474,10 +407,10 @@ xla::StatusOr GetGpuClient(double memory_fraction, EXLA_ASSIGN_OR_RETURN(std::unique_ptr client, xla::GetStreamExecutorGpuClient(client_options)); - return new ExlaClient(std::move(client)); + return fine::make_resource(std::move(client)); } -xla::StatusOr GetTpuClient() { +xla::StatusOr> GetTpuClient() { auto statusor = pjrt::LoadPjrtPlugin("tpu", "libtpu.so"); if (!statusor.ok()) { return statusor.status(); @@ -492,13 +425,13 @@ xla::StatusOr GetTpuClient() { EXLA_ASSIGN_OR_RETURN(std::unique_ptr client, xla::GetCApiClient("TPU")); - return new ExlaClient(std::move(client)); + return fine::make_resource(std::move(client)); } -xla::StatusOr GetCApiClient(std::string device_type) { +xla::StatusOr> GetCApiClient(std::string device_type) { EXLA_ASSIGN_OR_RETURN(std::unique_ptr client, xla::GetCApiClient(device_type)); - return new ExlaClient(std::move(client)); + return fine::make_resource(std::move(client)); } } // namespace exla diff --git a/exla/c_src/exla/exla_client.h b/exla/c_src/exla/exla_client.h index 229853c4c9..0dcc0842cb 100644 --- a/exla/c_src/exla/exla_client.h +++ b/exla/c_src/exla/exla_client.h @@ -5,7 +5,8 @@ #include #include -#include "erl_nif.h" +#include +#include #include "exla_types.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/OwningOpRef.h" @@ -27,7 +28,7 @@ class ExlaBuffer { int device_id() { return buffer_->device()->id(); } xla::PjRtBuffer* buffer() { return buffer_.get(); } - xla::StatusOr CopyToDevice(xla::PjRtDevice* dst_device); + xla::StatusOr> CopyToDevice(xla::PjRtDevice* dst_device); xla::StatusOr ToBinary(ErlNifEnv* env, exla::int64 size); xla::Status Deallocate(); @@ -50,15 +51,19 @@ class ExlaBuffer { class ExlaExecutable { public: + using ReplicaArgument = std::variant, std::tuple>; + using RunArguments = std::vector>; + + using RunReplicaResult = std::tuple>, int64_t>; + using RunResult =std::vector; + ExlaExecutable(std::unique_ptr executable, absl::optional fingerprint, ExlaClient* client); xla::PjRtLoadedExecutable* executable() { return executable_.get(); } - xla::StatusOr Run(ErlNifEnv* env, - ERL_NIF_TERM arguments, - int device_id); + xla::StatusOr Run(ErlNifEnv* env, RunArguments arguments, int device_id); xla::StatusOr SerializeExecutable() { return executable_->SerializeExecutable(); } @@ -78,17 +83,16 @@ class ExlaClient { // Compiles the given computation with the given compile options - xla::StatusOr Compile(const mlir::OwningOpRef& computation, + xla::StatusOr> Compile(mlir::ModuleOp computation, std::vector argument_layouts, xla::ExecutableBuildOptions& options, bool compile_portable_executable); - xla::StatusOr BufferFromBinary(ErlNifEnv* env, - ERL_NIF_TERM binary_term, + xla::StatusOr> BufferFromBinary(ERL_NIF_TERM binary_term, xla::Shape& shape, int device_id); - xla::StatusOr DeserializeExecutable(std::string serialized_executable); + xla::StatusOr> DeserializeExecutable(std::string serialized_executable); // TODO(seanmor5): This is device logic and should be refactored xla::Status TransferToInfeed(ErlNifEnv* env, @@ -102,15 +106,15 @@ class ExlaClient { std::shared_ptr client_; }; -xla::StatusOr GetHostClient(); +xla::StatusOr> GetHostClient(); -xla::StatusOr GetGpuClient(double memory_fraction, +xla::StatusOr> GetGpuClient(double memory_fraction, bool preallocate, xla::GpuAllocatorConfig::Kind kind); -xla::StatusOr GetTpuClient(); +xla::StatusOr> GetTpuClient(); -xla::StatusOr GetCApiClient(std::string device_type); +xla::StatusOr> GetCApiClient(std::string device_type); } // namespace exla #endif diff --git a/exla/c_src/exla/exla_cuda.cc b/exla/c_src/exla/exla_cuda.cc index 172a03bfdb..395fce3f9b 100644 --- a/exla/c_src/exla/exla_cuda.cc +++ b/exla/c_src/exla/exla_cuda.cc @@ -5,24 +5,31 @@ #include #include +#include +#include -std::pair, int> get_cuda_ipc_handle(std::uintptr_t ptr) { +std::optional get_cuda_ipc_handle(std::uintptr_t ptr) { cudaIpcMemHandle_t ipc_handle; cudaError_t status = cudaIpcGetMemHandle(&ipc_handle, reinterpret_cast(ptr)); + if (status != cudaSuccess) { + return std::nullopt; + } + // Assuming sizeof(cudaIpcMemHandle_t) is constant const size_t size = sizeof(cudaIpcMemHandle_t); - // Copy the memory handle to a byte array - std::vector result(size); - memcpy(result.data(), &ipc_handle, size); + // Copy the memory handle to a buffer + std::string buffer; + buffer.resize(size); + memcpy(&(*(buffer.begin())), &ipc_handle, size); - return std::make_pair(result, status != cudaSuccess); + return buffer; } -std::pair get_pointer_for_ipc_handle(uint8_t* handle_bin, size_t handle_size, int device_id) { +std::optional get_pointer_for_ipc_handle(uint8_t* handle_bin, size_t handle_size, int device_id) { if (handle_size != sizeof(cudaIpcMemHandle_t)) { - return std::make_pair(nullptr, 1); // Return with error status + return std::make_tuple(nullptr, 1); // Return with error status } unsigned char ipc_handle_data[sizeof(cudaIpcMemHandle_t)]; @@ -37,23 +44,23 @@ std::pair get_pointer_for_ipc_handle(uint8_t* handle_bin, size_t han cudaError_t cuda_status = cudaSetDevice(device_id); // Assuming device 0, change as needed if (cuda_status != cudaSuccess) { printf("Error setting CUDA device: %s\n", cudaGetErrorString(cuda_status)); - return std::make_pair(nullptr, 1); // Return with error status + return std::nullopt; } cuda_status = cudaIpcOpenMemHandle((void**)&ptr, ipc_handle, cudaIpcMemLazyEnablePeerAccess); if (cuda_status != cudaSuccess) { printf("Error opening CUDA IPC memory handle: %s\n", cudaGetErrorString(cuda_status)); - return std::make_pair(nullptr, 1); // Return with error status + return std::nullopt; } - return std::make_pair(ptr, cuda_status != cudaSuccess); + return ptr; } #else -std::pair, int> get_cuda_ipc_handle(std::uintptr_t ptr) { - return std::make_pair(std::vector(0), 1); +std::optional get_cuda_ipc_handle(std::uintptr_t ptr) { + return std::nullopt; } -std::pair get_pointer_for_ipc_handle(uint8_t* handle_bin, size_t handle_size, int device_id) { - return std::make_pair(nullptr, 1); +std::optional get_pointer_for_ipc_handle(uint8_t* handle_bin, size_t handle_size, int device_id) { + return std::nullopt; } -#endif \ No newline at end of file +#endif diff --git a/exla/c_src/exla/exla_cuda.h b/exla/c_src/exla/exla_cuda.h index a61e47ce65..10c258fb12 100644 --- a/exla/c_src/exla/exla_cuda.h +++ b/exla/c_src/exla/exla_cuda.h @@ -2,7 +2,9 @@ #include #include +#include +#include #include -std::pair, int> get_cuda_ipc_handle(std::uintptr_t); -std::pair get_pointer_for_ipc_handle(uint8_t*, size_t, int); \ No newline at end of file +std::optional get_cuda_ipc_handle(std::uintptr_t); +std::optional get_pointer_for_ipc_handle(uint8_t *, size_t, int); diff --git a/exla/c_src/exla/exla_log_sink.h b/exla/c_src/exla/exla_log_sink.h index 7857f44de3..54d2a9cff6 100644 --- a/exla/c_src/exla/exla_log_sink.h +++ b/exla/c_src/exla/exla_log_sink.h @@ -3,9 +3,10 @@ #include +#include "absl/base/log_severity.h" #include "exla_nif_util.h" +#include #include "tsl/platform/logging.h" -#include "absl/base/log_severity.h" namespace exla { @@ -13,74 +14,51 @@ namespace exla { // is the PID for a GenServer in Elixir which receives messages // with logging information on every call to `LOG(severity)`. class ExlaLogSink : public tsl::TFLogSink { - public: - explicit ExlaLogSink(ErlNifPid sink_pid) : sink_pid_(sink_pid) { - // Logger Env - env_ = enif_alloc_env(); - } +public: + explicit ExlaLogSink(ErlNifPid sink_pid) : sink_pid_(sink_pid) {} - ~ExlaLogSink() { enif_free_env(env_); } + void Send(const tsl::TFLogEntry &entry) { + auto string = entry.ToString(); + auto fname = entry.FName(); + int64_t line = entry.Line(); + auto severity = entry.log_severity(); - ERL_NIF_TERM info(std::string str, std::string fname, int32 line) { - ERL_NIF_TERM status = nif::atom(env_, "info"); - ERL_NIF_TERM msg = nif::make(env_, str); - ERL_NIF_TERM file = nif::make(env_, fname); - ERL_NIF_TERM line_term = nif::make(env_, line); - return enif_make_tuple4(env_, status, msg, file, line_term); - } + if (severity == absl::LogSeverity::kFatal) { + // LOG(FATAL) aborts the program before we are able to send and + // log the information from Elixir, so we need to get it out + // there for debugging before everything crashes + std::cerr << "[FATAL] " << fname << ":" << line << " " << string << "\n"; + } - ERL_NIF_TERM warning(std::string str, std::string fname, int32 line) { - ERL_NIF_TERM status = nif::atom(env_, "warning"); - ERL_NIF_TERM msg = nif::make(env_, str); - ERL_NIF_TERM file = nif::make(env_, fname); - ERL_NIF_TERM line_term = nif::make(env_, line); - return enif_make_tuple4(env_, status, msg, file, line_term); - } + auto env = enif_alloc_env(); + + auto message = fine::encode( + env, std::make_tuple(severity_to_atom(severity), string, fname, line)); + + enif_send(NULL, &sink_pid_, env, message); - ERL_NIF_TERM error(std::string str, std::string fname, int32 line) { - ERL_NIF_TERM status = nif::atom(env_, "error"); - ERL_NIF_TERM msg = nif::make(env_, str); - ERL_NIF_TERM file = nif::make(env_, fname); - ERL_NIF_TERM line_term = nif::make(env_, line); - return enif_make_tuple4(env_, status, msg, file, line_term); + enif_free_env(env); } - void Send(const tsl::TFLogEntry& entry) { - ERL_NIF_TERM msg; - std::string msg_str = entry.ToString(); - std::string fname = entry.FName(); - int32 line = entry.Line(); - switch (entry.log_severity()) { - case absl::LogSeverity::kInfo: - msg = info(msg_str, fname, line); - break; - case absl::LogSeverity::kWarning: - msg = warning(msg_str, fname, line); - break; - case absl::LogSeverity::kError: - msg = error(msg_str, fname, line); - break; - case absl::LogSeverity::kFatal: - // LOG(FATAL) aborts the program before we are able - // to send and log the information from Elixir, so we - // need to get it out there for debugging before everything - // crashes - std::cerr << "[FATAL] " << fname << ":" - << line << " " << msg_str << "\n"; - // In case there is a race, set msg just to be safe - msg = error(msg_str, fname, line); - break; - default: - msg = info(msg_str, fname, line); +private: + fine::Atom severity_to_atom(absl::LogSeverity severity) { + switch (severity) { + case absl::LogSeverity::kInfo: + return atoms::info; + case absl::LogSeverity::kWarning: + return atoms::warning; + case absl::LogSeverity::kError: + return atoms::error; + case absl::LogSeverity::kFatal: + return atoms::error; + default: + return atoms::info; } - enif_send(env_, &sink_pid_, NULL, msg); } - private: ErlNifPid sink_pid_; - ErlNifEnv* env_; }; -} // namespace exla +} // namespace exla #endif diff --git a/exla/c_src/exla/exla_mlir.cc b/exla/c_src/exla/exla_mlir.cc index 3448e95262..c4aabdcd9e 100644 --- a/exla/c_src/exla/exla_mlir.cc +++ b/exla/c_src/exla/exla_mlir.cc @@ -1,32 +1,40 @@ #include "exla_mlir.h" +#include #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/AsmParser/AsmParser.h" +#include "mlir/IR/Region.h" namespace exla { -MLIRFunction::MLIRFunction(MLIRModule *module, std::unique_ptr func) +MLIRFunction::MLIRFunction(fine::ResourcePtr module, std::unique_ptr func) : func_(std::move(func)), module_(module) {} -std::vector MLIRFunction::Op( - std::string op_name, std::vector operands, +std::vector> MLIRFunction::Op( + std::string op_name, std::vector> operands, std::vector result_types, - std::vector> attributes, - std::vector regions) { + std::vector> attributes, + std::vector> regions) { auto builder = module_->builder(); auto context = builder->getContext(); auto types_range = mlir::TypeRange{llvm::ArrayRef{result_types}}; auto named_attributes = std::vector{}; - for (auto const &pair : attributes) { - auto attribute = builder->getNamedAttr(pair.first, pair.second); + for (auto const &[key, value] : attributes) { + auto attribute = builder->getNamedAttr(key, value); named_attributes.push_back(attribute); } - auto operands_range = mlir::ValueRange{ - llvm::ArrayRef(operands.data(), operands.size())}; + + auto operand_values = std::vector(); + operand_values.reserve(operands.size()); + for (const auto &operand : operands) { + operand_values.push_back(*operand); + } + + auto operands_range = mlir::ValueRange{llvm::ArrayRef{operand_values}}; auto attributes_array = llvm::ArrayRef{named_attributes}; setInsertionPoint(); @@ -46,29 +54,37 @@ std::vector MLIRFunction::Op( auto op = builder->create(op_state); - auto results = op->getResults(); - return std::vector(results.begin(), results.end()); + auto result_values = op->getResults(); + + auto results = std::vector>(); + results.reserve(result_values.size()); + for (const auto &result : result_values) { + results.push_back(fine::make_resource(result)); + } + + return results; } -std::pair> MLIRFunction::PushRegion(std::vector types) { +std::tuple, std::vector>> +MLIRFunction::PushRegion(std::vector types) { auto context = module_->builder()->getContext(); - auto region = new mlir::Region(); + auto region = fine::make_resource(); auto & block = region->emplaceBlock(); for (mlir::Type type : types) { block.addArgument(type, mlir::UnknownLoc::get(context)); } - auto args = std::vector{}; + auto args = std::vector>(); for (auto &arg : block.getArguments()) { - args.push_back(arg); + args.push_back(fine::make_resource(arg)); } - region_stack.push(std::move(region)); + region_stack.push(region); setInsertionPoint(); - return {region, args}; + return std::make_tuple(region, args); } void MLIRFunction::PopRegion() { @@ -84,14 +100,14 @@ void MLIRFunction::setInsertionPoint() { } } -MLIRModule::MLIRModule(mlir::MLIRContext *context) { +MLIRModule::MLIRModule(fine::ResourcePtr context) { context_ = context; - module_ = mlir::OwningOpRef(mlir::ModuleOp::create(mlir::UnknownLoc::get(context_))); - builder_ = std::make_unique(context_); + module_ = mlir::OwningOpRef(mlir::ModuleOp::create(mlir::UnknownLoc::get(context_.get()))); + builder_ = std::make_unique(context_.get()); builder_->setInsertionPointToStart(module_->getBody()); } -MLIRFunction *MLIRModule::CreateFunction( +std::unique_ptr MLIRModule::CreateFunction( std::string name, std::vector arg_types, std::vector ret_types, @@ -106,7 +122,7 @@ MLIRFunction *MLIRModule::CreateFunction( funcOp->addEntryBlock(); builder_->setInsertionPointToStart(&funcOp->getBody().front()); - return new MLIRFunction(this, std::move(funcOp)); + return funcOp; } std::string MLIRModule::ToString() { @@ -117,11 +133,23 @@ std::string MLIRModule::ToString() { } mlir::Type MLIRModule::ParseType(std::string string) { - return mlir::parseType(string, context_); + auto type = mlir::parseType(string, context_.get()); + + if (type == nullptr) { + throw std::runtime_error("unable to parse MLIR type: " + string); + } + + return type; } mlir::Attribute MLIRModule::ParseAttribute(std::string string) { - return mlir::parseAttribute(string, context_); + auto attribute = mlir::parseAttribute(string, context_.get()); + + if (attribute == nullptr) { + throw std::runtime_error("unable to parse MLIR type: " + string); + } + + return attribute; } } // namespace exla diff --git a/exla/c_src/exla/exla_mlir.h b/exla/c_src/exla/exla_mlir.h index a3c3b91dfa..095ad4c1a7 100644 --- a/exla/c_src/exla/exla_mlir.h +++ b/exla/c_src/exla/exla_mlir.h @@ -1,6 +1,7 @@ #ifndef EXLA_MLIR_BUILDER_H_ #define EXLA_MLIR_BUILDER_H_ +#include #include #include "mlir/Dialect/Func/IR/FuncOps.h" @@ -15,36 +16,35 @@ class MLIRModule; class MLIRFunction { public: - MLIRFunction(MLIRModule *module, std::unique_ptr func); + MLIRFunction(fine::ResourcePtr module, std::unique_ptr func); - std::vector Op( - std::string op_name, - std::vector operands, + std::vector> Op( + std::string op_name, std::vector> operands, std::vector result_types, - std::vector> attributes, - std::vector regions); + std::vector> attributes, + std::vector> regions); - std::pair> PushRegion(std::vector types); + std::tuple, std::vector>> PushRegion(std::vector types); void PopRegion(); llvm::MutableArrayRef GetArguments() { return func_->getBody().front().getArguments(); } - std::shared_ptr module() { return module_; } + fine::ResourcePtr module() { return module_; } private: - std::shared_ptr module_; + fine::ResourcePtr module_; std::unique_ptr func_; - std::stack region_stack; + std::stack> region_stack; void setInsertionPoint(); }; class MLIRModule { public: - MLIRModule(mlir::MLIRContext *context); + MLIRModule(fine::ResourcePtr context); - MLIRFunction *CreateFunction( + std::unique_ptr CreateFunction( std::string name, std::vector arg_types, std::vector ret_types, @@ -60,7 +60,7 @@ class MLIRModule { mlir::OpBuilder *builder() { return builder_.get(); } private: - mlir::MLIRContext *context_; + fine::ResourcePtr context_; mlir::OwningOpRef module_; std::unique_ptr builder_; }; diff --git a/exla/c_src/exla/exla_nif_util.cc b/exla/c_src/exla/exla_nif_util.cc deleted file mode 100644 index f639896fa7..0000000000 --- a/exla/c_src/exla/exla_nif_util.cc +++ /dev/null @@ -1,317 +0,0 @@ -#include "exla_nif_util.h" - -#include "mlir/IR/Builders.h" -#include "mlir/IR/BuiltinTypes.h" -#include "stablehlo/dialect/StablehloOps.h" -#include "xla/primitive_util.h" -#include "xla/shape_util.h" - -namespace exla { -namespace nif { - -// Status helpers - -ERL_NIF_TERM error(ErlNifEnv* env, const char* msg) { - ERL_NIF_TERM atom = enif_make_atom(env, "error"); - ERL_NIF_TERM msg_term = enif_make_string(env, msg, ERL_NIF_LATIN1); - return enif_make_tuple2(env, atom, msg_term); -} - -ERL_NIF_TERM ok(ErlNifEnv* env, ERL_NIF_TERM term) { - return enif_make_tuple2(env, ok(env), term); -} - -ERL_NIF_TERM ok(ErlNifEnv* env) { - return enif_make_atom(env, "ok"); -} - -// Numeric types - -int get(ErlNifEnv* env, ERL_NIF_TERM term, int8* var) { - int value; - if (!enif_get_int(env, term, &value)) return 0; - *var = static_cast(value); - return 1; -} - -int get(ErlNifEnv* env, ERL_NIF_TERM term, int16* var) { - int value; - if (!enif_get_int(env, term, &value)) return 0; - *var = static_cast(value); - return 1; -} - -int get(ErlNifEnv* env, ERL_NIF_TERM term, int32* var) { - return enif_get_int(env, term, - reinterpret_cast(var)); -} - -int get(ErlNifEnv* env, ERL_NIF_TERM term, int64* var) { - return enif_get_int64(env, term, - reinterpret_cast(var)); -} - -int get(ErlNifEnv* env, ERL_NIF_TERM term, uint8* var) { - unsigned int value; - if (!enif_get_uint(env, term, &value)) return 0; - *var = static_cast(value); - return 1; -} - -int get(ErlNifEnv* env, ERL_NIF_TERM term, uint16* var) { - unsigned int value; - if (!enif_get_uint(env, term, &value)) return 0; - *var = static_cast(value); - return 1; -} - -int get(ErlNifEnv* env, ERL_NIF_TERM term, uint32* var) { - return enif_get_uint(env, term, - reinterpret_cast(var)); -} - -int get(ErlNifEnv* env, ERL_NIF_TERM term, uint64* var) { - return enif_get_uint64(env, term, - reinterpret_cast(var)); -} - -int get(ErlNifEnv* env, ERL_NIF_TERM term, float16* var) { - double value; - if (!enif_get_double(env, term, &value)) return 0; - *var = static_cast(value); - return 1; -} - -int get(ErlNifEnv* env, ERL_NIF_TERM term, bfloat16* var) { - double value; - if (!enif_get_double(env, term, &value)) return 0; - *var = static_cast(value); - return 1; -} - -int get(ErlNifEnv* env, ERL_NIF_TERM term, float32* var) { - double value; - if (!enif_get_double(env, term, &value)) return 0; - *var = static_cast(value); - return 1; -} - -int get(ErlNifEnv* env, ERL_NIF_TERM term, float64* var) { - return enif_get_double(env, term, var); -} - -int get(ErlNifEnv* env, ERL_NIF_TERM term, complex64* var) { - return 0; -} - -int get(ErlNifEnv* env, ERL_NIF_TERM term, complex128* var) { - return 0; -} - -ERL_NIF_TERM make(ErlNifEnv* env, int32 var) { - return enif_make_int(env, var); -} - -// Standard types - -int get(ErlNifEnv* env, ERL_NIF_TERM term, std::string& var) { - unsigned len; - int ret = enif_get_list_length(env, term, &len); - - if (!ret) { - ErlNifBinary bin; - ret = enif_inspect_binary(env, term, &bin); - if (!ret) { - return 0; - } - var = std::string((const char*)bin.data, bin.size); - return ret; - } - - var.resize(len + 1); - ret = enif_get_string(env, term, &*(var.begin()), var.size(), ERL_NIF_LATIN1); - - if (ret > 0) { - var.resize(ret - 1); - } else if (ret == 0) { - var.resize(0); - } else { - } - - return ret; -} - -int get(ErlNifEnv* env, ERL_NIF_TERM term, bool* var) { - int value; - if (!enif_get_int(env, term, &value)) return 0; - *var = static_cast(value); - return 1; -} - -ERL_NIF_TERM make(ErlNifEnv* env, ErlNifBinary var) { - return enif_make_binary(env, &var); -} - -ERL_NIF_TERM make(ErlNifEnv* env, std::string var) { - return enif_make_string(env, var.c_str(), ERL_NIF_LATIN1); -} - -ERL_NIF_TERM make(ErlNifEnv* env, const char* string) { - return enif_make_string(env, string, ERL_NIF_LATIN1); -} - -// Atoms - -int get_atom(ErlNifEnv* env, ERL_NIF_TERM term, std::string& var) { - unsigned atom_length; - if (!enif_get_atom_length(env, term, &atom_length, ERL_NIF_LATIN1)) { - return 0; - } - - var.resize(atom_length + 1); - - if (!enif_get_atom(env, term, &(*(var.begin())), var.size(), ERL_NIF_LATIN1)) return 0; - - var.resize(atom_length); - - return 1; -} - -ERL_NIF_TERM atom(ErlNifEnv* env, const char* msg) { - return enif_make_atom(env, msg); -} - -// Containers - -int get_tuple(ErlNifEnv* env, ERL_NIF_TERM tuple, std::vector& var) { - const ERL_NIF_TERM* terms; - int length; - if (!enif_get_tuple(env, tuple, &length, &terms)) return 0; - var.reserve(length); - - for (int i = 0; i < length; i++) { - int64 data; - if (!get(env, terms[i], &data)) return 0; - var.push_back(data); - } - return 1; -} - -int get_list(ErlNifEnv* env, - ERL_NIF_TERM list, - std::vector& var) { - unsigned int length; - if (!enif_get_list_length(env, list, &length)) return 0; - var.reserve(length); - ERL_NIF_TERM head, tail; - - while (enif_get_list_cell(env, list, &head, &tail)) { - ErlNifBinary elem; - if (!get_binary(env, head, &elem)) return 0; - var.push_back(elem); - list = tail; - } - return 1; -} - -int get_list(ErlNifEnv* env, ERL_NIF_TERM list, std::vector& var) { - unsigned int length; - if (!enif_get_list_length(env, list, &length)) return 0; - var.reserve(length); - ERL_NIF_TERM head, tail; - - while (enif_get_list_cell(env, list, &head, &tail)) { - int64 elem; - if (!get(env, head, &elem)) return 0; - var.push_back(elem); - list = tail; - } - return 1; -} - -int get_list(ErlNifEnv* env, ERL_NIF_TERM list, std::vector& var) { - unsigned int length; - if (!enif_get_list_length(env, list, &length)) { - return 0; - } - var.reserve(length); - ERL_NIF_TERM head, tail; - - while (enif_get_list_cell(env, list, &head, &tail)) { - std::string elem; - if (!get(env, head, elem)) { - return 0; - } - var.push_back(elem); - list = tail; - } - return 1; -} - -int get_binary(ErlNifEnv* env, ERL_NIF_TERM term, ErlNifBinary* var) { - return enif_inspect_binary(env, term, var); -} - -ERL_NIF_TERM make_map(ErlNifEnv* env, std::map& map) { - ERL_NIF_TERM term = enif_make_new_map(env); - std::map::iterator itr; - for (itr = map.begin(); itr != map.end(); ++itr) { - ERL_NIF_TERM key = make(env, itr->first); - ERL_NIF_TERM value = make(env, itr->second); - enif_make_map_put(env, term, key, value, &term); - } - return term; -} - -int get_primitive_type(ErlNifEnv* env, ERL_NIF_TERM term, xla::PrimitiveType* type) { - std::string type_str; - if (!get(env, term, type_str)) return 0; - - xla::StatusOr type_status = - xla::primitive_util::StringToPrimitiveType(type_str); - - if (!type_status.ok()) { - return 0; - } - *type = type_status.value(); - return 1; -} - -int get_typespec_as_xla_shape(ErlNifEnv* env, ERL_NIF_TERM term, xla::Shape* shape) { - int arity; - const ERL_NIF_TERM* tuple; - - if (!enif_get_tuple(env, term, &arity, &tuple)) return 0; - - xla::PrimitiveType element_type; - std::vector dims; - - if (!get_primitive_type(env, tuple[0], &element_type)) return 0; - if (!get_tuple(env, tuple[1], dims)) return 0; - - *shape = xla::ShapeUtil::MakeShape(element_type, dims); - - return 1; -} - -int get_list(ErlNifEnv* env, ERL_NIF_TERM list, std::vector& var) { - unsigned int length; - if (!enif_get_list_length(env, list, &length)) { - return 0; - } - var.reserve(length); - ERL_NIF_TERM head, tail; - - while (enif_get_list_cell(env, list, &head, &tail)) { - xla::Shape elem; - if (!get_typespec_as_xla_shape(env, head, &elem)) { - return 0; - } - var.push_back(elem); - list = tail; - } - return 1; -} - -} // namespace nif -} // namespace exla diff --git a/exla/c_src/exla/exla_nif_util.h b/exla/c_src/exla/exla_nif_util.h index 1be278be88..714f74f2da 100644 --- a/exla/c_src/exla/exla_nif_util.h +++ b/exla/c_src/exla/exla_nif_util.h @@ -1,336 +1,270 @@ #ifndef EXLA_NIF_UTIL_H_ #define EXLA_NIF_UTIL_H_ -#include -#include -#include -#include +#include +#include -#include "erl_nif.h" #include "xla/shape.h" -#include "exla_types.h" - -#if !defined(__GNUC__) && (defined(__WIN32__) || defined(_WIN32) || defined(_WIN32_)) -typedef unsigned __int64 nif_uint64_t; -typedef signed __int64 nif_int64_t; -#else -typedef unsigned long nif_uint64_t; -typedef signed long nif_int64_t; -#endif - -// Implementation Notes: -// -// In most of these implementations you'll find we prefer output parameters -// over returning values. This follows the convention of the Erlang NIF -// API in which functions for retrieving terms from the VM return an -// integer status and populate an output parameter. -// -// We also follow the naming convention set forth in the the Erlang NIF -// API. Numeric, standard, and resource types use the polymorphic/template -// `get` or `make`. -// -// We mostly use vectors for containers (lists and tuples), and maps for -// returning maps back to the VM. These have suffixes to avoid conflicting -// signatures for retrieving/returning different signatures. -// -// We create separate methods for each XLA protobuf type, so we can guarantee -// the format we receive the protobuf in is correct. +#include "xla/shape_util.h" +#include "mlir/IR/Types.h" +#include "stablehlo/dialect/StablehloOps.h" namespace exla { -namespace nif { - -// Status helpers - -// Helper for returning `{:error, msg}` from NIF. -ERL_NIF_TERM error(ErlNifEnv* env, const char* msg); - -// Helper for returning `{:ok, term}` from NIF. -ERL_NIF_TERM ok(ErlNifEnv* env, ERL_NIF_TERM term); - -// Helper for returning `:ok` from NIF. -ERL_NIF_TERM ok(ErlNifEnv* env); - -// Numeric types -// -// Floating/Complex types will never get used, except -// when defining scalar-constants with `constant`. - -int get(ErlNifEnv* env, ERL_NIF_TERM term, int8* var); -int get(ErlNifEnv* env, ERL_NIF_TERM term, int16* var); -int get(ErlNifEnv* env, ERL_NIF_TERM term, int32* var); -int get(ErlNifEnv* env, ERL_NIF_TERM term, int64* var); -int get(ErlNifEnv* env, ERL_NIF_TERM term, uint8* var); -int get(ErlNifEnv* env, ERL_NIF_TERM term, uint16* var); -int get(ErlNifEnv* env, ERL_NIF_TERM term, uint32* var); -int get(ErlNifEnv* env, ERL_NIF_TERM term, uint64* var); -int get(ErlNifEnv* env, ERL_NIF_TERM term, bfloat16* var); -int get(ErlNifEnv* env, ERL_NIF_TERM term, float16* var); -int get(ErlNifEnv* env, ERL_NIF_TERM term, float32* var); -int get(ErlNifEnv* env, ERL_NIF_TERM term, float64* var); -int get(ErlNifEnv* env, ERL_NIF_TERM term, complex64* var); -int get(ErlNifEnv* env, ERL_NIF_TERM term, complex128* var); - -ERL_NIF_TERM make(ErlNifEnv* env, int32 var); - -// Standard types -// -// We only define implementations for types we use in the -// NIF. - -int get(ErlNifEnv* env, ERL_NIF_TERM term, std::string& var); -int get(ErlNifEnv* env, ERL_NIF_TERM term, bool* var); +namespace atoms { +static auto ElixirEXLATypespec = fine::Atom("Elixir.EXLA.Typespec"); +static auto __struct__ = fine::Atom("__struct__"); +static auto already_deallocated = fine::Atom("already_deallocated"); +static auto bf = fine::Atom("bf"); +static auto c = fine::Atom("c"); +static auto error = fine::Atom("error"); +static auto f = fine::Atom("f"); +static auto info = fine::Atom("info"); +static auto pred = fine::Atom("pred"); +static auto s = fine::Atom("s"); +static auto shape = fine::Atom("shape"); +static auto token = fine::Atom("token"); +static auto type = fine::Atom("type"); +static auto u = fine::Atom("u"); +static auto warning = fine::Atom("warning"); +} // namespace atoms +} // namespace exla + +namespace fine { + +// Define decoding for xla::Shape from for %EXLA.Typespec{} term +template <> struct Decoder { + static xla::Shape decode(ErlNifEnv *env, const ERL_NIF_TERM &term) { + ERL_NIF_TERM type_term; + ERL_NIF_TERM shape_term; + + if (!enif_get_map_value(env, term, fine::encode(env, exla::atoms::type), + &type_term)) { + throw std::invalid_argument( + "decode failed, expected EXLA.Typespec struct"); + } + + if (!enif_get_map_value(env, term, fine::encode(env, exla::atoms::shape), + &shape_term)) { + throw std::invalid_argument( + "decode failed, expected EXLA.Typespec struct"); + } + + return xla::ShapeUtil::MakeShape(decode_type(env, type_term), + decode_shape(env, shape_term)); + } -ERL_NIF_TERM make(ErlNifEnv* env, std::string var); -ERL_NIF_TERM make(ErlNifEnv* env, ErlNifBinary var); -ERL_NIF_TERM make(ErlNifEnv* env, const char* string); +private: + static std::vector decode_shape(ErlNifEnv *env, + const ERL_NIF_TERM &term) { + int size; + const ERL_NIF_TERM *terms; -// Atoms -// -// We have to be explicit in naming these functions because -// their signatures are the same for retrieving/returning -// regular strings. + if (!enif_get_tuple(env, term, &size, &terms)) { + throw std::invalid_argument( + "decode failed, expected shape to be a tuple"); + } -int get_atom(ErlNifEnv* env, ERL_NIF_TERM term, std::string& var); + auto vector = std::vector(); + vector.reserve(size); -ERL_NIF_TERM atom(ErlNifEnv* env, const char* status); + for (auto i = 0; i < size; i++) { + auto elem = fine::decode(env, terms[i]); + vector.push_back(elem); + } -// Template struct for resources. The struct lets us use templates -// to store and retrieve open resources later on. This implementation -// is the same as the approach taken in the goertzenator/nifpp -// C++11 wrapper around the Erlang NIF API. -template -struct resource_object { - static ErlNifResourceType* type; -}; -template -ErlNifResourceType* resource_object::type = 0; - -// Default destructor passed when opening a resource. The default -// behavior is to invoke the underlying objects destructor and -// set the resource pointer to NULL. -template -void default_dtor(ErlNifEnv* env, void* obj) { - T* resource = reinterpret_cast(obj); - resource->~T(); - resource = nullptr; -} - -// Opens a resource for the given template type T. If no -// destructor is given, uses the default destructor defined -// above. -template -int open_resource(ErlNifEnv* env, - const char* mod, - const char* name, - ErlNifResourceDtor* dtor = nullptr) { - if (dtor == nullptr) { - dtor = &default_dtor; - } - ErlNifResourceType* type; - ErlNifResourceFlags flags = ErlNifResourceFlags(ERL_NIF_RT_CREATE | ERL_NIF_RT_TAKEOVER); - type = enif_open_resource_type(env, mod, name, dtor, flags, NULL); - if (type == NULL) { - resource_object::type = 0; - return -1; - } else { - resource_object::type = type; - } - return 1; -} - -// Returns a resource of the given template type T. -template -ERL_NIF_TERM get(ErlNifEnv* env, ERL_NIF_TERM term, T*& var) { - return enif_get_resource(env, term, - resource_object::type, - reinterpret_cast(&var)); -} - -// Creates a reference to the given resource of type T. We -// use the move constructor by default because some XLA -// objects delete the copy-constructor. The move is intended -// to represent a transfer of ownership of the object to -// the VM. - -template -ERL_NIF_TERM make(ErlNifEnv* env, T& var) { - void* ptr = enif_alloc_resource(resource_object::type, sizeof(T)); - new (ptr) T(std::move(var)); - ERL_NIF_TERM ret = enif_make_resource(env, ptr); - enif_release_resource(ptr); - return ret; -} - -template -ERL_NIF_TERM make_list(ErlNifEnv* env, std::vector result) { - size_t n = result.size(); - - std::vector nif_terms; - nif_terms.reserve(n); - - for (size_t i = 0; i < n; i++) { - nif_terms[i] = exla::nif::make(env, result[i]); + return vector; } - auto data = nif_terms.data(); - auto list = enif_make_list_from_array(env, &data[0], n); - return list; -} -// Containers -// -// Both tuples and lists are treated as vectors, but extracting -// terms from both is slightly different, so we have to be -// explicit in the naming convention in order to differentiate. -// -// We also support reading resources into vectors from both tuples -// and lists. Once again, implementation is slightly different -// for resources, so we need to be explicit. -// -// Similar to standard types, we only define implementations for -// types used. - -int get_tuple(ErlNifEnv* env, - ERL_NIF_TERM tuple, - std::vector& var); - -template -int get_tuple(ErlNifEnv* env, ERL_NIF_TERM tuple, std::vector& var) { - const ERL_NIF_TERM* terms; - int length; - if (!enif_get_tuple(env, tuple, &length, &terms)) return 0; - var.reserve(length); - - for (int i = 0; i < length; i++) { - T* elem; - if (!get(env, terms[i], elem)) return 0; - var.push_back(*elem); + static xla::PrimitiveType decode_type(ErlNifEnv *env, + const ERL_NIF_TERM &term) { + auto [element, size] = + fine::decode>(env, term); + + if (element == "u") { + switch (size) { + case 2: + return xla::U2; + case 4: + return xla::U4; + case 8: + return xla::U8; + case 16: + return xla::U16; + case 32: + return xla::U32; + case 64: + return xla::U64; + } + } + if (element == "s") { + switch (size) { + case 2: + return xla::S2; + case 4: + return xla::S4; + case 8: + return xla::S8; + case 16: + return xla::S16; + case 32: + return xla::S32; + case 64: + return xla::S64; + } + } + if (element == "f") { + switch (size) { + case 8: + return xla::F8E5M2; + case 16: + return xla::F16; + case 32: + return xla::F32; + case 64: + return xla::F64; + } + } + if (element == "bf") { + switch (size) { + case 16: + return xla::BF16; + } + } + if (element == "c") { + switch (size) { + case 64: + return xla::C64; + case 128: + return xla::C128; + } + } + if (element == "pred") { + return xla::PRED; + } + + throw std::invalid_argument("decode failed, unexpected type"); } - return 1; -} - -int get_list(ErlNifEnv* env, - ERL_NIF_TERM list, - std::vector& var); -int get_list(ErlNifEnv* env, - ERL_NIF_TERM list, - std::vector& var); - -int get_list(ErlNifEnv* env, ERL_NIF_TERM list, std::vector& var); - -int get_list(ErlNifEnv* env, ERL_NIF_TERM list, std::vector& var); - -template -int get_list(ErlNifEnv* env, ERL_NIF_TERM list, std::vector& var) { - unsigned int length; - if (!enif_get_list_length(env, list, &length)) return 0; - var.reserve(length); - ERL_NIF_TERM head, tail; - - while (enif_get_list_cell(env, list, &head, &tail)) { - T* elem; - if (!get(env, head, elem)) return 0; - var.push_back(elem); - list = tail; +}; + +// Define encoding for mlir::Type into %EXLA.Typespec{} term +template <> struct Encoder { + static ERL_NIF_TERM encode(ErlNifEnv *env, const mlir::Type &type) { + ERL_NIF_TERM keys[] = { + fine::encode(env, exla::atoms::__struct__), + fine::encode(env, exla::atoms::type), + fine::encode(env, exla::atoms::shape), + }; + + ERL_NIF_TERM values[] = { + fine::encode(env, exla::atoms::ElixirEXLATypespec), + encode_type(env, type), + encode_shape(env, type), + }; + + ERL_NIF_TERM map; + if (!enif_make_map_from_arrays(env, keys, values, 3, &map)) { + throw std::runtime_error("encode: failed to make a map"); + } + + return map; } - return 1; -} - -template -int get_list(ErlNifEnv* env, ERL_NIF_TERM list, std::vector& var) { - unsigned int length; - if (!enif_get_list_length(env, list, &length)) return 0; - var.reserve(length); - ERL_NIF_TERM head, tail; - - while (enif_get_list_cell(env, list, &head, &tail)) { - T* elem; - if (!get(env, head, elem)) return 0; - var.push_back(*elem); - list = tail; + +private: + static ERL_NIF_TERM encode_type(ErlNifEnv *env, const mlir::Type &type) { + if (mlir::isa(type)) { + return fine::encode(env, exla::atoms::token); + } + + std::optional type_name; + std::optional type_size; + + if (mlir::isa(type)) { + auto tensor_type = mlir::cast(type); + auto element_type = tensor_type.getElementType(); + + if (element_type.isSignlessInteger(1)) { + type_name = exla::atoms::pred; + type_size = 8; + } else if (auto integer_type = + mlir::dyn_cast(element_type)) { + if (integer_type.isUnsigned()) { + type_name = exla::atoms::u; + } else { + type_name = exla::atoms::s; + } + + type_size = integer_type.getWidth(); + } else if (element_type.isBF16()) { + type_name = exla::atoms::bf; + type_size = 16; + } else if (auto float_type = + mlir::dyn_cast(element_type)) { + type_name = exla::atoms::f; + type_size = float_type.getWidth(); + } else if (auto complex_type = + mlir::dyn_cast(element_type)) { + auto element_type = complex_type.getElementType(); + type_name = exla::atoms::c; + type_size = mlir::cast(element_type).getWidth() * 2; + } + } + + if (type_name) { + return fine::encode( + env, std::make_tuple(type_name.value(), type_size.value())); + } else { + throw std::invalid_argument("encode failed, unexpected mlir type"); + } } - return 1; -} -template -int get_keyword_list(ErlNifEnv* env, ERL_NIF_TERM list, std::vector>& var) { - unsigned int length; - if (!enif_get_list_length(env, list, &length)) return 0; - var.reserve(length); - ERL_NIF_TERM head, tail; + static ERL_NIF_TERM encode_shape(ErlNifEnv *env, const mlir::Type &type) { + if (mlir::isa(type)) { + return enif_make_tuple(env, 0); + } - while (enif_get_list_cell(env, list, &head, &tail)) { - const ERL_NIF_TERM* terms; - int count; + if (mlir::isa(type)) { + auto tensor_type = mlir::cast(type); + auto dims_array = tensor_type.getShape(); - if (!enif_get_tuple(env, head, &count, &terms)) return 0; - if (count != 2) return 0; + auto dims = std::vector{}; + dims.reserve(dims_array.size()); - std::string lo; - T hi; - if (!get_atom(env, terms[0], lo)) return 0; - if (!get(env, terms[1], hi)) return 0; + for (auto dim : dims_array) { + dims.push_back(fine::encode(env, dim)); + } - var.push_back(std::pair(lo, hi)); + return enif_make_tuple_from_array(env, dims.data(), dims.size()); + } - list = tail; + throw std::invalid_argument("encode failed, unexpected mlir type"); } - return 1; -} - -int get_binary(ErlNifEnv* env, ERL_NIF_TERM term, ErlNifBinary* var); - -ERL_NIF_TERM make_map(ErlNifEnv* env, std::map& map); - -// XLA Protobuf Types -// -// See: https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/xla/xla_data.proto -// for more details on each type and additional types not listed here. - -// Gets encoded EXLA.Typespec as xla::Shape. -int get_typespec_as_xla_shape(ErlNifEnv* env, ERL_NIF_TERM term, xla::Shape* shape); - -} // namespace nif -} // namespace exla +}; +} // namespace fine // Helper Macros // -// See: https://github.com/tensorflow/tensorflow/blob/master/tensorflow/stream_executor/lib/statusor.h +// See: +// https://github.com/tensorflow/tensorflow/blob/master/tensorflow/stream_executor/lib/statusor.h -#define EXLA_STATUS_MACROS_CONCAT_NAME(x, y) \ +#define EXLA_STATUS_MACROS_CONCAT_NAME(x, y) \ EXLA_STATUS_MACROS_CONCAT_NAME_IMPL(x, y) #define EXLA_STATUS_MACROS_CONCAT_NAME_IMPL(x, y) x##y -// Macro to be used to consume StatusOr from within a NIF. Will -// bind lhs to value if the status is OK, otherwise will return -// `{:error, msg}`. -#define EXLA_ASSIGN_OR_RETURN_NIF(lhs, rexpr, env) \ - EXLA_ASSIGN_OR_RETURN_NIF_IMPL( \ - EXLA_STATUS_MACROS_CONCAT_NAME(_status_or_value, __COUNTER__), \ - lhs, rexpr, env) - -#define EXLA_ASSIGN_OR_RETURN_NIF_IMPL(statusor, lhs, rexpr, env) \ - auto statusor = (rexpr); \ - if (!statusor.ok()) { \ - return exla::nif::error(env, statusor.status().message().data()); \ - } \ - lhs = std::move(statusor.value()); - // Macro to be used to consume StatusOr. Will bind lhs // to value if the status is OK, otherwise will return // the status. -#define EXLA_ASSIGN_OR_RETURN(lhs, rexpr) \ - EXLA_ASSIGN_OR_RETURN_IMPL( \ - EXLA_STATUS_MACROS_CONCAT_NAME( \ - _status_or_value, __COUNTER__), \ - lhs, rexpr) - -#define EXLA_ASSIGN_OR_RETURN_IMPL(statusor, lhs, rexpr) \ - auto statusor = (rexpr); \ - if (!statusor.ok()) { \ - return statusor.status(); \ - } \ +#define EXLA_ASSIGN_OR_RETURN(lhs, rexpr) \ + EXLA_ASSIGN_OR_RETURN_IMPL( \ + EXLA_STATUS_MACROS_CONCAT_NAME(_status_or_value, __COUNTER__), lhs, \ + rexpr) + +#define EXLA_ASSIGN_OR_RETURN_IMPL(statusor, lhs, rexpr) \ + auto statusor = (rexpr); \ + if (!statusor.ok()) { \ + return statusor.status(); \ + } \ lhs = std::move(statusor.value()); #endif diff --git a/exla/c_src/exla/ipc.cc b/exla/c_src/exla/ipc.cc index 9187170e2a..b1616c046b 100644 --- a/exla/c_src/exla/ipc.cc +++ b/exla/c_src/exla/ipc.cc @@ -32,7 +32,7 @@ void* open_ipc_handle(int fd, size_t memsize) { return ptr; } -int close_ipc_handle(int fd, void* ptr, char* memname, size_t memsize) { +int close_ipc_handle(int fd, void* ptr, const char* memname, size_t memsize) { if (munmap(ptr, memsize) == -1) { return -1; } @@ -44,4 +44,4 @@ int close_ipc_handle(int fd, void* ptr, char* memname, size_t memsize) { shm_unlink(memname); return 0; -} \ No newline at end of file +} diff --git a/exla/c_src/exla/ipc.h b/exla/c_src/exla/ipc.h index 5458a0092a..f11a13d582 100644 --- a/exla/c_src/exla/ipc.h +++ b/exla/c_src/exla/ipc.h @@ -4,4 +4,4 @@ int get_ipc_handle(const char* memname, size_t memsize); void* open_ipc_handle(int fd, size_t memsize); -int close_ipc_handle(int fd, void* ptr, char* memname, size_t memsize); +int close_ipc_handle(int fd, void* ptr, const char* memname, size_t memsize); diff --git a/exla/lib/exla/backend.ex b/exla/lib/exla/backend.ex index f242780fc9..a94511cb19 100644 --- a/exla/lib/exla/backend.ex +++ b/exla/lib/exla/backend.ex @@ -114,34 +114,35 @@ defmodule EXLA.Backend do client = EXLA.Client.fetch!(buffer.client_name) case EXLA.NIF.get_buffer_device_pointer(client.ref, buffer.ref, mode) do - {:ok, result} -> - handle = - case {result, mode} do - {{ptr, size}, :local} when is_integer(ptr) -> - # pointer is an integer here - %Nx.Pointer{kind: :local, address: ptr, data_size: size} - - {{handle_name, fd, size}, :host_ipc} -> - %Nx.Pointer{ - kind: :ipc, - handle: handle_name, - address: fd, - data_size: size - } - - {{handle, size}, :cuda_ipc} -> - %Nx.Pointer{ - kind: :ipc, - handle: handle, - address: buffer.device_id, - data_size: size - } - end - - {:ok, handle} - - error -> - error + {:ok, ptr, size} when mode == :local and is_integer(ptr) -> + # Pointer is an integer here + {:ok, + %Nx.Pointer{ + kind: :local, + address: ptr, + data_size: size + }} + + {:ok, handle_name, fd, size} when mode == :host_ipc -> + {:ok, + %Nx.Pointer{ + kind: :ipc, + handle: handle_name, + address: fd, + data_size: size + }} + + {:ok, handle, size} when mode === :cuda_ipc -> + {:ok, + %Nx.Pointer{ + kind: :ipc, + handle: handle, + address: buffer.device_id, + data_size: size + }} + + {:error, reason} -> + {:error, reason} end end @@ -169,7 +170,7 @@ defmodule EXLA.Backend do raise ArgumentError, "invalid pointer data_size for shape, expected: #{shape_size}, got: #{size}" - {%Nx.Pointer{address: address, kind: :local}, _} -> + {%Nx.Pointer{kind: :local, address: address}, _} -> {:local, address} {%Nx.Pointer{kind: :ipc, address: fd, handle: handle}, :host} -> @@ -184,7 +185,7 @@ defmodule EXLA.Backend do client.ref, mode, handle_nif, - EXLA.Typespec.nif_encode(typespec), + typespec, device_id ) diff --git a/exla/lib/exla/client.ex b/exla/lib/exla/client.ex index 688be5a70f..2d14d09955 100644 --- a/exla/lib/exla/client.ex +++ b/exla/lib/exla/client.ex @@ -78,11 +78,6 @@ defmodule EXLA.Client do """ def get_supported_platforms do EXLA.NIF.get_supported_platforms() - |> unwrap!() - |> Map.new(fn {k, v} -> - k = k |> List.to_string() |> String.downcase(:ascii) |> String.to_atom() - {k, v} - end) end @doc """ @@ -94,12 +89,8 @@ defmodule EXLA.Client do """ def to_infeed(%EXLA.Client{ref: client}, device_id, data_and_typespecs) when is_list(data_and_typespecs) do - data_and_typespecs = - Enum.map(data_and_typespecs, fn {binary, typespec} when is_binary(binary) -> - {binary, EXLA.Typespec.nif_encode(typespec)} - end) - - EXLA.NIF.transfer_to_infeed(client, device_id, data_and_typespecs) |> unwrap!() + {buffers, typespecs} = Enum.unzip(data_and_typespecs) + EXLA.NIF.transfer_to_infeed(client, device_id, buffers, typespecs) end @doc """ @@ -107,8 +98,7 @@ defmodule EXLA.Client do """ def from_outfeed(%EXLA.Client{ref: client}, device_id, typespecs, pid, ref) when is_list(typespecs) do - typespecs = Enum.map(typespecs, &EXLA.Typespec.nif_encode/1) - EXLA.NIF.transfer_from_outfeed(client, device_id, typespecs, pid, ref) |> unwrap!() + EXLA.NIF.transfer_from_outfeed(client, device_id, typespecs, pid, ref) end ## Callbacks @@ -134,7 +124,6 @@ defmodule EXLA.Client do platform = Keyword.get(options, :platform) memory_fraction = Keyword.get(options, :memory_fraction, 0.9) preallocate = Keyword.get(options, :preallocate, true) - preallocate_int = if preallocate, do: 1, else: 0 platforms = Map.keys(EXLA.Client.get_supported_platforms()) ref = @@ -151,10 +140,10 @@ defmodule EXLA.Client do EXLA.NIF.get_host_client() :cuda -> - EXLA.NIF.get_gpu_client(memory_fraction, preallocate_int) + EXLA.NIF.get_gpu_client(memory_fraction, preallocate) :rocm -> - EXLA.NIF.get_gpu_client(memory_fraction, preallocate_int) + EXLA.NIF.get_gpu_client(memory_fraction, preallocate) :tpu -> EXLA.NIF.get_tpu_client() @@ -162,9 +151,8 @@ defmodule EXLA.Client do _ -> raise ArgumentError, "unknown EXLA platform: #{inspect(platform)}" end - |> unwrap!() - device_count = EXLA.NIF.get_device_count(ref) |> unwrap!() + device_count = EXLA.NIF.get_device_count(ref) default_device_id = Keyword.get(options, :default_device_id, 0) if default_device_id not in 0..(device_count - 1) do @@ -182,8 +170,4 @@ defmodule EXLA.Client do automatic_transfers: automatic_transfers } end - - defp unwrap!(:ok), do: :ok - defp unwrap!({:ok, ref}), do: ref - defp unwrap!({:error, error}), do: raise(List.to_string(error)) end diff --git a/exla/lib/exla/device_buffer.ex b/exla/lib/exla/device_buffer.ex index d1753d7630..816e4847dc 100644 --- a/exla/lib/exla/device_buffer.ex +++ b/exla/lib/exla/device_buffer.ex @@ -33,10 +33,7 @@ defmodule EXLA.DeviceBuffer do data end - ref = - client.ref - |> EXLA.NIF.binary_to_device_mem(data, EXLA.Typespec.nif_encode(typespec), device_id) - |> unwrap!() + ref = EXLA.NIF.binary_to_device_mem(client.ref, data, typespec, device_id) %DeviceBuffer{ref: ref, client_name: client.name, device_id: device_id, typespec: typespec} end @@ -50,7 +47,7 @@ defmodule EXLA.DeviceBuffer do device_id ) when is_integer(device_id) do - ref = client.ref |> EXLA.NIF.copy_buffer_to_device(buffer, device_id) |> unwrap!() + ref = EXLA.NIF.copy_buffer_to_device(client.ref, buffer, device_id) %DeviceBuffer{ref: ref, client_name: client.name, device_id: device_id, typespec: typespec} end @@ -62,7 +59,7 @@ defmodule EXLA.DeviceBuffer do reads the whole buffer. """ def read(%DeviceBuffer{ref: ref, typespec: typespec}, size \\ -1) do - data = EXLA.NIF.read_device_mem(ref, size) |> unwrap!() + data = EXLA.NIF.read_device_mem(ref, size) # At the moment XLA does not support reading a packed buffer, # so we pack the elements ourselves @@ -83,10 +80,10 @@ defmodule EXLA.DeviceBuffer do Returns `:ok` | `:already_deallocated`. """ - def deallocate(%DeviceBuffer{ref: ref}), - do: EXLA.NIF.deallocate_device_mem(ref) |> unwrap!() - - defp unwrap!({:ok, ref}), do: ref - defp unwrap!({:error, error}), do: raise(List.to_string(error)) - defp unwrap!(status) when is_atom(status), do: status + def deallocate(%DeviceBuffer{ref: ref}) do + case EXLA.NIF.deallocate_device_mem(ref) do + :ok -> :ok + {:error, :already_deallocated} -> :already_deallocated + end + end end diff --git a/exla/lib/exla/executable.ex b/exla/lib/exla/executable.ex index e3494f60bf..15ffbbdfe0 100644 --- a/exla/lib/exla/executable.ex +++ b/exla/lib/exla/executable.ex @@ -51,7 +51,6 @@ defmodule EXLA.Executable do serialized_exec = ref |> EXLA.NIF.serialize_executable() - |> unwrap!() |> IO.iodata_to_binary() %{ @@ -79,10 +78,7 @@ defmodule EXLA.Executable do device_id: device_id } = data - ref = - serialized - |> then(&EXLA.NIF.deserialize_executable(client.ref, &1)) - |> unwrap!() + ref = EXLA.NIF.deserialize_executable(client.ref, serialized) %EXLA.Executable{ output_typespecs: output_typespecs, @@ -102,17 +98,14 @@ defmodule EXLA.Executable do ref %BinaryBuffer{data: data, typespec: typespec} -> - {data, EXLA.Typespec.nif_encode(typespec)} + {data, typespec} end) end - data = - case client.platform do - :host -> EXLA.NIF.run_cpu(client.ref, ref, inputs, device_id) - _ -> EXLA.NIF.run_io(client.ref, ref, inputs, device_id) - end - - unwrap!(data) + case client.platform do + :host -> EXLA.NIF.run_cpu(ref, inputs, device_id) + _ -> EXLA.NIF.run_io(ref, inputs, device_id) + end end defp decompose_output({data, device_id}, output_typespecs, client) do @@ -124,8 +117,4 @@ defmodule EXLA.Executable do BinaryBuffer.from_binary(buf, typespec) end) end - - defp unwrap!(:ok), do: :ok - defp unwrap!({:ok, ref}), do: ref - defp unwrap!({:error, error}), do: raise(List.to_string(error)) end diff --git a/exla/lib/exla/mlir/context_pool.ex b/exla/lib/exla/mlir/context_pool.ex index 9e0d8155bb..081e922a8d 100644 --- a/exla/lib/exla/mlir/context_pool.ex +++ b/exla/lib/exla/mlir/context_pool.ex @@ -14,14 +14,14 @@ defmodule EXLA.MLIR.ContextPool do @impl NimblePool def init_pool(%{pool_size: pool_size}) do - {:ok, thread_pool} = EXLA.NIF.mlir_new_thread_pool(pool_size) + thread_pool = EXLA.NIF.mlir_new_thread_pool(pool_size) {:ok, %{thread_pool: thread_pool}} end @impl NimblePool def init_worker(%{thread_pool: thread_pool} = pool_state) do - {:ok, context} = EXLA.NIF.mlir_new_context(thread_pool) + context = EXLA.NIF.mlir_new_context(thread_pool) {:ok, context, pool_state} end diff --git a/exla/lib/exla/mlir/function.ex b/exla/lib/exla/mlir/function.ex index d9aaa4f7b1..7b1157955a 100644 --- a/exla/lib/exla/mlir/function.ex +++ b/exla/lib/exla/mlir/function.ex @@ -13,7 +13,7 @@ defmodule EXLA.MLIR.Function do which can be used in MLIR operations. """ def get_arguments(%Function{ref: ref} = function) do - arg_refs = EXLA.NIF.mlir_get_function_arguments(ref) |> unwrap!() + arg_refs = EXLA.NIF.mlir_get_function_arguments(ref) Enum.map(arg_refs, fn arg -> %Value{ref: arg, function: function} end) end @@ -26,7 +26,7 @@ defmodule EXLA.MLIR.Function do """ def push_region(%Function{ref: ref} = function, arg_typespecs) do arg_mlir_types = Value.typespecs_to_mlir_types(arg_typespecs) - {region, args} = EXLA.NIF.mlir_push_region(ref, arg_mlir_types) |> unwrap!() + {region, args} = EXLA.NIF.mlir_push_region(ref, arg_mlir_types) {%Region{ref: region}, Enum.map(args, &%Value{function: function, ref: &1})} end @@ -34,10 +34,6 @@ defmodule EXLA.MLIR.Function do Pops region created with `push_region/2`. """ def pop_region(%Function{ref: ref}) do - EXLA.NIF.mlir_pop_region(ref) |> unwrap!() + EXLA.NIF.mlir_pop_region(ref) end - - defp unwrap!(:ok), do: :ok - defp unwrap!({:ok, value}), do: value - defp unwrap!(other), do: raise("error: #{inspect(other)}") end diff --git a/exla/lib/exla/mlir/module.ex b/exla/lib/exla/mlir/module.ex index 76a52af31b..04f1c38a3c 100644 --- a/exla/lib/exla/mlir/module.ex +++ b/exla/lib/exla/mlir/module.ex @@ -15,7 +15,7 @@ defmodule EXLA.MLIR.Module do """ def new(arg_typespecs, return_typespecs, fun) when is_function(fun, 1) do ContextPool.checkout(fn context -> - ref = context |> EXLA.NIF.mlir_new_module() |> unwrap!() + ref = EXLA.NIF.mlir_new_module(context) %__MODULE__{ref: ref} |> create_function("main", arg_typespecs, return_typespecs, true) @@ -47,9 +47,8 @@ defmodule EXLA.MLIR.Module do name, arg_types, return_types, - if(is_public, do: 1, else: 0) + is_public ) - |> unwrap!() %Function{module: module, ref: ref, name: name, return_typespecs: return_typespecs} end @@ -96,7 +95,7 @@ defmodule EXLA.MLIR.Module do # JAX comments say SPMD can lead to subtle bugs so they only enable # when strictly necessary, which is when num_partitions is greater than 1. - use_spmd = if Keyword.get(options, :use_spmd, true) or num_partitions >= 1, do: 1, else: 0 + use_spmd = Keyword.get(options, :use_spmd, true) or num_partitions >= 1 device_id = if num_replicas > 1 or num_partitions > 1, @@ -115,13 +114,12 @@ defmodule EXLA.MLIR.Module do EXLA.NIF.mlir_compile( client.ref, module.ref, - Enum.map(argument_typespecs, &EXLA.Typespec.nif_encode/1), + argument_typespecs, num_replicas, num_partitions, use_spmd, device_id ) - |> unwrap!() end %Executable{ @@ -139,10 +137,6 @@ defmodule EXLA.MLIR.Module do syntax. """ def as_string(module = %__MODULE__{}) do - EXLA.NIF.mlir_module_to_string(module.ref) |> unwrap!() + EXLA.NIF.mlir_module_to_string(module.ref) end - - defp unwrap!(:ok), do: :ok - defp unwrap!({:ok, ref}), do: ref - defp unwrap!({:error, error}), do: raise(List.to_string(error)) end diff --git a/exla/lib/exla/mlir/value.ex b/exla/lib/exla/mlir/value.ex index 2b25c6f8f6..46baa95c8c 100644 --- a/exla/lib/exla/mlir/value.ex +++ b/exla/lib/exla/mlir/value.ex @@ -909,8 +909,6 @@ defmodule EXLA.MLIR.Value do def get_typespec(value) do EXLA.NIF.mlir_get_typespec(value.ref) - |> unwrap!() - |> Typespec.nif_decode() end def typespecs_to_mlir_types(shapes) do @@ -920,10 +918,6 @@ defmodule EXLA.MLIR.Value do defp typespec_to_mlir_type(%{type: :token}), do: type_token() defp typespec_to_mlir_type(%{type: type, shape: shape}), do: type_tensor(type, shape) - defp unwrap!(:ok), do: :ok - defp unwrap!({:ok, value}), do: value - defp unwrap!(other), do: raise("#{inspect(other)}") - defp one!([value]), do: value defp one!(other) do @@ -951,7 +945,6 @@ defmodule EXLA.MLIR.Value do opts[:attributes], opts[:regions] ) - |> unwrap!() Enum.map(refs, &%Value{function: function, ref: &1}) end diff --git a/exla/lib/exla/nif.ex b/exla/lib/exla/nif.ex index 023a0bcbd2..b6d93e5b32 100644 --- a/exla/lib/exla/nif.ex +++ b/exla/lib/exla/nif.ex @@ -20,26 +20,15 @@ defmodule EXLA.NIF do end end - def mlir_new_thread_pool(_concurrency), do: :erlang.nif_error(:undef) - def mlir_new_context(_thread_pool_ref), do: :erlang.nif_error(:undef) - - def mlir_new_module(_context), do: :erlang.nif_error(:undef) - - def mlir_create_function(_module, _name, _arg_types, _ret_type, _is_public), - do: :erlang.nif_error(:undef) - - def mlir_get_function_arguments(_function), do: :erlang.nif_error(:undef) - - def mlir_op(_function, _op_name, _operands, _result_type, _attributes, _blocks), - do: :erlang.nif_error(:undef) - - def mlir_push_region(_function, _arg_types), - do: :erlang.nif_error(:undef) - - def mlir_pop_region(_function), - do: :erlang.nif_error(:undef) - - def mlir_build(_function, _root), do: :erlang.nif_error(:undef) + def mlir_new_thread_pool(_concurrency), do: err!() + def mlir_new_context(_thread_pool_ref), do: err!() + def mlir_new_module(_context), do: err!() + def mlir_create_function(_module, _name, _arg_types, _ret_type, _is_public), do: err!() + def mlir_get_function_arguments(_function), do: err!() + def mlir_op(_function, _op_name, _operands, _result_type, _attributes, _blocks), do: err!() + def mlir_push_region(_function, _arg_types), do: err!() + def mlir_pop_region(_function), do: err!() + def mlir_build(_function, _root), do: err!() def mlir_compile( _client, @@ -50,80 +39,40 @@ defmodule EXLA.NIF do _use_spmd, _device_id ), - do: :erlang.nif_error(:undef) - - def mlir_get_typespec(_tensor), do: :erlang.nif_error(:undef) - - def mlir_module_to_string(_builder), do: :erlang.nif_error(:undef) - - def get_host_client(), - do: :erlang.nif_error(:undef) - - def get_gpu_client( - _memory_fraction, - _preallocate - ), - do: :erlang.nif_error(:undef) - - def get_tpu_client(), do: :erlang.nif_error(:undef) - - def get_supported_platforms, do: :erlang.nif_error(:undef) - - def get_device_count(_client), - do: :erlang.nif_error(:undef) - - def serialize_executable(_executable), do: :erlang.nif_error(:undef) - def deserialize_executable(_client, _string), do: :erlang.nif_error(:undef) - - def run_cpu( - _client, - _executable, - _arguments, - _device_id - ), - do: :erlang.nif_error(:undef) + do: err!() - def run_io( - _client, - _executable, - _arguments, - _device_id - ), - do: :erlang.nif_error(:undef) + def mlir_get_typespec(_tensor), do: err!() + def mlir_module_to_string(_builder), do: err!() - def get_buffer_device_pointer(_client, _buffer, _pointer_kind), do: :erlang.nif_error(:undef) + def get_buffer_device_pointer(_client, _buffer, _pointer_kind), do: err!() def create_buffer_from_device_pointer( _client, - _opaque_pointer, _pointer_kind, + _pointer_data, _typespec, _device_id ), - do: :erlang.nif_error(:undef) - - def binary_to_device_mem(_client, _binary, _typespec, _device_ordinal), - do: :erlang.nif_error(:undef) - - def read_device_mem(_buffer, _size), - do: :erlang.nif_error(:undef) - - def deallocate_device_mem(_buffer), - do: :erlang.nif_error(:undef) - - def transfer_to_infeed(_client, _device, _data_typespecs), - do: :erlang.nif_error(:undef) - - def transfer_from_outfeed(_client, _device, _typespecs, _pid, _ref), - do: :erlang.nif_error(:undef) - - def copy_buffer_to_device(_client, _buffer, _device), - do: :erlang.nif_error(:undef) - - def start_log_sink(_sink_pid), - do: :erlang.nif_error(:undef) - - def get_c_api_client(_device_type), do: :erlang.nif_error(:undef) - - def load_pjrt_plugin(_device_type, _library_path), do: :erlang.nif_error(:undef) + do: err!() + + def binary_to_device_mem(_client, _binary, _typespec, _device_ordinal), do: err!() + def read_device_mem(_buffer, _size), do: err!() + def deallocate_device_mem(_buffer), do: err!() + def transfer_to_infeed(_client, _device, _buffers, _typespecs), do: err!() + def transfer_from_outfeed(_client, _device, _typespecs, _pid, _ref), do: err!() + def copy_buffer_to_device(_client, _buffer, _device), do: err!() + def get_host_client(), do: err!() + def get_gpu_client(_memory_fraction, _preallocate), do: err!() + def get_tpu_client(), do: err!() + def get_c_api_client(_device_type), do: err!() + def load_pjrt_plugin(_device_type, _library_path), do: err!() + def get_device_count(_client), do: err!() + def get_supported_platforms, do: err!() + def run_cpu(_executable, _arguments, _device_id), do: err!() + def run_io(_executable, _arguments, _device_id), do: err!() + def serialize_executable(_executable), do: err!() + def deserialize_executable(_client, _string), do: err!() + def start_log_sink(_sink_pid), do: err!() + + defp err!(), do: :erlang.nif_error(:undef) end diff --git a/exla/lib/exla/typespec.ex b/exla/lib/exla/typespec.ex index 60166ef84d..eeaac9d71e 100644 --- a/exla/lib/exla/typespec.ex +++ b/exla/lib/exla/typespec.ex @@ -39,45 +39,4 @@ defmodule EXLA.Typespec do Returns an updated typespec with the given shape. """ def to_shape(typespec, shape), do: %{typespec | shape: shape} - - @doc false - def nif_encode(typespec) do - {type_to_charlist(typespec.type), typespec.shape} - end - - @doc false - def nif_decode({type_charlist, shape}) do - %__MODULE__{shape: shape, type: charlist_to_type(type_charlist)} - end - - type_to_charlist = %{ - :token => ~c"token", - {:pred, 8} => ~c"pred", - {:s, 2} => ~c"s2", - {:s, 4} => ~c"s4", - {:s, 8} => ~c"s8", - {:s, 16} => ~c"s16", - {:s, 32} => ~c"s32", - {:s, 64} => ~c"s64", - {:u, 2} => ~c"u2", - {:u, 4} => ~c"u4", - {:u, 8} => ~c"u8", - {:u, 16} => ~c"u16", - {:u, 32} => ~c"u32", - {:u, 64} => ~c"u64", - {:f, 16} => ~c"f16", - {:f, 32} => ~c"f32", - {:f, 64} => ~c"f64", - {:bf, 16} => ~c"bf16", - {:c, 64} => ~c"c64", - {:c, 128} => ~c"c128" - } - - defp type_to_charlist({:f, 8}), do: ~c"f8e5m2" - defp charlist_to_type(~c"f8"), do: {:f, 8} - - for {type, charlist} <- type_to_charlist do - defp charlist_to_type(unquote(charlist)), do: unquote(type) - defp type_to_charlist(unquote(type)), do: unquote(charlist) - end end diff --git a/exla/mix.exs b/exla/mix.exs index 7ce6727d2f..2d1980bfd5 100644 --- a/exla/mix.exs +++ b/exla/mix.exs @@ -34,6 +34,7 @@ defmodule EXLA.MixProject do cwd_relative_to_priv = relative_to(File.cwd!(), priv_path) %{ + "FINE_INCLUDE_DIR" => Fine.include_dir(), "MIX_BUILD_EMBEDDED" => "#{Mix.Project.config()[:build_embedded]}", "CWD_RELATIVE_TO_PRIV_PATH" => cwd_relative_to_priv, "EXLA_VERSION" => "#{@version}" @@ -68,6 +69,7 @@ defmodule EXLA.MixProject do {:nx, path: "../nx"}, {:telemetry, "~> 0.4.0 or ~> 1.0"}, {:xla, "~> 0.8.0", runtime: false}, + {:fine, "~> 0.1.0", runtime: false}, {:elixir_make, "~> 0.6", runtime: false}, {:benchee, "~> 1.0", only: :dev}, {:ex_doc, "~> 0.29", only: :docs}, diff --git a/exla/mix.lock b/exla/mix.lock index 51edca3542..de6ba787a3 100644 --- a/exla/mix.lock +++ b/exla/mix.lock @@ -5,6 +5,7 @@ "earmark_parser": {:hex, :earmark_parser, "1.4.41", "ab34711c9dc6212dda44fcd20ecb87ac3f3fce6f0ca2f28d4a00e4154f8cd599", [:mix], [], "hexpm", "a81a04c7e34b6617c2792e291b5a2e57ab316365c2644ddc553bb9ed863ebefa"}, "elixir_make": {:hex, :elixir_make, "0.8.4", "4960a03ce79081dee8fe119d80ad372c4e7badb84c493cc75983f9d3bc8bde0f", [:mix], [{:castore, "~> 0.1 or ~> 1.0", [hex: :castore, repo: "hexpm", optional: true]}, {:certifi, "~> 2.0", [hex: :certifi, repo: "hexpm", optional: true]}], "hexpm", "6e7f1d619b5f61dfabd0a20aa268e575572b542ac31723293a4c1a567d5ef040"}, "ex_doc": {:hex, :ex_doc, "0.34.2", "13eedf3844ccdce25cfd837b99bea9ad92c4e511233199440488d217c92571e8", [:mix], [{:earmark_parser, "~> 1.4.39", [hex: :earmark_parser, repo: "hexpm", optional: false]}, {:makeup_c, ">= 0.1.0", [hex: :makeup_c, repo: "hexpm", optional: true]}, {:makeup_elixir, "~> 0.14 or ~> 1.0", [hex: :makeup_elixir, repo: "hexpm", optional: false]}, {:makeup_erlang, "~> 0.1 or ~> 1.0", [hex: :makeup_erlang, repo: "hexpm", optional: false]}, {:makeup_html, ">= 0.1.0", [hex: :makeup_html, repo: "hexpm", optional: true]}], "hexpm", "5ce5f16b41208a50106afed3de6a2ed34f4acfd65715b82a0b84b49d995f95c1"}, + "fine": {:hex, :fine, "0.1.0", "9bb99a5ff9b968f12c3b458fa1277c39e9a620b23a9439103703a25917293871", [:mix], [], "hexpm", "1d6485bf811b95dc6ae3d197c0e6f994880b86167a827983bb29cbfc03a02684"}, "makeup": {:hex, :makeup, "1.1.2", "9ba8837913bdf757787e71c1581c21f9d2455f4dd04cfca785c70bbfff1a76a3", [:mix], [{:nimble_parsec, "~> 1.2.2 or ~> 1.3", [hex: :nimble_parsec, repo: "hexpm", optional: false]}], "hexpm", "cce1566b81fbcbd21eca8ffe808f33b221f9eee2cbc7a1706fc3da9ff18e6cac"}, "makeup_elixir": {:hex, :makeup_elixir, "0.16.2", "627e84b8e8bf22e60a2579dad15067c755531fea049ae26ef1020cad58fe9578", [:mix], [{:makeup, "~> 1.0", [hex: :makeup, repo: "hexpm", optional: false]}, {:nimble_parsec, "~> 1.2.3 or ~> 1.3", [hex: :nimble_parsec, repo: "hexpm", optional: false]}], "hexpm", "41193978704763f6bbe6cc2758b84909e62984c7752b3784bd3c218bb341706b"}, "makeup_erlang": {:hex, :makeup_erlang, "1.0.1", "c7f58c120b2b5aa5fd80d540a89fdf866ed42f1f3994e4fe189abebeab610839", [:mix], [{:makeup, "~> 1.0", [hex: :makeup, repo: "hexpm", optional: false]}], "hexpm", "8a89a1eeccc2d798d6ea15496a6e4870b75e014d1af514b1b71fa33134f57814"}, diff --git a/nx/lib/nx.ex b/nx/lib/nx.ex index bde368fad7..8a1cd582cc 100644 --- a/nx/lib/nx.ex +++ b/nx/lib/nx.ex @@ -16823,11 +16823,11 @@ defmodule Nx do t = Nx.u8([10, 20, 30]) Nx.to_pointer(t, mode: :local) - %Nx.Pointer{kind: :local, address: 1234, data_size: 3, handle: nil} + #=> {:ok, %Nx.Pointer{kind: :local, address: 1234, data_size: 3, handle: nil}} t = Nx.s32([1, 2, 3]) Nx.to_pointer(t, mode: :ipc) - %Nx.Pointer{kind: :ipc, address: nil, data_size: 32, handle: "some-ipc-handle"} + #=> {:ok, %Nx.Pointer{kind: :ipc, address: nil, data_size: 32, handle: "some-ipc-handle"}} """ @doc type: :creation def to_pointer(tensor, opts \\ []) do From 97e0f92445a3d8aa0ba7c4ff3de67e8f8c89f0b3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jonatan=20K=C5=82osko?= Date: Wed, 19 Feb 2025 13:22:28 +0100 Subject: [PATCH 10/36] Change Nx.to_pointer/2 and Nx.from_pointer/5 to raise on errors (#1582) Co-authored-by: Paulo Valente <16843419+polvalente@users.noreply.github.com> --- exla/c_src/exla/exla.cc | 35 +++++++------- exla/c_src/exla/exla_cuda.cc | 4 +- exla/c_src/exla/ipc.cc | 3 +- exla/lib/exla/backend.ex | 48 +++++-------------- exla/test/exla/device_memory_sharing_test.exs | 4 +- nx/lib/nx.ex | 28 +++++------ nx/lib/nx/backend.ex | 4 +- 7 files changed, 48 insertions(+), 78 deletions(-) diff --git a/exla/c_src/exla/exla.cc b/exla/c_src/exla/exla.cc index 2108de5c6e..ed3ce31a03 100644 --- a/exla/c_src/exla/exla.cc +++ b/exla/c_src/exla/exla.cc @@ -217,9 +217,9 @@ FINE_NIF(mlir_compile, ERL_NIF_DIRTY_JOB_CPU_BOUND); // ExlaBuffer Functions -std::variant, - fine::Ok, - fine::Ok, fine::Error> +std::variant, + std::tuple, + std::tuple> get_buffer_device_pointer(ErlNifEnv *env, fine::ResourcePtr client, fine::Term buffer_term, fine::Atom pointer_kind) { auto buffer = decode_exla_buffer(env, buffer_term); @@ -228,7 +228,7 @@ get_buffer_device_pointer(ErlNifEnv *env, fine::ResourcePtr client, uint64_t ptr = unwrap(buffer->GetDevicePointer(client->client())); if (pointer_kind == "local") { - return fine::Ok(ptr, device_size); + return std::make_tuple(pointer_kind, ptr, device_size); } if (pointer_kind == "host_ipc") { @@ -237,26 +237,27 @@ get_buffer_device_pointer(ErlNifEnv *env, fine::ResourcePtr client, auto fd = get_ipc_handle(handle_name.c_str(), device_size); if (fd == -1) { - return fine::Error(std::string("unable to get IPC handle")); + throw std::runtime_error("unable to get IPC handle"); } auto ipc_ptr = open_ipc_handle(fd, device_size); if (ipc_ptr == nullptr) { - return fine::Error(std::string("unable to open IPC handle")); + throw std::runtime_error("unable to open IPC handle"); } memcpy(ipc_ptr, reinterpret_cast(ptr), device_size); - return fine::Ok(handle_name, static_cast(fd), device_size); + return std::make_tuple(pointer_kind, handle_name, static_cast(fd), + device_size); } if (pointer_kind == "cuda_ipc") { auto maybe_handle = get_cuda_ipc_handle(ptr); if (!maybe_handle) { - return fine::Error(std::string("unable to get cuda IPC handle")); + throw std::runtime_error("unable to get cuda IPC handle"); } - return fine::Ok(maybe_handle.value(), device_size); + return std::make_tuple(pointer_kind, maybe_handle.value(), device_size); } throw std::invalid_argument("unexpected pointer type"); @@ -264,12 +265,10 @@ get_buffer_device_pointer(ErlNifEnv *env, fine::ResourcePtr client, FINE_NIF(get_buffer_device_pointer, 0); -std::variant>, fine::Error> -create_buffer_from_device_pointer(ErlNifEnv *env, - fine::ResourcePtr client, - fine::Atom pointer_kind, - fine::Term pointer_data, xla::Shape shape, - int64_t device_id) { +fine::ResourcePtr create_buffer_from_device_pointer( + ErlNifEnv *env, fine::ResourcePtr client, + fine::Atom pointer_kind, fine::Term pointer_data, xla::Shape shape, + int64_t device_id) { void *ptr = nullptr; std::function on_delete_callback = []() {}; @@ -278,7 +277,7 @@ create_buffer_from_device_pointer(ErlNifEnv *env, auto maybe_pointer = get_pointer_for_ipc_handle( cuda_ipc_handle_bin.data, cuda_ipc_handle_bin.size, device_id); if (!maybe_pointer) { - return fine::Error("unable to get pointer for IPC handle"); + throw std::runtime_error("unable to get pointer for IPC handle"); } ptr = maybe_pointer.value(); } else if (pointer_kind == "host_ipc") { @@ -289,7 +288,7 @@ create_buffer_from_device_pointer(ErlNifEnv *env, auto device_size = xla::ShapeUtil::ByteSizeOf(shape); ptr = open_ipc_handle(fd, device_size); if (ptr == nullptr) { - return fine::Error("unable to get pointer for IPC handle"); + throw std::runtime_error("unable to get pointer for IPC handle"); } on_delete_callback = [fd, memname, ptr, device_size]() { close_ipc_handle(fd, ptr, memname.c_str(), device_size); @@ -305,7 +304,7 @@ create_buffer_from_device_pointer(ErlNifEnv *env, client->client()->LookupDevice(xla::PjRtGlobalDeviceId(device_id))); auto buffer = unwrap(client->client()->CreateViewOfDeviceBuffer( ptr, shape, device, on_delete_callback)); - return fine::Ok(fine::make_resource(std::move(buffer))); + return fine::make_resource(std::move(buffer)); } FINE_NIF(create_buffer_from_device_pointer, 0); diff --git a/exla/c_src/exla/exla_cuda.cc b/exla/c_src/exla/exla_cuda.cc index 395fce3f9b..6f5bbe7b97 100644 --- a/exla/c_src/exla/exla_cuda.cc +++ b/exla/c_src/exla/exla_cuda.cc @@ -20,9 +20,7 @@ std::optional get_cuda_ipc_handle(std::uintptr_t ptr) { const size_t size = sizeof(cudaIpcMemHandle_t); // Copy the memory handle to a buffer - std::string buffer; - buffer.resize(size); - memcpy(&(*(buffer.begin())), &ipc_handle, size); + auto buffer = std::string(reinterpret_cast(&ipc_handle), size); return buffer; } diff --git a/exla/c_src/exla/ipc.cc b/exla/c_src/exla/ipc.cc index b1616c046b..afcbb09859 100644 --- a/exla/c_src/exla/ipc.cc +++ b/exla/c_src/exla/ipc.cc @@ -1,12 +1,11 @@ #include "ipc.h" +#include #include #include #include #include -#include - // Function to create or open a shared memory object and set its size int get_ipc_handle(const char* memname, size_t memsize) { int fd = shm_open(memname, O_CREAT | O_RDWR, 0666); diff --git a/exla/lib/exla/backend.ex b/exla/lib/exla/backend.ex index a94511cb19..2ac5c62646 100644 --- a/exla/lib/exla/backend.ex +++ b/exla/lib/exla/backend.ex @@ -114,35 +114,15 @@ defmodule EXLA.Backend do client = EXLA.Client.fetch!(buffer.client_name) case EXLA.NIF.get_buffer_device_pointer(client.ref, buffer.ref, mode) do - {:ok, ptr, size} when mode == :local and is_integer(ptr) -> + {:local, ptr, size} -> # Pointer is an integer here - {:ok, - %Nx.Pointer{ - kind: :local, - address: ptr, - data_size: size - }} - - {:ok, handle_name, fd, size} when mode == :host_ipc -> - {:ok, - %Nx.Pointer{ - kind: :ipc, - handle: handle_name, - address: fd, - data_size: size - }} - - {:ok, handle, size} when mode === :cuda_ipc -> - {:ok, - %Nx.Pointer{ - kind: :ipc, - handle: handle, - address: buffer.device_id, - data_size: size - }} - - {:error, reason} -> - {:error, reason} + %Nx.Pointer{kind: :local, address: ptr, data_size: size} + + {:host_ipc, handle_name, fd, size} -> + %Nx.Pointer{kind: :ipc, handle: handle_name, address: fd, data_size: size} + + {:cuda_ipc, handle, size} -> + %Nx.Pointer{kind: :ipc, handle: handle, address: buffer.device_id, data_size: size} end end @@ -180,7 +160,7 @@ defmodule EXLA.Backend do {:cuda_ipc, handle} end - result = + ref = EXLA.NIF.create_buffer_from_device_pointer( client.ref, mode, @@ -189,14 +169,8 @@ defmodule EXLA.Backend do device_id ) - case result do - {:ok, ref} -> - buffer = EXLA.DeviceBuffer.from_ref(ref, client, device_id, typespec) - {:ok, %{template | data: %EXLA.Backend{buffer: buffer}}} - - error -> - error - end + buffer = EXLA.DeviceBuffer.from_ref(ref, client, device_id, typespec) + %{template | data: %EXLA.Backend{buffer: buffer}} end @impl true diff --git a/exla/test/exla/device_memory_sharing_test.exs b/exla/test/exla/device_memory_sharing_test.exs index 09e54a42eb..7ef3b165ef 100644 --- a/exla/test/exla/device_memory_sharing_test.exs +++ b/exla/test/exla/device_memory_sharing_test.exs @@ -11,9 +11,9 @@ defmodule EXLA.DeviceMemorySharingTest do assert inspect(t1) =~ "1, 2, 3" - assert {:ok, pointer} = Nx.to_pointer(t1, mode: :local) + assert pointer = Nx.to_pointer(t1, mode: :local) - assert {:ok, t2} = + assert t2 = Nx.from_pointer( {EXLA.Backend, client: unquote(client_name)}, pointer, diff --git a/nx/lib/nx.ex b/nx/lib/nx.ex index 8a1cd582cc..170652e4d7 100644 --- a/nx/lib/nx.ex +++ b/nx/lib/nx.ex @@ -16771,21 +16771,21 @@ defmodule Nx do pointer = %Nx.Pointer{kind: :local, address: 1234} Nx.from_pointer(MyBackend, pointer, {:s, 32}, {1, 3}) - #Nx.Tensor< - s32[1][3] - [ - [10, 20, 30] - ] - > + #=> #Nx.Tensor< + #=> s32[1][3] + #=> [ + #=> [10, 20, 30] + #=> ] + #=> > pointer = %Nx.Pointer{kind: :ipc, handle: "some-ipc-handle"} Nx.from_pointer({MyBackend, some: :opt}, pointer, {:s, 32}, {1, 3}, names: [nil, :col]) - #Nx.Tensor< - s32[1][col: 3] - [ - [10, 20, 30] - ] - > + #=> #Nx.Tensor< + #=> s32[1][col: 3] + #=> [ + #=> [10, 20, 30] + #=> ] + #=> > """ @doc type: :creation def from_pointer(backend, pointer, type, shape, opts \\ []) @@ -16823,11 +16823,11 @@ defmodule Nx do t = Nx.u8([10, 20, 30]) Nx.to_pointer(t, mode: :local) - #=> {:ok, %Nx.Pointer{kind: :local, address: 1234, data_size: 3, handle: nil}} + #=> %Nx.Pointer{kind: :local, address: 1234, data_size: 3, handle: nil} t = Nx.s32([1, 2, 3]) Nx.to_pointer(t, mode: :ipc) - #=> {:ok, %Nx.Pointer{kind: :ipc, address: nil, data_size: 32, handle: "some-ipc-handle"}} + #=> %Nx.Pointer{kind: :ipc, address: nil, data_size: 32, handle: "some-ipc-handle"} """ @doc type: :creation def to_pointer(tensor, opts \\ []) do diff --git a/nx/lib/nx/backend.ex b/nx/lib/nx/backend.ex index 1193a0b8c4..638a1dfdde 100644 --- a/nx/lib/nx/backend.ex +++ b/nx/lib/nx/backend.ex @@ -57,8 +57,8 @@ defmodule Nx.Backend do shape :: tuple(), backend_opts :: keyword(), opts :: keyword() - ) :: {:ok, tensor} | {:error, term()} - @callback to_pointer(tensor, opts :: keyword) :: {:ok, term()} | {:error, term()} + ) :: tensor | no_return() + @callback to_pointer(tensor, opts :: keyword) :: term() | no_return() @callback as_type(out :: tensor, tensor) :: tensor @callback bitcast(out :: tensor, tensor) :: tensor From ca0fe2b9ac1d35d65a8ff385c6d251abc17690b5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tom=C3=A1s=20Lenzi?= <61877861+TomasPegado@users.noreply.github.com> Date: Wed, 5 Mar 2025 13:07:01 -0300 Subject: [PATCH 11/36] Lu matrix decomposition (#1587) Co-authored-by: Paulo Valente <16843419+polvalente@users.noreply.github.com> --- exla/lib/exla/backend.ex | 1 - nx/lib/nx/backend.ex | 1 + nx/lib/nx/binary_backend.ex | 31 ------ nx/lib/nx/binary_backend/matrix.ex | 135 ----------------------- nx/lib/nx/defn/expr.ex | 8 -- nx/lib/nx/defn/grad.ex | 23 ---- nx/lib/nx/lin_alg.ex | 54 ++++----- nx/lib/nx/lin_alg/lu.ex | 169 +++++++++++++++++++++++++++++ nx/lib/nx/lin_alg/qr.ex | 4 +- nx/test/nx/lin_alg_test.exs | 4 +- 10 files changed, 202 insertions(+), 228 deletions(-) create mode 100644 nx/lib/nx/lin_alg/lu.ex diff --git a/exla/lib/exla/backend.ex b/exla/lib/exla/backend.ex index 2ac5c62646..5135e1770f 100644 --- a/exla/lib/exla/backend.ex +++ b/exla/lib/exla/backend.ex @@ -381,7 +381,6 @@ defmodule EXLA.Backend do [:tensor, :source, :init_value]}, {:indexed_add, [:tensor, :indices, :updates, :opts], [:tensor, :indices, :updates]}, {:indexed_put, [:tensor, :indices, :updates, :opts], [:tensor, :indices, :updates]}, - {:lu, [:tensor, :opts], [:tensor]}, {:triangular_solve, [:a, :b, :opts], [:a, :b]}, {:fft, [:tensor, :opts], [:tensor]}, {:ifft, [:tensor, :opts], [:tensor]} diff --git a/nx/lib/nx/backend.ex b/nx/lib/nx/backend.ex index 638a1dfdde..3c463ba237 100644 --- a/nx/lib/nx/backend.ex +++ b/nx/lib/nx/backend.ex @@ -175,6 +175,7 @@ defmodule Nx.Backend do top_k: 3, fft2: 3, ifft2: 3, + lu: 3, qr: 3, cholesky: 2, eigh: 3, diff --git a/nx/lib/nx/binary_backend.ex b/nx/lib/nx/binary_backend.ex index 10c3b58e33..74c3247585 100644 --- a/nx/lib/nx/binary_backend.ex +++ b/nx/lib/nx/binary_backend.ex @@ -1258,26 +1258,6 @@ defmodule Nx.BinaryBackend do output_batch_groups |> Enum.with_index() |> Enum.map(fn {x, i} -> {x, rem(i, groups)} end) end - @impl true - def lu( - {%{type: p_type} = p_holder, %{type: l_type} = l_holder, %{type: u_type} = u_holder}, - %{type: input_type, shape: input_shape} = tensor, - opts - ) do - bin = to_binary(tensor) - rank = tuple_size(input_shape) - n = elem(input_shape, rank - 1) - - {p, l, u} = - bin_batch_reduce(bin, n * n, input_type, {<<>>, <<>>, <<>>}, fn matrix, - {p_acc, l_acc, u_acc} -> - {p, l, u} = B.Matrix.lu(matrix, input_type, {n, n}, p_type, l_type, u_type, opts) - {p_acc <> p, l_acc <> l, u_acc <> u} - end) - - {from_binary(p_holder, p), from_binary(l_holder, l), from_binary(u_holder, u)} - end - @impl true def triangular_solve( %{type: output_type} = out, @@ -2414,17 +2394,6 @@ defmodule Nx.BinaryBackend do bin_zip_reduce_axis(rest1, rest2, s1, s2, bin, acc, fun) end - defp bin_batch_reduce(bin, batch_size, {_, size}, acc, fun) do - batch_bit_size = batch_size * size - batches = bit_size(bin) |> div(batch_bit_size) - - for i <- 0..(batches - 1), reduce: acc do - acc -> - batch = bitstring_part(bin, i * batch_bit_size, batch_bit_size) - fun.(batch, acc) - end - end - ## Conversion helpers defp bitstring_part(bitstring, skip, size) do diff --git a/nx/lib/nx/binary_backend/matrix.ex b/nx/lib/nx/binary_backend/matrix.ex index afb55fb668..f989dd748c 100644 --- a/nx/lib/nx/binary_backend/matrix.ex +++ b/nx/lib/nx/binary_backend/matrix.ex @@ -1,8 +1,6 @@ defmodule Nx.BinaryBackend.Matrix do @moduledoc false use Complex.Kernel - import Kernel, except: [abs: 1] - import Complex, only: [abs: 1] import Nx.Shared @@ -116,107 +114,6 @@ defmodule Nx.BinaryBackend.Matrix do defp do_ts([], [], _idx, acc), do: acc - def lu(input_data, input_type, {n, n} = input_shape, p_type, l_type, u_type, opts) do - a = binary_to_matrix(input_data, input_type, input_shape) - eps = opts[:eps] - - {p, a_prime} = lu_validate_and_pivot(a, n) - - # We'll work with linear indices because of the way each matrix - # needs to be updated/accessed - zeros_matrix = List.duplicate(List.duplicate(0, n), n) - - {l, u} = - for j <- 0..(n - 1), reduce: {zeros_matrix, zeros_matrix} do - {l, u} -> - l = replace_matrix_element(l, j, j, 1.0) - - u = - for i <- 0..j, reduce: u do - u -> - u_slice = slice_matrix(u, [0, j], [i, 1]) - l_slice = slice_matrix(l, [i, 0], [1, i]) - sum = dot_matrix(u_slice, l_slice) - [a_ij] = get_matrix_elements(a_prime, [[i, j]]) - - value = a_ij - sum - - if abs(value) < eps do - replace_matrix_element(u, i, j, 0) - else - replace_matrix_element(u, i, j, value) - end - end - - l = - for i <- j..(n - 1), i != j, reduce: l do - l -> - u_slice = slice_matrix(u, [0, j], [i, 1]) - l_slice = slice_matrix(l, [i, 0], [1, i]) - sum = dot_matrix(u_slice, l_slice) - - [a_ij] = get_matrix_elements(a_prime, [[i, j]]) - [u_jj] = get_matrix_elements(u, [[j, j]]) - - value = - cond do - u_jj != 0 -> - (a_ij - sum) / u_jj - - a_ij >= sum -> - :infinity - - true -> - :neg_infinity - end - - if abs(value) < eps do - replace_matrix_element(l, i, j, 0) - else - replace_matrix_element(l, i, j, value) - end - end - - {l, u} - end - - # Transpose because since P is orthogonal, inv(P) = tranpose(P) - # and we want to return P such that A = P.L.U - {p |> transpose_matrix() |> matrix_to_binary(p_type), - l |> approximate_zeros(eps) |> matrix_to_binary(l_type), - u |> approximate_zeros(eps) |> matrix_to_binary(u_type)} - end - - defp lu_validate_and_pivot(a, n) do - # pivots a tensor so that the biggest elements of each column lie on the diagonal. - # if any of the diagonal elements ends up being 0, raises an ArgumentError - - identity = - Enum.map(0..(n - 1), fn i -> Enum.map(0..(n - 1), fn j -> if i == j, do: 1, else: 0 end) end) - - # For each row, find the max value by column. - # If its index (max_idx) is not in the diagonal (i.e. j != max_idx) - # we need to swap rows j and max_idx in both the permutation matrix - # and in the a matrix. - Enum.reduce(0..(n - 2), {identity, a}, fn j, {p, a} -> - [max_idx | _] = - Enum.sort_by(j..(n - 1), fn i -> a |> Enum.at(i) |> Enum.at(j) |> abs() end, &>=/2) - - if max_idx == j do - {p, a} - else - p_row = Enum.at(p, max_idx) - p_j = Enum.at(p, j) - p = p |> List.replace_at(max_idx, p_j) |> List.replace_at(j, p_row) - - a_row = Enum.at(a, max_idx) - a_j = Enum.at(a, j) - a = a |> List.replace_at(max_idx, a_j) |> List.replace_at(j, a_row) - {p, a} - end - end) - end - ## Matrix (2-D array) manipulation defp dot_matrix([], _), do: 0 @@ -279,41 +176,9 @@ defmodule Nx.BinaryBackend.Matrix do |> Enum.chunk_every(num_cols) end - defp slice_matrix(a, [row_start, col_start], [row_length, col_length]) do - a - |> Enum.slice(row_start, row_length) - |> Enum.flat_map(&Enum.slice(&1, col_start, col_length)) - end - defp get_matrix_column(m, col) do Enum.map(m, fn row -> Enum.at(row, col) end) end - - defp get_matrix_elements(m, row_col_pairs) do - Enum.map(row_col_pairs, fn [row, col] -> - m - |> Enum.at(row, []) - |> Enum.at(col) - |> case do - nil -> raise ArgumentError, "invalid index [#{row},#{col}] for matrix" - item -> item - end - end) - end - - defp replace_matrix_element(m, row, col, value) do - updated = m |> Enum.at(row) |> List.replace_at(col, value) - List.replace_at(m, row, updated) - end - - defp approximate_zeros(matrix, tol) do - do_round = fn x -> if Complex.abs(x) < tol, do: 0 * x, else: x end - - Enum.map(matrix, fn - row when is_list(row) -> Enum.map(row, do_round) - e -> do_round.(e) - end) - end end diff --git a/nx/lib/nx/defn/expr.ex b/nx/lib/nx/defn/expr.ex index a51d72e677..3940be47d9 100644 --- a/nx/lib/nx/defn/expr.ex +++ b/nx/lib/nx/defn/expr.ex @@ -1188,14 +1188,6 @@ defmodule Nx.Defn.Expr do expr(out, context, :triangular_solve, [a, b, opts]) end - @impl true - def lu({p, l, u}, tensor, opts) do - tensor = to_expr(tensor) - context = tensor.data.context - out = %T{names: [], shape: {}, type: {:tuple, 3}} - tuple(expr(out, context, :lu, [{p, l, u}, tensor, opts]), [p, l, u]) - end - @impl true def sort(out, tensor, opts) do %{data: %{context: context}} = tensor = to_expr(tensor) diff --git a/nx/lib/nx/defn/grad.ex b/nx/lib/nx/defn/grad.ex index 25b4a1178b..c18dbf0970 100644 --- a/nx/lib/nx/defn/grad.ex +++ b/nx/lib/nx/defn/grad.ex @@ -716,29 +716,6 @@ defmodule Nx.Defn.Grad do pairs end - defp grad(:lu, [{p, l, u}, input, _opts], ans, [_dp, dl, du]) do - # Definition taken from: https://sethaxen.com/blog/2021/02/differentiating-the-lu-decomposition/ - # Where dF = tril_strict(L^* . dL) + triu(dU . U^*) - # dA = P^t . (L^*)^-1 . dF . (U^*)^-1 - - {p, l, u} = Nx.Defn.Expr.tuple(ans, [p, l, u]) - - u_h = Nx.LinAlg.adjoint(u) - l_h = Nx.LinAlg.adjoint(l) - p_t = Nx.LinAlg.adjoint(p) - - lh_dl = Nx.dot(l_h, dl) - du_uh = Nx.dot(du, u_h) - - lt_inv = Nx.LinAlg.invert(l_h) - ut_inv = Nx.LinAlg.invert(u_h) - - df = lh_dl |> Nx.tril(k: -1) |> Nx.add(Nx.triu(du_uh)) - da = p_t |> Nx.dot(lt_inv) |> Nx.dot(df) |> Nx.dot(ut_inv) - - [{input, da}] - end - defp grad(:sort, [t, opts], _ans, g) do idx = Nx.argsort(t, opts) take_along_opts = Keyword.take(opts, [:axis]) diff --git a/nx/lib/nx/lin_alg.ex b/nx/lib/nx/lin_alg.ex index edced80246..abd0fda989 100644 --- a/nx/lib/nx/lin_alg.ex +++ b/nx/lib/nx/lin_alg.ex @@ -1523,7 +1523,7 @@ defmodule Nx.LinAlg do [ [1.0, 0.0, 0.0], [0.5714285969734192, 1.0, 0.0], - [0.1428571492433548, 2.0, 1.0] + [0.1428571492433548, 2.0000009536743164, 1.0] ] > iex> u @@ -1531,7 +1531,7 @@ defmodule Nx.LinAlg do f32[3][3] [ [7.0, 8.0, 9.0], - [0.0, 0.4285714328289032, 0.8571428656578064], + [0.0, 0.4285712242126465, 0.857142448425293], [0.0, 0.0, 0.0] ] > @@ -1607,7 +1607,7 @@ defmodule Nx.LinAlg do [ [1.0, 0.0, 0.0], [0.6666666865348816, 1.0, 0.0], - [0.3333333432674408, 2.0, 1.0] + [0.3333333432674408, 1.9999992847442627, 1.0] ], [ [1.0, 0.0, 0.0], @@ -1622,8 +1622,8 @@ defmodule Nx.LinAlg do [ [ [9.0, 8.0, 7.0], - [0.0, -0.3333333432674408, -0.6666666865348816], - [0.0, 0.0, 0.0] + [0.0, -0.33333349227905273, -0.6666669845581055], + [0.0, 0.0, 5.960464477539063e-8] ], [ [-1.0, 0.0, -1.0], @@ -1638,7 +1638,7 @@ defmodule Nx.LinAlg do [ [ [9.0, 8.0, 7.0], - [6.0, 5.0, 4.0], + [6.0, 5.0, 3.999999761581421], [3.0, 2.0, 1.0] ], [ @@ -1676,7 +1676,7 @@ defmodule Nx.LinAlg do [ [1.0, 0.0, 0.0], [0.6666666865348816, 1.0, 0.0], - [0.3333333432674408, 2.0, 1.0] + [0.3333333432674408, 1.9999992847442627, 1.0] ], [ [1.0, 0.0, 0.0], @@ -1692,8 +1692,8 @@ defmodule Nx.LinAlg do [ [ [9.0, 8.0, 7.0], - [0.0, -0.3333333432674408, -0.6666666865348816], - [0.0, 0.0, 0.0] + [0.0, -0.33333349227905273, -0.6666669845581055], + [0.0, 0.0, 5.960464477539063e-8] ], [ [-1.0, 0.0, -1.0], @@ -1709,22 +1709,22 @@ defmodule Nx.LinAlg do ** (ArgumentError) tensor must be a square matrix or a batch of square matrices, got shape: {3, 4} """ def lu(tensor, opts \\ []) do - apply_vectorized(tensor, fn tensor -> - opts = keyword!(opts, eps: 1.0e-10) - %T{type: type, shape: shape} = tensor - - output_type = Nx.Type.to_floating(type) - {p_shape, l_shape, u_shape} = Nx.Shape.lu(shape) - names = List.duplicate(nil, tuple_size(shape)) - - impl!(tensor).lu( - {%{tensor | type: type, shape: p_shape, names: names}, - %{tensor | type: output_type, shape: l_shape, names: names}, - %{tensor | type: output_type, shape: u_shape, names: names}}, - tensor, - opts - ) - end) + opts = keyword!(opts, eps: 1.0e-10) + %T{vectorized_axes: vectorized_axes} = tensor = Nx.to_tensor(tensor) + %T{type: type, shape: shape} = tensor = Nx.devectorize(tensor) + + output_type = Nx.Type.to_floating(type) + {p_shape, l_shape, u_shape} = Nx.Shape.lu(shape) + names = List.duplicate(nil, tuple_size(shape)) + + output = + {%{tensor | type: type, shape: p_shape, names: names}, + %{tensor | type: output_type, shape: l_shape, names: names}, + %{tensor | type: output_type, shape: u_shape, names: names}} + + :lu + |> Nx.Shared.optional([tensor, opts], output, &Nx.LinAlg.LU.lu/2) + |> Nx.vectorize(vectorized_axes) end @doc """ @@ -1892,7 +1892,7 @@ defmodule Nx.LinAlg do ...> ])) #Nx.Tensor< f32 - 48.0 + 47.999996185302734 > iex> Nx.LinAlg.determinant(Nx.tensor([ @@ -1904,7 +1904,7 @@ defmodule Nx.LinAlg do ...> ])) #Nx.Tensor< f32 - 48.0 + 47.999996185302734 > iex> Nx.LinAlg.determinant(Nx.tensor([ diff --git a/nx/lib/nx/lin_alg/lu.ex b/nx/lib/nx/lin_alg/lu.ex new file mode 100644 index 0000000000..82d7d588be --- /dev/null +++ b/nx/lib/nx/lin_alg/lu.ex @@ -0,0 +1,169 @@ +defmodule Nx.LinAlg.LU do + import Nx.Defn + + defn lu(a, opts \\ []) do + opts = keyword!(opts, eps: 1.0e-10) + + vectorized_axes = a.vectorized_axes + + result = + a + |> Nx.revectorize([collapsed_axes: :auto], + target_shape: {Nx.axis_size(a, -2), Nx.axis_size(a, -1)} + ) + |> lu_matrix(opts) + |> revectorize_result(a.shape, vectorized_axes) + + custom_grad(result, [a], fn g -> + lu_grad(result, g) + end) + end + + defnp lu_matrix(a, opts \\ []) do + eps = opts[:eps] + type = Nx.Type.to_floating(a.type) + real_type = Nx.Type.to_real(type) + + {p, a_prime} = lu_validate_and_pivot(a) + # {p, a_prime} = {Nx.eye(a.shape, vectorized_axes: a.vectorized_axes, type: a.type), a} + a_prime = Nx.as_type(a_prime, type) + + {n, _} = Nx.shape(a) + + l = u = Nx.fill(a_prime, 0.0) + [eps, _] = Nx.broadcast_vectors([Nx.as_type(eps, real_type), l]) + + {l, u, _} = + while {l, u, {a_prime, eps, n}}, j <- 0..(n - 1) do + l = Nx.put_slice(l, [j, j], Nx.tensor([[1.0]], type: type)) + [j, i, _] = Nx.broadcast_vectors([j, 0, l]) + + {u, _} = + while {u, {l, a_prime, eps, j, i}}, i <= j do + sum = vector_dot_slice(u[[.., j]], l[i], i) + a_ij = a_prime[i][j] + + value = a_ij - sum + + updated_u = + if Nx.abs(value) < eps do + Nx.put_slice(u, [i, j], Nx.tensor([[0]], type: type)) + else + Nx.put_slice(u, [i, j], Nx.reshape(value, {1, 1})) + end + + {updated_u, {l, a_prime, eps, j, i + 1}} + end + + {l, _} = + while {l, {u, a_prime, eps, j, n, i = j + 1}}, i <= n - 1 do + sum = vector_dot_slice(u[[.., j]], l[i], i) + + a_ij = a_prime[i][j] + u_jj = u[j][j] + + value = + case Nx.Type.complex?(type) do + true -> + if u_jj != 0 do + (a_ij - sum) / u_jj + else + Nx.Constants.nan(real_type) + end + + false -> + cond do + u_jj != 0 -> + (a_ij - sum) / u_jj + + a_ij >= sum -> + Nx.Constants.infinity(real_type) + + true -> + Nx.Constants.neg_infinity(real_type) + end + end + + updated_l = + if Nx.abs(value) < eps do + Nx.put_slice(l, [i, j], Nx.tensor([[0]], type: type)) + else + Nx.put_slice(l, [i, j], Nx.reshape(value, {1, 1})) + end + + {updated_l, {u, a_prime, eps, j, n, i + 1}} + end + + {l, u, {a_prime, eps, n}} + end + + {p, l, u} + end + + deftransformp revectorize_result({p, l, u}, shape, vectorized_axes) do + {p_shape, l_shape, u_shape} = Nx.Shape.lu(shape) + + { + Nx.revectorize(p, vectorized_axes, target_shape: p_shape), + Nx.revectorize(l, vectorized_axes, target_shape: l_shape), + Nx.revectorize(u, vectorized_axes, target_shape: u_shape) + } + end + + defnp vector_dot_slice(u, v, last_idx) do + {n} = Nx.shape(u) + selector = Nx.iota({n}) < last_idx + u = Nx.select(selector, u, 0) + v = Nx.select(selector, v, 0) + Nx.dot(u, v) + end + + defnp lu_validate_and_pivot(t) do + {n, _} = Nx.shape(t) + p = Nx.iota({n}, vectorized_axes: t.vectorized_axes) + + {p, _} = + while {p, t}, i <- 0..(n - 2) do + max_idx = + Nx.select(Nx.iota({n}) < i, 0, Nx.abs(t[[.., i]])) + |> Nx.argmax(axis: 0) + + if max_idx == i do + {p, t} + else + indices = Nx.stack([i, max_idx]) |> Nx.reshape({2, 1}) + updates = Nx.stack([p[max_idx], p[i]]) + + p = Nx.indexed_put(p, indices, updates) + + {p, Nx.take(t, p)} + end + end + + # The comparison order here is deliberate, because if + # we use p == iota instead, we get the inverse/transposed permutation. + permutation = Nx.iota({n, 1}) == Nx.new_axis(p, 0) + + {Nx.as_type(permutation, t.type), t[p]} + end + + defn lu_grad({p, l, u}, {_dp, dl, du}) do + # Definition taken from https://arxiv.org/pdf/2009.10071.pdf + # Equation (3) + + u_h = Nx.LinAlg.adjoint(u) + l_h = Nx.LinAlg.adjoint(l) + p_t = Nx.LinAlg.adjoint(p) + + lh_dl = Nx.dot(l_h, dl) + du_uh = Nx.dot(du, u_h) + + lt_inv = Nx.LinAlg.invert(l_h) + ut_inv = Nx.LinAlg.invert(u_h) + + df = lh_dl |> Nx.tril(k: -1) |> Nx.add(Nx.triu(du_uh)) + da = p_t |> Nx.dot(lt_inv) |> Nx.dot(df) |> Nx.dot(ut_inv) + + [da] + end +end diff --git a/nx/lib/nx/lin_alg/qr.ex b/nx/lib/nx/lin_alg/qr.ex index 9b428ea3cc..a93602a95e 100644 --- a/nx/lib/nx/lin_alg/qr.ex +++ b/nx/lib/nx/lin_alg/qr.ex @@ -16,7 +16,7 @@ defmodule Nx.LinAlg.QR do |> revectorize_result(a.shape, vectorized_axes, opts) custom_grad(result, [a], fn g -> - qr_grad(result, a, g) + qr_grad(result, g) end) end @@ -145,7 +145,7 @@ defmodule Nx.LinAlg.QR do Nx.select(selector, eye - scale * Nx.outer(v, v), eye) end - defn qr_grad({q, r}, _input, {dq, dr}) do + defn qr_grad({q, r}, {dq, dr}) do # Definition taken from https://arxiv.org/pdf/2009.10071.pdf # Equation (3) r_inv = Nx.LinAlg.invert(r) diff --git a/nx/test/nx/lin_alg_test.exs b/nx/test/nx/lin_alg_test.exs index d8c8fe2bb4..4e146cf53c 100644 --- a/nx/test/nx/lin_alg_test.exs +++ b/nx/test/nx/lin_alg_test.exs @@ -941,7 +941,9 @@ defmodule Nx.LinAlgTest do a = Nx.dot(l_prime, [2], [0], u_prime, [1], [0]) assert {p, l, u} = Nx.LinAlg.lu(a) - assert_all_close(p |> Nx.dot([2], [0], l, [1], [0]) |> Nx.dot([2], [0], u, [1], [0]), a) + + actual = p |> Nx.dot([2], [0], l, [1], [0]) |> Nx.dot([2], [0], u, [1], [0]) + assert_all_close(actual, a) key end end From a49d407ee74281e5a232a11c801602e4a6a597ce Mon Sep 17 00:00:00 2001 From: Paulo Valente <16843419+polvalente@users.noreply.github.com> Date: Wed, 5 Mar 2025 14:47:06 -0300 Subject: [PATCH 12/36] docs: autodiff docs (#1580) --- .../advanced/automatic_differentiation.livemd | 187 ++++++++++++++++++ nx/mix.exs | 1 + 2 files changed, 188 insertions(+) create mode 100644 nx/guides/advanced/automatic_differentiation.livemd diff --git a/nx/guides/advanced/automatic_differentiation.livemd b/nx/guides/advanced/automatic_differentiation.livemd new file mode 100644 index 0000000000..eb27215153 --- /dev/null +++ b/nx/guides/advanced/automatic_differentiation.livemd @@ -0,0 +1,187 @@ +# Automatic Differentation + +```elixir +Mix.install([ + {:nx, "~> 0.7"} +]) +``` + +## What is Function Differentiation? + +Nx, through the `Nx.Defn.grad/2` and `Nx.Defn.value_and_grad/3` functions allows the user to differentiate functions that were defined through `defn`. +This is really important in Machine Learning settings because, in general, the training process happens through optimization methods that require calculating the gradient of tensor functions. + +Before we get too far ahead of ourselves, let's talk about what is the derivative or the gradient of a function. +In simple terms, the derivative tells us how a function changes at a given point and lets us measure things such as where a function as maximum, +minimum or turning points (for example, where a parabola has its vertex). + +The ability to measure local minima and maxima is what makes them important to optimization problems, because if we can find them, we can solve problems that want +to minimize a given function. For higher dimensional problems, we deal with functions of many variables, and thus we use the gradient, which measures the "derivative" in the axis of each function. +The gradient, then, is a vector that points in the direction where the function changes the most, which leads to the so-called gradient descent method of optimization. + +In the gradient descent method, we take tiny steps following the gradient of the function in order to find the nearest local minimum (which hopefully is either the global minimum or close enough to it). +This is what makes function differentiation so important for Machine Learning. + +Let's take for example the following $f(x)$ and $f'(x)$ scalar function and derivative pair: + +$$ +f(x) = x^3 + x\\ +f'(x) = 3x^2 + 1 +$$ + +We can define a similar function-derivative pair for tensor functions: + +$$ +f(\bold{x}) = \bold{x}^3 + \bold{x}\\ +\nabla f(\bold{x}) = 3 \bold{x} ^ 2 + 1 +$$ + +These may look similar, but the difference is that $f(\bold{x})$ takes in $\bold{x}$ which is a tensor argument. This means that we can have the following argument and results for the function and its gradient: + +$$ +\bold{x} = +\begin{bmatrix} +1 & 1 \\ +2 & 3 \\ +5 & 8 \\ +\end{bmatrix}\\\ +$$ + +$$ +f(\bold{x}) = \bold{x}^3 + \bold{x} = +\begin{bmatrix} +2 & 2 \\ +10 & 30 \\ +130 & 520 +\end{bmatrix} +$$ + +$$ +\nabla f(\bold{x}) = 3 \bold{x} ^ 2 + 1 = +\begin{bmatrix} +4 & 4 \\ +13 & 28 \\ +76 & 193 +\end{bmatrix} +$$ + +## Automatic Differentiation + +Now that we have a general feeling of what a function and its gradient are, we can talk about how Nx can use `defn` to calculate gradients for us. + +In the following code blocks we're going to define the same tensor function as above and then we'll differentiate it only using Nx, without having to write the explicit derivative at all. + +```elixir +defmodule Math do + import Nx.Defn + + defn f(x) do + x ** 3 + x + end + + defn grad_f(x) do + Nx.Defn.grad(x, &f/1) + end +end +``` + +```elixir +x = + Nx.tensor([ + [1, 1], + [2, 3], + [5, 8] + ]) + +{ + Math.f(x), + Math.grad_f(x) +} +``` + +As we can see, we get the results we expected, aside from the type of the grad, which will always be a floating-point number, even if you pass an integer tensor as input. + +Next, we'll using `Nx.Defn.debug_expr` to see what's happening under the hood. + +```elixir +Nx.Defn.debug_expr(&Math.f/1).(x) +``` + +```elixir +Nx.Defn.debug_expr(&Math.grad_f/1).(x) +``` + +If we look closely at the returned `Nx.Defn.Expr` representations for `f` and `grad_f`, we can see that they pretty much translate to the mathematical definitions we had originally. + +This possible because Nx holds onto the symbolic representation of a `defn` function while inside `defn`-land, and thus `Nx.Defn.grad` (and similar) can operate on that symbolic representation to return a new symbolic representation (as seen in the second block). + + + +`Nx.Defn.value_and_grad` can be used to calculate both things at once for us: + +```elixir +Nx.Defn.value_and_grad(x, &Math.f/1) +``` + +And if we use `debug_expr` again, we can see that the symbolic representation is actually both the function and the grad, returned in a tuple: + +```elixir +Nx.Defn.debug_expr(Nx.Defn.value_and_grad(&Math.f/1)).(x) +``` + +Finally, we can talk about functions that receive many arguments, such as the following `add_multiply` function: + +```elixir +add_multiply = fn x, y, z -> + addition = Nx.add(x, y) + Nx.multiply(z, addition) +end +``` + +At first you may think that if we want to differentiate it, we need to wrap it into a single-argument function so that we can differentiate with respect to a specific argument, which would treat other arguments as constants, as we can see below: + +```elixir +x = Nx.tensor([1, 2]) +y = Nx.tensor([3, 4]) +z = Nx.tensor([5, 6]) + +{ + Nx.Defn.grad(x, fn t -> add_multiply.(t, y, z) end), + Nx.Defn.grad(y, fn t -> add_multiply.(x, t, z) end), + Nx.Defn.grad(z, fn t -> add_multiply.(x, y, t) end) +} +``` + +However, Nx is smart enough to deal with multi-valued functions through `Nx.Container` representations such as a tuple or a map: + +```elixir +Nx.Defn.grad({x, y, z}, fn {x, y, z} -> add_multiply.(x, y, z) end) +``` + +Likewise, we can also deal with functions that return multiple values. + +`Nx.Defn.grad` requires us to return a scalar from function (that is, a tensor of shape `{}`). +However, there are instances where we might want to use `value_and_grad` to get out a tuple from our function, while still calculating its gradient. + +For this, we have the `value_and_grad/3` arity, which accepts a transformation argument. + +```elixir +x = + Nx.tensor([ + [1, 1], + [2, 3], + [5, 8] + ]) + +# Notice that the returned values are the 2 addition terms from `Math.f/1` +multi_valued_return_fn = + fn x -> + {Nx.pow(x, 3), x} + end + +transform_fn = fn {x_cubed, x} -> Nx.add(x_cubed, x) end + +{{x_cubed, x}, grad} = Nx.Defn.value_and_grad(x, multi_valued_return_fn, transform_fn) +``` + +If we go back to the start of this livebook, we can see that `grad` holds exactly the result `Math.grad_f`, but now we have access to `x ** 3`, which wasn't accessible before, as originally we could only obtain `x ** 3 + x`. \ No newline at end of file diff --git a/nx/mix.exs b/nx/mix.exs index 90ffa9c51e..fb0af74bcd 100644 --- a/nx/mix.exs +++ b/nx/mix.exs @@ -63,6 +63,7 @@ defmodule Nx.MixProject do "guides/getting_started/quickstart.livemd", "guides/advanced/vectorization.livemd", "guides/advanced/aggregation.livemd", + "guides/advanced/automatic_differentiation.livemd", "guides/exercises/exercises-1-20.livemd" ], skip_undefined_reference_warnings_on: ["CHANGELOG.md"], From f8132e3df40937814f8710474ae2aa18c4c9bba7 Mon Sep 17 00:00:00 2001 From: Paulo Valente <16843419+polvalente@users.noreply.github.com> Date: Wed, 5 Mar 2025 15:18:27 -0300 Subject: [PATCH 13/36] chore: hide Nx.LinAlg.LU module --- nx/lib/nx/lin_alg/lu.ex | 1 + 1 file changed, 1 insertion(+) diff --git a/nx/lib/nx/lin_alg/lu.ex b/nx/lib/nx/lin_alg/lu.ex index 82d7d588be..57b1b7a325 100644 --- a/nx/lib/nx/lin_alg/lu.ex +++ b/nx/lib/nx/lin_alg/lu.ex @@ -1,4 +1,5 @@ defmodule Nx.LinAlg.LU do + @moduledoc false import Nx.Defn defn lu(a, opts \\ []) do From 68c8acc9c801baade2d62384dccea1af17ca2bc9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jonatan=20K=C5=82osko?= Date: Thu, 6 Mar 2025 11:39:02 +0100 Subject: [PATCH 14/36] Hide symbols from the NIF shared library (#1589) --- exla/Makefile | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/exla/Makefile b/exla/Makefile index 77b863dd7c..a8d4733389 100644 --- a/exla/Makefile +++ b/exla/Makefile @@ -36,7 +36,7 @@ else CFLAGS += -O3 endif -LDFLAGS = -L$(XLA_EXTENSION_LIB) -lxla_extension -shared +LDFLAGS = -L$(XLA_EXTENSION_LIB) -lxla_extension -shared -fvisibility=hidden ifeq ($(CROSSCOMPILE),) # Interrogate the system for local compilation From feed978c57d5a12c66710280f8437f36adfb21bf Mon Sep 17 00:00:00 2001 From: Paulo Valente <16843419+polvalente@users.noreply.github.com> Date: Thu, 6 Mar 2025 15:29:43 -0300 Subject: [PATCH 15/36] feat(exla): take advantage of the new LU impl (#1590) --- exla/lib/exla/defn.ex | 22 ++++++++-------------- 1 file changed, 8 insertions(+), 14 deletions(-) diff --git a/exla/lib/exla/defn.ex b/exla/lib/exla/defn.ex index fbf392862f..3a676b5942 100644 --- a/exla/lib/exla/defn.ex +++ b/exla/lib/exla/defn.ex @@ -546,21 +546,15 @@ defmodule EXLA.Defn do defp cached_recur_operator( :lu, - %T{data: %Expr{args: [{p_expr, l_expr, u_expr}, tensor, _opts]}}, - state, + %T{ + data: %Expr{args: [{p_expr, l_expr, u_expr}, %{type: {type_kind, _}} = tensor, _opts]} + }, + %{client: %{platform: :host}} = state, cache - ) do - %{type: {p_type_kind, _}} = p_expr - %{type: {out_type_kind, _}} = l_expr - - if state.client.platform != :host do - raise ArgumentError, "XLA does not currently support the LU operation on non-host devices" - end - - if p_type_kind == :c or out_type_kind == :c do - raise ArgumentError, "XLA does not currently support the LU operation for complex inputs" - end - + ) + when type_kind != :c do + # We only want to accelerate the LU operation for real inputs on the host device. + # Otherwise, we use the default implementation in Nx. {tensor, cache} = recur_operator(tensor, state, cache) |> unwrap_single_tensor!() tensor = From 23bf64c2ab474a6efb479ff786c840844217ed10 Mon Sep 17 00:00:00 2001 From: Paulo Valente <16843419+polvalente@users.noreply.github.com> Date: Thu, 6 Mar 2025 15:32:39 -0300 Subject: [PATCH 16/36] feat: allow explicitly disabling CUDA step (#1588) --- exla/Makefile | 16 +++++++++------- exla/README.md | 2 ++ 2 files changed, 11 insertions(+), 7 deletions(-) diff --git a/exla/Makefile b/exla/Makefile index a8d4733389..a36df218a8 100644 --- a/exla/Makefile +++ b/exla/Makefile @@ -36,27 +36,29 @@ else CFLAGS += -O3 endif +NVCC := $(CXX) +NVCCFLAGS = $(CFLAGS) LDFLAGS = -L$(XLA_EXTENSION_LIB) -lxla_extension -shared -fvisibility=hidden ifeq ($(CROSSCOMPILE),) # Interrogate the system for local compilation UNAME_S = $(shell uname -s) +ifdef ($(EXLA_CPU_ONLY),) +$(info EXLA_CPU_ONLY is not set, checking for nvcc availability) NVCC_RESULT := $(shell which nvcc 2> /dev/null) NVCC_TEST := $(notdir $(NVCC_RESULT)) -ifeq ($(NVCC_TEST),nvcc) - NVCC := nvcc - NVCCFLAGS += -DCUDA_ENABLED + ifeq ($(NVCC_TEST),nvcc) + NVCC := nvcc + NVCCFLAGS += -DCUDA_ENABLED + endif else - NVCC := $(CXX) - NVCCFLAGS = $(CFLAGS) +$(info EXLA_CPU_ONLY is set, skipping nvcc step) endif else # Determine settings for cross-compiled builds like for Nerves UNAME_S = Linux - NVCC := $(CXX) - NVCCFLAGS = $(CFLAGS) endif ifeq ($(UNAME_S), Darwin) diff --git a/exla/README.md b/exla/README.md index 2ffe144ff7..d837b5ec73 100644 --- a/exla/README.md +++ b/exla/README.md @@ -43,6 +43,8 @@ EXLA relies on the [XLA](https://github.com/elixir-nx/xla) package to provide th * Incompatible protocol buffer versions * Error message: "this file was generated by an older version of protoc which is incompatible with your Protocol Buffer headers". * If you have `protoc` installed on your machine, it may conflict with the `protoc` precompiled inside XLA. Uninstall, unlink, or remove `protoc` from your path to continue. + * Missing CUDA symbols + * In some cases, you might be compiling a CPU-only version of `:xla` in an environment that has CUDA available. For these cases, you can set the `EXLA_CPU_ONLY` environment variable to any value to disable custom CUDA functionality in EXLA. ### Usage with Nerves From 9b8ad38014a09c2234dbea6d106bf25a035bec9a Mon Sep 17 00:00:00 2001 From: Paulo Valente <16843419+polvalente@users.noreply.github.com> Date: Tue, 11 Mar 2025 22:55:26 -0300 Subject: [PATCH 17/36] fix(exla): batched eigh (#1591) --- exla/c_src/exla/custom_calls/eigh.h | 64 ++++++++++++++++------ exla/lib/exla/defn.ex | 6 ++- exla/test/exla/nx_linalg_doctest_test.exs | 66 +++++++++++++++++++++++ 3 files changed, 117 insertions(+), 19 deletions(-) diff --git a/exla/c_src/exla/custom_calls/eigh.h b/exla/c_src/exla/custom_calls/eigh.h index 5acc8af664..55cb5adfc1 100644 --- a/exla/c_src/exla/custom_calls/eigh.h +++ b/exla/c_src/exla/custom_calls/eigh.h @@ -2,11 +2,18 @@ #include "Eigen/Eigenvalues" +#include #include +#include +#include template -void single_matrix_eigh_cpu_custom_call(DataType *eigenvalues_out, DataType *eigenvectors_out, DataType *in, uint64_t m, uint64_t n) { - typedef Eigen::Matrix RowMajorMatrix; +void single_matrix_eigh_cpu_custom_call(DataType *eigenvalues_out, + DataType *eigenvectors_out, + DataType *in, uint64_t m, uint64_t n) { + typedef Eigen::Matrix + RowMajorMatrix; // Map the input matrix Eigen::Map input(in, m, n); @@ -20,14 +27,32 @@ void single_matrix_eigh_cpu_custom_call(DataType *eigenvalues_out, DataType *eig } // Get the eigenvalues and eigenvectors - Eigen::Matrix eigenvalues = eigensolver.eigenvalues(); + Eigen::Matrix eigenvalues = + eigensolver.eigenvalues(); RowMajorMatrix eigenvectors = eigensolver.eigenvectors(); - // Copy the eigenvalues to the output - std::memcpy(eigenvalues_out, eigenvalues.data(), m * sizeof(DataType)); + // Create a vector of indices and sort it based on eigenvalues in decreasing + // order + std::vector indices(m); + std::iota(indices.begin(), indices.end(), 0); + std::sort(indices.begin(), indices.end(), [&eigenvalues](int i, int j) { + return std::abs(eigenvalues(i)) > std::abs(eigenvalues(j)); + }); + + // Sort eigenvalues and rearrange eigenvectors + Eigen::Matrix sorted_eigenvalues(m); + RowMajorMatrix sorted_eigenvectors(m, n); + for (int i = 0; i < m; ++i) { + sorted_eigenvalues(i) = eigenvalues(indices[i]); + sorted_eigenvectors.col(i) = eigenvectors.col(indices[i]); + } + + // Copy the sorted eigenvalues to the output + std::memcpy(eigenvalues_out, sorted_eigenvalues.data(), m * sizeof(DataType)); - // Copy the eigenvectors to the output - std::memcpy(eigenvectors_out, eigenvectors.data(), m * n * sizeof(DataType)); + // Copy the sorted eigenvectors to the output + std::memcpy(eigenvectors_out, sorted_eigenvectors.data(), + m * n * sizeof(DataType)); } template @@ -40,18 +65,22 @@ void eigh_cpu_custom_call(void *out[], const void *in[]) { uint64_t num_eigenvectors_dims = dim_sizes[2]; uint64_t *operand_dims_ptr = (uint64_t *)in[2]; - std::vector operand_dims(operand_dims_ptr, operand_dims_ptr + num_operand_dims); + std::vector operand_dims(operand_dims_ptr, + operand_dims_ptr + num_operand_dims); uint64_t *eigenvalues_dims_ptr = (uint64_t *)in[3]; - std::vector eigenvalues_dims(eigenvalues_dims_ptr, eigenvalues_dims_ptr + num_eigenvalues_dims); + std::vector eigenvalues_dims( + eigenvalues_dims_ptr, eigenvalues_dims_ptr + num_eigenvalues_dims); uint64_t *eigenvectors_dims_ptr = (uint64_t *)in[4]; - std::vector eigenvectors_dims(eigenvectors_dims_ptr, eigenvectors_dims_ptr + num_eigenvectors_dims); + std::vector eigenvectors_dims( + eigenvectors_dims_ptr, eigenvectors_dims_ptr + num_eigenvectors_dims); uint64_t m = eigenvectors_dims[eigenvectors_dims.size() - 2]; uint64_t n = eigenvectors_dims[eigenvectors_dims.size() - 1]; - auto leading_dimensions = std::vector(operand_dims.begin(), operand_dims.end() - 2); + auto leading_dimensions = + std::vector(operand_dims.begin(), operand_dims.end() - 2); uint64_t batch_items = 1; for (uint64_t i = 0; i < leading_dimensions.size(); i++) { @@ -61,15 +90,16 @@ void eigh_cpu_custom_call(void *out[], const void *in[]) { DataType *eigenvalues = (DataType *)out[0]; DataType *eigenvectors = (DataType *)out[1]; - uint64_t eigenvalues_stride = eigenvalues_dims[eigenvalues_dims.size() - 1] * sizeof(DataType); - uint64_t eigenvectors_stride = eigenvectors_dims[eigenvectors_dims.size() - 1] * eigenvectors_dims[eigenvectors_dims.size() - 2] * sizeof(DataType); - uint64_t inner_stride = m * n * sizeof(DataType); + uint64_t eigenvalues_stride = eigenvalues_dims[eigenvalues_dims.size() - 1]; + uint64_t eigenvectors_stride = + eigenvectors_dims[eigenvectors_dims.size() - 1] * + eigenvectors_dims[eigenvectors_dims.size() - 2]; + uint64_t inner_stride = m * n; for (uint64_t i = 0; i < batch_items; i++) { single_matrix_eigh_cpu_custom_call( eigenvalues + i * eigenvalues_stride, - eigenvectors + i * eigenvectors_stride, - operand + i * inner_stride / sizeof(DataType), - m, n); + eigenvectors + i * eigenvectors_stride, operand + i * inner_stride, m, + n); } } \ No newline at end of file diff --git a/exla/lib/exla/defn.ex b/exla/lib/exla/defn.ex index 3a676b5942..3bb478258d 100644 --- a/exla/lib/exla/defn.ex +++ b/exla/lib/exla/defn.ex @@ -404,14 +404,16 @@ defmodule EXLA.Defn do data: %Expr{ args: [ %{data: %{op: :eigh, args: [tensor, _opts]}}, - {eigenvecs_expr, eigenvals_expr}, + {%{type: {evec_type_kind, _}} = eigenvecs_expr, + %{type: {eval_type_kind, _}} = eigenvals_expr}, _callback ] } }, %{client: %EXLA.Client{platform: :host}, builder: %Function{}} = state, cache - ) do + ) + when evec_type_kind != :c and eval_type_kind != :c do # We match only on platform: :host for MLIR, as we want to support # eigh-on-cpu as a custom call only in this case {tensor, cache} = recur_operator(tensor, state, cache) |> unwrap_single_tensor!() diff --git a/exla/test/exla/nx_linalg_doctest_test.exs b/exla/test/exla/nx_linalg_doctest_test.exs index 09d60ba8f6..3d49d2c826 100644 --- a/exla/test/exla/nx_linalg_doctest_test.exs +++ b/exla/test/exla/nx_linalg_doctest_test.exs @@ -25,4 +25,70 @@ defmodule EXLA.MLIR.NxLinAlgDoctestTest do @invalid_type_error_doctests ++ [:moduledoc] doctest Nx.LinAlg, except: @excluded_doctests + + describe "eigh" do + test "properties for matrices with different eigenvalues" do + # Generate real Hermitian matrices with different eigenvalues + # from random matrices based on the relation A = Q.Λ.Q^* + # where Λ is the diagonal matrix of eigenvalues and Q is unitary matrix. + + key = Nx.Random.key(System.unique_integer()) + + for type <- [f: 32, c: 64], reduce: key do + key -> + # Unitary matrix from a random matrix + {base, key} = Nx.Random.uniform(key, shape: {2, 3, 3}, type: type) + {q, _} = Nx.LinAlg.qr(base) + + # Different eigenvalues from random values + evals_test = + [100, 10, 1] + |> Enum.map(fn magnitude -> + sign = + if :rand.uniform() - 0.5 > 0 do + 1 + else + -1 + end + + rand = :rand.uniform() * magnitude * 0.1 + magnitude + rand * sign + end) + |> Nx.tensor(type: type) + + evals_test_diag = + evals_test + |> Nx.make_diagonal() + |> Nx.reshape({1, 3, 3}) + |> Nx.tile([2, 1, 1]) + + # Hermitian matrix with different eigenvalues + # using A = A^* = Q^*.Λ.Q. + a = + q + |> Nx.LinAlg.adjoint() + |> Nx.dot([2], [0], evals_test_diag, [1], [0]) + |> Nx.dot([2], [0], q, [1], [0]) + + # Eigenvalues and eigenvectors + assert {evals, evecs} = Nx.LinAlg.eigh(a, eps: 1.0e-8) + + assert_all_close(evals_test, evals[0], atol: 1.0e-8) + assert_all_close(evals_test, evals[1], atol: 1.0e-8) + + evals = + evals + |> Nx.vectorize(:x) + |> Nx.make_diagonal() + |> Nx.devectorize(keep_names: false) + + # Eigenvalue equation + evecs_evals = Nx.dot(evecs, [2], [0], evals, [1], [0]) + a_evecs = Nx.dot(evecs_evals, [2], [0], Nx.LinAlg.adjoint(evecs), [1], [0]) + + assert_all_close(a, a_evecs, atol: 1.0e-8) + key + end + end + end end From d07145359103e02c195548fc26f55a6650a84a8b Mon Sep 17 00:00:00 2001 From: Paulo Valente <16843419+polvalente@users.noreply.github.com> Date: Thu, 13 Mar 2025 09:04:05 -0300 Subject: [PATCH 18/36] fix(exla): respect device id when automatic transfers are disabled (#1592) --- exla/config/runtime.exs | 3 ++- exla/lib/exla/backend.ex | 4 ++-- exla/test/exla/backend_test.exs | 27 +++++++++++++++++++++++++++ 3 files changed, 31 insertions(+), 3 deletions(-) diff --git a/exla/config/runtime.exs b/exla/config/runtime.exs index 60b8f5102f..80e2d65da1 100644 --- a/exla/config/runtime.exs +++ b/exla/config/runtime.exs @@ -3,7 +3,8 @@ import Config config :exla, :clients, cuda: [platform: :cuda, memory_fraction: 0.8], rocm: [platform: :rocm, memory_fraction: 0.8], - other_host: [platform: :host] + other_host: [platform: :host], + no_automatic_transfers_host: [platform: :host, automatic_transfers: false] config :exla, default_client: String.to_atom(System.get_env("EXLA_TARGET", "host")) diff --git a/exla/lib/exla/backend.ex b/exla/lib/exla/backend.ex index 5135e1770f..50747de4bb 100644 --- a/exla/lib/exla/backend.ex +++ b/exla/lib/exla/backend.ex @@ -264,14 +264,14 @@ defmodule EXLA.Backend do def concatenate(out, tensors, axis) do copied = Enum.map(tensors, &Nx.backend_copy(&1, Nx.BinaryBackend)) result = Nx.BinaryBackend.concatenate(out, copied, axis) - Nx.backend_transfer(result, {EXLA.Backend, jit_opts([], tensors)}) + Nx.backend_transfer(result, {EXLA.Backend, jit_opts(tensors, [])}) end @impl true def stack(out, tensors, axis) do copied = Enum.map(tensors, &Nx.backend_copy(&1, Nx.BinaryBackend)) result = Nx.BinaryBackend.stack(out, copied, axis) - Nx.backend_transfer(result, {EXLA.Backend, jit_opts([], tensors)}) + Nx.backend_transfer(result, {EXLA.Backend, jit_opts(tensors, [])}) end @impl true diff --git a/exla/test/exla/backend_test.exs b/exla/test/exla/backend_test.exs index 22e3c60850..7f99c8bc6f 100644 --- a/exla/test/exla/backend_test.exs +++ b/exla/test/exla/backend_test.exs @@ -147,6 +147,33 @@ defmodule EXLA.BackendTest do assert %{device_id: 1, client_name: :other_host} = Nx.reshape(a, {1}).data.buffer end + @tag :multi_device + test "stack and concatenate should end up in the same client" do + t_0 = + Nx.tensor([1], backend: {EXLA.Backend, client: :no_automatic_transfers_host, device_id: 0}) + + t_1 = + Nx.tensor([1], backend: {EXLA.Backend, client: :no_automatic_transfers_host, device_id: 1}) + + t_stack_0 = Nx.stack([t_0, t_1]) + t_concat_0 = Nx.concatenate([t_0, t_1]) + + assert t_stack_0.data.buffer.client_name == :no_automatic_transfers_host + assert t_stack_0.data.buffer.device_id == 1 + + assert t_concat_0.data.buffer.client_name == :no_automatic_transfers_host + assert t_concat_0.data.buffer.device_id == 1 + + t_stack_1 = Nx.stack([t_1, t_0]) + t_concat_1 = Nx.concatenate([t_1, t_0]) + + assert t_stack_1.data.buffer.client_name == :no_automatic_transfers_host + assert t_stack_1.data.buffer.device_id == 0 + + assert t_concat_1.data.buffer.client_name == :no_automatic_transfers_host + assert t_concat_1.data.buffer.device_id == 0 + end + test "Kernel.inspect/2" do t = Nx.tensor([1, 2, 3, 4], backend: EXLA.Backend) From 42861da51da9eb931b62430615aa18361962ca4e Mon Sep 17 00:00:00 2001 From: Paulo Valente <16843419+polvalente@users.noreply.github.com> Date: Mon, 17 Mar 2025 03:22:41 -0300 Subject: [PATCH 19/36] test(exla): add more tests for LinAlg functions (#1594) --- exla/test/exla/nx_linalg_doctest_test.exs | 329 +++++++++++++++++++++- 1 file changed, 320 insertions(+), 9 deletions(-) diff --git a/exla/test/exla/nx_linalg_doctest_test.exs b/exla/test/exla/nx_linalg_doctest_test.exs index 3d49d2c826..3eeeb30546 100644 --- a/exla/test/exla/nx_linalg_doctest_test.exs +++ b/exla/test/exla/nx_linalg_doctest_test.exs @@ -1,17 +1,15 @@ -defmodule EXLA.MLIR.NxLinAlgDoctestTest do +defmodule EXLA.NxLinAlgDoctestTest do use EXLA.Case, async: true - - @invalid_type_error_doctests [ - svd: 2, - pinv: 2 - ] + import Nx, only: :sigils @function_clause_error_doctests [ - solve: 2 + solve: 2, + triangular_solve: 3 ] @rounding_error_doctests [ - triangular_solve: 3, + svd: 2, + pinv: 2, eigh: 2, cholesky: 1, least_squares: 3, @@ -22,7 +20,6 @@ defmodule EXLA.MLIR.NxLinAlgDoctestTest do @excluded_doctests @function_clause_error_doctests ++ @rounding_error_doctests ++ - @invalid_type_error_doctests ++ [:moduledoc] doctest Nx.LinAlg, except: @excluded_doctests @@ -91,4 +88,318 @@ defmodule EXLA.MLIR.NxLinAlgDoctestTest do end end end + + describe "cholesky" do + test "property" do + key = Nx.Random.key(System.unique_integer()) + + for _ <- 1..10, type <- [{:f, 32}, {:c, 64}], reduce: key do + key -> + # Generate random L matrix so we can construct + # a factorizable A matrix: + shape = {3, 4, 4} + + {a_prime, key} = Nx.Random.normal(key, 0, 1, shape: shape, type: type) + + a_prime = Nx.add(a_prime, Nx.eye(shape)) + b = Nx.dot(Nx.LinAlg.adjoint(a_prime), [-1], [0], a_prime, [-2], [0]) + + d = Nx.eye(shape) |> Nx.multiply(0.1) + + a = Nx.add(b, d) + + assert l = Nx.LinAlg.cholesky(a) + assert_all_close(Nx.dot(l, [2], [0], Nx.LinAlg.adjoint(l), [1], [0]), a, atol: 1.0e-2) + key + end + end + end + + describe "least_squares" do + test "properties for linear equations" do + key = Nx.Random.key(System.unique_integer()) + + # Calucate linear equations Ax = y by using least-squares solution + for {m, n} <- [{2, 2}, {3, 2}, {4, 3}], reduce: key do + key -> + # Generate x as temporary solution and A as base matrix + {a_base, key} = Nx.Random.randint(key, 1, 10, shape: {m, n}) + {x_temp, key} = Nx.Random.randint(key, 1, 10, shape: {n}) + + # Generate y as base vector by x and A + # to prepare an equation that can be solved exactly + y_base = Nx.dot(a_base, x_temp) + + # Generate y as random noise vector and A as random noise matrix + noise_eps = 1.0e-2 + {a_noise, key} = Nx.Random.uniform(key, 0, noise_eps, shape: {m, n}) + {y_noise, key} = Nx.Random.uniform(key, 0, noise_eps, shape: {m}) + + # Add noise to prepare equations that cannot be solved without approximation, + # such as the least-squares method + a = Nx.add(a_base, a_noise) + y = Nx.add(y_base, y_noise) + + # Calculate least-squares solution to a linear matrix equation Ax = y + x = Nx.LinAlg.least_squares(a, y) + + # Check linear matrix equation + + assert_all_close(Nx.dot(a, x), y, atol: noise_eps * 10) + + key + end + end + end + + describe "determinant" do + test "supports batched matrices" do + two_by_two = Nx.tensor([[[2, 3], [4, 5]], [[6, 3], [4, 8]]]) + assert_equal(Nx.LinAlg.determinant(two_by_two), Nx.tensor([-2.0, 36.0])) + + three_by_three = + Nx.tensor([ + [[1.0, 2.0, 3.0], [1.0, 5.0, 3.0], [7.0, 6.0, 9.0]], + [[5.0, 2.0, 3.0], [8.0, 5.0, 4.0], [3.0, 1.0, -9.0]] + ]) + + assert_equal(Nx.LinAlg.determinant(three_by_three), Nx.tensor([-36.0, -98.0])) + + four_by_four = + Nx.tensor([ + [ + [1.0, 2.0, 3.0, 0.0], + [1.0, 5.0, 3.0, 0.0], + [7.0, 6.0, 9.0, 0.0], + [0.0, -11.0, 2.0, 3.0] + ], + [ + [5.0, 2.0, 3.0, 0.0], + [8.0, 5.0, 4.0, 0.0], + [3.0, 1.0, -9.0, 0.0], + [8.0, 2.0, -4.0, 5.0] + ] + ]) + + assert_all_close(Nx.LinAlg.determinant(four_by_four), Nx.tensor([-108.0, -490])) + end + end + + describe "matrix_power" do + test "supports complex with positive exponent" do + a = ~MAT[ + 1 1i + -1i 1 + ] + + n = 5 + + assert_all_close(Nx.LinAlg.matrix_power(a, n), Nx.multiply(2 ** (n - 1), a)) + end + + test "supports complex with 0 exponent" do + a = ~MAT[ + 1 1i + -1i 1 + ] + + assert_all_close(Nx.LinAlg.matrix_power(a, 0), Nx.eye(Nx.shape(a))) + end + + test "supports complex with negative exponent" do + a = ~MAT[ + 1 -0.5i + 0 0.5 + ] + + result = ~MAT[ + 1 15i + 0 16 + ] + + assert_all_close(Nx.LinAlg.matrix_power(a, -4), result) + end + + test "supports batched matrices" do + a = + Nx.tensor([ + [[5, 3], [1, 2]], + [[9, 0], [4, 7]] + ]) + + result = + Nx.tensor([ + [[161, 126], [42, 35]], + [[729, 0], [772, 343]] + ]) + + assert_all_close(Nx.LinAlg.matrix_power(a, 3), result) + end + end + + describe "lu" do + test "property" do + key = Nx.Random.key(System.unique_integer()) + + for _ <- 1..10, type <- [{:f, 32}, {:c, 64}], reduce: key do + key -> + # Generate random L and U matrices so we can construct + # a factorizable A matrix: + shape = {3, 4, 4} + lower_selector = Nx.iota(shape, axis: 1) |> Nx.greater_equal(Nx.iota(shape, axis: 2)) + upper_selector = Nx.LinAlg.adjoint(lower_selector) + + {l_prime, key} = Nx.Random.uniform(key, 0, 1, shape: shape, type: type) + l_prime = Nx.multiply(l_prime, lower_selector) + + {u_prime, key} = Nx.Random.uniform(key, 0, 1, shape: shape, type: type) + u_prime = Nx.multiply(u_prime, upper_selector) + + a = Nx.dot(l_prime, [2], [0], u_prime, [1], [0]) + + assert {p, l, u} = Nx.LinAlg.lu(a) + + actual = p |> Nx.dot([2], [0], l, [1], [0]) |> Nx.dot([2], [0], u, [1], [0]) + assert_all_close(actual, a) + key + end + end + end + + describe "svd" do + test "finds the singular values of tall matrices" do + t = Nx.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0], [10.0, 11.0, 12.0]]) + + assert {%{type: output_type} = u, %{type: output_type} = s, %{type: output_type} = v} = + Nx.LinAlg.svd(t, max_iter: 1000) + + s_matrix = 0 |> Nx.broadcast({4, 3}) |> Nx.put_diagonal(s) + + assert_all_close(t, u |> Nx.dot(s_matrix) |> Nx.dot(v), atol: 1.0e-2, rtol: 1.0e-2) + + assert_all_close( + u, + Nx.tensor([ + [0.140, 0.824, 0.521, -0.166], + [0.343, 0.426, -0.571, 0.611], + [0.547, 0.0278, -0.422, -0.722], + [0.750, -0.370, 0.472, 0.277] + ]), + atol: 1.0e-3, + rtol: 1.0e-3 + ) + + assert_all_close(Nx.tensor([25.462, 1.291, 0.0]), s, atol: 1.0e-3, rtol: 1.0e-3) + + assert_all_close( + Nx.tensor([ + [0.504, 0.574, 0.644], + [-0.760, -0.057, 0.646], + [0.408, -0.816, 0.408] + ]), + v, + atol: 1.0e-3, + rtol: 1.0e-3 + ) + end + + test "works with batched matrices" do + t = + Nx.tensor([ + [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]], + [[1.0, 2.0, 3.0], [0.0, 4.0, 0.0], [0.0, 0.0, 9.0]] + ]) + + assert {u, s, v} = Nx.LinAlg.svd(t) + + s_matrix = + Nx.stack([ + Nx.broadcast(0, {3, 3}) |> Nx.put_diagonal(s[0]), + Nx.broadcast(0, {3, 3}) |> Nx.put_diagonal(s[1]) + ]) + + reconstructed_t = + u + |> Nx.dot([2], [0], s_matrix, [1], [0]) + |> Nx.dot([2], [0], v, [1], [0]) + + assert_all_close(t, reconstructed_t, atol: 1.0e-2, rtol: 1.0e-2) + end + + test "works with vectorized tensors matrices" do + t = + Nx.tensor([ + [[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]], + [[[1.0, 2.0, 3.0], [0.0, 4.0, 0.0], [0.0, 0.0, 9.0]]] + ]) + |> Nx.vectorize(x: 2, y: 1) + + assert {u, s, v} = Nx.LinAlg.svd(t) + + s_matrix = Nx.put_diagonal(Nx.broadcast(0, {3, 3}), s) + + reconstructed_t = + u + |> Nx.dot(s_matrix) + |> Nx.dot(v) + + assert reconstructed_t.vectorized_axes == [x: 2, y: 1] + assert reconstructed_t.shape == {3, 3} + + assert_all_close(Nx.devectorize(t), Nx.devectorize(reconstructed_t), + atol: 1.0e-2, + rtol: 1.0e-2 + ) + end + + test "works with vectors" do + t = Nx.tensor([[-2], [1]]) + + {u, s, vt} = Nx.LinAlg.svd(t) + assert_all_close(u |> Nx.dot(Nx.stack([s, Nx.tensor([0])])) |> Nx.dot(vt), t) + end + + test "works with zero-tensor" do + for {m, n, k} <- [{3, 3, 3}, {3, 4, 3}, {4, 3, 3}] do + t = Nx.broadcast(0, {m, n}) + {u, s, vt} = Nx.LinAlg.svd(t) + assert_all_close(u, Nx.eye({m, m})) + assert_all_close(s, Nx.broadcast(0, {k})) + assert_all_close(vt, Nx.eye({n, n})) + end + end + end + + describe "pinv" do + test "does not raise for 0 singular values" do + key = Nx.Random.key(System.unique_integer()) + + for {m, n} <- [{3, 4}, {3, 3}, {4, 3}], reduce: key do + key -> + # generate u and vt as random orthonormal matrices + {base_u, key} = Nx.Random.uniform(key, 0, 1, shape: {m, m}, type: :f64) + {u, _} = Nx.LinAlg.qr(base_u) + {base_vt, key} = Nx.Random.uniform(key, 0, 1, shape: {n, n}, type: :f64) + {vt, _} = Nx.LinAlg.qr(base_vt) + + # because min(m, n) is always 3, we can use fixed values here + # the important thing is that there's at least one zero in the + # diagonal, to ensure that we're guarding against 0 division + zeros = Nx.broadcast(0, {m, n}) + s = Nx.put_diagonal(zeros, Nx.f64([1, 4, 0])) + s_inv = Nx.put_diagonal(Nx.transpose(zeros), Nx.f64([1, 0.25, 0])) + + # construct t with the given singular values + t = u |> Nx.dot(s) |> Nx.dot(vt) + pinv = Nx.LinAlg.pinv(t) + + # ensure that the returned pinv is close to what we expect + assert_all_close(pinv, Nx.transpose(vt) |> Nx.dot(s_inv) |> Nx.dot(Nx.transpose(u)), + atol: 1.0e-2 + ) + + key + end + end + end end From 3934cdb327d82184ed20a38e2c2ebfad8d004388 Mon Sep 17 00:00:00 2001 From: Paulo Valente <16843419+polvalente@users.noreply.github.com> Date: Mon, 17 Mar 2025 04:31:00 -0300 Subject: [PATCH 20/36] fix(exla): vectorized gather (#1595) --- exla/lib/exla/defn.ex | 3 +-- exla/test/exla/backend_test.exs | 36 +++++++++++++++++++++++++++++++++ 2 files changed, 37 insertions(+), 2 deletions(-) diff --git a/exla/lib/exla/defn.ex b/exla/lib/exla/defn.ex index 3bb478258d..fece273800 100644 --- a/exla/lib/exla/defn.ex +++ b/exla/lib/exla/defn.ex @@ -1141,8 +1141,7 @@ defmodule EXLA.Defn do end batch_size = tensor_rank - length(axes) - offset_size = indices_rank - length(axes) - offset_dims = count_up(batch_size, offset_size) + offset_dims = count_up(batch_size, index_vector_dim) Value.gather( tensor, diff --git a/exla/test/exla/backend_test.exs b/exla/test/exla/backend_test.exs index 7f99c8bc6f..e917c68f72 100644 --- a/exla/test/exla/backend_test.exs +++ b/exla/test/exla/backend_test.exs @@ -225,6 +225,42 @@ defmodule EXLA.BackendTest do "1.0-0.0i, 2.0+0.0i, 3.0-0.0i, 0.0+1.0i, 0.0+2.0i" end + test "gather vectorized regression" do + gradients = + Nx.tensor( + [ + [1.0, 1.0], + [-1.0, 1.0], + [1.0, -1.0], + [-1.0, -1.0] + ], + backend: EXLA.Backend + ) + + i = + Nx.tensor([[0, 2, 3, 2, 2, 2, 2, 1]], type: {:u, 16}, backend: EXLA.Backend) + |> Nx.vectorize([:x, :octaves]) + + result = Nx.gather(gradients, Nx.reshape(i, {1})) + + assert_equal( + result, + Nx.tensor([ + [ + [1.0, 1.0], + [1.0, -1.0], + [-1.0, -1.0], + [1.0, -1.0], + [1.0, -1.0], + [1.0, -1.0], + [1.0, -1.0], + [-1.0, 1.0] + ] + ]) + |> Nx.vectorize([:x, :octaves]) + ) + end + describe "quantized types" do test "s2" do tensor = Nx.s2(-1) From abb94ce17b686405fbf462611e54a2fe0d2f8d47 Mon Sep 17 00:00:00 2001 From: Paulo Valente <16843419+polvalente@users.noreply.github.com> Date: Tue, 18 Mar 2025 04:02:17 -0300 Subject: [PATCH 21/36] feat: Nx.Defn.Graph (#1544) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: José Valim --- nx/lib/nx/defn/graph.ex | 327 +++++++++++++++++++++++ nx/test/nx/defn/graph_test.exs | 461 +++++++++++++++++++++++++++++++++ 2 files changed, 788 insertions(+) create mode 100644 nx/lib/nx/defn/graph.ex create mode 100644 nx/test/nx/defn/graph_test.exs diff --git a/nx/lib/nx/defn/graph.ex b/nx/lib/nx/defn/graph.ex new file mode 100644 index 0000000000..c11c1ef18b --- /dev/null +++ b/nx/lib/nx/defn/graph.ex @@ -0,0 +1,327 @@ +defmodule Nx.Defn.Graph do + @moduledoc """ + A module for splitting `Nx.Defn.Expr` into stages. + + This module is used to split an `Nx.Defn.Expr` into stages, which are then + executed in a chain. + + `split/2` and `t:Stage.t()` describe how to split + the graph and what's the expected result. + + `run/2` executes the given graph against the provided arguments in a sequential manner. + """ + alias Nx.Defn.Composite + + alias Nx.Tensor, as: T + alias Nx.Defn.Expr + + defmodule Stage do + @typedoc """ + A stage in the graph splitter. + + * `:arguments`: a list of maps that point to the source from which to fetch the corresponding + value for the given argument. + * `:expr`: the expression that represents the computation for the Stage. + * `:id`: the unique id for the Stage. + """ + @type t :: %__MODULE__{ + id: reference(), + expr: %{__struct__: Nx.Defn.Expr}, + arguments: [%{source: {reference() | nil, non_neg_integer()}}] + } + + defstruct [:id, :expr, :arguments] + end + + @doc """ + Splits the received Nx.Defn.Expr into stages given the rules. + + `expr_split_fn` is a function that receives an `Nx.Tensor` containing an `Nx.Defn.Expr` + and returns `true` when a split must happen, and `false` otherwise. + + ## Examples + + iex> expr = Nx.Defn.debug_expr(fn x, y -> x |> Nx.negate() |> Nx.sin() |> Nx.cos() |> Nx.add(y) end).(1, 2) + iex> [stage0, stage1] = Nx.Defn.Graph.split(expr, fn %Nx.Tensor{data: %Nx.Defn.Expr{op: op}} -> op == :cos end) + iex> {out0} = stage0.expr + iex> out0 + #Nx.Tensor< + f32 + \n\ + Nx.Defn.Expr + parameter a:0 s32 + b = negate a s32 + c = sin b f32 + > + iex> stage1.expr + #Nx.Tensor< + f32 + \n\ + Nx.Defn.Expr + parameter a:1 f32 + parameter c:0 s32 + b = cos a f32 + d = add b, c f32 + > + """ + def split(expr, expr_split_fn) when is_function(expr_split_fn, 1) do + {chain, _, _} = __split__(expr, expr_split_fn) + chain + end + + @doc """ + Executes the stage chain with the given arguments. + """ + def run(chain, args) do + scope = + Enum.with_index(args, fn arg, idx -> {{nil, idx}, arg} end) + |> Map.new() + + {result, _scope} = + Enum.reduce(chain, {nil, scope}, fn stage, {_result, scope} -> + %{id: id, expr: expr, arguments: arguments} = stage + + args = + Enum.map(arguments, fn %{source: source} -> + Map.fetch!(scope, source) + end) + + case Nx.Defn.jit_apply(fn _ -> expr end, [List.to_tuple(args)]) do + %T{} = tensor -> + {tensor, Map.put(scope, {id, 0}, tensor)} + + tuple -> + {_idx, scope} = + tuple + |> Tuple.to_list() + |> Enum.reduce({0, scope}, fn tensor, {idx, scope} -> + {idx + 1, Map.put(scope, {id, idx}, tensor)} + end) + + {tuple, scope} + end + end) + + result + end + + @doc false + def __split__(expr, expr_split_fn) do + # state.expression_chain is a reverse accumulation of the stages and + # snapshots of the state at each one so that we can properly remap parameters for each stage. + state = %{ + expression_chain: [], + nodes_to_replace: %{}, + expr_split_fn: expr_split_fn, + # args is a map of id -> {stage_id, output_container_position} + args: %{} + } + + cache = %{} + {expr, {cache, state}} = composite_eval(expr, state, cache) + + expr_chain = + Enum.reduce( + [{make_ref(), expr, state.nodes_to_replace} | state.expression_chain], + [], + fn {id, expr, nodes_to_replace}, acc -> + # TO-DO: we need to also do a pass to avoid recalculating results that have been previously calculated. + # For example: + # x = arg0 + arg1 + # y = arg0 - arg1 + # z = x + y + # ----- + # w = dot(z, arg1) + # y + w <- here, we currently have to recalculate y given that only z, arg0 and arg1 will be passed as arguments. + # ideally, we should also pass y as a value to avoid recalculating it. + # We might be able to calculate this in the first traversal somehow. + + {expr, %{used_args: used_args}} = + composite_rewrite_subtree( + expr, + %{state | nodes_to_replace: nodes_to_replace} + ) + + arg_remapping = + used_args + |> Enum.sort_by(fn {_id, %T{data: %Expr{op: :parameter, args: [idx]}}} -> idx end) + |> Enum.with_index(fn + {id, expr}, idx -> + {id, put_in(expr.data.args, [idx])} + end) + |> Map.new() + + {expr, _} = + composite_rewrite_subtree(expr, %{state | nodes_to_replace: arg_remapping}) + + arguments = + arg_remapping + |> Enum.map(fn {_id, arg_expr} -> + id = arg_expr.data.id + [idx] = arg_expr.data.args + source = Map.fetch!(state.args, id) + {idx, %{source: source}} + end) + |> Enum.sort_by(fn {idx, _} -> idx end) + |> Enum.map(fn {_, arg} -> arg end) + + [ + %Stage{ + id: id, + expr: expr, + arguments: arguments + } + | acc + ] + end + ) + + {expr_chain, cache, Map.delete(state, :expression_chain)} + end + + defp composite_eval(expr, state, cache) do + Composite.traverse(expr, {cache, state}, &eval/2) + end + + defp eval(%T{data: %Expr{id: id, op: op}} = ans, {cache, state}) do + case {cache, state.nodes_to_replace} do + {_, %{^id => res}} -> + # Replace the node with the corresponding parameter + {res, {Map.put(cache, id, res), state}} + + {%{^id => res}, _} -> + {res, {cache, state}} + + _ -> + if state.expr_split_fn.(ans) do + split_expr(ans, {cache, state}) + else + eval_apply(op, ans, {cache, state}) + end + end + end + + defp eval(other, {cache, state}) do + {other, {cache, state}} + end + + defp split_expr(expr, {cache, state}) do + {args, {cache, state}} = Nx.Defn.Tree.apply_args(expr, {cache, state}, &eval/2) + # We need to save this so that each previous stage + # isn't affected by following ones + nodes_to_replace = state.nodes_to_replace + + stage_id = make_ref() + + {args, {tensor_args, _out_position, state}} = + Enum.map_reduce(args, {[], 0, state}, fn + %T{} = expr, {tensor_args, out_position, state} -> + arg = Expr.parameter(expr, map_size(state.args)) + + state = %{ + state + | args: Map.put(state.args, arg.data.id, {stage_id, out_position}), + nodes_to_replace: Map.put(state.nodes_to_replace, expr.data.id, arg) + } + + {arg, {[expr | tensor_args], out_position + 1, state}} + + non_tensor_arg, acc -> + {non_tensor_arg, acc} + end) + + new_expr = put_in(expr.data.args, args) + + state = + update_in( + state.expression_chain, + &[ + {stage_id, List.to_tuple(Enum.reverse(tensor_args)), nodes_to_replace} + | &1 + ] + ) + + cache = Map.put(cache, new_expr.data.id, new_expr) + + {new_expr, {cache, state}} + end + + defp eval_apply(:parameter, %T{data: %Expr{id: id, args: [idx]}} = expr, {cache, state}) do + state = put_in(state.args[id], {nil, idx}) + {expr, {Map.put(cache, id, expr), state}} + end + + defp eval_apply(:elem, %T{data: %Expr{id: id, args: [tuple, i]}}, {cache, state}) do + {tuple, cache} = composite_eval(tuple, state, cache) + res = elem(tuple, i) + {res, {Map.put(cache, id, res), state}} + end + + defp eval_apply(_op, %T{data: %Expr{id: id}} = ans, {cache, state}) do + {args, {cache, state}} = Nx.Defn.Tree.apply_args(ans, {cache, state}, &eval/2) + ans = put_in(ans.data.args, args) + {ans, {Map.put(cache, id, ans), state}} + end + + defp composite_rewrite_subtree(container, state, acc \\ %{used_args: %{}}) + + defp composite_rewrite_subtree(container, state, acc) when is_list(container) do + Enum.map_reduce(container, acc, fn + %T{} = arg, acc -> + composite_rewrite_subtree(arg, state, acc) + + arg, acc when is_list(arg) -> + composite_rewrite_subtree(arg, state, acc) + + arg, acc -> + {arg, acc} + end) + end + + defp composite_rewrite_subtree(container, state, acc) do + Composite.traverse(container, acc, &rewrite_subtree(&1, state, &2)) + end + + defp rewrite_subtree(%T{data: %Expr{id: id, op: :parameter}} = expr, state, acc) do + case state.nodes_to_replace do + %{^id => res} -> + {res, put_in(acc.used_args[id], res)} + + _ -> + {expr, put_in(acc.used_args[id], expr)} + end + end + + defp rewrite_subtree( + %T{data: %Expr{op: :optional, id: id, args: [call, subexpr, fun]}} = expr, + state, + acc + ) do + case state.nodes_to_replace do + %{^id => res} -> + {res, put_in(acc.used_args[id], res)} + + _ -> + {call, acc} = rewrite_subtree(call, state, acc) + # `subexpr` is hermetic, in the sense that it is a self-contained scope + # from which the arguments always come from `call`, so we can + # keep it as is. + + {put_in(expr.data.args, [call, subexpr, fun]), acc} + end + end + + defp rewrite_subtree(%T{data: %Expr{id: id, args: args}} = expr, state, acc) do + case state.nodes_to_replace do + %{^id => res} -> + # nodes_to_replace always contains a param + {res, put_in(acc.used_args[id], res)} + + _ -> + {args, acc} = composite_rewrite_subtree(args, state, acc) + {put_in(expr.data.args, args), acc} + end + end + + defp rewrite_subtree(other, _, acc), do: {other, acc} +end diff --git a/nx/test/nx/defn/graph_test.exs b/nx/test/nx/defn/graph_test.exs new file mode 100644 index 0000000000..144b413873 --- /dev/null +++ b/nx/test/nx/defn/graph_test.exs @@ -0,0 +1,461 @@ +defmodule Nx.Defn.GraphTest do + use ExUnit.Case, async: true + + alias Nx.Defn.Graph + alias Nx.Defn.Graph.Stage + + alias Nx.Tensor, as: T + alias Nx.Defn.Expr + + doctest Nx.Defn.Graph + + describe "traverse/1" do + test "simple expression with 1 split and no common nodes" do + expr = + Nx.Defn.debug_expr(fn arg0, arg1 -> + x = Nx.add(arg0, arg1) + y = Nx.subtract(arg0, arg1) + z = Nx.dot(x, y) + w = Nx.multiply(z, 2) + Nx.divide(w, 4) + end).(Nx.tensor([1, 2]), Nx.tensor([3, 4])) + + split_fn = fn + %T{data: %Expr{op: :dot}} -> true + _ -> false + end + + {chain, cache, state} = Graph.__split__(expr, split_fn) + + assert [ + %Stage{ + id: stage_0_id, + expr: stage_0_expr, + arguments: stage_0_arguments + }, + %Stage{ + id: _stage_1_id, + expr: stage_1_expr, + arguments: stage_1_arguments + } + ] = chain + + assert [%{source: {nil, 0}}, %{source: {nil, 1}}] == stage_0_arguments + + assert [{2, arg_2_original_node_id, arg_2_id}, {3, arg_3_original_node_id, arg_3_id}] = + state.nodes_to_replace + |> Enum.map(fn {original_node_id, + %T{data: %Expr{id: id, op: :parameter, args: [idx]}}} -> + {idx, original_node_id, id} + end) + |> Enum.sort() + + # ensure that arg2 and arg3 map to the correct stage and output container position + assert [%{source: {stage_0_id, 0}}, %{source: {stage_0_id, 1}}] == stage_1_arguments + + # ensure that arg2 and arg3 are replacing the correct nodes + {_dot_node_id, %T{data: %Expr{args: [dot_arg_0, _, _, dot_arg_1, _, _]}}} = + Enum.find(cache, fn + {_, %T{data: %Expr{op: :dot}}} -> true + _ -> false + end) + + assert dot_arg_0.data.id == arg_2_id + assert dot_arg_1.data.id == arg_3_id + + # ensure that the output of the first stage contains the original nodes from dot(x, y) + # also assert on the rough shape for the expression + assert {%T{data: %Expr{id: ^arg_2_original_node_id}} = left, + %T{data: %Expr{id: ^arg_3_original_node_id}} = right} = stage_0_expr + + assert %T{ + data: %Expr{ + op: :add, + args: [ + %T{data: %Expr{op: :parameter, args: [0]}}, + %T{data: %Expr{op: :parameter, args: [1]}} + ] + } + } = left + + assert %T{ + data: %Expr{ + op: :subtract, + args: [ + %T{data: %Expr{op: :parameter, args: [0]}}, + %T{data: %Expr{op: :parameter, args: [1]}} + ] + } + } = right + + assert %T{ + data: %Expr{ + op: :divide, + args: [ + %T{ + data: %Expr{ + op: :multiply, + args: [ + %T{data: %Expr{op: :constant, args: [2]}}, + %T{ + data: %Expr{ + op: :dot, + args: [ + %T{data: %Expr{op: :parameter, args: [0]}}, + [0], + [], + %T{data: %Expr{op: :parameter, args: [1]}}, + [0], + [] + ] + } + } + ] + } + }, + %T{data: %Expr{op: :constant, args: [4]}} + ] + } + } = stage_1_expr + end + + test "expression with 2 splits, common nodes and argument separation" do + expr = + Nx.Defn.debug_expr(fn arg0, arg1, arg2 -> + x = Nx.add(arg0, arg1) + y = Nx.subtract(arg0, arg1) + z = Nx.dot(x, y) + w = Nx.multiply(z, 2) + a = Nx.sum(w) + + a + |> Nx.add(w) + |> Nx.subtract(arg2) + end).(Nx.tensor([[1, 2]]), Nx.tensor([[3], [4]]), Nx.tensor([5, 6])) + + split_fn = fn + %T{data: %Expr{op: :dot}} -> true + %T{data: %Expr{op: :sum}} -> true + _ -> false + end + + {chain, cache, state} = Graph.__split__(expr, split_fn) + + assert [ + %Stage{ + id: stage_0_id, + expr: stage_0_expr, + arguments: stage_0_arguments + }, + %Stage{ + id: stage_1_id, + expr: stage_1_expr, + arguments: stage_1_arguments + }, + %Stage{ + id: _stage_2_id, + expr: stage_2_expr, + arguments: stage_2_arguments + } + ] = chain + + assert [%{source: {nil, 0}}, %{source: {nil, 1}}] == stage_0_arguments + + assert map_size(state.args) == 6 + + original_args = + Enum.reduce(state.args, [], fn {id, _}, acc -> + if node = cache[id] do + [{hd(node.data.args), id} | acc] + else + acc + end + end) + |> Enum.sort() + |> Enum.map(fn {_, id} -> id end) + + [arg_0_id, arg_1_id, arg_2_id] = original_args + + assert [ + {2, arg_3_original_node_id, arg_3_id}, + {3, arg_4_original_node_id, arg_4_id}, + {4, arg_5_original_node_id, arg_5_id} + ] = + state.nodes_to_replace + |> Enum.map(fn {original_node_id, + %T{data: %Expr{id: id, op: :parameter, args: [idx]}}} -> + {idx, original_node_id, id} + end) + |> Enum.sort() + + assert arg_3_id not in original_args + assert arg_4_id not in original_args + assert arg_5_id not in original_args + + # ensure that arg3 and arg4 map to the correct stage and output container position + assert [%{source: {stage_0_id, 0}}, %{source: {stage_0_id, 1}}] == stage_1_arguments + + # ensure that arg3 and arg4 are replacing the correct nodes + {_dot_node_id, %T{data: %Expr{args: [dot_arg_0, _, _, dot_arg_1, _, _]}}} = + Enum.find(cache, fn + {_, %T{data: %Expr{op: :dot}}} -> true + _ -> false + end) + + assert dot_arg_0.data.id == arg_3_id + assert dot_arg_1.data.id == arg_4_id + + # ensure that the output of the first stage contains the original nodes from dot(x, y) + # also assert on the rough shape for the expression + assert {%T{data: %Expr{id: ^arg_3_original_node_id}} = left, + %T{data: %Expr{id: ^arg_4_original_node_id}} = right} = stage_0_expr + + assert %T{ + data: %Expr{ + op: :add, + args: [ + %T{data: %Expr{id: ^arg_0_id, op: :parameter, args: [0]}}, + %T{data: %Expr{id: ^arg_1_id, op: :parameter, args: [1]}} + ] + } + } = left + + assert %T{ + data: %Expr{ + op: :subtract, + args: [ + %T{data: %Expr{id: ^arg_0_id, op: :parameter, args: [0]}}, + %T{data: %Expr{id: ^arg_1_id, op: :parameter, args: [1]}} + ] + } + } = right + + assert {%T{ + data: %Expr{ + id: ^arg_5_original_node_id, + op: :multiply, + args: [ + %T{data: %Expr{op: :constant, args: [2]}}, + %T{ + data: %Expr{ + op: :dot, + args: [ + %T{data: %Expr{op: :parameter, args: [0]}}, + [1], + [], + %T{data: %Expr{op: :parameter, args: [1]}}, + [0], + [] + ] + } + } + ] + } + }} = stage_1_expr + + assert %T{data: %Expr{op: :subtract, args: [c, d]}} = stage_2_expr + assert %T{data: %Expr{op: :add, args: [b, a]}} = c + assert %T{data: %Expr{id: ^arg_2_id, op: :parameter, args: [0]}} = d + assert %T{data: %Expr{op: :sum, args: [^a, [axes: nil, keep_axes: false]]}} = b + assert %T{data: %Expr{id: ^arg_5_id, op: :parameter, args: [1]}} = a + + assert [%{source: {nil, 2}}, %{source: {stage_1_id, 0}}] == stage_2_arguments + end + + test "supports optional callbacks" do + arg0 = + Nx.u8([ + [1, 0, 1], + [1, 1, 1] + ]) + + expr = + Nx.Defn.debug_expr(fn a, b -> + x = Nx.add(b, 1) + y = Nx.sum(x, axes: [1]) + z = Nx.logical_not(y) + Nx.subtract(z, a) + end).(1, arg0) + + split_fn = fn + %T{data: %Expr{op: :sum}} -> true + _ -> false + end + + assert [%Stage{} = stage_0, %Stage{} = stage_1] = Graph.split(expr, split_fn) + + assert stage_0.arguments == [%{source: {nil, 1}}] + assert stage_1.arguments == [%{source: {nil, 0}}, %{source: {stage_0.id, 0}}] + + assert %T{data: %Expr{op: :subtract, args: [c, d]}} = stage_1.expr + assert %T{data: %Expr{op: :optional, args: [call, subexpr, _fun]}} = c + + assert %T{data: %Expr{id: arg_0_id, op: :parameter, args: [0]}} = d + + assert %T{data: %Expr{op: :logical_not, args: [b]}} = call + assert %T{data: %Expr{op: :sum, args: [a, [axes: [1], keep_axes: false]]}} = b + assert %T{data: %Expr{id: arg_1_id, op: :parameter, args: [1]}} = a + + assert %T{ + data: %Expr{ + op: :equal, + args: [ + %T{data: %Expr{id: subexpr_arg_0_id, op: :parameter, args: [0]}}, + %T{data: %Expr{op: :constant, args: [0]}} + ] + } + } = subexpr + + # ensure subexpr is hermetic + assert subexpr_arg_0_id != arg_0_id + assert subexpr_arg_0_id != arg_1_id + end + + test "supports in-line anonymous functions" do + arg0 = + Nx.u8([ + [1, 0, 1], + [1, 1, 1] + ]) + + expr = + Nx.Defn.debug_expr(fn a, b -> + x = Nx.add(b, 1) + y = Nx.sum(x, axes: [1]) + f = fn a -> Nx.equal(a, 0) end + z = f.(y) + Nx.subtract(z, a) + end).(1, arg0) + + split_fn = fn + %T{data: %Expr{op: :sum}} -> true + _ -> false + end + + assert [%Stage{} = stage_0, %Stage{} = stage_1] = Graph.split(expr, split_fn) + + assert [%{source: {nil, 1}}] == stage_0.arguments + + assert [%{source: {nil, 0}}, %{source: {stage_0.id, 0}}] == stage_1.arguments + assert %T{data: %Expr{op: :subtract, args: [c, d]}} = stage_1.expr + + assert %T{ + data: %Expr{ + op: :equal, + args: [ + left, + %T{data: %Expr{op: :constant, args: [0]}} + ] + } + } = c + + assert %T{data: %Expr{op: :parameter, args: [0]}} = d + + assert %T{data: %Expr{op: :sum, args: [a, [axes: [1], keep_axes: false]]}} = left + assert %T{data: %Expr{op: :parameter, args: [1]}} = a + end + end + + describe "run/2" do + test "executes the stages chain and returns the correct result" do + function = fn arg0, arg1 -> + # root + x = Nx.multiply(arg0, arg1) |> Nx.Defn.Expr.metadata(%{split: true}) + + # left side + w_left = Nx.multiply(x, arg1) |> Nx.Defn.Expr.metadata(%{split: true}) + + # right side + w_right = Nx.divide(x, arg1) |> Nx.Defn.Expr.metadata(%{split: true}) + + # merge + Nx.add(w_right, w_left) + end + + args = [Nx.tensor([1, 2]), Nx.tensor([3, 4])] + + # This is used in the final assertion of this test + expected_result = Nx.Defn.jit_apply(function, args) + + expr = apply(Nx.Defn.debug_expr(function), args) + + split_fn = fn + %T{data: %Expr{op: :metadata, args: [_expr, %{split: true}]}} -> true + _ -> false + end + + chain = Graph.split(expr, split_fn) + + assert [root, right, left, merge] = chain + + assert {%T{data: %Expr{op: :multiply, args: [arg0, arg1]}}} = root.expr + assert %T{data: %Expr{op: :parameter, args: [0]}} = arg0 + assert %T{data: %Expr{op: :parameter, args: [1]}} = arg1 + + # left should depend on exactly the same parameters as the root, as it's pulling from + # the global scope + assert {%T{data: %Expr{op: :multiply, args: [x, arg1_left]}}} = left.expr + + assert %T{ + data: %Expr{ + op: :metadata, + args: [ + %T{data: %Expr{op: :parameter, args: [1]}}, + %{split: true} + ] + } + } = x + + assert %T{data: %Expr{op: :parameter, args: [0]}} = arg1_left + + assert Enum.fetch!(left.arguments, 0).source == {nil, 1} + assert Enum.fetch!(left.arguments, 1).source == {root.id, 0} + + # right should depend on the result of the root and on arg1, but arg1 will be reindexed + # we should assert that the argument source for arg1_right is correct + assert {%T{data: %Expr{op: :divide, args: [x, arg1_right]}}} = right.expr + + assert %T{ + data: %Expr{ + op: :metadata, + args: [ + %T{data: %Expr{op: :parameter, args: [1]}}, + %{split: true} + ] + } + } = x + + assert %T{data: %Expr{op: :parameter, args: [0]}} = arg1_right + + assert Enum.fetch!(right.arguments, 0).source == {nil, 1} + assert Enum.fetch!(right.arguments, 1).source == {root.id, 0} + + assert %T{data: %Expr{op: :add, args: [w_right, w_left]}} = merge.expr + + assert %T{ + data: %Expr{ + op: :metadata, + args: [ + %T{data: %Expr{op: :parameter, args: [0]}}, + %{split: true} + ] + } + } = w_right + + assert %T{ + data: %Expr{ + op: :metadata, + args: [ + %T{data: %Expr{op: :parameter, args: [1]}}, + %{split: true} + ] + } + } = w_left + + assert Enum.fetch!(merge.arguments, 0).source == {right.id, 0} + assert Enum.fetch!(merge.arguments, 1).source == {left.id, 0} + + assert Graph.run(chain, args) == expected_result + end + end +end From aa8eac10a7e79870abd0976a546e90ca5007496c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tom=C3=A1s=20Lenzi?= <61877861+TomasPegado@users.noreply.github.com> Date: Wed, 19 Mar 2025 17:12:18 -0300 Subject: [PATCH 22/36] Fix(exla): triangular_solve with batched matrix input (#1596) --- exla/lib/exla/defn.ex | 31 ++++---- exla/test/exla/nx_linalg_doctest_test.exs | 97 ++++++++++++++++++++++- 2 files changed, 112 insertions(+), 16 deletions(-) diff --git a/exla/lib/exla/defn.ex b/exla/lib/exla/defn.ex index fece273800..55b13ee9b0 100644 --- a/exla/lib/exla/defn.ex +++ b/exla/lib/exla/defn.ex @@ -782,26 +782,27 @@ defmodule EXLA.Defn do lower = Keyword.fetch!(opts, :lower) transform = Keyword.fetch!(opts, :transform_a) - case Value.get_typespec(b).shape do - {dim} -> - b_shape = {dim, 1} + a_shape = Value.get_typespec(a).shape + b_shape = Value.get_typespec(b).shape - b = - b - |> to_type(type) - |> Value.reshape(Typespec.tensor(type, b_shape)) + if tuple_size(a_shape) > tuple_size(b_shape) do + b_shape = Tuple.insert_at(b_shape, tuple_size(b_shape), 1) - typespec = Typespec.tensor(type, b_shape) + b = + b + |> to_type(type) + |> Value.reshape(Typespec.tensor(type, b_shape)) - to_type(a, type) - |> Value.triangular_solve(b, left_side, lower, transform, typespec) - |> Value.reshape(Typespec.tensor(type, ans.shape)) + typespec = Typespec.tensor(type, b_shape) - _ -> - typespec = Typespec.tensor(type, ans.shape) + to_type(a, type) + |> Value.triangular_solve(b, left_side, lower, transform, typespec) + |> Value.reshape(Typespec.tensor(type, ans.shape)) + else + typespec = Typespec.tensor(type, ans.shape) - to_type(a, type) - |> Value.triangular_solve(to_type(b, type), left_side, lower, transform, typespec) + to_type(a, type) + |> Value.triangular_solve(to_type(b, type), left_side, lower, transform, typespec) end end diff --git a/exla/test/exla/nx_linalg_doctest_test.exs b/exla/test/exla/nx_linalg_doctest_test.exs index 3eeeb30546..38879e4048 100644 --- a/exla/test/exla/nx_linalg_doctest_test.exs +++ b/exla/test/exla/nx_linalg_doctest_test.exs @@ -2,6 +2,11 @@ defmodule EXLA.NxLinAlgDoctestTest do use EXLA.Case, async: true import Nx, only: :sigils + setup do + Nx.default_backend(EXLA.Backend) + :ok + end + @function_clause_error_doctests [ solve: 2, triangular_solve: 3 @@ -15,7 +20,8 @@ defmodule EXLA.NxLinAlgDoctestTest do least_squares: 3, determinant: 1, matrix_power: 2, - lu: 2 + lu: 2, + qr: 2 ] @excluded_doctests @function_clause_error_doctests ++ @@ -402,4 +408,93 @@ defmodule EXLA.NxLinAlgDoctestTest do end end end + + describe "triangular_solve" do + test "works with batched input" do + a = + Nx.tensor([ + [ + [-1, 0, 0], + [1, 1, 0], + [1, 1, 1] + ], + [ + [2, 0, 0], + [4, -2, 0], + [-5, 1, 3] + ] + ]) + + b = + Nx.tensor([ + [1.0, 2.0, 3.0], + [6, 10, 1] + ]) + + assert_equal(Nx.dot(a, [2], [0], Nx.LinAlg.triangular_solve(a, b), [1], [0]), b) + end + + test "works with B that has more columns than rows" do + a = + Nx.tensor( + [ + [1, 0], + [1, 1] + ], + type: :f64 + ) + + b = + Nx.tensor( + [ + [1, 1, 1], + [2, 2, 2] + ], + type: :f64 + ) + + x = Nx.LinAlg.triangular_solve(a, b) + + assert_equal( + x, + Nx.tensor( + [ + [1, 1, 1], + [1, 1, 1] + ], + type: :f64 + ) + ) + end + + test "property" do + a = Nx.tensor([[1, 0, 0], [1, 1, 0], [0, 1, 1]]) + b = Nx.tensor([[1.0, 2.0, 3.0], [2.0, 2.0, 4.0], [2.0, 0.0, 1.0]]) + assert_equal(Nx.dot(a, Nx.LinAlg.triangular_solve(a, b)), b) + + upper = Nx.transpose(a) + assert_equal(Nx.dot(upper, Nx.LinAlg.triangular_solve(upper, b, lower: false)), b) + + assert_equal( + Nx.dot( + Nx.LinAlg.triangular_solve(upper, b, left_side: false, lower: false), + upper + ), + b + ) + + assert_equal( + Nx.LinAlg.triangular_solve(a, b, transform_a: :transpose), + Nx.LinAlg.triangular_solve(upper, b, lower: false) + ) + + assert_equal( + Nx.dot( + Nx.transpose(a), + Nx.LinAlg.triangular_solve(a, b, transform_a: :transpose) + ), + b + ) + end + end end From 6adeb54da986f405a17a454c4019260ef42ecf9d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jos=C3=A9=20Valim?= Date: Thu, 20 Mar 2025 10:47:41 +0100 Subject: [PATCH 23/36] Clarify composite docs Closes #1597. --- nx/lib/nx/defn/composite.ex | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/nx/lib/nx/defn/composite.ex b/nx/lib/nx/defn/composite.ex index 42d51486b4..9f38b385b4 100644 --- a/nx/lib/nx/defn/composite.ex +++ b/nx/lib/nx/defn/composite.ex @@ -7,10 +7,10 @@ defmodule Nx.Defn.Composite do Numerical values, such as integers, floats, and complex numbers are not normalized before hand. Use `Nx.to_tensor/1` to do so. - The functions in this module can be used both inside and outside `defn`. - Note that, when a value is given to `defn`, it is first converted to - tensors and containers via `Nx.LazyContainer`. Inside `defn`, there are - no lazy containers, only containers. + The functions in this module are invoked outside of `defn` or inside + `deftransform`. Note that, when a value is given to `defn`, it is + first converted to tensors and containers via `Nx.LazyContainer`. + Inside `defn`, there are no lazy containers, only containers. """ alias Nx.Tensor, as: T From 7d18684613ca082da3c0df82dd1a19043b1ba12d Mon Sep 17 00:00:00 2001 From: TomasPegado Date: Wed, 26 Mar 2025 16:09:07 -0300 Subject: [PATCH 24/36] docs: getting started section --- nx/guides/getting_started/broadcast.livemd | 153 ++++++++++++++++++ .../numerical_definitions.livemd | 121 ++++++++++++++ nx/mix.exs | 2 + 3 files changed, 276 insertions(+) create mode 100644 nx/guides/getting_started/broadcast.livemd create mode 100644 nx/guides/getting_started/numerical_definitions.livemd diff --git a/nx/guides/getting_started/broadcast.livemd b/nx/guides/getting_started/broadcast.livemd new file mode 100644 index 0000000000..ecd46f1b7a --- /dev/null +++ b/nx/guides/getting_started/broadcast.livemd @@ -0,0 +1,153 @@ +# Broadcasts + +Often, the dimensions of tensors in an operator don't match. +For example, you might want to subtract a `1` from every +element of a `{2, 2}` tensor, like this: + +$$ +\begin{bmatrix} + 1 & 2 \\\\ + 3 & 4 +\end{bmatrix} - 1 = +\begin{bmatrix} + 0 & 1 \\\\ + 2 & 3 +\end{bmatrix} +$$ + +Mathematically, it's the same as this: + +$$ +\begin{bmatrix} + 1 & 2 \\\\ + 3 & 4 +\end{bmatrix} - +\begin{bmatrix} + 1 & 1 \\\\ + 1 & 1 +\end{bmatrix} = +\begin{bmatrix} + 0 & 1 \\\\ + 2 & 3 +\end{bmatrix} +$$ + +That means we need a way to convert `1` to a `{2, 2}` tensor. +`Nx.broadcast/2` solves that problem. This function takes +a tensor or a scalar and a shape. + +```elixir +Mix.install([ + {:nx, "~> 0.5"} +]) + + +Nx.broadcast(1, {2, 2}) +``` + +This broadcast takes the scalar `1` and translates it +to a compatible shape by copying it. Sometimes, it's easier +to provide a tensor as the second argument, and let `broadcast/2` +extract its shape: + +```elixir +tensor = Nx.tensor([[1, 2], [3, 4]]) +Nx.broadcast(1, tensor) +``` + +The code broadcasts `1` to the shape of `tensor`. In many operators +and functions, the broadcast happens automatically: + +```elixir +Nx.subtract(tensor, 1) +``` + +This result is possible because Nx broadcasts _both tensors_ +in `subtract/2` to compatible shapes. That means you can provide +scalar values as either argument: + +```elixir +Nx.subtract(10, tensor) +``` + +Or subtract a row or column. Mathematically, it would look like this: + +$$ +\begin{bmatrix} + 1 & 2 \\\\ + 3 & 4 +\end{bmatrix} - +\begin{bmatrix} + 1 & 2 +\end{bmatrix} = +\begin{bmatrix} + 0 & 0 \\\\ + 2 & 2 +\end{bmatrix} +$$ + +which is the same as this: + +$$ +\begin{bmatrix} + 1 & 2 \\\\ + 3 & 4 +\end{bmatrix} - +\begin{bmatrix} + 1 & 2 \\\\ + 1 & 2 +\end{bmatrix} = +\begin{bmatrix} + 0 & 0 \\\\ + 2 & 2 +\end{bmatrix} +$$ + +This rewrite happens in Nx too, also through a broadcast. We want to +broadcast the tensor `[1, 2]` to match the `{2, 2}` shape, like this: + +```elixir +Nx.broadcast(Nx.tensor([1, 2]), {2, 2}) +``` + +The `subtract` function in `Nx` takes care of that broadcast +implicitly, as before: + +```elixir +Nx.subtract(tensor, Nx.tensor([1, 2])) +``` + +The broadcast worked as advertised, copying the `[1, 2]` row +enough times to fill a `{2, 2}` tensor. A tensor with a +dimension of `1` will broadcast to fill the tensor: + +```elixir +[[1], [2]] |> Nx.tensor() |> Nx.broadcast({1, 2, 2}) +``` + +```elixir +[[[1, 2, 3]]] +|> Nx.tensor() +|> Nx.broadcast({4, 2, 3}) +``` + +Both of these examples copy parts of the tensor enough +times to fill out the broadcast shape. You can check out the +Nx broadcasting documentation for more details: + + + +```elixir +h Nx.broadcast +``` + +Much of the time, you won't have to broadcast yourself. Many of +the functions and operators Nx supports will do so automatically. + +We can use tensor-aware operators via various `Nx` functions and +many of them implicitly broadcast tensors. + +Throughout this section, we have been invoking `Nx.subtract/2` and +our code would be more expressive if we could use its equivalent +mathematical operator. Fortunately, Nx provides a way. Next, we'll +dive into numerical definitions using `defn`. diff --git a/nx/guides/getting_started/numerical_definitions.livemd b/nx/guides/getting_started/numerical_definitions.livemd new file mode 100644 index 0000000000..3f7b607042 --- /dev/null +++ b/nx/guides/getting_started/numerical_definitions.livemd @@ -0,0 +1,121 @@ +# Numerical definitions (defn) + +The `defn` macro simplifies the expression of mathematical formulas +containing tensors. Numerical definitions have two primary benefits +over classic Elixir functions. + +- They are _tensor-aware_. Nx replaces operators like `Kernel.-/2` + with the `Defn` counterparts — which in turn use `Nx` functions + optimized for tensors — so the formulas we express can use + tensors out of the box. + +- `defn` definitions allow for building computation graph of all the + individual operations and using a just-in-time (JIT) compiler to emit + highly specialized native code for the desired computation unit. + +We don't have to do anything special to get access to +get tensor awareness beyond importing `Nx.Defn` and writing +our code within a `defn` block. + +To use Nx in a Mix project or a notebook, we need to include +the `:nx` dependency and import the `Nx.Defn` module, +like this: + +```elixir +Mix.install([ + {:nx, "~> 0.5"} +]) +``` + +```elixir +import Nx.Defn +``` + +Just as the Elixir language supports `def`, `defmacro`, and `defp`, +Nx supports `defn`. There are a few restrictions. It allows only +numerical arguments in the form of primitives or tensors as arguments +or return values, and supports only a subset of the language. + +The subset of Elixir allowed within `defn` is quite broad, though. We can +use macros, pipes, and even conditionals, so we're not giving up +much when you're declaring mathematical functions. + +Additionally, despite these small concessions, `defn` provides huge benefits. +Code in a `defn` block uses tensor aware operators and types, so the math +beneath your functions has a better chance to shine through. Numerical +definitions can also run on accelerated numerical processors like GPUs and +TPUs. Here's an example numerical definition: + +```elixir +defmodule TensorMath do + import Nx.Defn + + defn subtract(a, b) do + a - b + end +end +``` + +This module has a numerical definition that will be compiled. +If we wanted to specify a compiler for this module, we could add +a module attribute before the `defn` clause. One of such compilers +is [the EXLA compiler](https://github.com/elixir-nx/nx/tree/main/exla). +You'd add the `mix` dependency for EXLA and do this: + + + +```elixir +@defn_compiler EXLA +defn subtract(a, b) do + a - b +end +``` + +Now, it's your turn. Add a `defn` to `TensorMath` +that accepts two tensors representing the lengths of sides of a +right triangle and uses the pythagorean theorem to return the +[length of the hypotenuse](https://www.mathsisfun.com/pythagoras.html). +Add your function directly to the previous Code cell. + +## deftransform + +The defn macro in Nx allows you to define functions that compile to efficient +numerical computations, but it comes with certain limitations—such as restrictions +on argument types, return values, and the subset of Elixir that it supports. +To overcome many of these limitations, Nx offers the deftransform macro. + +deftransform lets you perform computations or execute code that isn't directly +supported by defn, and then incorporate those results back into your numerical +function. This separation lets you use standard Elixir features where necessary +while keeping your core numerical logic optimized. + +In the following example, we define a deftransform function called +compute_tensor_from_list/1 that receives a list, which is not allowed +inside defn. Inside this transform function, we convert the list to a tensor +using Nx.tensor/1, and then pass it to a defn function called double_tensor/1, +which performs the actual numerical computation. + +```elixir +defmodule MyMath do + import Nx.Defn + + defn double_tensor(tensor) do + tensor * 2 + end + + deftransform compute_tensor_from_list(list) do + tensor = Nx.tensor(list) + double_tensor(tensor) + end +end + +``` + +```elixir +input = [1, 2, 3, 4] +result = MyMath.compute_tensor_from_list(input) +``` + +This setup allows us to keep our defn code clean and focused only on tensor +operations, while using deftransform to handle Elixir-native types and +preprocessing. diff --git a/nx/mix.exs b/nx/mix.exs index fb0af74bcd..46ac5ad508 100644 --- a/nx/mix.exs +++ b/nx/mix.exs @@ -61,6 +61,8 @@ defmodule Nx.MixProject do "guides/getting_started/introduction.md", "guides/getting_started/installation.md", "guides/getting_started/quickstart.livemd", + "guides/getting_started/broadcast.livemd", + "guides/getting_started/numerical_definitions.livemd", "guides/advanced/vectorization.livemd", "guides/advanced/aggregation.livemd", "guides/advanced/automatic_differentiation.livemd", From 2c51ba5ec2da2a4cb128fe140977e1ca933d24f6 Mon Sep 17 00:00:00 2001 From: TomasPegado Date: Wed, 26 Mar 2025 19:07:54 -0300 Subject: [PATCH 25/36] docs: Adding Cheatsheet --- nx/guides/cheatsheet/cheatsheet.cheatmd | 7 +++++++ nx/mix.exs | 2 ++ 2 files changed, 9 insertions(+) create mode 100644 nx/guides/cheatsheet/cheatsheet.cheatmd diff --git a/nx/guides/cheatsheet/cheatsheet.cheatmd b/nx/guides/cheatsheet/cheatsheet.cheatmd new file mode 100644 index 0000000000..b9bdb1cc8a --- /dev/null +++ b/nx/guides/cheatsheet/cheatsheet.cheatmd @@ -0,0 +1,7 @@ +# Cheatsheet + +This cheatsheet is designed to assist Python developers in transitioning to Elixir, +specifically by providing equivalent commands and code examples between NumPy and Nx. + +## Numpy -> Nx + diff --git a/nx/mix.exs b/nx/mix.exs index 46ac5ad508..6a73c427d6 100644 --- a/nx/mix.exs +++ b/nx/mix.exs @@ -63,6 +63,7 @@ defmodule Nx.MixProject do "guides/getting_started/quickstart.livemd", "guides/getting_started/broadcast.livemd", "guides/getting_started/numerical_definitions.livemd", + "guides/cheatsheet/cheatsheet.cheatmd", "guides/advanced/vectorization.livemd", "guides/advanced/aggregation.livemd", "guides/advanced/automatic_differentiation.livemd", @@ -119,6 +120,7 @@ defmodule Nx.MixProject do ], groups_for_extras: [ "Getting Started": ~r"^guides/getting_started/", + Cheatsheet: ~r"^guides/cheatsheet/", Exercises: ~r"^guides/exercises/", Advanced: ~r"^guides/advanced/" ] From e049bfb728b3e437a56543dfd43a55453bdb19b9 Mon Sep 17 00:00:00 2001 From: TomasPegado Date: Thu, 27 Mar 2025 14:21:40 -0300 Subject: [PATCH 26/36] docs: cheatsheet on array creation --- nx/guides/cheatsheet/cheatsheet.cheatmd | 47 +++++++++++++++++++++++++ 1 file changed, 47 insertions(+) diff --git a/nx/guides/cheatsheet/cheatsheet.cheatmd b/nx/guides/cheatsheet/cheatsheet.cheatmd index b9bdb1cc8a..69f46c9563 100644 --- a/nx/guides/cheatsheet/cheatsheet.cheatmd +++ b/nx/guides/cheatsheet/cheatsheet.cheatmd @@ -4,4 +4,51 @@ This cheatsheet is designed to assist Python developers in transitioning to Elix specifically by providing equivalent commands and code examples between NumPy and Nx. ## Numpy -> Nx +{: .col-2} +### Array Creation + + +#### Python code + +```python +import numpy as np + +# From list or nested list +a = np.array([1, 2, 3]) +b = np.array([[1, 2], [3, 4]]) + +# zeros and ones +np.zeros((2, 3)) # 2x3 array filled with zeros +np.ones((2, 3)) # 2x3 array filled with ones + +# Range of Numbers (like range()) +np.arange(0, 10, 2) # [0 2 4 6 8] + +# Linearly Spaced Values +np.linspace(0, 1, 5) # [0. 0.25 0.5 0.75 1. ] + +``` + +### Tensor Creation + +#### Elixir code + +```elixir +Mix.install([:nx]) + +# From list or nested list +a = Nx.tensor([1, 2, 3]) +b = Nx.tensor([[1, 2], [3, 4]]) + +# zeros and ones +Nx.broadcast(0, {2, 3}) # 2x3 tensor filled with zeros +Nx.broadcast(1, {2, 3}) # 2x3 tensor filled with ones + +# Range of Numbers (like range()) +Nx.iota({5}, axis: 0) |> Nx.multiply(2) # [0 2 4 6 8] + +# Linearly Spaced Values +Nx.iota({5}) |> Nx.divide(4) # [0.0, 0.25, 0.5, 0.75, 1.0] + +``` \ No newline at end of file From 69f69419ac08609ddc038018b6bedb851e50fc2f Mon Sep 17 00:00:00 2001 From: TomasPegado Date: Tue, 1 Apr 2025 17:38:32 -0300 Subject: [PATCH 27/36] docs: Cheatsheet basics --- nx/guides/cheatsheet/cheatsheet.cheatmd | 77 ++++++++++++++++++++++++- 1 file changed, 76 insertions(+), 1 deletion(-) diff --git a/nx/guides/cheatsheet/cheatsheet.cheatmd b/nx/guides/cheatsheet/cheatsheet.cheatmd index 69f46c9563..444f1d8414 100644 --- a/nx/guides/cheatsheet/cheatsheet.cheatmd +++ b/nx/guides/cheatsheet/cheatsheet.cheatmd @@ -8,7 +8,6 @@ specifically by providing equivalent commands and code examples between NumPy an ### Array Creation - #### Python code ```python @@ -30,6 +29,44 @@ np.linspace(0, 1, 5) # [0. 0.25 0.5 0.75 1. ] ``` +### Array Inspection + +#### Python code + +```python + +# Shape +a = np.array([[1, 2, 3], [4, 5, 6]]) +a.shape # (2, 3) + +# Number of dimensions +a.ndim # 2 + +# Data Type +a.dtype # dtype('int64') + +# Total Number of Elements +a.size +``` + +### Indexing and Slicing: Array + +#### Python code +```python + +# Indexing a Single Element +a = np.array([[10, 20], [30, 40]]) +a[0, 1] # 20 + +# Slicing a Range +a = np.array([10, 20, 30, 40, 50]) +a[1:4] # [20 30 40] + +# Selecting Along a Specific Axis +a = np.array([[1, 2], [3, 4], [5, 6]]) +a[:, 1] # [2 4 6] +``` + ### Tensor Creation #### Elixir code @@ -51,4 +88,42 @@ Nx.iota({5}, axis: 0) |> Nx.multiply(2) # [0 2 4 6 8] # Linearly Spaced Values Nx.iota({5}) |> Nx.divide(4) # [0.0, 0.25, 0.5, 0.75, 1.0] +``` + +### Tensor Inspection + +#### Elixir code + +```elixir + +# Shape +a = Nx.tensor([[1, 2, 3], [4, 5, 6]]) +Nx.shape(a) # {2, 3} + +# Number of Dimensions +Nx.rank(a) # 2 + +# Data Type +Nx.type(a) # {:s, 64} + +# Total Number of Elements +Nx.size(a) # 6 +``` + +### Indexing ans Slicing: Tensor + +#### Elixir code +```elixir +# Indexing a Single Element +a = Nx.tensor([[10, 20], [30, 40]]) +tensor[[0, 1]] # Returns a tensor, even for a single value (not a scalar like NumPy). + +# Slicing a Range + +a = Nx.tensor([10, 20, 30, 40, 50]) +a[1..3] # [20 30 40] + +# Selecting Along a Specific Axis +a = Nx.tensor([[1, 2], [3, 4], [5, 6]]) +a[[.., 1]] # [2 4 6] ``` \ No newline at end of file From f6966fd2af62420bbe294744b4bf1fe32ee9d45d Mon Sep 17 00:00:00 2001 From: TomasPegado Date: Wed, 2 Apr 2025 19:29:44 -0300 Subject: [PATCH 28/36] docs: cheatsheet linear algebra --- nx/guides/cheatsheet/cheatsheet.cheatmd | 73 +++++++++++++++++++++++-- 1 file changed, 68 insertions(+), 5 deletions(-) diff --git a/nx/guides/cheatsheet/cheatsheet.cheatmd b/nx/guides/cheatsheet/cheatsheet.cheatmd index 444f1d8414..f2cfa7bfbd 100644 --- a/nx/guides/cheatsheet/cheatsheet.cheatmd +++ b/nx/guides/cheatsheet/cheatsheet.cheatmd @@ -34,7 +34,6 @@ np.linspace(0, 1, 5) # [0. 0.25 0.5 0.75 1. ] #### Python code ```python - # Shape a = np.array([[1, 2, 3], [4, 5, 6]]) a.shape # (2, 3) @@ -53,7 +52,6 @@ a.size #### Python code ```python - # Indexing a Single Element a = np.array([[10, 20], [30, 40]]) a[0, 1] # 20 @@ -67,6 +65,41 @@ a = np.array([[1, 2], [3, 4], [5, 6]]) a[:, 1] # [2 4 6] ``` +### Linear Algebra Operations + +#### Python code +```python +# Matrix Multiplication +A = np.array([[1, 2], [3, 4]]) +B = np.array([[5, 6], [7, 8]]) + +np.matmul(A, B) +# or simply +A @ B + +# Transpose +A.T + +# Identity Matrix +np.eye(3) + +# Determinant +np.linalg.det(A) + +# Inverse +np.linalg.inv(A) + +# Solve a System of Linear Equations +A = np.array([[3, 1], [1, 2]]) +b = np.array([9, 8]) + +np.linalg.solve(A, b) + +# Eigenvalues and Eigenvectors +np.linalg.eig(A) + +``` + ### Tensor Creation #### Elixir code @@ -95,7 +128,6 @@ Nx.iota({5}) |> Nx.divide(4) # [0.0, 0.25, 0.5, 0.75, 1.0] #### Elixir code ```elixir - # Shape a = Nx.tensor([[1, 2, 3], [4, 5, 6]]) Nx.shape(a) # {2, 3} @@ -110,7 +142,7 @@ Nx.type(a) # {:s, 64} Nx.size(a) # 6 ``` -### Indexing ans Slicing: Tensor +### Indexing and Slicing: Tensor #### Elixir code ```elixir @@ -119,11 +151,42 @@ a = Nx.tensor([[10, 20], [30, 40]]) tensor[[0, 1]] # Returns a tensor, even for a single value (not a scalar like NumPy). # Slicing a Range - a = Nx.tensor([10, 20, 30, 40, 50]) a[1..3] # [20 30 40] # Selecting Along a Specific Axis a = Nx.tensor([[1, 2], [3, 4], [5, 6]]) a[[.., 1]] # [2 4 6] +``` + +### Linear Algebra Operations + +#### Elixir code +```elixir +# Matrix Multiplication +a = Nx.tensor([[1, 2], [3, 4]]) +b = Nx.tensor([[5, 6], [7, 8]]) + +Nx.dot(a, b) + +# Transpose +Nx.transpose(a) + +# Identity Matrix +Nx.eye({3, 3}) + +# Determinant +Nx.LinAlg.det(a) + +# Inverse +Nx.LinAlg.inverse(a) + +# Solve a System of Linear Equations +a = Nx.tensor([[3.0, 1.0], [1.0, 2.0]]) +b = Nx.tensor([9.0, 8.0]) + +Nx.LinAlg.solve(a, b) + +# Eigenvalues and Eigenvectors +Nx.LinAlg.eigh(a) ``` \ No newline at end of file From 12462988158f9fd61366ff5d23bd53426e48fa56 Mon Sep 17 00:00:00 2001 From: TomasPegado Date: Thu, 3 Apr 2025 14:08:28 -0300 Subject: [PATCH 29/36] docs: numpy -> elixir cheatsheet --- .../numpy_nx.cheatmd} | 142 +++++++++--------- nx/mix.exs | 2 + 2 files changed, 69 insertions(+), 75 deletions(-) rename nx/guides/{cheatsheet/cheatsheet.cheatmd => cheatsheets/numpy_nx.cheatmd} (87%) diff --git a/nx/guides/cheatsheet/cheatsheet.cheatmd b/nx/guides/cheatsheets/numpy_nx.cheatmd similarity index 87% rename from nx/guides/cheatsheet/cheatsheet.cheatmd rename to nx/guides/cheatsheets/numpy_nx.cheatmd index f2cfa7bfbd..3628c89077 100644 --- a/nx/guides/cheatsheet/cheatsheet.cheatmd +++ b/nx/guides/cheatsheets/numpy_nx.cheatmd @@ -1,14 +1,12 @@ -# Cheatsheet +# Numpy -> Nx This cheatsheet is designed to assist Python developers in transitioning to Elixir, specifically by providing equivalent commands and code examples between NumPy and Nx. -## Numpy -> Nx +## Tensor Creation {: .col-2} -### Array Creation - -#### Python code +### Python ```python import numpy as np @@ -29,9 +27,31 @@ np.linspace(0, 1, 5) # [0. 0.25 0.5 0.75 1. ] ``` -### Array Inspection +### Elixir + +```elixir +Mix.install([:nx]) + +# From list or nested list +a = Nx.tensor([1, 2, 3]) +b = Nx.tensor([[1, 2], [3, 4]]) + +# zeros and ones +Nx.broadcast(0, {2, 3}) # 2x3 tensor filled with zeros +Nx.broadcast(1, {2, 3}) # 2x3 tensor filled with ones + +# Range of Numbers (like range()) +Nx.iota({5}, axis: 0) |> Nx.multiply(2) # [0 2 4 6 8] + +# Linearly Spaced Values +Nx.iota({5}) |> Nx.divide(4) # [0.0, 0.25, 0.5, 0.75, 1.0] + +``` + +## Tensor Inspection +{: .col-2} -#### Python code +### Python ```python # Shape @@ -48,9 +68,27 @@ a.dtype # dtype('int64') a.size ``` -### Indexing and Slicing: Array +### Elixir -#### Python code +```elixir +# Shape +a = Nx.tensor([[1, 2, 3], [4, 5, 6]]) +Nx.shape(a) # {2, 3} + +# Number of Dimensions +Nx.rank(a) # 2 + +# Data Type +Nx.type(a) # {:s, 64} + +# Total Number of Elements +Nx.size(a) # 6 +``` + +## Indexing and Slicing +{: .col-2} + +### Python ```python # Indexing a Single Element a = np.array([[10, 20], [30, 40]]) @@ -65,9 +103,25 @@ a = np.array([[1, 2], [3, 4], [5, 6]]) a[:, 1] # [2 4 6] ``` -### Linear Algebra Operations +### Elixir +```elixir +# Indexing a Single Element +a = Nx.tensor([[10, 20], [30, 40]]) +tensor[[0, 1]] # Returns a tensor, even for a single value (not a scalar like NumPy). -#### Python code +# Slicing a Range +a = Nx.tensor([10, 20, 30, 40, 50]) +a[1..3] # [20 30 40] + +# Selecting Along a Specific Axis +a = Nx.tensor([[1, 2], [3, 4], [5, 6]]) +a[[.., 1]] # [2 4 6] +``` + +## Linear Algebra Operations +{: .col-2} + +### Python ```python # Matrix Multiplication A = np.array([[1, 2], [3, 4]]) @@ -96,72 +150,10 @@ b = np.array([9, 8]) np.linalg.solve(A, b) # Eigenvalues and Eigenvectors -np.linalg.eig(A) - +np.linalg.eigh(A) ``` -### Tensor Creation - -#### Elixir code - -```elixir -Mix.install([:nx]) - -# From list or nested list -a = Nx.tensor([1, 2, 3]) -b = Nx.tensor([[1, 2], [3, 4]]) - -# zeros and ones -Nx.broadcast(0, {2, 3}) # 2x3 tensor filled with zeros -Nx.broadcast(1, {2, 3}) # 2x3 tensor filled with ones - -# Range of Numbers (like range()) -Nx.iota({5}, axis: 0) |> Nx.multiply(2) # [0 2 4 6 8] - -# Linearly Spaced Values -Nx.iota({5}) |> Nx.divide(4) # [0.0, 0.25, 0.5, 0.75, 1.0] - -``` - -### Tensor Inspection - -#### Elixir code - -```elixir -# Shape -a = Nx.tensor([[1, 2, 3], [4, 5, 6]]) -Nx.shape(a) # {2, 3} - -# Number of Dimensions -Nx.rank(a) # 2 - -# Data Type -Nx.type(a) # {:s, 64} - -# Total Number of Elements -Nx.size(a) # 6 -``` - -### Indexing and Slicing: Tensor - -#### Elixir code -```elixir -# Indexing a Single Element -a = Nx.tensor([[10, 20], [30, 40]]) -tensor[[0, 1]] # Returns a tensor, even for a single value (not a scalar like NumPy). - -# Slicing a Range -a = Nx.tensor([10, 20, 30, 40, 50]) -a[1..3] # [20 30 40] - -# Selecting Along a Specific Axis -a = Nx.tensor([[1, 2], [3, 4], [5, 6]]) -a[[.., 1]] # [2 4 6] -``` - -### Linear Algebra Operations - -#### Elixir code +### Elixir ```elixir # Matrix Multiplication a = Nx.tensor([[1, 2], [3, 4]]) diff --git a/nx/mix.exs b/nx/mix.exs index 0b05a2df72..0d4b5e823a 100644 --- a/nx/mix.exs +++ b/nx/mix.exs @@ -62,6 +62,7 @@ defmodule Nx.MixProject do "guides/getting_started/quickstart.livemd", "guides/getting_started/broadcasting.livemd", "guides/getting_started/numerical_definitions.livemd", + "guides/cheatsheets/numpy_nx.cheatmd", "guides/advanced/vectorization.livemd", "guides/advanced/aggregation.livemd", "guides/advanced/automatic_differentiation.livemd", @@ -118,6 +119,7 @@ defmodule Nx.MixProject do ], groups_for_extras: [ "Getting Started": ~r"^guides/getting_started/", + "Cheatsheets": ~r"^guides/cheatsheets/", Exercises: ~r"^guides/exercises/", Advanced: ~r"^guides/advanced/" ] From 730e55addced2536cdec29ef8f3850abcf38015f Mon Sep 17 00:00:00 2001 From: TomasPegado Date: Mon, 7 Apr 2025 18:12:00 -0300 Subject: [PATCH 30/36] docs: adds :makeup_syntect and updates makeup to 1.2.1 --- nx/mix.exs | 4 +++- nx/mix.lock | 7 +++++-- 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/nx/mix.exs b/nx/mix.exs index 6a73c427d6..eb9d31c87a 100644 --- a/nx/mix.exs +++ b/nx/mix.exs @@ -37,7 +37,9 @@ defmodule Nx.MixProject do [ {:complex, "~> 0.6"}, {:telemetry, "~> 0.4.0 or ~> 1.0"}, - {:ex_doc, "~> 0.29", only: :docs} + {:ex_doc, "~> 0.29", only: :docs}, + {:makeup, "~> 1.2.1"}, + {:makeup_syntect, "~> 0.1"} ] end diff --git a/nx/mix.lock b/nx/mix.lock index 2d7fd485c4..6af877993c 100644 --- a/nx/mix.lock +++ b/nx/mix.lock @@ -1,10 +1,13 @@ %{ + "castore": {:hex, :castore, "1.0.12", "053f0e32700cbec356280c0e835df425a3be4bc1e0627b714330ad9d0f05497f", [:mix], [], "hexpm", "3dca286b2186055ba0c9449b4e95b97bf1b57b47c1f2644555879e659960c224"}, "complex": {:hex, :complex, "0.6.0", "b0130086a7a8c33574d293b2e0e250f4685580418eac52a5658a4bd148f3ccf1", [:mix], [], "hexpm", "0a5fa95580dcaf30fcd60fe1aaf24327c0fe401e98c24d892e172e79498269f9"}, "earmark_parser": {:hex, :earmark_parser, "1.4.41", "ab34711c9dc6212dda44fcd20ecb87ac3f3fce6f0ca2f28d4a00e4154f8cd599", [:mix], [], "hexpm", "a81a04c7e34b6617c2792e291b5a2e57ab316365c2644ddc553bb9ed863ebefa"}, "ex_doc": {:hex, :ex_doc, "0.34.2", "13eedf3844ccdce25cfd837b99bea9ad92c4e511233199440488d217c92571e8", [:mix], [{:earmark_parser, "~> 1.4.39", [hex: :earmark_parser, repo: "hexpm", optional: false]}, {:makeup_c, ">= 0.1.0", [hex: :makeup_c, repo: "hexpm", optional: true]}, {:makeup_elixir, "~> 0.14 or ~> 1.0", [hex: :makeup_elixir, repo: "hexpm", optional: false]}, {:makeup_erlang, "~> 0.1 or ~> 1.0", [hex: :makeup_erlang, repo: "hexpm", optional: false]}, {:makeup_html, ">= 0.1.0", [hex: :makeup_html, repo: "hexpm", optional: true]}], "hexpm", "5ce5f16b41208a50106afed3de6a2ed34f4acfd65715b82a0b84b49d995f95c1"}, - "makeup": {:hex, :makeup, "1.1.2", "9ba8837913bdf757787e71c1581c21f9d2455f4dd04cfca785c70bbfff1a76a3", [:mix], [{:nimble_parsec, "~> 1.2.2 or ~> 1.3", [hex: :nimble_parsec, repo: "hexpm", optional: false]}], "hexpm", "cce1566b81fbcbd21eca8ffe808f33b221f9eee2cbc7a1706fc3da9ff18e6cac"}, + "makeup": {:hex, :makeup, "1.2.1", "e90ac1c65589ef354378def3ba19d401e739ee7ee06fb47f94c687016e3713d1", [:mix], [{:nimble_parsec, "~> 1.4", [hex: :nimble_parsec, repo: "hexpm", optional: false]}], "hexpm", "d36484867b0bae0fea568d10131197a4c2e47056a6fbe84922bf6ba71c8d17ce"}, "makeup_elixir": {:hex, :makeup_elixir, "0.16.2", "627e84b8e8bf22e60a2579dad15067c755531fea049ae26ef1020cad58fe9578", [:mix], [{:makeup, "~> 1.0", [hex: :makeup, repo: "hexpm", optional: false]}, {:nimble_parsec, "~> 1.2.3 or ~> 1.3", [hex: :nimble_parsec, repo: "hexpm", optional: false]}], "hexpm", "41193978704763f6bbe6cc2758b84909e62984c7752b3784bd3c218bb341706b"}, "makeup_erlang": {:hex, :makeup_erlang, "1.0.1", "c7f58c120b2b5aa5fd80d540a89fdf866ed42f1f3994e4fe189abebeab610839", [:mix], [{:makeup, "~> 1.0", [hex: :makeup, repo: "hexpm", optional: false]}], "hexpm", "8a89a1eeccc2d798d6ea15496a6e4870b75e014d1af514b1b71fa33134f57814"}, - "nimble_parsec": {:hex, :nimble_parsec, "1.4.0", "51f9b613ea62cfa97b25ccc2c1b4216e81df970acd8e16e8d1bdc58fef21370d", [:mix], [], "hexpm", "9c565862810fb383e9838c1dd2d7d2c437b3d13b267414ba6af33e50d2d1cf28"}, + "makeup_syntect": {:hex, :makeup_syntect, "0.1.3", "ae2c3437f479ea50d08d794acaf02a2f3a8c338dd1f757f6b237c42eb27fcde1", [:mix], [{:makeup, "~> 1.2", [hex: :makeup, repo: "hexpm", optional: false]}, {:rustler, "~> 0.36.1", [hex: :rustler, repo: "hexpm", optional: true]}, {:rustler_precompiled, "~> 0.8.2", [hex: :rustler_precompiled, repo: "hexpm", optional: false]}], "hexpm", "a27bd3bd8f7b87465d110295a33ed1022202bea78701bd2bbeadfb45d690cdbf"}, + "nimble_parsec": {:hex, :nimble_parsec, "1.4.2", "8efba0122db06df95bfaa78f791344a89352ba04baedd3849593bfce4d0dc1c6", [:mix], [], "hexpm", "4b21398942dda052b403bbe1da991ccd03a053668d147d53fb8c4e0efe09c973"}, + "rustler_precompiled": {:hex, :rustler_precompiled, "0.8.2", "5f25cbe220a8fac3e7ad62e6f950fcdca5a5a5f8501835d2823e8c74bf4268d5", [:mix], [{:castore, "~> 0.1 or ~> 1.0", [hex: :castore, repo: "hexpm", optional: false]}, {:rustler, "~> 0.23", [hex: :rustler, repo: "hexpm", optional: true]}], "hexpm", "63d1bd5f8e23096d1ff851839923162096364bac8656a4a3c00d1fff8e83ee0a"}, "telemetry": {:hex, :telemetry, "1.2.1", "68fdfe8d8f05a8428483a97d7aab2f268aaff24b49e0f599faa091f1d4e7f61c", [:rebar3], [], "hexpm", "dad9ce9d8effc621708f99eac538ef1cbe05d6a874dd741de2e689c47feafed5"}, } From 3afc26b6d19253ca62bde24189194a5c73507e2b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tom=C3=A1s=20Lenzi?= <61877861+TomasPegado@users.noreply.github.com> Date: Tue, 22 Apr 2025 13:07:38 -0300 Subject: [PATCH 31/36] Update nx/mix.exs Co-authored-by: Paulo Valente <16843419+polvalente@users.noreply.github.com> --- nx/mix.exs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/nx/mix.exs b/nx/mix.exs index 98b99bed86..7b71db5a1c 100644 --- a/nx/mix.exs +++ b/nx/mix.exs @@ -38,8 +38,8 @@ defmodule Nx.MixProject do {:complex, "~> 0.6"}, {:telemetry, "~> 0.4.0 or ~> 1.0"}, {:ex_doc, "~> 0.29", only: :docs}, - {:makeup, "~> 1.2.1"}, - {:makeup_syntect, "~> 0.1"} + {:makeup, "~> 1.2.1", only: :docs}, + {:makeup_syntect, "~> 0.1", only: :docs} ] end From 1533ff82f76909382e88dbb4255ce094ff727525 Mon Sep 17 00:00:00 2001 From: TomasPegado Date: Thu, 24 Apr 2025 18:06:23 -0300 Subject: [PATCH 32/36] fix: changes based on reviews --- nx/guides/cheatsheets/numpy_nx.cheatmd | 327 ++++++++++++++---- .../numerical_definitions.livemd | 8 - 2 files changed, 262 insertions(+), 73 deletions(-) diff --git a/nx/guides/cheatsheets/numpy_nx.cheatmd b/nx/guides/cheatsheets/numpy_nx.cheatmd index 3628c89077..1a4d2ea146 100644 --- a/nx/guides/cheatsheets/numpy_nx.cheatmd +++ b/nx/guides/cheatsheets/numpy_nx.cheatmd @@ -9,42 +9,85 @@ specifically by providing equivalent commands and code examples between NumPy an ### Python ```python -import numpy as np +>>> import numpy as np # From list or nested list -a = np.array([1, 2, 3]) -b = np.array([[1, 2], [3, 4]]) +>>> np.array([1, 2, 3]) +array([1, 2, 3]) +>>> np.array([[1, 2], [3, 4]]) +array([[1, 2], + [3, 4]]) # zeros and ones -np.zeros((2, 3)) # 2x3 array filled with zeros -np.ones((2, 3)) # 2x3 array filled with ones +>>> np.zeros((2, 3)) +array([[0., 0., 0.], + [0., 0., 0.]]) +>>> np.ones((2, 3)) +array([[1., 1., 1.], + [1., 1., 1.]]) # Range of Numbers (like range()) -np.arange(0, 10, 2) # [0 2 4 6 8] +>>> np.arange(0, 10, 2) +array([0, 2, 4, 6, 8]) # Linearly Spaced Values -np.linspace(0, 1, 5) # [0. 0.25 0.5 0.75 1. ] +>>> np.linspace(0, 1, 5) +array([0. , 0.25, 0.5 , 0.75, 1. ]) ``` ### Elixir ```elixir -Mix.install([:nx]) +iex> Mix.install([:nx, "~> 0.9"]) +:ok # From list or nested list -a = Nx.tensor([1, 2, 3]) -b = Nx.tensor([[1, 2], [3, 4]]) +iex> Nx.tensor([1, 2, 3]) +#Nx.Tensor< + s32[3] + [1, 2, 3] +> +iex> Nx.tensor([[1, 2], [3, 4]]) +#Nx.Tensor< + s32[2][2] + [ + [1, 2], + [3, 4] + ] +> # zeros and ones -Nx.broadcast(0, {2, 3}) # 2x3 tensor filled with zeros -Nx.broadcast(1, {2, 3}) # 2x3 tensor filled with ones +iex> Nx.broadcast(0, {2, 3}) +#Nx.Tensor< + s32[2][3] + [ + [0, 0, 0], + [0, 0, 0] + ] +> +iex> Nx.broadcast(1, {2, 3}) +#Nx.Tensor< + s32[2][3] + [ + [1, 1, 1], + [1, 1, 1] + ] +> # Range of Numbers (like range()) -Nx.iota({5}, axis: 0) |> Nx.multiply(2) # [0 2 4 6 8] +iex> Nx.iota({5}, axis: 0) |> Nx.multiply(2) +#Nx.Tensor< + s32[5] + [0, 2, 4, 6, 8] +> # Linearly Spaced Values -Nx.iota({5}) |> Nx.divide(4) # [0.0, 0.25, 0.5, 0.75, 1.0] +iex> Nx.iota({5}) |> Nx.divide(4) +#Nx.Tensor< + f32[5] + [0.0, 0.25, 0.5, 0.75, 1.0] +> ``` @@ -54,35 +97,55 @@ Nx.iota({5}) |> Nx.divide(4) # [0.0, 0.25, 0.5, 0.75, 1.0] ### Python ```python +>>> import numpy as np + # Shape -a = np.array([[1, 2, 3], [4, 5, 6]]) -a.shape # (2, 3) +>>> a = np.array([[1, 2, 3], [4, 5, 6]]) +>>> a.shape +(2, 3) # Number of dimensions -a.ndim # 2 +>>> a.ndim +2 # Data Type -a.dtype # dtype('int64') +>>> a.dtype +dtype('int64') # Total Number of Elements -a.size +>>> a.size +6 ``` ### Elixir ```elixir +iex> Mix.install([:nx]) +:ok + # Shape -a = Nx.tensor([[1, 2, 3], [4, 5, 6]]) -Nx.shape(a) # {2, 3} +iex> a = Nx.tensor([[1, 2, 3], [4, 5, 6]]) +#Nx.Tensor< + s32[2][3] + [ + [1, 2, 3], + [4, 5, 6] + ] +> +iex> Nx.shape(a) +{2, 3} -# Number of Dimensions -Nx.rank(a) # 2 +# Number of dimensions +iex> Nx.rank(a) +2 # Data Type -Nx.type(a) # {:s, 64} +iex> Nx.type(a) +{:s, 32} # Total Number of Elements -Nx.size(a) # 6 +iex> Nx.size(a) +6 ``` ## Indexing and Slicing @@ -90,32 +153,81 @@ Nx.size(a) # 6 ### Python ```python +>>> import numpy as np + # Indexing a Single Element -a = np.array([[10, 20], [30, 40]]) -a[0, 1] # 20 +>>> a = np.array([[10, 20], [30, 40]]) +>>> a[0, 1] +np.int64(20) # Slicing a Range -a = np.array([10, 20, 30, 40, 50]) -a[1:4] # [20 30 40] +>>> a = np.array([10, 20, 30, 40, 50]) +>>> a[1:4] +array([20, 30, 40]) # Selecting Along a Specific Axis -a = np.array([[1, 2], [3, 4], [5, 6]]) -a[:, 1] # [2 4 6] +>>> a = np.array([[1, 2], [3, 4], [5, 6]]) +>>> a[:, 1] +array([2, 4, 6]) + +# Boolean Masking +>>> x = np.arange(10) +>>> x[x % 2 == 0] +array([0, 2, 4, 6, 8]) ``` ### Elixir ```elixir +iex> Mix.install([:nx, "~> 0.9"]) +:ok + # Indexing a Single Element -a = Nx.tensor([[10, 20], [30, 40]]) -tensor[[0, 1]] # Returns a tensor, even for a single value (not a scalar like NumPy). +iex> tensor = Nx.tensor([[10, 20], [30, 40]]) +#Nx.Tensor< + s32[2][2] + [ + [10, 20], + [30, 40] + ] +> +iex> tensor[[0, 1]] +#Nx.Tensor< + s32 + 20 +> # Slicing a Range -a = Nx.tensor([10, 20, 30, 40, 50]) -a[1..3] # [20 30 40] +iex> a = Nx.tensor([10, 20, 30, 40, 50]) +#Nx.Tensor< + s32[5] + [10, 20, 30, 40, 50] +> +iex> a[1..3] +#Nx.Tensor< + s32[3] + [20, 30, 40] +> # Selecting Along a Specific Axis -a = Nx.tensor([[1, 2], [3, 4], [5, 6]]) -a[[.., 1]] # [2 4 6] +iex> a = Nx.tensor([[1, 2], [3, 4], [5, 6]]) +#Nx.Tensor< + s32[3][2] + [ + [1, 2], + [3, 4], + [5, 6] + ] +> +iex> a[[.., 1]] +#Nx.Tensor< + s32[3] + [2, 4, 6] +> + +# Boolean Masking +# requires dynamic shape behavior, which is not directly supported +# in Nx because Nx compiles all operations ahead-of-time (like XLA or JAX), +# and tensors must have static shapes. ``` ## Linear Algebra Operations @@ -123,62 +235,147 @@ a[[.., 1]] # [2 4 6] ### Python ```python -# Matrix Multiplication -A = np.array([[1, 2], [3, 4]]) -B = np.array([[5, 6], [7, 8]]) +>>> import numpy as np -np.matmul(A, B) -# or simply -A @ B +# Matrix Multiplication +>>> A = np.array([[1, 2], [3, 4]]) +>>> B = np.array([[5, 6], [7, 8]]) +>>> np.matmul(A, B) +array([[19, 22], + [43, 50]]) # Transpose -A.T +>>> A.T +array([[1, 3], + [2, 4]]) # Identity Matrix -np.eye(3) +>>> np.eye(3) +array([[1., 0., 0.], + [0., 1., 0.], + [0., 0., 1.]]) # Determinant -np.linalg.det(A) +>>> np.linalg.det(A) +np.float64(-2.0000000000000004) # Inverse -np.linalg.inv(A) +>>> np.linalg.inv(A) +array([[-2. , 1. ], + [ 1.5, -0.5]]) # Solve a System of Linear Equations -A = np.array([[3, 1], [1, 2]]) -b = np.array([9, 8]) - -np.linalg.solve(A, b) +>>> A = np.array([[3, 1], [1, 2]]) +>>> b = np.array([9, 8]) +>>> np.linalg.solve(A, b) +array([2., 3.]) # Eigenvalues and Eigenvectors -np.linalg.eigh(A) +>>> np.linalg.eigh(A) +EighResult(eigenvalues=array([1.38196601, 3.61803399]), eigenvectors=array([[ 0.52573111, -0.85065081], + [-0.85065081, -0.52573111]])) ``` ### Elixir ```elixir -# Matrix Multiplication -a = Nx.tensor([[1, 2], [3, 4]]) -b = Nx.tensor([[5, 6], [7, 8]]) +iex(108)> Mix.install([:nx]) +:ok -Nx.dot(a, b) +# Matrix Multiplication +iex(111)> a = Nx.tensor([[1, 2], [3, 4]]) +#Nx.Tensor< + s32[2][2] + [ + [1, 2], + [3, 4] + ] +> +iex(112)> b = Nx.tensor([[5, 6], [7, 8]]) +#Nx.Tensor< + s32[2][2] + [ + [5, 6], + [7, 8] + ] +> +iex(114)> Nx.dot(a, b) +#Nx.Tensor< + s32[2][2] + [ + [19, 22], + [43, 50] + ] +> # Transpose -Nx.transpose(a) +iex(117)> Nx.transpose(a) +#Nx.Tensor< + s32[2][2] + [ + [1, 3], + [2, 4] + ] +> # Identity Matrix -Nx.eye({3, 3}) +iex(120)> Nx.eye({3, 3}) +#Nx.Tensor< + s32[3][3] + [ + [1, 0, 0], + [0, 1, 0], + [0, 0, 1] + ] +> # Determinant -Nx.LinAlg.det(a) +iex(123)> Nx.LinAlg.determinant(a) +#Nx.Tensor< + f32 + -2.0 +> # Inverse -Nx.LinAlg.inverse(a) +iex(126)> Nx.LinAlg.invert(a) +#Nx.Tensor< + f32[2][2] + [ + [-2.000000476837158, 1.0000003576278687], + [1.5000004768371582, -0.5000002384185791] + ] +> # Solve a System of Linear Equations -a = Nx.tensor([[3.0, 1.0], [1.0, 2.0]]) -b = Nx.tensor([9.0, 8.0]) - -Nx.LinAlg.solve(a, b) +iex(129)> a = Nx.tensor([[3.0, 1.0], [1.0, 2.0]]) +#Nx.Tensor< + f32[2][2] + [ + [3.0, 1.0], + [1.0, 2.0] + ] +> +iex(130)> b = Nx.tensor([9.0, 8.0]) +#Nx.Tensor< + f32[2] + [9.0, 8.0] +> +iex(132)> Nx.LinAlg.solve(a, b) +#Nx.Tensor< + f32[2] + [2.0, 3.0] +> # Eigenvalues and Eigenvectors -Nx.LinAlg.eigh(a) +iex(135)> Nx.LinAlg.eigh(a) +{#Nx.Tensor< + f32[2] + [3.618025779724121, 1.381974220275879] + >, + #Nx.Tensor< + f32[2][2] + [ + [0.8516583442687988, -0.5240974426269531], + [0.5240974426269531, 0.8516583442687988] + ] + >} ``` \ No newline at end of file diff --git a/nx/guides/getting_started/numerical_definitions.livemd b/nx/guides/getting_started/numerical_definitions.livemd index 910a5e4916..04ba8dadda 100644 --- a/nx/guides/getting_started/numerical_definitions.livemd +++ b/nx/guides/getting_started/numerical_definitions.livemd @@ -1,14 +1,6 @@ -<<<<<<< HEAD # Numerical definitions (defn) The `defn` macro simplifies the expression of mathematical formulas -======= -# Numerical Definitions (defn) - -## Section - -The `defn` macro and its siblings simplify the expression of mathematical formulas ->>>>>>> f9428e6e2b3fac14ea675ad71ab2f024b6d467b5 containing tensors. Numerical definitions have two primary benefits over classic Elixir functions. From a5f64689a28d2c1700f688ee2a8cb9e75cfc3a72 Mon Sep 17 00:00:00 2001 From: TomasPegado Date: Thu, 24 Apr 2025 18:08:25 -0300 Subject: [PATCH 33/36] another format fix --- nx/guides/cheatsheets/numpy_nx.cheatmd | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/nx/guides/cheatsheets/numpy_nx.cheatmd b/nx/guides/cheatsheets/numpy_nx.cheatmd index 1a4d2ea146..9476f4e3bb 100644 --- a/nx/guides/cheatsheets/numpy_nx.cheatmd +++ b/nx/guides/cheatsheets/numpy_nx.cheatmd @@ -278,11 +278,11 @@ EighResult(eigenvalues=array([1.38196601, 3.61803399]), eigenvectors=array([[ 0. ### Elixir ```elixir -iex(108)> Mix.install([:nx]) +iex> Mix.install([:nx]) :ok # Matrix Multiplication -iex(111)> a = Nx.tensor([[1, 2], [3, 4]]) +iex> a = Nx.tensor([[1, 2], [3, 4]]) #Nx.Tensor< s32[2][2] [ @@ -290,7 +290,7 @@ iex(111)> a = Nx.tensor([[1, 2], [3, 4]]) [3, 4] ] > -iex(112)> b = Nx.tensor([[5, 6], [7, 8]]) +iex> b = Nx.tensor([[5, 6], [7, 8]]) #Nx.Tensor< s32[2][2] [ @@ -298,7 +298,7 @@ iex(112)> b = Nx.tensor([[5, 6], [7, 8]]) [7, 8] ] > -iex(114)> Nx.dot(a, b) +iex> Nx.dot(a, b) #Nx.Tensor< s32[2][2] [ @@ -308,7 +308,7 @@ iex(114)> Nx.dot(a, b) > # Transpose -iex(117)> Nx.transpose(a) +iex> Nx.transpose(a) #Nx.Tensor< s32[2][2] [ @@ -318,7 +318,7 @@ iex(117)> Nx.transpose(a) > # Identity Matrix -iex(120)> Nx.eye({3, 3}) +iex> Nx.eye({3, 3}) #Nx.Tensor< s32[3][3] [ @@ -329,14 +329,14 @@ iex(120)> Nx.eye({3, 3}) > # Determinant -iex(123)> Nx.LinAlg.determinant(a) +iex> Nx.LinAlg.determinant(a) #Nx.Tensor< f32 -2.0 > # Inverse -iex(126)> Nx.LinAlg.invert(a) +iex> Nx.LinAlg.invert(a) #Nx.Tensor< f32[2][2] [ @@ -346,7 +346,7 @@ iex(126)> Nx.LinAlg.invert(a) > # Solve a System of Linear Equations -iex(129)> a = Nx.tensor([[3.0, 1.0], [1.0, 2.0]]) +iex> a = Nx.tensor([[3.0, 1.0], [1.0, 2.0]]) #Nx.Tensor< f32[2][2] [ @@ -354,19 +354,19 @@ iex(129)> a = Nx.tensor([[3.0, 1.0], [1.0, 2.0]]) [1.0, 2.0] ] > -iex(130)> b = Nx.tensor([9.0, 8.0]) +iex> b = Nx.tensor([9.0, 8.0]) #Nx.Tensor< f32[2] [9.0, 8.0] > -iex(132)> Nx.LinAlg.solve(a, b) +iex> Nx.LinAlg.solve(a, b) #Nx.Tensor< f32[2] [2.0, 3.0] > # Eigenvalues and Eigenvectors -iex(135)> Nx.LinAlg.eigh(a) +iex> Nx.LinAlg.eigh(a) {#Nx.Tensor< f32[2] [3.618025779724121, 1.381974220275879] From 42c7430153948e9a73060701fba0446f3a48a91e Mon Sep 17 00:00:00 2001 From: TomasPegado Date: Fri, 25 Apr 2025 18:04:50 -0300 Subject: [PATCH 34/36] adding blank-spaces for alignments --- nx/guides/cheatsheets/numpy_nx.cheatmd | 157 ++++++++++++++++++++++++- 1 file changed, 151 insertions(+), 6 deletions(-) diff --git a/nx/guides/cheatsheets/numpy_nx.cheatmd b/nx/guides/cheatsheets/numpy_nx.cheatmd index 9476f4e3bb..5d4e37907b 100644 --- a/nx/guides/cheatsheets/numpy_nx.cheatmd +++ b/nx/guides/cheatsheets/numpy_nx.cheatmd @@ -11,29 +11,56 @@ specifically by providing equivalent commands and code examples between NumPy an ```python >>> import numpy as np + # From list or nested list >>> np.array([1, 2, 3]) array([1, 2, 3]) + + + + >>> np.array([[1, 2], [3, 4]]) array([[1, 2], [3, 4]]) + + + + + # zeros and ones >>> np.zeros((2, 3)) array([[0., 0., 0.], [0., 0., 0.]]) + + + + + + >>> np.ones((2, 3)) array([[1., 1., 1.], [1., 1., 1.]]) + + + + + # Range of Numbers (like range()) >>> np.arange(0, 10, 2) array([0, 2, 4, 6, 8]) + + + # Linearly Spaced Values >>> np.linspace(0, 1, 5) array([0. , 0.25, 0.5 , 0.75, 1. ]) + + + ``` ### Elixir @@ -48,6 +75,7 @@ iex> Nx.tensor([1, 2, 3]) s32[3] [1, 2, 3] > + iex> Nx.tensor([[1, 2], [3, 4]]) #Nx.Tensor< s32[2][2] @@ -66,6 +94,7 @@ iex> Nx.broadcast(0, {2, 3}) [0, 0, 0] ] > + iex> Nx.broadcast(1, {2, 3}) #Nx.Tensor< s32[2][3] @@ -88,7 +117,6 @@ iex> Nx.iota({5}) |> Nx.divide(4) f32[5] [0.0, 0.25, 0.5, 0.75, 1.0] > - ``` ## Tensor Inspection @@ -99,8 +127,17 @@ iex> Nx.iota({5}) |> Nx.divide(4) ```python >>> import numpy as np + # Shape >>> a = np.array([[1, 2, 3], [4, 5, 6]]) + + + + + + + + >>> a.shape (2, 3) @@ -132,6 +169,7 @@ iex> a = Nx.tensor([[1, 2, 3], [4, 5, 6]]) [4, 5, 6] ] > + iex> Nx.shape(a) {2, 3} @@ -155,25 +193,59 @@ iex> Nx.size(a) ```python >>> import numpy as np + # Indexing a Single Element >>> a = np.array([[10, 20], [30, 40]]) + + + + + + + + >>> a[0, 1] np.int64(20) + + + # Slicing a Range >>> a = np.array([10, 20, 30, 40, 50]) + + + + + >>> a[1:4] array([20, 30, 40]) + + + # Selecting Along a Specific Axis >>> a = np.array([[1, 2], [3, 4], [5, 6]]) + + + + + + + + + >>> a[:, 1] array([2, 4, 6]) + + + # Boolean Masking >>> x = np.arange(10) >>> x[x % 2 == 0] array([0, 2, 4, 6, 8]) + + ``` ### Elixir @@ -190,6 +262,7 @@ iex> tensor = Nx.tensor([[10, 20], [30, 40]]) [30, 40] ] > + iex> tensor[[0, 1]] #Nx.Tensor< s32 @@ -202,6 +275,7 @@ iex> a = Nx.tensor([10, 20, 30, 40, 50]) s32[5] [10, 20, 30, 40, 50] > + iex> a[1..3] #Nx.Tensor< s32[3] @@ -218,6 +292,7 @@ iex> a = Nx.tensor([[1, 2], [3, 4], [5, 6]]) [5, 6] ] > + iex> a[[.., 1]] #Nx.Tensor< s32[3] @@ -225,9 +300,10 @@ iex> a[[.., 1]] > # Boolean Masking -# requires dynamic shape behavior, which is not directly supported -# in Nx because Nx compiles all operations ahead-of-time (like XLA or JAX), -# and tensors must have static shapes. +# requires dynamic shape behavior, which is not directly +# supported in Nx because Nx compiles all operations +# ahead-of-time (like XLA or JAX), and tensors must have +# static shapes. ``` ## Linear Algebra Operations @@ -237,43 +313,108 @@ iex> a[[.., 1]] ```python >>> import numpy as np + # Matrix Multiplication >>> A = np.array([[1, 2], [3, 4]]) + + + + + + + + >>> B = np.array([[5, 6], [7, 8]]) + + + + + + + + >>> np.matmul(A, B) array([[19, 22], [43, 50]]) + + + + + # Transpose >>> A.T array([[1, 3], [2, 4]]) + + + + + # Identity Matrix >>> np.eye(3) array([[1., 0., 0.], [0., 1., 0.], [0., 0., 1.]]) + + + + + # Determinant >>> np.linalg.det(A) np.float64(-2.0000000000000004) + + + # Inverse >>> np.linalg.inv(A) array([[-2. , 1. ], [ 1.5, -0.5]]) + + + + + # Solve a System of Linear Equations >>> A = np.array([[3, 1], [1, 2]]) + + + + + + + + >>> b = np.array([9, 8]) + + + + + >>> np.linalg.solve(A, b) array([2., 3.]) + + + # Eigenvalues and Eigenvectors >>> np.linalg.eigh(A) -EighResult(eigenvalues=array([1.38196601, 3.61803399]), eigenvectors=array([[ 0.52573111, -0.85065081], - [-0.85065081, -0.52573111]])) +EighResult( + eigenvalues=array([1.38196601, 3.61803399]), + eigenvectors=array([ + [ 0.52573111, -0.85065081], + [-0.85065081, -0.52573111] + ])) + + + + + ``` ### Elixir @@ -290,6 +431,7 @@ iex> a = Nx.tensor([[1, 2], [3, 4]]) [3, 4] ] > + iex> b = Nx.tensor([[5, 6], [7, 8]]) #Nx.Tensor< s32[2][2] @@ -298,6 +440,7 @@ iex> b = Nx.tensor([[5, 6], [7, 8]]) [7, 8] ] > + iex> Nx.dot(a, b) #Nx.Tensor< s32[2][2] @@ -354,11 +497,13 @@ iex> a = Nx.tensor([[3.0, 1.0], [1.0, 2.0]]) [1.0, 2.0] ] > + iex> b = Nx.tensor([9.0, 8.0]) #Nx.Tensor< f32[2] [9.0, 8.0] > + iex> Nx.LinAlg.solve(a, b) #Nx.Tensor< f32[2] From b6d1e0eb8dff775bfb39db75678b1cfc095788dd Mon Sep 17 00:00:00 2001 From: Paulo Valente <16843419+polvalente@users.noreply.github.com> Date: Fri, 25 Apr 2025 18:56:39 -0300 Subject: [PATCH 35/36] refactor: use H4 trick formatting --- nx/guides/cheatsheets/numpy_nx.cheatmd | 533 ++++++++++--------------- 1 file changed, 206 insertions(+), 327 deletions(-) diff --git a/nx/guides/cheatsheets/numpy_nx.cheatmd b/nx/guides/cheatsheets/numpy_nx.cheatmd index 5d4e37907b..dd8bddf97c 100644 --- a/nx/guides/cheatsheets/numpy_nx.cheatmd +++ b/nx/guides/cheatsheets/numpy_nx.cheatmd @@ -1,81 +1,37 @@ -# Numpy -> Nx +# NumPy -> Nx -This cheatsheet is designed to assist Python developers in transitioning to Elixir, +This cheatsheet is designed to assist Python developers in transitioning to Elixir, specifically by providing equivalent commands and code examples between NumPy and Nx. ## Tensor Creation {: .col-2} -### Python - +### From list or nested list +#### NumPy ```python ->>> import numpy as np - - -# From list or nested list >>> np.array([1, 2, 3]) array([1, 2, 3]) - - - - ->>> np.array([[1, 2], [3, 4]]) -array([[1, 2], - [3, 4]]) - - - - - - -# zeros and ones ->>> np.zeros((2, 3)) -array([[0., 0., 0.], - [0., 0., 0.]]) - - - - - - ->>> np.ones((2, 3)) -array([[1., 1., 1.], - [1., 1., 1.]]) - - - - - - -# Range of Numbers (like range()) ->>> np.arange(0, 10, 2) -array([0, 2, 4, 6, 8]) - - - - -# Linearly Spaced Values ->>> np.linspace(0, 1, 5) -array([0. , 0.25, 0.5 , 0.75, 1. ]) - - - - ``` -### Elixir - +#### Nx ```elixir -iex> Mix.install([:nx, "~> 0.9"]) -:ok - -# From list or nested list iex> Nx.tensor([1, 2, 3]) #Nx.Tensor< s32[3] [1, 2, 3] > +``` +### 2D Arrays/Tensors +#### NumPy +```python +>>> np.array([[1, 2], [3, 4]]) +array([[1, 2], + [3, 4]]) +``` + +#### Nx +```elixir iex> Nx.tensor([[1, 2], [3, 4]]) #Nx.Tensor< s32[2][2] @@ -84,8 +40,23 @@ iex> Nx.tensor([[1, 2], [3, 4]]) [3, 4] ] > +``` + +### Zeros and Ones + +#### NumPy +```python +>>> np.zeros((2, 3)) +array([[0., 0., 0.], + [0., 0., 0.]]) + +>>> np.ones((2, 3)) +array([[1., 1., 1.], + [1., 1., 1.]]) +``` -# zeros and ones +#### NumPy +```elixir iex> Nx.broadcast(0, {2, 3}) #Nx.Tensor< s32[2][3] @@ -103,15 +74,35 @@ iex> Nx.broadcast(1, {2, 3}) [1, 1, 1] ] > +``` + +### Range of Numbers + +#### NumPy +```python +>>> np.arange(0, 10, 2) +array([0, 2, 4, 6, 8]) +``` -# Range of Numbers (like range()) +#### NumPy +```elixir iex> Nx.iota({5}, axis: 0) |> Nx.multiply(2) #Nx.Tensor< s32[5] [0, 2, 4, 6, 8] > +``` + +### Linearly Spaced Values -# Linearly Spaced Values +#### NumPy +```python +>>> np.linspace(0, 1, 5) +array([0. , 0.25, 0.5 , 0.75, 1. ]) +``` + +#### NumPy +```elixir iex> Nx.iota({5}) |> Nx.divide(4) #Nx.Tensor< f32[5] @@ -120,47 +111,18 @@ iex> Nx.iota({5}) |> Nx.divide(4) ``` ## Tensor Inspection -{: .col-2} - -### Python +{: .col-2-left} +### Shape +#### NumPy ```python ->>> import numpy as np - - -# Shape >>> a = np.array([[1, 2, 3], [4, 5, 6]]) - - - - - - - - >>> a.shape (2, 3) - -# Number of dimensions ->>> a.ndim -2 - -# Data Type ->>> a.dtype -dtype('int64') - -# Total Number of Elements ->>> a.size -6 ``` -### Elixir - +#### Nx ```elixir -iex> Mix.install([:nx]) -:ok - -# Shape iex> a = Nx.tensor([[1, 2, 3], [4, 5, 6]]) #Nx.Tensor< s32[2][3] @@ -169,19 +131,48 @@ iex> a = Nx.tensor([[1, 2, 3], [4, 5, 6]]) [4, 5, 6] ] > - iex> Nx.shape(a) {2, 3} +``` + +### Number of dimensions + +#### NumPy +```python +>>> a.ndim +2 +``` -# Number of dimensions -iex> Nx.rank(a) +#### Nx +```elixir +iex> Nx.rank(a) 2 +``` + +### Data Type + +#### NumPy +```python +>>> a.dtype +dtype('int64') +``` -# Data Type -iex> Nx.type(a) +#### Nx +```elixir +iex> Nx.type(a) {:s, 32} +``` + +### Total Number of Elements -# Total Number of Elements +#### NumPy +```python +>>> a.size +6 +``` + +#### Nx +```elixir iex> Nx.size(a) 6 ``` @@ -189,258 +180,94 @@ iex> Nx.size(a) ## Indexing and Slicing {: .col-2} -### Python +### Indexing a Single Element +#### NumPy ```python ->>> import numpy as np - - -# Indexing a Single Element >>> a = np.array([[10, 20], [30, 40]]) - - - - - - - - ->>> a[0, 1] +>>> a[0, 1] np.int64(20) - - - - -# Slicing a Range ->>> a = np.array([10, 20, 30, 40, 50]) - - - - - ->>> a[1:4] -array([20, 30, 40]) - - - - -# Selecting Along a Specific Axis ->>> a = np.array([[1, 2], [3, 4], [5, 6]]) - - - - - - - - - ->>> a[:, 1] -array([2, 4, 6]) - - - - -# Boolean Masking ->>> x = np.arange(10) ->>> x[x % 2 == 0] -array([0, 2, 4, 6, 8]) - - ``` -### Elixir +#### Nx ```elixir -iex> Mix.install([:nx, "~> 0.9"]) -:ok - # Indexing a Single Element iex> tensor = Nx.tensor([[10, 20], [30, 40]]) -#Nx.Tensor< - s32[2][2] - [ - [10, 20], - [30, 40] - ] -> - iex> tensor[[0, 1]] #Nx.Tensor< s32 20 > +``` +### Slicing a Range +#### NumPy +```python +>>> a = np.array([10, 20, 30, 40, 50]) +>>> a[1:4] +array([20, 30, 40]) +``` + +#### Nx +```elixir # Slicing a Range iex> a = Nx.tensor([10, 20, 30, 40, 50]) -#Nx.Tensor< - s32[5] - [10, 20, 30, 40, 50] -> - iex> a[1..3] #Nx.Tensor< s32[3] [20, 30, 40] > +``` + +### Selecting Along a Specific Axis +#### NumPy +```python +>>> a = np.array([[1, 2], [3, 4], [5, 6]]) +>>> a[:, 1] +array([2, 4, 6]) +``` +#### Nx +```elixir # Selecting Along a Specific Axis iex> a = Nx.tensor([[1, 2], [3, 4], [5, 6]]) -#Nx.Tensor< - s32[3][2] - [ - [1, 2], - [3, 4], - [5, 6] - ] -> - iex> a[[.., 1]] #Nx.Tensor< s32[3] [2, 4, 6] > +``` -# Boolean Masking -# requires dynamic shape behavior, which is not directly -# supported in Nx because Nx compiles all operations -# ahead-of-time (like XLA or JAX), and tensors must have -# static shapes. +### Boolean Masking +#### NumPy +```python +>>> x = np.arange(10) +>>> x[x % 2 == 0] +array([0, 2, 4, 6, 8]) ``` +#### Nx + +Boolean masking requires dynamic shape behavior, which is not +supported in Nx because Nx compiles all operations +ahead-of-time (like XLA or Jax), and for that, tensors must have static shapes. + ## Linear Algebra Operations {: .col-2} -### Python +### Matrix Multiplication +#### NumPy ```python ->>> import numpy as np - - -# Matrix Multiplication >>> A = np.array([[1, 2], [3, 4]]) - - - - - - - - >>> B = np.array([[5, 6], [7, 8]]) - - - - - - - - >>> np.matmul(A, B) array([[19, 22], [43, 50]]) - - - - - - -# Transpose ->>> A.T -array([[1, 3], - [2, 4]]) - - - - - - -# Identity Matrix ->>> np.eye(3) -array([[1., 0., 0.], - [0., 1., 0.], - [0., 0., 1.]]) - - - - - - -# Determinant ->>> np.linalg.det(A) -np.float64(-2.0000000000000004) - - - - -# Inverse ->>> np.linalg.inv(A) -array([[-2. , 1. ], - [ 1.5, -0.5]]) - - - - - - -# Solve a System of Linear Equations ->>> A = np.array([[3, 1], [1, 2]]) - - - - - - - - ->>> b = np.array([9, 8]) - - - - - ->>> np.linalg.solve(A, b) -array([2., 3.]) - - - - -# Eigenvalues and Eigenvectors ->>> np.linalg.eigh(A) -EighResult( - eigenvalues=array([1.38196601, 3.61803399]), - eigenvectors=array([ - [ 0.52573111, -0.85065081], - [-0.85065081, -0.52573111] - ])) - - - - - ``` -### Elixir +#### Nx ```elixir -iex> Mix.install([:nx]) -:ok - -# Matrix Multiplication iex> a = Nx.tensor([[1, 2], [3, 4]]) -#Nx.Tensor< - s32[2][2] - [ - [1, 2], - [3, 4] - ] -> - iex> b = Nx.tensor([[5, 6], [7, 8]]) -#Nx.Tensor< - s32[2][2] - [ - [5, 6], - [7, 8] - ] -> - iex> Nx.dot(a, b) #Nx.Tensor< s32[2][2] @@ -449,8 +276,18 @@ iex> Nx.dot(a, b) [43, 50] ] > +``` + +### Transpose +#### NumPy +```python +>>> A.T +array([[1, 3], + [2, 4]]) +``` -# Transpose +#### Nx +```elixir iex> Nx.transpose(a) #Nx.Tensor< s32[2][2] @@ -459,8 +296,19 @@ iex> Nx.transpose(a) [2, 4] ] > +``` + +### Identity Matrix +#### NumPy +```python +>>> np.eye(3) +array([[1., 0., 0.], + [0., 1., 0.], + [0., 0., 1.]]) +``` -# Identity Matrix +#### Nx +```elixir iex> Nx.eye({3, 3}) #Nx.Tensor< s32[3][3] @@ -470,15 +318,34 @@ iex> Nx.eye({3, 3}) [0, 0, 1] ] > +``` + +### Determinant +#### NumPy +```python +>>> np.linalg.det(A) +np.float64(-2.0000000000000004) +``` -# Determinant +#### Nx +```elixir iex> Nx.LinAlg.determinant(a) #Nx.Tensor< f32 -2.0 > +``` + +### Inverse +#### NumPy +```python +>>> np.linalg.inv(A) +array([[-2. , 1. ], + [ 1.5, -0.5]]) +``` -# Inverse +#### Nx +```elixir iex> Nx.LinAlg.invert(a) #Nx.Tensor< f32[2][2] @@ -487,30 +354,42 @@ iex> Nx.LinAlg.invert(a) [1.5000004768371582, -0.5000002384185791] ] > +``` -# Solve a System of Linear Equations -iex> a = Nx.tensor([[3.0, 1.0], [1.0, 2.0]]) -#Nx.Tensor< - f32[2][2] - [ - [3.0, 1.0], - [1.0, 2.0] - ] -> +### Solve a System of Linear Equations +#### NumPy +```python +>>> A = np.array([[3, 1], [1, 2]]) +>>> b = np.array([9, 8]) +>>> np.linalg.solve(A, b) +array([2., 3.]) +``` +#### Nx +```elixir +iex> a = Nx.tensor([[3.0, 1.0], [1.0, 2.0]]) iex> b = Nx.tensor([9.0, 8.0]) -#Nx.Tensor< - f32[2] - [9.0, 8.0] -> - iex> Nx.LinAlg.solve(a, b) #Nx.Tensor< f32[2] [2.0, 3.0] > +``` + +### Eigenvalues and Eigenvectors +#### NumPy +```python +>>> np.linalg.eigh(A) +EighResult( + eigenvalues=array([1.38196601, 3.61803399]), + eigenvectors=array([ + [ 0.52573111, -0.85065081], + [-0.85065081, -0.52573111] + ])) +``` -# Eigenvalues and Eigenvectors +#### Nx +```elixir iex> Nx.LinAlg.eigh(a) {#Nx.Tensor< f32[2] From 103a1af6c025c3d2f716d7d8a2d9cba62e994985 Mon Sep 17 00:00:00 2001 From: Paulo Valente <16843419+polvalente@users.noreply.github.com> Date: Fri, 25 Apr 2025 18:57:58 -0300 Subject: [PATCH 36/36] formatting --- nx/mix.exs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nx/mix.exs b/nx/mix.exs index 7b71db5a1c..e76161f5d1 100644 --- a/nx/mix.exs +++ b/nx/mix.exs @@ -121,7 +121,7 @@ defmodule Nx.MixProject do ], groups_for_extras: [ "Getting Started": ~r"^guides/getting_started/", - "Cheatsheets": ~r"^guides/cheatsheets/", + Cheatsheets: ~r"^guides/cheatsheets/", Exercises: ~r"^guides/exercises/", Advanced: ~r"^guides/advanced/" ]