# opticks Models and analysis tools for optical system engineering
#
# Copyright (C) Egemen Imre
#
# Licensed under GNU GPL v3.0. See LICENSE.md for more info.
from enum import Enum
import numpy as np
from astropy.units import Quantity
from numpy import ndarray
from prysm._richdata import RichData
from prysm.polynomials import sum_of_2d_modes
from prysm.propagation import Wavefront
from opticks import u
from opticks.contrast_model.mtf import MTF_Model_1D
from opticks.contrast_model.optics_mtf import _psf_to_mtf # type: ignore[attr-defined]
from opticks.imaging_model.aperture import Aperture
from opticks.utils.prysm_utils import OptPathDiff
from opticks.utils.unit_utils import split_value_and_force_unit
[docs]
class WvlRef(Enum):
"""Reference wavelength selector for PSF output metadata."""
FIRST = "first"
LAST = "last"
MID = "mid" # middle element of the array
AVERAGE = "average"
[docs]
def resolve(self, wavelengths: Quantity) -> Quantity:
"""Select a wavelength from the Quantity array based on this selector.
Parameters
----------
wavelengths : Quantity
wavelengths array to select from
Returns
-------
Quantity
the selected wavelength
"""
if self == WvlRef.FIRST:
return wavelengths[0] # type: ignore[return-value]
elif self == WvlRef.LAST:
return wavelengths[-1] # type: ignore[return-value]
elif self == WvlRef.MID:
return wavelengths[len(wavelengths) // 2] # type: ignore[return-value]
elif self == WvlRef.AVERAGE:
return wavelengths.mean() # type: ignore[return-value]
else:
raise ValueError(f"Unknown WvlRef value: {self}")
[docs]
class PupilFunction:
    """Pupil function combining amplitude and phase information.

    A pupil function encapsulates one or more wavelength+OPD combinations
    (monochromatic or polychromatic) and can compute the resulting PSF.

    Users should not create PupilFunction objects directly.
    Use `Optics.add_mono_pupil_function()` or
    `Optics.add_poly_pupil_function()` instead.

    The prysm wavefronts are built eagerly at construction time. The PSF
    and the MTF are cached on the instance: the PSF is (re)computed by
    :meth:`compute_psf`, the MTF is derived lazily from the cached PSF.

    Parameters
    ----------
    wavelengths : Quantity
        wavelengths array (in microns)
    opds : list[OptPathDiff | None]
        list of optical path differences (in nm), one per wavelength.
        Use None for zero phase (perfect wavefront).
    aperture : Aperture
        aperture object
    aperture_dx : Quantity
        aperture sample distance (in mm)
    focal_length : Quantity
        focal length (in mm)
    spectral_weights : np.ndarray, optional
        spectral weight of each wavelength, by default uniform

    Raises
    ------
    ValueError
        if `opds` or `spectral_weights` do not have the same length
        as `wavelengths`
    """

    def __init__(
        self,
        wavelengths: Quantity,
        opds: list[OptPathDiff | None],
        aperture: Aperture,
        aperture_dx: Quantity,
        focal_length: Quantity,
        spectral_weights: np.ndarray | None = None,
    ) -> None:
        # validate lengths
        if len(wavelengths) != len(opds):
            raise ValueError(
                f"wavelengths and opds must have the same length, "
                f"got {len(wavelengths)} and {len(opds)}."
            )
        if spectral_weights is not None and len(spectral_weights) != len(wavelengths):
            raise ValueError(
                f"spectral_weights must have the same length as wavelengths, "
                f"got {len(spectral_weights)} and {len(wavelengths)}."
            )

        self._wavelengths = wavelengths
        self._opds = opds
        self._aperture = aperture
        self._aperture_dx = aperture_dx
        self._focal_length = focal_length
        # missing weights mean every wavelength contributes equally
        self._spectral_weights = (
            spectral_weights
            if spectral_weights is not None
            else np.ones(len(wavelengths))
        )

        # build wavefronts eagerly
        self._wavefronts = self._build_wavefronts()

        # PSF and MTF caches
        self._psf: RichData | None = None
        self._mtf: RichData | None = None

    @classmethod
    def monochromatic(
        cls,
        wavelength: Quantity,
        opd: OptPathDiff | None,
        aperture: Aperture,
        aperture_dx: Quantity,
        focal_length: Quantity,
    ) -> "PupilFunction":
        """Create a monochromatic PupilFunction (single wavelength).

        This is useful for sampling of a narrow beam. The single
        wavelength and OPD are wrapped into one-element containers and
        the spectral weights are left at their uniform default.

        Parameters
        ----------
        wavelength : Quantity
            wavelength of light (in microns)
        opd : OptPathDiff | None
            optical path difference (in nm), or None for zero phase
        aperture : Aperture
            aperture object
        aperture_dx : Quantity
            aperture sample distance (in mm)
        focal_length : Quantity
            focal length (in mm)

        Returns
        -------
        PupilFunction
            pupil function with exactly one wavelength
        """
        return cls(
            wavelengths=Quantity([wavelength]),
            opds=[opd],
            aperture=aperture,
            aperture_dx=aperture_dx,
            focal_length=focal_length,
        )

    @classmethod
    def polychromatic(
        cls,
        wavelengths: Quantity,
        opds: list[OptPathDiff | None],
        aperture: Aperture,
        aperture_dx: Quantity,
        focal_length: Quantity,
        spectral_weights: np.ndarray | None = None,
    ) -> "PupilFunction":
        """Create a polychromatic PupilFunction (multiple wavelengths).

        This is useful for sampling of a broadband beam.

        If `wavelengths` holds a single element, the call is delegated to
        :meth:`monochromatic` using ``wavelengths[0]`` and ``opds[0]``;
        `spectral_weights` is not used in that case.

        Parameters
        ----------
        wavelengths : Quantity
            wavelengths array (in microns)
        opds : list[OptPathDiff | None]
            list of optical path differences (in nm), one per wavelength
        aperture : Aperture
            aperture object
        aperture_dx : Quantity
            aperture sample distance (in mm)
        focal_length : Quantity
            focal length (in mm)
        spectral_weights : np.ndarray, optional
            spectral weight of each wavelength, by default uniform

        Returns
        -------
        PupilFunction
            pupil function covering all given wavelengths
        """
        if len(wavelengths) == 1:
            return cls.monochromatic(
                wavelengths[0],  # type: ignore[arg-type]
                opds[0],
                aperture,
                aperture_dx,
                focal_length,
            )

        return cls(
            wavelengths=Quantity(wavelengths),
            opds=opds,
            aperture=aperture,
            aperture_dx=aperture_dx,
            focal_length=focal_length,
            spectral_weights=spectral_weights,
        )

    @property
    def is_monochromatic(self) -> bool:
        """True if this pupil function has a single wavelength."""
        return len(self._wavelengths) == 1

    @property
    def num_wavelengths(self) -> int:
        """Number of wavelengths in this pupil function."""
        return len(self._wavelengths)

    @property
    def wavelengths(self) -> Quantity:
        """Wavelengths array."""
        return self._wavelengths

    @property
    def spectral_weights(self) -> np.ndarray:
        """Spectral weights array."""
        return self._spectral_weights

    def compute_psf(
        self,
        psf_dx: Quantity,
        psf_samples: int = 512,
        wvl_ref: WvlRef | None = None,
        with_units: bool = True,
    ) -> RichData:
        """Compute the PSF, cache internally, and return it.

        Every call recomputes the PSF, replaces the cached PSF and
        discards the cached MTF so that it is rebuilt on next access.

        Parameters
        ----------
        psf_dx : Quantity
            sample distance of the output PSF plane grid (in microns)
        psf_samples : int, optional
            number of samples in the output plane, by default 512
        wvl_ref : WvlRef, optional
            reference wavelength selector for output metadata.
            For monochromatic, defaults to the only wavelength.
            For polychromatic, defaults to WvlRef.AVERAGE.
        with_units : bool, optional
            output the PSF with or without units, by default True

        Returns
        -------
        RichData
            PSF model
        """
        # resolve reference wavelength
        if wvl_ref is None:
            wvl_ref = WvlRef.FIRST if self.is_monochromatic else WvlRef.AVERAGE
        ref_wvl = wvl_ref.resolve(self._wavelengths)

        # compute PSF (returned without units)
        psf = _compute_psf(
            self._wavefronts,
            self._focal_length,
            ref_wvl,
            psf_dx,
            psf_samples,
            self._spectral_weights,
        )

        # add units if requested
        if with_units:
            psf = RichData(psf.data, psf.dx * u.um, psf.wavelength * u.um)

        # cache and invalidate MTF cache
        self._psf = psf
        self._mtf = None
        return psf

    @property
    def psf(self) -> RichData:
        """Return the cached PSF.

        Raises
        ------
        ValueError
            if compute_psf() has not been called yet
        """
        if self._psf is None:
            raise ValueError("PSF has not been computed yet. Call compute_psf() first.")
        return self._psf

    @property
    def mtf(self) -> RichData:
        """Return the MTF, computing lazily from the cached PSF if needed.

        The MTF is derived once per cached PSF and always requested
        with units.

        Raises
        ------
        ValueError
            if compute_psf() has not been called yet
        """
        # going through the property keeps the "not computed" check
        # (and its error message) in a single place
        psf = self.psf
        if self._mtf is None:
            self._mtf = _psf_to_mtf(psf, with_units=True)
        return self._mtf  # type: ignore[return-value]

    def to_MTF_Model_1D(self, slice: str) -> MTF_Model_1D:
        """Convert the cached 2D MTF to a 1D MTF model.

        Extracts a 1D slice from the 2D MTF and returns it as an
        ``MTF_Model_1D`` object. The PSF must have been computed
        before calling this method.

        Possible slice strings are ``x``, ``y``, ``azavg``, ``azavmedian``,
        ``azmin``, ``azpv``, ``azvar``, ``azstd``.

        Parameters
        ----------
        slice : str
            slice type (e.g., "x", "y", "azavg")

        Returns
        -------
        MTF_Model_1D
            1D MTF model

        Raises
        ------
        ValueError
            if compute_psf() has not been called yet
        """
        return MTF_Model_1D.from_mtf_2d(self.mtf, slice)

    def _build_wavefronts(self) -> list[Wavefront]:
        """Build prysm Wavefront objects from stored wavelengths/OPDs.

        One wavefront is created per wavelength/OPD pair. The aperture
        data is used as the amplitude, the OPD (converted to nm) as the
        phase; a None OPD is passed on as ``phase=None``.

        Returns
        -------
        list[Wavefront]
            wavefronts, in the same order as the wavelengths
        """
        # identical for every wavelength, so evaluate them only once
        amplitude = self._aperture.data
        aperture_dx_mm = self._aperture_dx.to_value(u.mm)

        wavefronts = []
        for wvl, opd in zip(self._wavelengths, self._opds, strict=True):
            # explicit None check: only a missing OPD means zero phase,
            # never an OPD object that merely evaluates as falsy
            opd_data = opd.strip_units(u.nm).data if opd is not None else None
            wf = Wavefront.from_amp_and_phase(
                amplitude,
                phase=opd_data,
                wavelength=wvl.to_value(u.um),
                dx=aperture_dx_mm,
            )
            wavefronts.append(wf)
        return wavefronts
def _compute_psf(
    pupils: list[Wavefront],
    focal_length: float | Quantity,
    wvl: float | Quantity,
    psf_dx: float | Quantity,
    psf_samples: int,
    spectral_weights: ndarray,
) -> RichData:
    """Compute the PSF for a single point on the Image Plane.

    Works for both monochromatic (one pupil) and polychromatic (several
    pupils) cases. Each pupil is focused onto the same user defined
    output grid, the resulting intensities are combined into one PSF
    using the spectral weights.

    The spectral weights array should have the same number of elements
    as the number of Pupil Functions.

    The operation can be fairly expensive, therefore it is advised
    to keep the output PSF stored.

    Parameters
    ----------
    pupils : list[Wavefront]
        list of Pupil functions or Wavefronts
    focal_length : float | Quantity
        focal length in mm
    wvl : float | Quantity
        reference wavelength (in microns)
    psf_dx : float | Quantity
        sample distance of the output PSF Plane grid (in microns)
    psf_samples : int
        number of samples in the output plane; handed over unchanged
        to the ``prepare_executor`` call of each wavefront
    spectral_weights : np.ndarray
        spectral weight of each wavelength

    Returns
    -------
    RichData
        PSF model (without units)
    """
    # reduce all inputs to plain numbers in the expected units
    fl_mm, _ = split_value_and_force_unit(focal_length, u.mm)
    ref_wvl_um, _ = split_value_and_force_unit(wvl, u.um)
    dx_um, _ = split_value_and_force_unit(psf_dx, u.um)

    # One intensity map per wavelength (no unit support).
    # Focusing changes the aperture sampling as a function of wavelength,
    # hence every pupil is propagated onto the same fixed output grid.
    intensity_maps = [
        pupil.focus_dft(
            pupil.prepare_executor(fl_mm, dx_um, psf_samples)
        ).intensity.data
        for pupil in pupils
    ]

    # wavelengths are mutually incoherent: combine intensities, not fields
    combined = sum_of_2d_modes(np.asarray(intensity_maps), spectral_weights)

    # attach sampling and reference wavelength
    # (pupil to psf plane means dx is switched from mm to um)
    return RichData(combined, dx_um, ref_wvl_um)