JAX: accelerated machine learning research via composable function transformations in Python
JAX is a system for high-performance machine learning research and numerical computing. It offers the familiarity of Python+NumPy together with hardware acceleration, and it enables the definition and composition of user-wielded function transformations useful for machine learning programs. These transformations include automatic differentiation, automatic batching, end-to-end compilation (via XLA), parallelizing over multiple accelerators, and more. Composing these transformations is the key to JAX’s power and simplicity.
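The composition idea can be illustrated with a minimal sketch that uses three of JAX's public transformations: `jax.grad` (automatic differentiation), `jax.vmap` (automatic batching), and `jax.jit` (compilation via XLA). The `loss` function and the data below are hypothetical and exist only for illustration. Parallelization across accelerators is not shown.

```python
import jax
import jax.numpy as jnp

# Hypothetical scalar loss: squared error of a linear model on one example.
def loss(w, x, y):
    pred = jnp.dot(x, w)
    return (pred - y) ** 2

# grad: differentiate the loss with respect to its first argument, w.
grad_loss = jax.grad(loss)

# vmap: batch the gradient over the leading axis of x and y while sharing w,
# which yields one gradient per example.
per_example_grads = jax.vmap(grad_loss, in_axes=(None, 0, 0))

# jit: compile the composed (batched gradient) function end to end with XLA.
fast_per_example_grads = jax.jit(per_example_grads)

# Illustrative data: three 2-dimensional examples with zero targets.
w = jnp.array([1.0, 2.0])
xs = jnp.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
ys = jnp.array([0.0, 0.0, 0.0])

g = fast_per_example_grads(w, xs, ys)
print(g.shape)  # (3, 2): one gradient vector per example
```

Each transformation takes a function and returns a new function, so they stack freely. Here `jit` wraps the batched gradient, which lets XLA compile the whole pipeline as a single program.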
JAX had its initial open-source release in December 2018 (https://github.com/google/jax). It’s used by researchers for a wide range of advanced applications, from studying training dynamics of neural networks, to probabilistic programming, to scientific applications in physics and biology.
This talk will be moderated by Todd Mytkowicz.
Matt Johnson is a research scientist at Google Brain interested in software systems powering machine learning research. When moonlighting as a machine learning researcher, he works on composing graphical models with neural networks, automatically recognizing and exploiting conjugacy structure, and model-based reinforcement learning from pixels. Matt was a postdoc with Ryan Adams at the Harvard Intelligent Probabilistic Systems Group and Bob Datta in the Datta Lab at the Harvard Medical School. His Ph.D. is from MIT in EECS, where he worked with Alan Willsky on Bayesian time series models and scalable inference. He was an undergrad at UC Berkeley (Go Bears!).
Thu 19 Nov, 13:00 - 13:40 (Central Time, US & Canada)
Talk (40 min), REBASE: JAX: accelerated machine learning research via composable function transformations in Python. Matthew J. Johnson, Google Brain.