Nx Library Scope


Jason S

Feb 18, 2021, 9:46:15 AM
to Numerical Elixir (Nx)
How closely is Nx following JAX/NumPy, and what is in scope for an Nx contribution?

Say I wanted to bring over the jax.random.* functions (specifically, I need the beta distribution). Should my workflow be:
  1. Make sure JAX can do it (https://github.com/google/jax/blob/master/jax/_src/random.py#L812)
  2. Build it using Nx primitives and see how far I get
  3. Open a PR?
Is this something to discuss module by module, or would you rather see it go in its own library?
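Step 2 of the workflow, building a distribution out of existing primitives, can be sketched with the standard Gamma-to-Beta identity: if X ~ Gamma(a, 1) and Y ~ Gamma(b, 1) are independent, then X / (X + Y) ~ Beta(a, b). The Python below is only an illustration. It assumes a Gamma sampler is available (the standard library's `random.gammavariate` stands in for one), and `sample_beta` / `sample_beta_many` are hypothetical names, not an Nx or JAX API.

```python
import random


def sample_beta(rng: random.Random, a: float, b: float) -> float:
    """Draw one Beta(a, b) sample from two independent Gamma(shape, 1) draws.

    Identity used: X ~ Gamma(a, 1), Y ~ Gamma(b, 1) independent
    implies X / (X + Y) ~ Beta(a, b).
    """
    if a <= 0 or b <= 0:
        raise ValueError("a and b must be positive")
    x = rng.gammavariate(a, 1.0)  # shape a, scale 1
    y = rng.gammavariate(b, 1.0)  # shape b, scale 1
    # Note: for very small a and b both draws can underflow toward 0;
    # a robust version would need to guard against that.
    return x / (x + y)


def sample_beta_many(seed: int, a: float, b: float, n: int) -> list[float]:
    """Draw n Beta(a, b) samples from a generator seeded with `seed`."""
    rng = random.Random(seed)
    return [sample_beta(rng, a, b) for _ in range(n)]


if __name__ == "__main__":
    samples = sample_beta_many(seed=0, a=2.0, b=5.0, n=10_000)
    mean = sum(samples) / len(samples)
    print(f"sample mean = {mean:.3f}, expected a / (a + b) = {2 / 7:.3f}")
```

The point for Nx is that once a Gamma sampler exists, Beta is a few lines of composition and needs no backend-specific code.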

Jason

smoriarity.5

Feb 18, 2021, 10:33:15 AM
to Numerical Elixir (Nx)
This is a good question, and I can see this going in a few directions. First, I think the overall goal is to avoid as much API bloat as possible. NumPy is pretty big, and when you add the SciPy routines that could arguably fit into the NumPy API, it's massive. I think in general the Nx API should be limited, but not at the cost of productivity. It would be pretty painful to have to reimplement beta, bernoulli, gamma, etc. every time you needed them.

So then this can go in two directions:

1) We implement some common distributions, probably under an `Nx.Random` namespace. The good thing is that all of these functions can be implemented in terms of other Nx primitives, so compiler writers wouldn't need to implement these routines. Another positive to this approach is that it would be easier (I think) to shell out to custom kernels if need be. I know JAX uses some custom CUDA PRNG kernels. XLA also has some routines specifically for PRNGs.

2) We shell this out to a library of numerical definitions. This avoids adding anything to the API, but any time somebody wants to use common RNGs they have to bring in another dependency (maybe not that bad). It keeps the API slim, but might be kind of annoying from a productivity perspective. You also lose the ability to (easily) plug in custom kernels where necessary.

We could do something in between and add primitives for making PRNG keys and some other random primitives. I could see something like `jax.random.shuffle` being annoying to implement as a numerical definition. I'm torn between these two choices, and I'd be curious to hear other opinions.
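The two ideas in the paragraph above (primitives for PRNG keys, and shuffle being awkward as a numerical definition) can be illustrated together. What follows is a toy sketch in Python under loud assumptions. `make_key`, `split`, `uniform`, and `shuffle` are hypothetical names loosely modeled on JAX's explicit-key style. SHA-256 stands in for a real counter-based generator such as JAX's Threefry and is not what JAX uses. The shuffle uses the "sort by random keys" technique, so it needs only a uniform draw, an argsort, and a gather.

```python
import hashlib
import struct

# Hypothetical toy API: a key is 16 bytes of immutable state.
# SHA-256 is used purely for illustration, not as a recommended PRNG.
Key = bytes


def make_key(seed: int) -> Key:
    """Derive a key from an integer seed (same seed -> same key)."""
    return hashlib.sha256(b"seed" + struct.pack("<q", seed)).digest()[:16]


def split(key: Key, num: int = 2) -> list[Key]:
    """Derive `num` child keys from `key` without mutating any state."""
    return [
        hashlib.sha256(b"split" + key + struct.pack("<I", i)).digest()[:16]
        for i in range(num)
    ]


def uniform(key: Key, n: int) -> list[float]:
    """Return n floats in [0, 1) that are a pure function of `key`."""
    out = []
    for i in range(n):
        digest = hashlib.sha256(b"uniform" + key + struct.pack("<I", i)).digest()
        (bits,) = struct.unpack("<Q", digest[:8])
        out.append((bits >> 11) / float(1 << 53))  # keep 53 bits -> [0, 1)
    return out


def shuffle(key: Key, items: list) -> list:
    """Shuffle by sorting positions on random keys (uniform + argsort + gather).

    Ties between 53-bit floats are vanishingly unlikely, so the result is a
    uniformly random permutation for practical purposes.
    """
    keys = uniform(key, len(items))
    order = sorted(range(len(items)), key=lambda i: keys[i])  # argsort
    return [items[i] for i in order]  # gather


if __name__ == "__main__":
    k1, k2 = split(make_key(0))
    print(shuffle(k1, list(range(10))))
    print(shuffle(k2, list(range(10))))
```

Sorting random keys costs O(n log n) instead of Fisher-Yates' O(n). In exchange, it needs only tensor-style primitives, which is what matters if shuffle is to be written as a numerical definition. Because keys are explicit values, the same key always reproduces the same permutation.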

Sean

Jason S

Feb 18, 2021, 10:48:34 AM
to elix...@googlegroups.com
Yeah, custom PRNGs might be out of MY depth; for my use case the built-in random_normal/random_uniform are adequate. I'm fairly confident I could port some of this from JAX to Nx, and for my project it would be enough to move over what I need.

But I was curious how far Nx plans to go in the NumPy direction, because NumPy is big. It's pretty neat how the JAX code looks similar to how you'd implement it in plain JavaScript/Elixir/C anyway.

--
You received this message because you are subscribed to a topic in the Google Groups "Numerical Elixir (Nx)" group.
To unsubscribe from this topic, visit https://groups.google.com/d/topic/elixir-nx/93SsPPlkUnU/unsubscribe.
To unsubscribe from this group and all its topics, send an email to elixir-nx+...@googlegroups.com.
To view this discussion on the web visit https://groups.google.com/d/msgid/elixir-nx/ce7c8315-6ed0-4cb3-bb79-326cb23a4d36n%40googlegroups.com.

José Valim

Feb 18, 2021, 11:00:13 AM
to elix...@googlegroups.com
> 1) We implement some common distributions, probably under an `Nx.Random` namespace.

I like this idea, but I have some questions. We would have some random functions in Nx and some in Nx.Random. Maybe this could be a feature, since it signals that everything in Nx goes directly to a backend/compiler while Nx.Random is built on top of it, but it poses a couple of challenges:

1. Nx.mean is built on top of existing primitives. Does that mean it should go in a separate module?

2. If in the future we port Nx.Random.random_whatever to Nx.random_whatever so it can have a compiler/backend-specific implementation, it becomes duplicated. But maybe that's not really an issue, precisely because of the above (and we would potentially phase out the Nx.Random implementation anyway).

It is also worth adding that, if the concern is discoverability, ExDoc allows us to group functions in the sidebar by specific concerns (like Linalg) even if they are all in the same module.


Jeff Smith

Feb 19, 2021, 7:16:48 PM
to Numerical Elixir (Nx)
Sharing two relevant PyTorch (PT) examples.

PT does keep random in the same codebase at the same level in the hierarchy as core concepts like tensors: https://pytorch.org/docs/stable/random.html

PT also allows for the registration of custom RNGs and provides a crypto-secure one as a separately installable component: https://github.com/pytorch/csprng

To me, a tensor library probably needs RNG functionality as a first-class citizen within the tensor library itself. That said, I could see achieving this by wrapping some other library, if there's a good collection of existing implementations.

José Valim

Feb 20, 2021, 3:37:53 AM
to elix...@googlegroups.com
Thanks, Jeff! Sounds good; let's increase our random repertoire then. Let's focus on growing the pizza, and later we can discuss how to slice it.

In any case, I have tagged all functions in Nx, and the image below shows how we can still classify them without using multiple modules. A couple of notes:

* I have kept "random" inside the creation functions for now, but it is a 1-LOC change
* N-dimensional is the opposite of element-wise (dot, conv, etc.); maybe there is a better name, but that is also a 1-LOC fix :)
* It was easy to find a place for most functions, except "sort", "reverse", and "concatenate" (which I put in n-dim)

[Image attachment: Screenshot 2021-02-20 at 09.33.25.png]

