
Chromatix: a differentiable, GPU-accelerated wave-optics library

Main

Many current microscopy methods increasingly rely on computation as an integral part of the imaging process. This model-based approach to optics, which integrates optical system design with algorithmic reconstruction or optimization, has had major implications for biological discovery. Single-molecule localization microscopy can reveal individual protein complexes without requiring electron microscopy 1 , 2 and three-dimensional (3D) snapshot microscopy enables volumetric imaging at the frame rate of the camera, facilitating whole-brain imaging at high temporal resolution 3 , 4 , 5 , 6 , 7 , 8 . For both techniques, using wave-optics models to engineer point spread functions (PSFs) allows improved resolution in 3D 3 , 9 , 10 , 11 . Computer-generated holography uses wave-optics models of light propagation to design precise optogenetic stimulation of many neurons simultaneously 12 , 13 , 14 , 15 . Integrating microscopy with wave-optics models also allows measurement of optical properties that are otherwise difficult or slow to obtain, such as using diffraction tomography through a strongly scattering sample to obtain 3D refractive index distributions of transparent tissues from only intensity images, allowing high-contrast, label-free imaging of optically transparent model organisms such as Caenorhabditis elegans or Danio rerio 16 , 17 . Differentiable models of wave optics have also allowed for the end-to-end optimization of optical design and image processing algorithms for a variety of techniques such as quantitative phase imaging 18 , hyperspectral imaging 19 , extended depth-of-field 20 , monocular depth estimation 21 , 22 , localization microscopy 9 , lensless imaging 23 and 3D snapshot microscopy 3 . However, each computational optics method typically requires the researcher to program an optics simulation from scratch.
These de novo simulations have differing conventions from other simulations, can be difficult to reuse in other applications and are often computationally suboptimal due to the difficulty of programming fast optics simulations on current computer hardware such as graphics processing units (GPUs). Moreover, these simulations typically need to be differentiable to facilitate efficient optimization of the relevant optical parameters or to be easily combined with deep neural networks. A standard library providing differentiable wave-optics simulations that make efficient use of GPUs is therefore desirable. Here we describe Chromatix, a high-performance differentiable wave-optics simulation library that could fill this gap in the field of computational optics. Drawing inspiration from current deep learning frameworks such as PyTorch 24 and TensorFlow 25 that construct neural networks from layers of mathematical operations, Chromatix enables researchers to describe optical systems as compositions of fundamental elements. This architectural parallel runs deeper than mere interface design: Chromatix shares the core requirements of differentiability for gradient-based optimization, scalability for tackling large-scale problems and composability for rapid method development. Chromatix is built on JAX (just after execution) 26 , a numerical computation library for Python that provides GPU acceleration and automatic differentiation to automatically and efficiently calculate gradients with respect to any input of differentiable functions. By leveraging JAX’s capabilities, Chromatix simulations can be seamlessly accelerated on GPUs and parallelized across multiple devices with minimal code changes (Extended Data Fig. 1 ), enabling integration with existing deep learning models, optimization loops and even hardware control systems. 
Chromatix implements diverse optical models, from conventional lenses to diffractive elements such as liquid-crystal-on-silicon spatial light modulators (SLMs), and various models of scalar and vectorial wave propagation through both free space and scattering media, enabling simulation of optical systems interacting with multiple-scattering (potentially anisotropic) biological samples whose amplitude and refractive index vary in 3D. While there exist established tools for optical design such as Zemax 27 or CODE V 28 that do support wave optics, they are not efficiently implemented for the types of model we present here and are not differentiable or interoperable with deep learning models. Recently, a number of open-source differentiable optics simulation libraries have been released 29 , 30 , 31 , 32 ; however, these libraries are either primarily ray-based 29 , 32 (while the computational optics methods we are interested in are wave-based) or lack support for all the features that would be desirable for such a standard library in computational optics, for example multiple-scattering 3D samples or polarization. We provide a more detailed comparison of Chromatix versus other optics simulation software in Extended Data Table 1 . Here we demonstrate Chromatix’s ability to simulate and optimize various optical systems for biological imaging, showing simulations of widefield and snapshot fluorescence microscopes, phase contrast microscopes and Fourier holography systems for optogenetics. In each case, we highlight Chromatix’s scalability, delivering results substantially faster than previous implementations—2–22× improvements through parallelization (depending on the problem)—while eliminating the challenges of correctly implementing wave-optics simulations from scratch. 
We anticipate that this open-source library will democratize high-performance, scalable wave-optics simulations, enabling exploration of a much richer design space in computational optics and accelerating innovation in biological imaging and beyond.

Results

Design and implementation

The design of Chromatix is strongly inspired by current deep learning frameworks. Here we detail the principles that informed our design choices and implementation decisions. We argue that effective frameworks for computational optics, like those for deep learning, must embody three key characteristics: differentiability, composability and scalability. We also discuss the high-level implementation of Chromatix as it relates to these three characteristics.

Differentiability

Differentiability is the ability to calculate gradients, which can be used for gradient-based optimization (for example, of the parameters of an optical simulation). For small, low-dimensional inputs, numerical differentiation can be sufficient, but for high-dimensional inputs (for example, the pixels of an SLM) this becomes too computationally expensive to evaluate, and it is preferable to use backpropagation of the gradients of each step of the simulation. When combined with the wide variety of possible optical models or neural networks, automatic differentiation becomes a desirable property. Common programming languages such as MATLAB ( https://www.mathworks.com ) and C do not provide general-purpose automatic differentiation, requiring gradients to be manually derived and implemented: an error-prone, time-consuming and inflexible process. Current deep learning frameworks 24 , 25 , 26 provide automatic differentiation: given a function, they can automatically calculate the gradient with respect to any parameter of that function as long as the function is differentiable with respect to that parameter.
Similarly, Chromatix can automatically calculate the gradient with respect to any parameter of a simulation as it has been written using JAX. Differentiability has already found several uses in optical design, enabling end-to-end design of computational optics systems for a variety of problems 3 , 9 , 18 , 19 , 20 , 21 , 22 , 23 . Differentiability also has the potential to improve solutions to inverse problems in optics. Traditional inverse problem approaches simplify both the sample and the optics 33 , 34 , 35 , 36 , whereas differentiable models can handle more realistic complexity. Automatic differentiation opens up a new class of gradient-based optimizers such as Adam 37 , which can improve reconstruction fidelity by allowing arbitrarily complex physics simulations (for example, scattering 16 , 38 or sample deformation 39 ) in the forward simulation. These benefits come at nearly zero programmer effort with automatic differentiation: once the forward model of the simulation has been defined, its gradients are automatically defined as well. Thus, automatic differentiation can also enable so-called self-calibrating algorithms 40 . As hardware has a finite accuracy, optimizing certain physical parameters (such as angle of illumination in tomography) together with the sample has been shown to improve fidelity 41 , 42 . A different line of work replaces discrete voxel-based representations with neural network-based continuous representations, a concept known as implicit neural representations (INRs) or neural radiance fields 39 , 43 , 44 , 45 . INRs have been applied to separation of motion artifacts from sample dynamics 39 , estimation of dynamic aberrations 44 , reconstruction of 3D quantitative phase of scattering samples 46 and aberration correction without wavefront sensors or calibration measurements 43 . Here too, differentiable simulations are required to train these networks. 
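To make the contrast between numerical and automatic differentiation drawn in this section concrete, the following is a minimal sketch written in plain JAX rather than with Chromatix's own API. The toy forward model `far_field_intensity`, the `loss`, the 16 × 16 phase mask and the single-spot target are all hypothetical choices made only for illustration. The sketch compares one gradient entry from `jax.grad` with a central finite difference; double precision is enabled so that the finite difference is a meaningful reference.

```python
import jax
import jax.numpy as jnp

# Assumption for this sketch: use float64 so the finite difference is accurate.
jax.config.update("jax_enable_x64", True)


def far_field_intensity(phase):
    """Hypothetical toy model: a phase mask followed by a Fourier transform."""
    field = jnp.exp(1j * phase)                   # unit-amplitude field with a phase mask
    far_field = jnp.fft.fft2(field) / phase.size  # normalized so the total intensity is 1
    # real^2 + imag^2 avoids the non-differentiable point of abs() at zero.
    return far_field.real**2 + far_field.imag**2


def loss(phase, target):
    """Squared error between the simulated and the desired intensity."""
    return jnp.sum((far_field_intensity(phase) - target) ** 2)


key = jax.random.PRNGKey(0)
phase = jax.random.uniform(key, (16, 16), minval=-jnp.pi, maxval=jnp.pi)
target = jnp.zeros((16, 16)).at[3, 5].set(1.0)  # all light in a single spot

# Automatic differentiation: one call returns the gradient for all 256 pixels.
autodiff_grad = jax.grad(loss)(phase, target)


def central_difference(index, eps=1e-5):
    """Numerical derivative for ONE pixel; needs two forward evaluations."""
    bump = jnp.zeros_like(phase).at[index].set(eps)
    return (loss(phase + bump, target) - loss(phase - bump, target)) / (2 * eps)


index = (2, 7)
numeric = central_difference(index)
print(float(autodiff_grad[index]), float(numeric))
```

In this sketch the finite-difference route would need two forward evaluations per pixel to fill in the whole gradient, whereas the `jax.grad` call returns every entry at once, which is the scaling argument made above for high-dimensional inputs such as SLM pixels.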
Composability

The principle underlying differentiability is composability: the gradient of a composition of two functions can be calculated from the gradient of each of those functions. Taking a broader and more practical view, we can interpret composability as being able to easily swap and replace components of a network, for example replacing the activation function in a multilayer perceptron, without requiring changes to the rest of the system. This composability is possible due to standardization in the field of machine learning, which enables machine learning researchers to conveniently incorporate their colleagues’ advances by quickly replacing a function rather than having to rewrite their code from scratch. The field of optics stands in stark contrast: implementations are often project-specific, each with their own conventions and quirks, and a baseline to compare these codes to with respect to accuracy and speed does not exist. This practice is time-consuming, error-prone and makes reproducing results challenging. Chromatix proposes a standard for wave-optics simulations to enable composition of a wide array of optical models (Fig. 1 ). The experiments presented in this paper all share a common, well-tested codebase and many more components are available in our documentation. We believe that the existence of both a standard library and baseline implementations can substantially speed up research and make it more reproducible.

Fig. 1: The design and components of Chromatix. a, Chromatix combines wave-optics models, GPU acceleration and differentiability in a single library, providing a unified modeling framework to allow a wide range of applications. b, Chromatix implements a wide range of optical elements such as lenses, sensors, free-space propagation models for scalar and vectorial waves 70 , 71 , 72 and complex scattering samples 16 , 38 .
c, These elements can be combined to simulate a wide variety of experimental systems and solve a wide range of problems in computational optics. Green highlighted elements indicate the element or sample that would be optimized in each application. DMD, digital micromirror device; SLM, spatial light modulator; f, focal length; z_n, propagation distance.

Scalability

Optics is moving to ever larger fields of view (FOVs) and higher resolutions, sometimes requiring large compute clusters for sample reconstruction. A key requirement for new optics simulations is thus the ability to scale; researchers may want to run code on laptops for quick prototyping, but also easily scale up to GPU clusters for large-scale sample reconstruction. Previous popular programming environments have made this difficult: NumPy runs only on central processing units 47 ; MATLAB requires specific code for GPU usage and does not support general-purpose automatic differentiation; PyTorch 24 /TensorFlow 25 make writing GPU code with automatic differentiation relatively easy but it can be tricky to support multiple GPUs (both PyTorch and TensorFlow) or achieve good performance for typical operations in optical simulation that differ greatly from typical operations in neural networks (PyTorch). Writing device-specific programs for simulations requires substantial effort and calcifies their capabilities, which is not appropriate for the fast iteration demanded by scientific research. Chromatix instead relies on JAX 26 and its underlying XLA (accelerated linear algebra) compiler to support fast optical simulation on central processing units, GPUs and tensor processing units with only a single implementation (and without requiring custom lower-level GPU code for fast operations as in PyTorch 24 ).
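As a rough illustration of what "a single implementation" compiled by XLA can look like, the sketch below writes a scalar free-space propagation step once in plain JAX and compiles it with `jax.jit`; the same code then runs on whichever backend JAX finds. This is not Chromatix's API: the function `angular_spectrum_propagate`, the angular spectrum method chosen here, and the grid, wavelength and beam parameters are assumptions made for illustration.

```python
import jax
import jax.numpy as jnp


def angular_spectrum_propagate(field, dx, wavelength, z):
    """Propagate a square 2D complex field a distance z.

    dx, wavelength and z share one length unit (micrometres in this sketch).
    """
    n = field.shape[-1]
    fx = jnp.fft.fftfreq(n) / dx  # spatial frequencies in cycles per unit length
    fxx, fyy = jnp.meshgrid(fx, fx, indexing="ij")
    # Squared longitudinal frequency; non-positive values are evanescent waves.
    arg = 1.0 / wavelength**2 - fxx**2 - fyy**2
    kz = 2 * jnp.pi * jnp.sqrt(jnp.maximum(arg, 0.0))
    # Propagating components pick up a phase; evanescent ones are discarded.
    transfer = jnp.where(arg > 0, jnp.exp(1j * kz * z), 0.0)
    return jnp.fft.ifft2(jnp.fft.fft2(field) * transfer)


# One definition, compiled by XLA for the available device (CPU, GPU or TPU).
propagate = jax.jit(angular_spectrum_propagate)

n, dx, wavelength = 128, 1.0, 0.5  # hypothetical grid: 1 um pixels, 0.5 um light
x = (jnp.arange(n) - n / 2) * dx
xx, yy = jnp.meshgrid(x, x, indexing="ij")
field_in = jnp.exp(-(xx**2 + yy**2) / (2 * 10.0**2)).astype(jnp.complex64)

field_out = propagate(field_in, dx, wavelength, 100.0)

energy_in = float(jnp.sum(jnp.abs(field_in) ** 2))
energy_out = float(jnp.sum(jnp.abs(field_out) ** 2))
print(jax.default_backend(), energy_in, energy_out)
```

With this sampling every spatial frequency on the grid is a propagating wave, so the transfer function has unit modulus and the total energy should be preserved up to floating-point error, which gives a quick sanity check that does not depend on the device used.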
JAX also offers several functions to automatically vectorize code (that is, parallelize a batch on a single GPU) or parallelize over multiple GPUs, independently of the description of the optical elements in an optical system 26 . For example, with only a couple of lines of changes to the code we can scale a two-dimensional (2D), single-wavelength simulation to a 3D, multi-wavelength simulation running on multiple GPUs in parallel (see Extended Data Fig. 1 for code examples).

Implementation

In deep learning frameworks, these principles manifest themselves as models consisting of sequences of deep learning operations (layers). We observe a clear correspondence to optics, where optical systems consist of a sequence of optical elements and propagations. A key difference, however, is that the ‘hidden state’ of an optical system has a clear physical meaning: it is the complex light field moving through the system. To completely describe this field, and thus the state of the system at any time, additional information, such as the wavelength, polarization and spatial sampling, is required. The core idea behind Chromatix is that all this information can be encoded in a single, fundamental structure. Any optical element can then be written as a transformation of this structured field, and any optical system as a sequence of these elements. This allows Chromatix to model a wide variety of optical systems under a unified interface, which makes extending its capabilities straightforward.

Experiments

We present six computational experiments demonstrating four major features of Chromatix: solving inverse problems to reconstruct samples, accelerating reconstruction and optical design using deep learning, composing modular optical elements and models in arbitrary ways, and scaling optical simulation speed by an order of magnitude. To do this, we showcase both reproductions of existing computational methods in opt

Source: Nature
