/r/JAX

Subreddit for the Machine Learning library JAX

r/jax is a place to discuss the machine learning library JAX

/r/JAX

1,498 Subscribers

1

Is there a flax (or jax generally) equivalent to something like lucidrains' x-transformers?

Basically I want a 'modern sequence-modelling function approximator' but I don't care about digging into all the implementation details and keeping up with the current sota on everything. I wanna just wrap it in a box and move on.

Is there a library that abstracts this away?

0 Comments
2024/10/04
17:38 UTC

3

JAX nested loops taking forever: need help with vectorization

I have the following code, which is called within jax.lax.scan. It is part of a Langevin simulation and runs for a very large number of steps. The issue is that with JAX it is taking forever.
I found out I can use vectorization to make things faster, but I can't see how to apply it across so many JAX transformations. Any help will be appreciated:

from collections import namedtuple

import jax
import jax.numpy as jnp
from jax import lax
from jax.tree_util import register_pytree_node_class

# Project-specific constants and classes (MAX_bases, MAX_ELEMENTS, MAX_TRESHOLD,
# NO_BUBBLE, MAX_BUBBLES, NO_ELEMENTS, NO_END, MIN_ELEMENTS_PER_BUBBLE, Monitor, ...)
# are assumed to be defined elsewhere in the poster's codebase.

Bubble = namedtuple('Bubble', ['base', 'threshold', 'number_elements', 'start', 'end'])

@register_pytree_node_class
class BubbleMonitor(Monitor):
    TRESHOLDS = jnp.array([i / 10 for i in range(5, 150, 5)])  # 0.5, 1.0, ..., 14.5 in steps of 0.5
    TRESHOLD_SIZE = len(TRESHOLDS)
    MIN_BUB_ELEM, MAX_BUB_ELEM = 3, 20

    def __init__(self, dna):
        super(BubbleMonitor, self).__init__(dna)
        self.dna = dna
        self.bubble_index_start, self.bubble_index_end, self.bubbles, self.max_elements_base,self.bubble_array = self.initialize_bubble()

    def initialize_bubble(self):
        bubble_index_start = 0
        bubble_index_end = jnp.full((MAX_bases + 1, MAX_ELEMENTS, MAX_TRESHOLD), NO_BUBBLE)
        bubble_array = jnp.full((self.dna.n_nt_bases, self.MIN_BUB_ELEM, self.TRESHOLD_SIZE), 0)
        bubbles = jax.tree_util.tree_map(
            lambda x: jnp.full(MAX_BUBBLES, x),
            Bubble(base=-1, threshold=-1.0, number_elements=-1, start=-1, end=-1)
        )
        max_elements_base = jnp.full((MAX_bases + 1,), NO_ELEMENTS)
        return bubble_index_start, bubble_index_end, bubbles, max_elements_base,bubble_array

    def add_bubble(self, base, tr_i, tr, elements, step_global, state):
        """Add a bubble to the monitor using JAX-compatible transformations."""
        bubble_index_start, bubble_index_end, bubbles, max_elements_base,bubble_array = state

        add_condition = (elements >= MIN_ELEMENTS_PER_BUBBLE) & (elements<=self.dna.n_nt_bases) & (bubble_index_end[base, elements, tr_i] == NO_BUBBLE) & (bubble_index_start < MAX_BUBBLES)

        def add_bubble_fn(state):
            bubble_index_start, bubble_index_end, bubbles, max_elements_base,bubble_array = state

            bubble_index_end = bubble_index_end.at[base, elements, tr_i].set(bubble_index_start)
            # int_data=bubble_array.at[base, elements, tr_i] +1 
            bubble_array=bubble_array.at[base, elements, tr_i].add(1.0)
            bubbles = bubbles._replace(
                base=bubbles.base.at[bubble_index_start].set(base),
                threshold=bubbles.threshold.at[bubble_index_start].set(tr),
                number_elements=bubbles.number_elements.at[bubble_index_start].set(elements),
                start=bubbles.start.at[bubble_index_start].set(step_global),
                end=bubbles.end.at[bubble_index_start].set(NO_END),
            )
            max_elements_base = max_elements_base.at[base].max(elements)

            return bubble_index_start + 1, bubble_index_end, bubbles, max_elements_base,bubble_array
        # print("WE ARE COLLECTING BUBBELS",bubbles)

        new_state = jax.lax.cond(add_condition, add_bubble_fn, lambda x: x, state)

        return new_state 
    
    def close_bubbles(self, base, tr_i, elements, state,step_global):
        """Close bubbles that are still open and have more elements."""
        bubble_index_start, bubble_index_end, bubbles, max_elements_base,bubble_array = state

        def close_bubble_body_fn(elem_i, carry):
            bubble_index_end, bubbles = carry
            # Capture the bubble index before clearing the slot; the original
            # ordering overwrote it with NO_BUBBLE and then used the clobbered
            # value when writing the bubble's end step.
            bubble_idx = bubble_index_end[base, elem_i, tr_i]
            condition = (bubble_idx != NO_BUBBLE) & (bubbles.end[bubble_idx] == NO_END)

            bubbles = jax.lax.cond(
                condition,
                lambda b: b._replace(end=b.end.at[bubble_idx].set(step_global)),
                lambda b: b,
                bubbles
            )

            bubble_index_end = jax.lax.cond(
                condition,
                lambda bie: bie.at[base, elem_i, tr_i].set(NO_BUBBLE),
                lambda bie: bie,
                bubble_index_end
            )

            return bubble_index_end, bubbles

        bubble_index_end, bubbles = lax.fori_loop(
            elements + 1, max_elements_base[base] + 1, close_bubble_body_fn, (bubble_index_end, bubbles)
        )

        return bubble_index_start, bubble_index_end, bubbles, max_elements_base,bubble_array
    
    def find_bubbles(self, dna_state, step):
        """Find and manage bubbles based on the current simulation step."""

        def base_loop_body(base, state):
            bubble_index_start, bubble_index_end, bubbles, max_elements_base,bubble_array = state

            def tr_loop_body(tr_i, state):
                bubble_index_start, bubble_index_end, bubbles, max_elements_base,bubble_array = state
                R = jnp.array(0, dtype=jnp.int32)
                p = jnp.array(base, dtype=jnp.int32)
                tr = self.TRESHOLDS[tr_i]

                def while_body_fn(carry):
                    R, p, state = carry
                    bubble_index_start, bubble_index_end, bubbles, max_elements_base,bubble_array = state
                    R += 1
                    p = (base + R) % (self.dna.n_nt_bases + 1)
                    state = self.add_bubble(base, tr_i, tr, R, step, state)
                    return R, p, state

                def while_cond_fn(carry):
                    R, p, _ = carry
                    return (dna_state['coords_distance'][p] >= tr) & (R <= self.dna.n_nt_bases)

                R, p, state = lax.while_loop(
                    while_cond_fn,
                    while_body_fn,
                    (R, p, state)
                )

                state = self.close_bubbles(base, tr_i, R, state,step)
                return state
            
            state = lax.fori_loop(0, self.TRESHOLD_SIZE, tr_loop_body, state)
            return state

        state = (self.bubble_index_start, self.bubble_index_end, self.bubbles, self.max_elements_base,self.bubble_array)
        state = lax.fori_loop(0, self.dna.n_nt_bases, base_loop_body, state)

        # Unpack state after loop
        self.bubble_index_start, self.bubble_index_end, self.bubbles, self.max_elements_base,self.bubble_array = state

        return self.bubble_index_start, self.bubble_index_end, self.bubbles, self.max_elements_base,self.bubble_array
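
One direction that might help, sketched very roughly below (this is my own simplification, not the poster's bubble logic): the loop over thresholds compares every position against each threshold independently, so that part can be expressed as a single vmap over a threshold axis instead of nested fori loops. The names coords_distance and bubble_mask_for_threshold are hypothetical stand-ins for the real quantities; the data-dependent while_loop is the harder part and would need a fixed-size, mask-based reformulation like this to vectorize.

import jax
import jax.numpy as jnp

# Hypothetical stand-ins for dna_state['coords_distance'] and TRESHOLDS.
coords_distance = jax.random.uniform(jax.random.PRNGKey(0), (200,)) * 15.0
thresholds = jnp.arange(0.5, 15.0, 0.5)

def bubble_mask_for_threshold(tr, coords):
    # A position counts as a "bubble element" here if its distance exceeds
    # the threshold; the real criterion is project-specific.
    return coords >= tr

# vmap over thresholds yields a (n_thresholds, n_bases) boolean array in one
# fused computation, replacing the Python-level loop over TRESHOLD_SIZE.
masks = jax.vmap(bubble_mask_for_threshold, in_axes=(0, None))(thresholds, coords_distance)
counts = masks.sum(axis=1)          # e.g. number of above-threshold positions per threshold
print(masks.shape, counts.shape)    # (29, 200) (29,)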

0 Comments
2024/09/30
03:25 UTC

6

Immutable arrays: how to optimize memory allocation?

I am considering picking up Jax. Reading the documentation I see that Jax arrays are immutable.

Optimizing pipelines usually involves preallocating buffer arrays and performing in-place modifications to avoid memory (re)allocation.

I'm not sure how I would avoid repetitive memory allocation in jax.

Is that somehow already taken care of?
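
For what it's worth, a rough sketch of the two mechanisms JAX leans on instead of manual buffer reuse: operation fusion under jit, and buffer donation via donate_argnums. The function and shapes below are placeholders, not a recommended pattern for any particular workload.

import jax
import jax.numpy as jnp

def step(x):
    # Written "out of place", but under jit XLA fuses these operations and is
    # free to reuse buffers internally when it can prove that is safe.
    return (x.at[0].set(0.0)) * 2.0

# Donating the argument additionally tells XLA it may reuse x's device buffer
# for the output, the closest analogue to mutating a preallocated array.
step_donated = jax.jit(step, donate_argnums=0)

x = jnp.ones((1024, 1024))
y = step_donated(x)   # x's buffer may be recycled; don't use x afterwards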

2 Comments
2024/09/25
19:21 UTC

4

Homography in JAX

I authored a Jupyter notebook implementing homography in JAX. I am currently learning JAX, so any feedback would be helpful. Thanks!

https://github.com/kindoblue/homography-with-jax/blob/main/homography.ipynb

0 Comments
2024/09/24
08:56 UTC

12

Sharing my toy project "JAxtar" the pure jax and jittable A* algorithm for puzzle solving

Hi, I'd like to introduce my toy project, JAxtar.

It's not code that many people will find useful, but I did most of the acrobatics with Jax while writing it, and I think it might inspire others who use Jax.

I wrote my master's thesis on A* and neural heuristics for solving the 15 Puzzle, but when I reflected on it, the biggest headache was the high frequency and length of data transfers between the CPU and GPU. Almost half of the execution time was spent in these communication bottlenecks. Another solution to this problem was the batched A* proposed by DeepCubeA, but I felt that it was not a complete solution.

I came across mctx one day, an MCTS library written in pure JAX by Google DeepMind.
I was fascinated by this approach and made many attempts to write A* in JAX, but was unsuccessful. The problem was the hashtable and the priority queue.

A long while after graduation, after studying many examples and a lot of brainfucking, I finally managed to write some working code.

There are a few special elements of this code that I'm proud of:

  • a hash_func_builder for converting defined states to hash keys
  • a hashtable to look up and insert in a parallel way
  • a priority queue that can be batched, pushed, and popped (sketched below)
  • a fully jitted A* algorithm for puzzles.

I hope this project can serve as an inspiring example for anyone who enjoys Jax.
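
To give a flavour of the priority-queue part, here is a much-simplified sketch of a batched, jittable queue kept as flat arrays (my own illustration, not JAxtar's actual implementation; the capacity and dtypes are arbitrary):

from functools import partial

import jax
import jax.numpy as jnp

CAPACITY = 1024  # fixed capacity keeps every shape static and jit-friendly

def empty_queue():
    # keys = priorities (lower is better), values = payloads (e.g. node ids)
    return jnp.full((CAPACITY,), jnp.inf), jnp.zeros((CAPACITY,), jnp.int32)

@jax.jit
def push_batch(keys, values, new_keys, new_values):
    # Concatenate and keep the CAPACITY smallest keys (top_k of the negated keys).
    all_keys = jnp.concatenate([keys, new_keys])
    all_values = jnp.concatenate([values, new_values])
    _, idx = jax.lax.top_k(-all_keys, CAPACITY)
    return all_keys[idx], all_values[idx]

@partial(jax.jit, static_argnames="k")
def pop_batch(keys, values, k):
    # Pop the k smallest keys in one shot and mark their slots empty again.
    _, idx = jax.lax.top_k(-keys, k)
    return keys[idx], values[idx], keys.at[idx].set(jnp.inf), values

keys, values = empty_queue()
keys, values = push_batch(keys, values, jnp.array([3.0, 1.0, 2.0]), jnp.array([30, 10, 20]))
popped_keys, popped_values, keys, values = pop_batch(keys, values, k=2)
print(popped_keys, popped_values)   # [1. 2.] [10 20]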

2 Comments
2024/09/03
14:37 UTC

1

Does JAX have a LISP port?

Basically what the title says. To me, JAX feels very much like a LISPy way of doing machine learning, so I was wondering if it has a port to some kind of LISP language.

0 Comments
2024/08/22
04:35 UTC

0

rant: Why Array instead of Tensor?

Why?

tensorflow: Tensor

pytorch: Tensor

caffe2: Tensor

Theano: Tensor

jax: Array

It makes me want to from jax import Array as Tensor

Tensor is just such a badass, well-accepted name for a differentiable multidimensional array data structure. Why did you do this? I'm going to make a pull request to add a Tensor class as some kind of alias or some kind of factory for arrays.

8 Comments
2024/08/20
22:59 UTC

6

Learning Jax best practices: what do you think about my toy library?

Dear all. My main work is R&D in computer vision. I always used PyTorch (and TF before TF2) but was curious about Jax. Therefore I created my own library of layers / preset architectures called Jimmy (https://github.com/clementpoiret/jimmy/). It uses Flax (their new NNX API).

For the sake of learning, it implements ViTs, Mamba-1 and Mamba-2 based models, and some techniques I want to have fun with (Memory Efficient Sharpness Aware training, Layer Sharing).

As I'm quite new to Jax, my code might be too "PyTorch-like", so I am open to any advice, feedback, or ideas of things to implement (methods, models, etc.). (Please don't look too closely at the way I save and load converted dinov2 weights, I still have to clean up that part.)

Also, if you have tips to improve jit compile time and overall compute performance, I'm all ears!

0 Comments
2024/08/15
11:58 UTC

0

I have a problem with jax

So I downloaded JAX from PyPI without pip (from the website, I mean) and installed it on Tails OS. Please help me.

0 Comments
2024/07/26
11:59 UTC

6

Best jax neural networks library for industrial projects

Hi,

I am currently working in a start-up which aims at discovering new materials through AI and an automated lab.

I am currently implementing a model we designed, which is going to be fairly complex: a transformer diffusion graph neural network. I am trying to choose which neural network library I will use. I will be using JAX as my automatic-differentiation backbone.

There are two libraries I am hesitating between: flax.nnx and Equinox.

Equinox seems fairly mature, but I am a bit scared that it won't be maintained in the future, since Patrick Kidger seems to be the only core developer of the project. On the other hand, flax.nnx seems to add an extra layer of abstraction on top of JAX, where JAX pytrees are exchanged for graphs, which they justify as necessary for shared parameter representations.
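
For readers weighing the same choice, here is roughly how the same two-layer MLP looks in each library, as I understand their current APIs (layer sizes are arbitrary and this is only a sketch):

import jax
import equinox as eqx
from flax import nnx

# Equinox: modules are frozen dataclass-style pytrees; randomness via explicit keys.
class EqxMLP(eqx.Module):
    l1: eqx.nn.Linear
    l2: eqx.nn.Linear

    def __init__(self, key):
        k1, k2 = jax.random.split(key)
        self.l1 = eqx.nn.Linear(4, 8, key=k1)
        self.l2 = eqx.nn.Linear(8, 1, key=k2)

    def __call__(self, x):
        return self.l2(jax.nn.relu(self.l1(x)))

# flax.nnx: modules are mutable Python objects backed by a graph; randomness via Rngs.
class NnxMLP(nnx.Module):
    def __init__(self, rngs: nnx.Rngs):
        self.l1 = nnx.Linear(4, 8, rngs=rngs)
        self.l2 = nnx.Linear(8, 1, rngs=rngs)

    def __call__(self, x):
        return self.l2(nnx.relu(self.l1(x)))

eqx_model = EqxMLP(jax.random.PRNGKey(0))
nnx_model = NnxMLP(rngs=nnx.Rngs(0))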

What are your recommendations here? Thanks :)

3 Comments
2024/07/09
15:01 UTC

1

How to log learning rate during training?

Hi,

I use the clu lib to track the metrics. I have a simple training step like https://flax.readthedocs.io/en/latest/guides/training_techniques/lr_schedule.html.

According to https://github.com/google/CommonLoopUtils/blob/main/clu/metrics.py#L661, a metrics.LastValue can help me collect the last learning rate. But I could not find out how to implement it.

Help please!🙏
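
Not a direct answer on metrics.LastValue, but a minimal sketch of the workaround I would try first: optax schedules are plain callables, so you can re-evaluate the schedule at the current step and write it out with a clu metric writer. The log directory and schedule parameters below are made up.

import optax
from clu import metric_writers

# Hypothetical schedule, standing in for the one from the Flax lr_schedule guide.
schedule = optax.cosine_decay_schedule(init_value=1e-3, decay_steps=10_000)
tx = optax.adam(learning_rate=schedule)

writer = metric_writers.create_default_writer("/tmp/logs")

for step in range(10_000):
    # ... run the usual jitted train_step here ...
    if step % 100 == 0:
        # Re-evaluate the schedule at the current step and log it as a scalar.
        writer.write_scalars(step, {"learning_rate": float(schedule(step))})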

1 Comment
2024/06/06
02:20 UTC

5

Is there a way to test if the GPU supports bfloat16?

Hi,

Can JAX or any other ML tool help me test whether the hardware supports bfloat16 natively?

I have an RTX 2070 and it does not support bfloat16. But if I write code that uses bfloat16, it still runs. I think the hardware just treats it as a normal float16.

It would be nice if I could detect it and apply the right dtype programmatically.
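
A rough sketch of one heuristic, assuming NVIDIA GPUs: CUDA devices in recent jaxlib versions expose a compute_capability string, and native bfloat16 arrived with Ampere (compute capability 8.0), so anything below that (like a 7.5 RTX 2070) is emulating it. Treat this as a heuristic rather than an official API guarantee.

import jax
import jax.numpy as jnp

def supports_native_bf16(device) -> bool:
    cc = getattr(device, "compute_capability", None)   # e.g. "7.5" on an RTX 2070
    if cc is not None:
        return float(cc) >= 8.0        # Ampere and newer have native bf16
    return device.platform == "tpu"    # TPUs are bf16-native; be conservative elsewhere

dev = jax.devices()[0]
dtype = jnp.bfloat16 if supports_native_bf16(dev) else jnp.float16
print(dev, "->", dtype)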

2 Comments
2024/06/05
06:35 UTC

1

How do I achieve this one in JAX? Jittable class method

I have the following code to start with:

from functools import partial
from jax import jit
import jax
import jax.numpy as jnp

class Counter:
    def __init__(self, count):
        self.count = count

    def add(self, number):
        # Mutate the count in place (in-place mutation like this won't work under jit/scan)
        self.count += number


def execute(counter, steps):
    for _ in range(steps):
        counter.add(steps)
        print(counter.count)


counter = Counter(0)
execute(counter, 10)

How can I replace the functionality with jax.lax.scan or jax.fori_loop?

I know there are ways to achieve similar functionality, but I need this for another project and it's not possible to write it all here.
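
One possible shape of the answer, sketched under the assumption that the real class can be made functional: register the class as a pytree, have add return a new instance instead of mutating, and carry it through lax.scan with the step count marked static.

from functools import partial

import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class

@register_pytree_node_class
class Counter:
    def __init__(self, count):
        self.count = count

    def add(self, number):
        # Functional update: return a new Counter instead of mutating self.
        return Counter(self.count + number)

    def tree_flatten(self):
        return (self.count,), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)

@partial(jax.jit, static_argnums=1)   # steps controls the scan length, so it must be static
def execute(counter, steps):
    def body(c, _):
        c = c.add(steps)
        return c, c.count             # carry the Counter, collect the running count
    counter, counts = jax.lax.scan(body, counter, xs=None, length=steps)
    return counter, counts

counter, counts = execute(Counter(jnp.array(0)), 10)
print(counts)   # [10 20 ... 100]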

0 Comments
2024/06/03
23:01 UTC

1

Independent parallel run: leveraging GPU

I have a scenario where I want to run MCMC simulations on some protein sequences.

I have working code written in JAX. My target is to run 100 independent simulations for each sequence, and I need to do it for millions of sequences. I have access to a supercomputer where each node has 4 80GB GPUs. I want to leverage the GPUs to make the computation faster, but I am not sure how to achieve the parallelism. I tried using pmap, but it only allows 4 parallel simulations (one per device), which still takes a lot of time. I am not sure how to achieve faster computation by leveraging the hardware that I have.

One of my ideas was to vmap over the sequences and pmap the parallel executions. Is that a correct approach?

My current implementation uses joblib to run the executions in parallel, but it is not very good at GPU utilization.
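
Combining the two (vmap inside, pmap outside) is a common pattern: vmap the independent chains within each device and pmap across devices. A minimal sketch, where run_chain and the shapes are placeholders for the real simulation:

import jax
import jax.numpy as jnp

def run_chain(key, sequence):
    # Stand-in for the real Langevin/MCMC kernel; returns a dummy statistic.
    return jax.random.normal(key, ())

n_devices = jax.local_device_count()    # 4 GPUs per node in the post
chains_per_device = 25                  # 4 * 25 = 100 independent chains

keys = jax.random.split(jax.random.PRNGKey(0), n_devices * chains_per_device)
keys = keys.reshape(n_devices, chains_per_device, 2)

sequence = jnp.zeros((128,))            # hypothetical encoded sequence

# pmap over devices, vmap over chains on each device.
run_all = jax.pmap(jax.vmap(run_chain, in_axes=(0, None)), in_axes=(0, None))
results = run_all(keys, sequence)       # shape (n_devices, chains_per_device)
print(results.shape)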

1 Comment
2024/05/28
03:26 UTC

1

Jax Enabled Environments

I am doing a research project in RL and need an environment where agents can show diverse behaviours, i.e. there are many qualitatively different ways of achieving the goal. Think StarCraft or Fortnite in terms of diversity of play styles, where you can be effective with loads of different strategies, though it would be amazing if it were a single-agent game, as multi-agent RL is beyond the scope.

I am planning on doing everything in JAX because I need to be super duper efficient.

Does anyone have a suggestion for a good environment to use? I am already looking at gymnax, XLand-Mini, and Jumanji.

Thanks!!!

0 Comments
2024/05/20
14:01 UTC

3

What are the best resources for learning JAX, GPU resource allocation, and acceleration?

Hi all,

I am a traditional SDE and pretty new to JAX, but I have a great interest in JAX, GPU resource allocation, and acceleration. I wanted to get some expert suggestions on what I can do to learn more about this stuff. Thank you so much!

4 Comments
2024/05/11
01:30 UTC

1

Seeking optimization advice for interpolation-heavy computation

Hey fellow JAX enthusiasts,

I'm currently working on a project that involves repeated interpolation of values, and I'm running into some performance issues. The current process involves loading grid values from a file and then interpolating them in each iteration. Unfortunately, the constant loading and data transfer between host and device is causing a significant bottleneck.

I've thought about utilizing the constant memory on NVIDIA GPUs to store my grid, but I'm unsure how to implement this or if it's even the best solution. Moreover, I'm stumped on how to optimize this process for TPUs.
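
In case it helps others reading along, the simplest first step is usually to load the grid once, move it to the device once, and close over it in a jitted interpolator so nothing crosses the host/device boundary per iteration. A minimal sketch, where the grid and jnp.interp stand in for the real data and interpolation routine:

import jax
import jax.numpy as jnp
import numpy as np

# Hypothetical stand-in for the grid loaded from file; in practice this would
# be something like np.load("grid.npy"), done once at startup.
grid_host = np.linspace(0.0, 1.0, 1024).astype(np.float32)

# Move the grid to the accelerator once, outside the iteration loop.
grid = jax.device_put(jnp.asarray(grid_host))
xp = jnp.linspace(0.0, 1.0, grid.shape[0])

@jax.jit
def interpolate(x):
    # Closing over `grid` keeps it resident on the device, so there is no
    # host-to-device transfer on each call.
    return jnp.interp(x, xp, grid)

queries = jnp.linspace(0.2, 0.8, 10_000)
values = interpolate(queries)
print(values.shape)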

If anyone has experience with similar challenges or can offer suggestions on how to overcome this performance overhead, I'd greatly appreciate it! Some potential solutions I'm open to exploring include:

  • Optimizing data transfer and loading
  • Leveraging GPU/TPU architecture for faster computation
  • Alternative interpolation methods or libraries
  • Any other creative solutions you might have!

Thanks in advance for your input!

0 Comments
2024/04/23
16:47 UTC

2

Here's the key benchmark table from the link. The JAX backend on GPUs is fastest for 7 of the 12 benchmarks, and the TensorFlow backend is fastest for the other 5. The PyTorch backend is not the fastest for any benchmark, and is often slower by a considerable margin.

0 Comments
2024/03/31
12:30 UTC

5

Optimization on Manifolds with JAX?

I am considering moving some Pytorch projects to JAX, since the speed up I see in toy problems is big. However, my projects involve optimizing matrices that are symmetric positive definite (SPD). For this, I use geotorch in Pytorch, which does Riemannian gradient descent and works like a charm. In JAX, however, I don't see a clear option of a package to use for this.

One option is Pymanopt, which supports JAX, but it seems like you can't use jit with Pymanopt (at least not out of the box). Another option is Rieoptax, but it seems like it is no longer maintained. I haven't found any other options. Any suggestions on what my available options are?
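
Not Riemannian optimization, but one dependency-free workaround worth mentioning: reparametrize the SPD matrix through an unconstrained factor and optimize that factor with plain JAX gradients. A toy sketch (sizes, learning rate, and the jitter term are arbitrary):

import jax
import jax.numpy as jnp

def spd_from_factor(L):
    # L @ L.T is positive semi-definite; the small jitter makes it strictly SPD.
    return L @ L.T + 1e-6 * jnp.eye(L.shape[0])

def loss(L, target):
    S = spd_from_factor(L)
    return jnp.sum((S - target) ** 2)

target = 2.0 * jnp.eye(3)
L = 0.1 * jax.random.normal(jax.random.PRNGKey(0), (3, 3))

grad_fn = jax.jit(jax.grad(loss))
for _ in range(500):
    L = L - 0.05 * grad_fn(L, target)

print(spd_from_factor(L))   # should approach 2 * I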

1 Comment
2024/03/26
14:09 UTC

2

Grad vs symbolic differentiation

It is my understanding that symbolic differentiation is when a new function is created (manually or by a program) that can compute the gradient of the function, whereas in automatic differentiation there is no explicit function for the gradient: the computation graph of the original function, in terms of elementary arithmetic operations, is used together with the sum and product rules for those elementary operations.

Based on this understanding, isn't "grad" using symbolic differentiation? JAX claims that this is automatic differentiation.
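
A tiny illustration of the distinction as JAX presents it: grad never manipulates a closed-form expression of f; it traces the primitive operations f performs and builds a new traced program for the derivative, which you can inspect:

import jax
import jax.numpy as jnp

def f(x):
    return jnp.sin(x) * x**2

df = jax.grad(f)

# The derivative is itself a traced program of primitives, not a symbolic formula.
print(jax.make_jaxpr(df)(1.0))

# It evaluates to the same value as the hand-derived d/dx[x^2 sin x].
print(df(1.0), 2 * 1.0 * jnp.sin(1.0) + 1.0**2 * jnp.cos(1.0))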

5 Comments
2024/03/17
11:51 UTC

2

Session m3ga

0507c64e7e34b13629c6ff03dff6b5481faf718db5509988465b02178fce3ce310

1 Comment
2024/03/06
07:49 UTC

3

JAX compared to PyTorch 2: Get a feeling for JAX!

0 Comments
2024/03/04
10:22 UTC

4

A JAX Based Library for training and inference of LLMs and Multi-modals on GPU, TPU

Hi guys, I have been working on a project named EasyDeL, an open-source library specifically designed to enhance and streamline the training of machine learning models. It focuses primarily on Jax/Flax and aims to provide convenient and effective solutions for training and serving Flax/Jax models on TPU/GPU. Some of the key features provided by EasyDeL include:

  • Serving and API engines for using and serving LLMs in JAX as efficiently as possible
  • Support for 8, 6, and 4-bit inference and training in JAX
  • A wide range of models in JAX that had never been implemented before, such as Falcon, Qwen2, Phi2, Mixtral, and MPT ...
  • Integration of FlashAttention in JAX for GPUs and TPUs
  • Automatic serving of LLMs with mid- and high-level APIs in both JAX and PyTorch
  • LLM trainer and fine-tuner in JAX
  • Video CLM trainer and fine-tuner for models such as Falcon, Qwen2, Phi2, Mixtral, and MPT ...
  • RLHF (Reinforcement Learning from Human Feedback) in JAX (beta stage)
  • DPOTrainer (supported) and SFTTrainer (in development)
  • Various other features to enhance the training process and optimize performance
  • LoRA: Low-Rank Adaptation of Large Language Models
  • RingAttention, FlashAttention, BlockWise FFN, and Efficient Attention are supported for more than 90% of models (FJFormer backbone)
  • Automatic conversion of models from JAX-EasyDeL to PyTorch-HF and back

For more information, Documents, Examples, and use cases check https://github.com/erfanzar/EasyDeL I'll be happy to get any feedback or new ideas for new models or features.

0 Comments
2024/02/21
11:15 UTC

7

A Jax-based library for designing and training transformer models from scratch.

Hey guys, I just published the developer version of NanoDL, a library for developing transformer models within the Jax/Flax ecosystem and would love your feedback!

Key Features of NanoDL include:

  • A wide array of blocks and layers, facilitating the creation of customised transformer models from scratch.
  • An extensive selection of models like LlaMa2, Mistral, Mixtral, GPT3, GPT4 (inferred), T5, Whisper, ViT, Mixers, GAT, CLIP, and more, catering to a variety of tasks and applications.
  • Data-parallel distributed trainers so developers can efficiently train large-scale models on multiple GPUs or TPUs, without the need for manual training loops.
  • Dataloaders, making the process of data handling for Jax/Flax more straightforward and effective.
  • Custom layers not found in Flax/Jax, such as RoPE, GQA, MQA, and SWin attention, allowing for more flexible model development.
  • GPU/TPU-accelerated classical ML models like PCA, KMeans, Regression, Gaussian Processes etc., akin to SciKit Learn on GPU.
  • Modular design so users can blend elements from various models, such as GPT, Mixtral, and LlaMa2, to craft unique hybrid transformer models.
  • A range of advanced algorithms for NLP and computer vision tasks, such as Gaussian Blur, BLEU etc.
  • Each model is contained in a single file with no external dependencies, so the source code can also be easily used.

Check out the repository for sample usage and more details: https://github.com/HMUNACHI/nanodl

Ultimately, I want as many opinions as possible: next steps to consider, issues, even contributions.

Note: I am working on the README docs. For now, each model file in the source code includes a comprehensive usage example in the comments at the top.

0 Comments
2024/02/08
10:18 UTC

2

JAX static arguments error

I have a function:

from functools import partial

from jax import jit
from jax import numpy as jnp

@partial(jit, static_argnums=(2, 3, 4, 5))
def f(a, b, c, d, e, f):
    # do something
    return # something

I want to set, say, c, d, e, and f as static variables since they don't change (they are config variables). Here c and d are jnp.ndarray, while e and f are floats. I get an error:
ValueError: Non-hashable static arguments are not supported. An error occurred during a call to 'f' while trying to hash an object of type <class 'jaxlib.xla_extension.ArrayImpl'>, [1. 1.]. The error was:

TypeError: unhashable type: 'ArrayImpl'

If I don't set c and d as static variables, I can run it without errors. How do I set c and d to be static variables?
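
A sketch of the two workarounds usually suggested for this (the names and values below are made up): jnp arrays are unhashable and cannot be static, but they generally do not need to be, since passing the same arrays as ordinary arguments does not trigger retracing as long as their shapes and dtypes are unchanged; alternatively, close over them.

from functools import partial

import jax
import jax.numpy as jnp

# Option 1: only mark the hashable config values (the floats) as static and
# pass the config arrays as ordinary traced arguments.
@partial(jax.jit, static_argnums=(4, 5))
def f(a, b, c, d, e, f_):   # f_ renamed to avoid clashing with the function name
    return a + b + c + d + e + f_

# Option 2: close over the config arrays instead of passing them at all.
c_cfg = jnp.array([1.0, 1.0])   # hypothetical config values
d_cfg = jnp.array([2.0, 2.0])

@partial(jax.jit, static_argnums=(2, 3))
def g(a, b, e, f_):
    return a + b + c_cfg + d_cfg + e + f_

print(f(1.0, 2.0, c_cfg, d_cfg, 0.1, 0.2))
print(g(1.0, 2.0, 0.1, 0.2))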

I can provide any more info if needed. Thanks in advance.

3 Comments
2023/12/19
01:46 UTC

1

JAX or TensorFlow?

Question: What should I use JAX or TensorFlow?

Context: I am working on a research project related to mergers of black holes. There is a code base that uses NumPy at the backend to perform the number crunching, but it is slow, so we have to shift to another code base that utilizes GPU/TPU effectively. Note that this is a research project, so the codebase will likely be changed over the years by the researchers. I have to build the same number-crunching code but using JAX; a friend has to build a Bayesian neural net which will later be integrated with my code. I want him to work in JAX or any other pure JAX-based framework, but he is set on using TensorFlow. What is the rational decision here?

7 Comments
2023/11/27
07:31 UTC

3

Learning resources?

Does anyone know of a good quickstart, tutorial, or curriculum for learning jax? I need to use it in a new project, and I'd like to get an overview of the whole language before getting started.

3 Comments
2023/11/04
17:26 UTC

2

Unable to create model in Jax

Hello, I'm trying to run code written by Google, but after following their directions for installing Jax/Flax and running their code, I keep on getting an error:

rng, init_rng, model_rng, dropout_rng = jax.random.split(rng, num=4)

init_conditioning = None
if config.get("conditioning_key"):
    init_conditioning = jnp.ones(
        [1] + list(train_ds.element_spec[config.conditioning_key].shape)[2:],
        jnp.int32)

init_inputs = jnp.ones(
    [1] + list(train_ds.element_spec["video"].shape)[2:],
    jnp.float32)

initial_vars = model.init(
    {"params": model_rng, "state_init": init_rng, "dropout": dropout_rng},
    video=init_inputs, conditioning=init_conditioning,
    padding_mask=jnp.ones(init_inputs.shape[:-1], jnp.int32))

# Split into state variables (e.g. for batchnorm stats) and model params.
# Note that `pop()` on a FrozenDict performs a deep copy.
state_vars, initial_params = initial_vars.pop("params")  # pytype: disable=attribute-error

In the last line, the code errors out saying that it expected two outputs but only received one.

This seems to be a problem when trying to run other JAX models as well, but I can't find a solution in any forum I've looked at online.

Does anyone know what this issue is?
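
One possible explanation, offered as a guess rather than a confirmed diagnosis: newer Flax versions return a plain dict from model.init, so .pop("params") is the ordinary dict pop (one return value) instead of FrozenDict.pop, which returns the remaining dict plus the popped value. A toy sketch of the corresponding fix:

import flax

# Toy stand-in for what model.init returns when Flax gives back a plain dict.
initial_vars = {"params": {"w": 1.0}, "batch_stats": {"mean": 0.0}}

# Wrapping it back into a FrozenDict restores the two-output pop behaviour
# that the Google code expects.
variables = flax.core.freeze(initial_vars)
state_vars, initial_params = variables.pop("params")
print(state_vars, initial_params)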

0 Comments
2023/10/29
03:58 UTC

1

Is the Response from the Google Bard correct????

The statement "Keras is to TensorFlow as Flax is to JAX" is a good analogy to describe the relationship between these two deep learning frameworks.

  • Keras is a high-level API that makes it easy to build and train neural networks. It is built on top of TensorFlow, which provides the underlying computation and numerical routines.
  • Flax is a high-performance deep learning library that is built on top of JAX. It is similar to Keras in that it provides a high-level API for building neural networks, but it is more performant and flexible.

In other words, Keras is to TensorFlow as Flax is to JAX:

  • Both Keras and Flax are high-level APIs for building neural networks.
  • Both Keras and Flax are built on top of a lower-level framework (TensorFlow and JAX, respectively).
  • Flax is more performant and flexible than Keras.

Here are some additional details about the similarities and differences between Keras and Flax:

  • Similarities:
    • Both Keras and Flax provide a high-level API for building neural networks. This means that you can use these frameworks to define the structure of your network, the type of layers you want to use, and the hyperparameters of your model.
    • Both Keras and Flax can be used to train neural networks on a variety of tasks, including image classification, natural language processing, and speech recognition.
    • Both Keras and Flax are open source frameworks that are actively maintained by their communities.
  • Differences:
    • Performance: Flax is generally more performant than Keras. This is because Flax is built on top of JAX, which is a high-performance numerical computation library.
    • Flexibility: Flax is more flexible than Keras. This is because Flax gives you more control over the underlying computation of your model.
    • Community: The Keras community is larger than the Flax community. This means that there are more resources available for learning Keras, and there are more people who can help you if you have problems.

Ultimately, the best framework for you will depend on your specific needs. If you are looking for a high-performance framework that gives you a lot of control over the underlying computation, then Flax is a good choice. If you are looking for a framework that is easy to learn and use, then Keras is a good choice.

I hope this helps!

1 Comment
2023/09/02
14:21 UTC
