/r/JAX
Subreddit for the Machine Learning library JAX
Hi group,
I'm searching for a multi-point boundary value problem (BVP) solver (comparable to the bvp4c/bvp5c solvers provided by MATLAB) to solve 8 coupled ODEs across the 7-8 layers of an electrochemical cell.
I already realized that in Python some workarounds are required to use solve_bvp from the scipy package, since it was originally designed for two-point BVPs. However, switching to scipy is not possible for me, since I need the solver for a real-time application.
Does anybody know of, or has heard about, activities in this direction within the JAX ecosystem? So far I have only seen the approach below for the two-point BVP case, but I'm not able to convert it to the multi-point case:
https://gist.github.com/RicardoDominguez/f013d21a5991e863ffcf9076f5b9b34d
Thank you very much! :)
I liked JAX both for its approach (FP) and for its speed. It was a little sad for me when I had to sacrifice the FP style for OO + transform (flax/haiku) or for callable objects (eqx).
I wanted to share with you a little library I wrote recently in my spare time. It's called zephyr (link in comments) and it is built on top of jax. You write in an FP style, and you call models which are functions (not callable objects, if you ignore that in Python type(function) is object).
It's not perfect: there's a lack of examples aside from the README, and no RNNs yet (haven't had time). But I'm able to use it and am using it. I found it simple; everything is a function.
I hope you can take a look, and I'd love to hear comments on how I can improve it! Thanks!
Hey guys, I'm a newbie in jax / flax, and I want to know others' opinions about the change from linen to nnx in flax: the usability changes, the decision itself, etc. Do you think dropping linen is the right long-term plan for better usability? Thanks!
Hey r/JAX ! Just wanted to share something exciting for those of you working across multiple ML frameworks.
Ivy is a Python package that allows you to seamlessly convert ML models and code between frameworks like PyTorch, TensorFlow, JAX, and NumPy. With Ivy, you can take a model you’ve built in PyTorch and easily bring it over to JAX without needing to rewrite everything. Great for experimenting, collaborating, or deploying across different setups!
On top of that, we’ve just partnered with Kornia, a popular differentiable computer vision library built on PyTorch, so now Kornia can also be used in TensorFlow, JAX, and NumPy. You can check it out in the latest Kornia release (v0.7.4) with the new methods:
kornia.to_tensorflow()
kornia.to_jax()
kornia.to_numpy()
It’s all powered by Ivy’s transpiler to make switching frameworks seamless. Give it a try and let us know what you think!
pip install ivy
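And here's roughly what the Kornia workflow could look like. The to_jax() call is from the v0.7.4 release notes above, but what it returns and whether usage looks exactly like this is my assumption, so check the Kornia/Ivy docs:

import kornia
import jax.numpy as jnp

kornia_jax = kornia.to_jax()                   # assumed: a JAX-backed kornia namespace
img = jnp.ones((1, 3, 64, 64))                 # NCHW image batch
gray = kornia_jax.color.rgb_to_grayscale(img)  # assumed: same API, JAX arrays in and out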
Happy experimenting!
I have a seemingly-simple 4x image upscaler model that's consuming 36GB of VRAM on a 48GB card.
When I profile the memory usage, 75% comes from `jax.image.resize` which I'm using to do a standard nearest-neighbor upscale prior to applying the convolutional network.
This strikes me as unreasonable. When I open one of the source images in GIMP, for instance, it claims that only 14.5MB of memory is used.
Why would the resize function use 27GB?
My batch size is 10, and images are cropped to 700x700 and 1400x1400.
Here's my model:
from pathlib import Path
import shutil

from flax import nnx
from flax.training.train_state import TrainState
import jax
import jax.numpy as jnp
import optax

INTERMEDIATE_FEATS = 16

class Model(nnx.Module):
    def __init__(self, rngs: nnx.Rngs):
        self.deep = nnx.Conv(
            in_features=INTERMEDIATE_FEATS,
            out_features=INTERMEDIATE_FEATS,
            kernel_size=(7, 7),
            padding='SAME',
            rngs=rngs,
        )
        self.deeper = nnx.Conv(
            in_features=INTERMEDIATE_FEATS,
            out_features=INTERMEDIATE_FEATS,
            kernel_size=(5, 5),
            padding='SAME',
            rngs=rngs,
        )
        self.deepest = nnx.Conv(
            in_features=INTERMEDIATE_FEATS,
            out_features=3,
            kernel_size=(3, 3),
            padding='SAME',
            rngs=rngs,
        )

    def __call__(self, x: jax.Array):
        # Nearest-neighbor 2x upscale (NHWC), then three convolutions.
        new_shape = (x.shape[0], x.shape[1] * 2,
                     x.shape[2] * 2, INTERMEDIATE_FEATS)
        upscaled = jax.image.resize(x, new_shape, "nearest")

        out = self.deep(upscaled)
        out = self.deeper(out)
        out = self.deepest(out)
        return out

def apply_model(state: TrainState, X: jax.Array, Y: jax.Array):
    """Computes gradients, loss and accuracy for a single batch."""
    def loss_fn(params):
        preds = state.apply_fn(params, X)
        loss = jnp.mean(optax.squared_error(preds, Y))
        return loss, preds

    grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
    (loss, preds), grads = grad_fn(state.params)
    return grads, loss

def update_model(state: TrainState, grads):
    return state.apply_gradients(grads=grads)
Thanks
Hi guys!
I am a clinical research associate and we are currently doing a study involving wound care for diabetic foot ulcers. We have three convenient offices around Jacksonville. This study provides wound care, weekly assessments from the physician, and compensation for travel. No health insurance is needed.
Additionally, we are in desperate need of a wheelchair for a patient experiencing a huge healthcare disparity. If anyone has an extra or any ideas, please let me know.
Please contact me for more details. Please feel free to share this post!
I have the following code, which is called within jax.lax.scan. It is part of a Langevin simulation and runs for a very large number of steps. The issue is that with JAX it is taking forever.
I found out I can use vectorization to make things faster, but I cannot see how to do that with so many JAX transformations involved. Any help will be appreciated (a rough sketch of the vectorization idea follows the code):
from collections import namedtuple

import jax
import jax.numpy as jnp
from jax import lax
from jax.tree_util import register_pytree_node_class

# Constants (MAX_bases, MAX_ELEMENTS, MAX_TRESHOLD, NO_BUBBLE, MAX_BUBBLES,
# NO_ELEMENTS, NO_END, MIN_ELEMENTS_PER_BUBBLE, MIN_BUB_ELEM, TRESHOLD_SIZE)
# and the Monitor base class are defined elsewhere in the project.

Bubble = namedtuple('Bubble', ['base', 'threshold', 'number_elements', 'start', 'end'])

@register_pytree_node_class
class BubbleMonitor(Monitor):
    TRESHOLDS = jnp.array([i / 10 for i in range(5, 150, 5)])  # start=0.5, end=14.5, step=0.5
    TRESHOLD_SIZE = len(TRESHOLDS)
    MIN_BUB_ELEM, MAX_BUB_ELEM = 3, 20

    def __init__(self, dna):
        super(BubbleMonitor, self).__init__(dna)
        self.dna = dna
        (self.bubble_index_start, self.bubble_index_end, self.bubbles,
         self.max_elements_base, self.bubble_array) = self.initialize_bubble()

    def initialize_bubble(self):
        bubble_index_start = 0
        bubble_index_end = jnp.full((MAX_bases + 1, MAX_ELEMENTS, MAX_TRESHOLD), NO_BUBBLE)
        bubble_array = jnp.full((self.dna.n_nt_bases, MIN_BUB_ELEM, TRESHOLD_SIZE), 0)
        bubbles = jax.tree_util.tree_map(
            lambda x: jnp.full(MAX_BUBBLES, x),
            Bubble(base=-1, threshold=-1.0, number_elements=-1, start=-1, end=-1)
        )
        max_elements_base = jnp.full((MAX_bases + 1,), NO_ELEMENTS)
        return bubble_index_start, bubble_index_end, bubbles, max_elements_base, bubble_array

    def add_bubble(self, base, tr_i, tr, elements, step_global, state):
        """Add a bubble to the monitor using JAX-compatible transformations."""
        bubble_index_start, bubble_index_end, bubbles, max_elements_base, bubble_array = state
        add_condition = ((elements >= MIN_ELEMENTS_PER_BUBBLE)
                         & (elements <= self.dna.n_nt_bases)
                         & (bubble_index_end[base, elements, tr_i] == NO_BUBBLE)
                         & (bubble_index_start < MAX_BUBBLES))

        def add_bubble_fn(state):
            bubble_index_start, bubble_index_end, bubbles, max_elements_base, bubble_array = state
            bubble_index_end = bubble_index_end.at[base, elements, tr_i].set(bubble_index_start)
            bubble_array = bubble_array.at[base, elements, tr_i].add(1.0)
            bubbles = bubbles._replace(
                base=bubbles.base.at[bubble_index_start].set(base),
                threshold=bubbles.threshold.at[bubble_index_start].set(tr),
                number_elements=bubbles.number_elements.at[bubble_index_start].set(elements),
                start=bubbles.start.at[bubble_index_start].set(step_global),
                end=bubbles.end.at[bubble_index_start].set(NO_END),
            )
            max_elements_base = max_elements_base.at[base].max(elements)
            return bubble_index_start + 1, bubble_index_end, bubbles, max_elements_base, bubble_array

        return jax.lax.cond(add_condition, add_bubble_fn, lambda x: x, state)

    def close_bubbles(self, base, tr_i, elements, state, step_global):
        """Close bubbles that are still open and have more elements."""
        bubble_index_start, bubble_index_end, bubbles, max_elements_base, bubble_array = state

        def close_bubble_body_fn(elem_i, carry):
            bubble_index_end, bubbles = carry
            condition = ((bubble_index_end[base, elem_i, tr_i] != NO_BUBBLE)
                         & (bubbles.end[bubble_index_end[base, elem_i, tr_i]] == NO_END))
            # Record the closing step while the index slot is still valid,
            # then clear the slot.
            bubbles = jax.lax.cond(
                condition,
                lambda b: b._replace(end=b.end.at[bubble_index_end[base, elem_i, tr_i]].set(step_global)),
                lambda b: b,
                bubbles
            )
            bubble_index_end = jax.lax.cond(
                condition,
                lambda bie: bie.at[base, elem_i, tr_i].set(NO_BUBBLE),
                lambda bie: bie,
                bubble_index_end
            )
            return bubble_index_end, bubbles

        bubble_index_end, bubbles = lax.fori_loop(
            elements + 1, max_elements_base[base] + 1, close_bubble_body_fn,
            (bubble_index_end, bubbles)
        )
        return bubble_index_start, bubble_index_end, bubbles, max_elements_base, bubble_array

    def find_bubbles(self, dna_state, step):
        """Find and manage bubbles based on the current simulation step."""
        def base_loop_body(base, state):
            def tr_loop_body(tr_i, state):
                R = jnp.array(0, dtype=jnp.int32)
                p = jnp.array(base, dtype=jnp.int32)
                tr = self.TRESHOLDS[tr_i]

                def while_body_fn(carry):
                    R, p, state = carry
                    R += 1
                    p = (base + R) % (self.dna.n_nt_bases + 1)
                    state = self.add_bubble(base, tr_i, tr, R, step, state)
                    return R, p, state

                def while_cond_fn(carry):
                    R, p, _ = carry
                    return (dna_state['coords_distance'][p] >= tr) & (R <= self.dna.n_nt_bases)

                R, p, state = lax.while_loop(while_cond_fn, while_body_fn, (R, p, state))
                state = self.close_bubbles(base, tr_i, R, state, step)
                return state

            return lax.fori_loop(0, self.TRESHOLD_SIZE, tr_loop_body, state)

        state = (self.bubble_index_start, self.bubble_index_end, self.bubbles,
                 self.max_elements_base, self.bubble_array)
        state = lax.fori_loop(0, self.dna.n_nt_bases, base_loop_body, state)

        # Unpack state after the loop.
        (self.bubble_index_start, self.bubble_index_end, self.bubbles,
         self.max_elements_base, self.bubble_array) = state
        return (self.bubble_index_start, self.bubble_index_end, self.bubbles,
                self.max_elements_base, self.bubble_array)
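The rough sketch I mentioned: a function written for a single threshold can be mapped over all TRESHOLDS at once with jax.vmap, assuming the per-threshold work can be phrased without the data-dependent while_loop. The run_length helper below is hypothetical, just to show the pattern:

import jax
import jax.numpy as jnp

thresholds = jnp.arange(5, 150, 5) / 10.0   # same values as TRESHOLDS

def run_length(threshold, distances):
    # Hypothetical stand-in for the per-threshold while_loop: length of the
    # run of positions with distance >= threshold, starting at position 0
    # (caveat: argmin returns 0 when *every* position is above threshold).
    above = distances >= threshold
    return jnp.argmin(above)

# `distances` plays the role of dna_state['coords_distance'].
distances = jnp.abs(jax.random.normal(jax.random.PRNGKey(0), (100,)))
lengths = jax.vmap(run_length, in_axes=(0, None))(thresholds, distances)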
I am considering picking up Jax. Reading the documentation I see that Jax arrays are immutable.
Optimizing pipelines usually involves preallocating buffer arrays and performing in-place modifications to avoid repeated memory (re)allocation.
I'm not sure how I would avoid repetitive memory allocation in jax.
Is that somehow already taken care of ?
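From what I've read so far, there seem to be two relevant mechanisms; a minimal sketch (the APIs are real, but my claims about when XLA actually reuses buffers are best-effort):

import jax
import jax.numpy as jnp

@jax.jit
def step(buf, x):
    # Functional-style "in-place" update; inside jit, XLA can often lower
    # this to a true in-place write when the input buffer isn't aliased.
    return buf.at[0].set(x)

buf = jnp.zeros(1024)
buf = step(buf, 1.0)

# donate_argnums tells jit the caller no longer needs the input buffer,
# letting the output reuse its memory instead of allocating a fresh array
# (donation may be ignored on CPU).
step_donating = jax.jit(lambda b, x: b.at[0].set(x), donate_argnums=0)
buf = step_donating(buf, 2.0)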
I authored a Jupyter notebook implementing homography in JAX. I am currently learning JAX, so any feedback would be helpful. Thanks!
https://github.com/kindoblue/homography-with-jax/blob/main/homography.ipynb
Hi, I'd like to introduce my toy project, JAxtar.
It's not code that many people will find useful, but I did most of the acrobatics with Jax while writing it, and I think it might inspire others who use Jax.
I wrote my master's thesis on A* and neural heuristics for solving the 15-puzzle, but when I reflected on it, the biggest headache was the high frequency and volume of data transfers between the CPU and GPU. Almost half of the execution time was spent in these communication bottlenecks. Another solution to this problem was the batched A* proposed by DeepCubeA, but I felt that it was not a complete solution.
Then one day I came across mctx, an MCTS library written in pure jax by Google DeepMind.
I was fascinated by this approach and made many attempts to write A* in Jax, but was unsuccessful. The problem was the hashtable and the priority queue.
A long time after graduation, after studying many examples and a lot of head-scratching, I finally managed to write some working code.
There are a few special elements of this code that I'm proud of.
I hope this project can serve as an inspiring example for anyone who enjoys Jax.
Basically what the title says. To me, JAX feels very much like a LISPy way of doing machine learning, so I was wondering if it has a port to some kind of LISP language.
Why?
tensorflow: Tensor
pytorch: Tensor
caffe2: Tensor
Theano: Tensor
jax: Array
It makes me want to from jax import Array as Tensor
Tensor is just such a badass, well-accepted name for a differentiable multidimensional array data structure. Why did you do this? I'm going to make a pull request to add a Tensor class as some kind of alias or some kind of factory for arrays.
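And for what it's worth, the alias does work today, since jax.Array is the public array type:

from jax import Array as Tensor  # jax.Array is public, so the alias is legal

def double(x: Tensor) -> Tensor:
    return x * 2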
Dear all. My main work is R&D in computer vision. I always used PyTorch (and TF before TF2) but was curious about Jax. Therefore I created my own library of layers / preset architectures called Jimmy (https://github.com/clementpoiret/jimmy/). It uses Flax (their new NNX API).
For the sake of learning, it implements ViTs, Mamba-1 and Mamba-2 based models, and some techniques I want to have fun with (Memory Efficient Sharpness Aware training, Layer Sharing).
As I'm quite new to Jax, my code might be too "PyTorch-like", so I am open to any advice, feedback, or ideas of things to implement (methods, models, etc.). (Please don't look too closely at the way I save and load the converted dinov2 weights; I still have to clean that part up.)
Also, if you have tips to improve jit compile time and overall compute performance, I am open!
So I downloaded JAX from PyPI without pip (from the website, I mean) and installed it on Tails OS. Please help me.
Hi,
I am currently working in a start-up which aims at discovering new materials through AI and an automated lab.
I am currently implementing a model we designed, which is going to be fairly complex: a transformer-diffusion graph neural network. I am trying to choose which neural network library I will be using, with JAX as my automatic-differentiation backbone.
There are two libraries I am hesitating between: flax.nnx and equinox.
Equinox seems fairly mature, but I am a bit scared that it won't be maintained in the future, since Patrick Kidger seems to be the only real developer of the project. On the other hand, flax.nnx seems to add an extra layer of abstraction on top of jax, where jax pytrees are exchanged for graphs, which they justify as necessary for shared parameter representations.
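To illustrate the difference I mean (a rough sketch from memory; details may lag the current releases):

import jax
import jax.numpy as jnp
import equinox as eqx
from flax import nnx

# Equinox: a module *is* a pytree, so jax transforms apply to it directly.
class EqxLinear(eqx.Module):
    w: jax.Array

    def __call__(self, x):
        return self.w @ x

# flax.nnx: modules form a graph with mutable state; split() recovers
# a pytree view when you need plain jax transforms.
linear = nnx.Linear(4, 2, rngs=nnx.Rngs(0))
graphdef, state = nnx.split(linear)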
What are your recommendations here? Thanks :)
Hi,
I use the clu library to track metrics. I have a simple training step like the one in https://flax.readthedocs.io/en/latest/guides/training_techniques/lr_schedule.html.
According to https://github.com/google/CommonLoopUtils/blob/main/clu/metrics.py#L661, a metrics.LastValue can help me collect the last learning rate, but I could not figure out how to implement it.
Help please!🙏
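My current (unverified) understanding from reading clu's source is that LastValue slots into a metrics.Collection like any other metric, roughly:

import flax
from clu import metrics

@flax.struct.dataclass
class TrainMetrics(metrics.Collection):
    loss: metrics.Average.from_output("loss")
    learning_rate: metrics.LastValue.from_output("learning_rate")

# Inside the train step, pass the current lr as a model "output":
# update = TrainMetrics.single_from_model_output(loss=loss, learning_rate=lr)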
Hi,
Can JAX, or any other ML tool, help me test whether the hardware supports bfloat16 natively?
I have an RTX 2070 and it does not support bfloat16. But if I write code that uses bfloat16, it still runs; I think the hardware treats it as normal float16. It would be nice if I could detect this and apply the right dtype programmatically.
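One heuristic I can think of (a sketch, not a definitive check: it assumes NVIDIA GPUs gained native bfloat16 with Ampere, i.e. compute capability 8.0, and that the GPU device object exposes compute_capability):

import jax

dev = jax.devices()[0]
cc = getattr(dev, "compute_capability", None)  # e.g. "7.5" on an RTX 2070
if cc is not None:
    print("native bfloat16 likely:", float(cc) >= 8.0)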
I have the following code to start with:
from functools import partial

from jax import jit
import jax
import jax.numpy as jnp

class Counter:
    def __init__(self, count):
        self.count = count

    def add(self, number):
        # Mutates the count in place.
        self.count += number

def execute(counter, steps):
    for _ in range(steps):
        counter.add(steps)
        print(counter.count)

counter = Counter(0)
execute(counter, 10)
How can I replace this functionality with jax.lax.scan or jax.lax.fori_loop?
I know there are ways to achieve similar functionality, but I need this pattern for another project and it's not possible to write the full context here.
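For reference, here is one way the loop could look with jax.lax.fori_loop, assuming only the running count needs to survive the loop (print becomes jax.debug.print inside traced code):

import jax

def execute_fori(count, steps):
    def body(i, c):
        c = c + steps
        jax.debug.print("{}", c)  # traced-code replacement for print
        return c
    return jax.lax.fori_loop(0, steps, body, count)

print(execute_fori(0, 10))  # final count: 100, matching the class version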
I have a scenario where I want to run MCMC simulations on some protein sequences.
I have working code written in JAX. My target is to run 100 independent simulations per sequence, and I need to do this for millions of sequences. I have access to a supercomputer where each node has 4 80GB GPUs. I want to leverage the GPUs to make the computation faster, but I am not sure how to achieve the parallelism. I tried using pmap, but it only allows 4 parallel simulations (one per GPU), which is still taking a lot of time.
One of my ideas was to vmap the sequences and pmap the parallel execution. Is that a correct approach?
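Concretely, something like this sketch, where run_mcmc is a trivial stand-in for my real single-chain simulation function:

import jax
import jax.numpy as jnp

def run_mcmc(seq, key):
    # stand-in for the real single-chain simulation
    return jnp.sum(seq) + jax.random.normal(key)

def run_batch(seqs, keys):
    per_seq = jax.vmap(run_mcmc, in_axes=(None, 0))  # 100 chains per sequence
    return jax.vmap(per_seq, in_axes=(0, 0))(seqs, keys)

run_sharded = jax.pmap(run_batch, in_axes=(0, 0))    # split the batch across 4 GPUs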
My current implementation uses joblib to run parallel executions, but it is not very good at GPU utilization.
I am doing a research project in RL and need an environment where agents can show diverse behaviours, i.e. there are many qualitatively different ways of achieving the goal. Think StarCraft or Fortnite in terms of diversity of play styles, where you can be effective with loads of different strategies; though it would be amazing if it were a single-agent game, as multi-agent RL is beyond the scope.
I am planning on doing everything in JAX because I need to be super duper efficient.
Does anyone have a suggestion for a good environment to use? I am already looking at gymnax, XLand-Mini, and Jumanji.
Thanks!!!
Hi all,
I am a traditional SDE and pretty new to JAX, but I have a great interest in JAX, GPU resource allocation, and acceleration. I wanted to get some expert suggestions on what I can do to learn more about this stuff. Thank you so much!
Hey fellow JAX enthusiasts,
I'm currently working on a project that involves repeated interpolation of values, and I'm running into some performance issues. The current process involves loading grid values from a file and then interpolating them in each iteration. Unfortunately, the constant loading and data transfer between host and device is causing a significant bottleneck.
I've thought about utilizing the constant memory on NVIDIA GPUs to store my grid, but I'm unsure how to implement this or if it's even the best solution. Moreover, I'm stumped on how to optimize this process for TPUs.
If anyone has experience with similar challenges or can offer suggestions on how to overcome this performance overhead, I'd greatly appreciate it! One of the workarounds I'm considering is sketched below.
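(A sketch under assumptions: the grid fits in device memory, and the file path, shape, and 1-D interpolation are placeholders for my real setup.)

import jax
import jax.numpy as jnp
import numpy as np

grid_host = np.load("grid.npy")        # hypothetical grid file
grid_dev = jax.device_put(grid_host)   # one-time host->device transfer

@jax.jit
def interp(points, grid):
    # 1-D linear interpolation as a stand-in; a real grid would use
    # jax.scipy.ndimage.map_coordinates or similar.
    xs = jnp.arange(grid.shape[0], dtype=grid.dtype)
    return jnp.interp(points, xs, grid)

# Reusing grid_dev across iterations avoids re-uploading it every step.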
Thanks in advance for your input!
I am considering moving some Pytorch projects to JAX, since the speed-up I see in toy problems is big. However, my projects involve optimizing matrices that are symmetric positive definite (SPD). For this, I use geotorch in Pytorch, which does Riemannian gradient descent and works like a charm. In JAX, however, I don't see a clear choice of package for this.
One option is Pymanopt, which supports JAX, but it seems you can't use jit (at least out of the box) with Pymanopt. Another option is Rieoptax, but it seems to be unmaintained. I haven't found any other options. Any suggestions on what my available options are?
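One library-free fallback I'm aware of, in case it's useful context: reparametrize the SPD matrix by an unconstrained factor and do plain gradient descent on that (note this is not Riemannian gradient descent):

import jax
import jax.numpy as jnp

def spd(params):
    # SPD by construction: lower-triangular factor times its transpose,
    # plus a small jitter for strict positive definiteness.
    L = jnp.tril(params)
    return L @ L.T + 1e-6 * jnp.eye(params.shape[0])

def loss(params, target):
    return jnp.sum((spd(params) - target) ** 2)

grads = jax.grad(loss)(jnp.eye(3), 2.0 * jnp.eye(3))  # plain GD on `params`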
It is my understanding that symbolic differentiation is when a new function is created (manually or by a program) that can compute the gradient of the function, whereas in automatic differentiation there is no explicit gradient function: the computation graph of the original function, in terms of arithmetic operations, is used along with the sum and product rules for elementary operations.
Based on this understanding, isn't grad using symbolic differentiation? JAX claims that this is automatic differentiation.
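One way to poke at this: jax.make_jaxpr shows the program that grad builds by tracing, which is elementary ops chained by the chain rule rather than a simplified symbolic expression.

import jax

f = lambda x: x * x + 3.0 * x
print(jax.make_jaxpr(jax.grad(f))(2.0))
# Prints a jaxpr: elementary mul/add ops composed by the chain rule,
# not a simplified closed-form derivative like 2*x + 3.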