A JAX implementation of Kolmogorov-Arnold Networks (KANs).
Original publication: KAN: Kolmogorov-Arnold Networks
Original PyTorch implementation: link
An efficient PyTorch implementation, which inspired this repo: link
This is a work in progress: it's a port of the efficient implementation mentioned above. Everything except the update_grid method is currently tested against the PyTorch implementation. Flax is strict about manipulating parameters outside of computational-graph functions, so I'm still looking for a clean way to implement update_grid. Contributions are welcome.
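One possible direction, offered as a sketch and not as this repo's solution: keep the spline grid out of the trainable parameters and write update_grid as a pure function that returns new pytrees instead of mutating anything in place. The code below uses plain JAX arrays in dicts rather than Flax modules. All function names (init, make_grid, b_spline_bases, kan_forward, update_grid) are hypothetical. The grid update is modelled loosely on the efficient PyTorch implementation's approach as I understand it: new knots blend data quantiles with a uniform grid, and the spline coefficients are refit by least squares (here via jnp.linalg.pinv) so each per-feature spline roughly keeps its old shape. Spline scaling and multi-layer plumbing are omitted.

```python
import jax
import jax.numpy as jnp


def make_grid(in_features, grid_size, k, lo=-1.0, hi=1.0):
    # Uniform knots extended by k extra knots on each side: shape (in, grid_size + 2k + 1).
    h = (hi - lo) / grid_size
    knots = jnp.arange(-k, grid_size + k + 1) * h + lo
    return jnp.tile(knots, (in_features, 1))


def b_spline_bases(x, grid, k):
    # Cox-de Boor recursion, vectorised over batch and input features.
    # x: (batch, in), grid: (in, G + 2k + 1) -> bases: (batch, in, G + k)
    x = x[:, :, None]
    bases = ((x >= grid[:, :-1]) & (x < grid[:, 1:])).astype(x.dtype)
    for p in range(1, k + 1):
        left = (x - grid[:, : -(p + 1)]) / (grid[:, p:-1] - grid[:, : -(p + 1)])
        right = (grid[:, p + 1 :] - x) / (grid[:, p + 1 :] - grid[:, 1:-p])
        bases = left * bases[:, :, :-1] + right * bases[:, :, 1:]
    return bases


def init(key, in_features, out_features, grid_size=5, k=3):
    # Trainable weights live in `params`; the grid lives in a separate, non-trainable `state`.
    k1, k2 = jax.random.split(key)
    params = {
        "base_w": 0.1 * jax.random.normal(k1, (out_features, in_features)),
        "spline_w": 0.1 * jax.random.normal(k2, (out_features, in_features, grid_size + k)),
    }
    return params, {"grid": make_grid(in_features, grid_size, k)}


def kan_forward(params, state, x, k=3):
    # SiLU "base" branch plus a learnable B-spline branch per (output, input) pair.
    base = jax.nn.silu(x) @ params["base_w"].T  # (batch, out)
    bases = b_spline_bases(x, state["grid"], k)  # (batch, in, G + k)
    return base + jnp.einsum("bic,oic->bo", bases, params["spline_w"])


def update_grid(params, state, x, k=3, margin=0.01, grid_eps=0.02):
    # Pure function: returns NEW (params, state); the inputs are left untouched.
    grid_size = state["grid"].shape[1] - 2 * k - 1
    # Per-feature spline outputs of the current model, shape (batch, in, out).
    old = jnp.einsum("bic,oic->bio", b_spline_bases(x, state["grid"], k), params["spline_w"])
    # New interior knots: blend of data quantiles and a uniform grid, shape (G + 1, in).
    x_sorted = jnp.sort(x, axis=0)
    idx = jnp.linspace(0, x.shape[0] - 1, grid_size + 1).astype(jnp.int32)
    step = (x_sorted[-1] - x_sorted[0] + 2 * margin) / grid_size
    uniform = jnp.arange(grid_size + 1)[:, None] * step + x_sorted[0] - margin
    grid = grid_eps * uniform + (1 - grid_eps) * x_sorted[idx]
    # Extend by k knots on each side, then transpose to (in, G + 2k + 1).
    grid = jnp.concatenate([
        grid[:1] - step * jnp.arange(k, 0, -1)[:, None],
        grid,
        grid[-1:] + step * jnp.arange(1, k + 1)[:, None],
    ], axis=0).T
    # Least-squares refit of coefficients on the new grid to match `old`.
    A = b_spline_bases(x, grid, k).transpose(1, 0, 2)  # (in, batch, G + k)
    B = old.transpose(1, 0, 2)  # (in, batch, out)
    coeff = jnp.linalg.pinv(A) @ B  # (in, G + k, out)
    return {**params, "spline_w": coeff.transpose(2, 0, 1)}, {**state, "grid": grid}


params, state = init(jax.random.PRNGKey(0), in_features=2, out_features=3)
x = jax.random.uniform(jax.random.PRNGKey(1), (64, 2), minval=-1.0, maxval=1.0)
y = kan_forward(params, state, x)  # shape (64, 3)
new_params, new_state = update_grid(params, state, x)
y_new = kan_forward(new_params, new_state, x)
```

Because update_grid only consumes and returns pytrees, nothing is mutated outside a traced function, and the caller simply swaps in the returned dicts. In Flax linen, the same split could map the grid onto a non-"params" variable collection updated through `apply(..., mutable=[...])`, though I haven't checked how well that fits this repo's module layout.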