import os, sys
from datetime import datetime
import warnings
from importlib import import_module
from addict import Dict
import json
from collections import OrderedDict
import lenstronomy.SimulationAPI.ObservationConfig as obs_cfg
class BaobabConfig:
    """Nested dictionary representing the configuration for Baobab data generation.

    The user configuration becomes this object's attribute dictionary (an
    ``addict.Dict``), so every top-level config key is readable as an attribute,
    e.g. ``cfg.bnn_omega`` or ``cfg.psf.type``.
    """

    def __init__(self, user_cfg):
        """
        Parameters
        ----------
        user_cfg : dict or Dict
            user-defined configuration

        Notes
        -----
        Besides the user keys, the instance gets ``out_dir`` (defaulted when
        absent, then made absolute), ``checkpoint_interval`` (defaulted when
        absent), ``survey_object_dict``, ``instrument``, ``log_filename`` and
        ``log_path``.
        """
        # Replacing __dict__ exposes every config key as an instance attribute.
        self.__dict__ = Dict(user_cfg)
        if not hasattr(self, 'out_dir'):
            # Default out_dir path if not specified
            self.out_dir = os.path.join(self.destination_dir, '{:s}_{:s}_prior={:s}_seed={:d}'.format(self.name, self.train_vs_val, self.bnn_prior_class, self.seed))
        self.out_dir = os.path.abspath(self.out_dir)
        if not hasattr(self, 'checkpoint_interval'):
            self.checkpoint_interval = max(100, self.n_data // 100)
        self.get_survey_info(self.survey_info, self.psf.type)
        self.interpret_magnification_cfg()
        self.interpret_kinematics_cfg()
        # NOTE(review): ':' in this file name is not a valid character on Windows.
        self.log_filename = datetime.now().strftime("log_%m-%d-%Y_%H:%M_baobab.json")
        self.log_path = os.path.join(self.out_dir, self.log_filename)

    @classmethod
    def from_file(cls, user_cfg_path):
        """Alternative constructor that accepts the path to the user-defined configuration file

        Parameters
        ----------
        user_cfg_path : str or os.PathLike
            path to the user-defined configuration: a python file defining a
            module-level ``cfg`` dictionary, or a JSON file (e.g. an exported log)

        Raises
        ------
        NotImplementedError
            if the file extension is neither ``.py`` nor ``.json``
        """
        user_cfg_path = os.path.abspath(user_cfg_path)
        dirname, filename = os.path.split(user_cfg_path)
        module_name, ext = os.path.splitext(filename)
        if ext == '.py':
            import importlib.util
            # Keep the config's directory importable so the script can import
            # sibling modules; insert it only once.
            if dirname not in sys.path:
                sys.path.insert(0, dirname)
            # Execute the file at this exact path. import_module(module_name)
            # would return a cached module if a config with the same file name
            # (from another directory) had been loaded earlier.
            spec = importlib.util.spec_from_file_location(module_name, user_cfg_path)
            user_cfg_script = importlib.util.module_from_spec(spec)
            spec.loader.exec_module(user_cfg_script)
            # Dict(...) also accepts a plain-dict ``cfg``; deepcopy detaches the
            # result from the module's object.
            user_cfg = Dict(user_cfg_script.cfg).deepcopy()
        elif ext == '.json':
            with open(user_cfg_path, 'r', encoding='utf-8') as f:
                user_cfg = Dict(json.load(f))
        else:
            raise NotImplementedError("This extension is not supported.")
        return cls(user_cfg)

    def export_log(self):
        """Export the baobab log (this configuration) as JSON to ``self.log_path``.

        ``survey_object_dict`` is left out: it holds survey class instances,
        which ``json`` cannot serialize, and ``__init__`` rebuilds it from
        ``survey_info`` when the log is loaded back through ``from_file``.
        """
        log = {key: value for key, value in self.__dict__.items()
               if key != 'survey_object_dict'}
        with open(self.log_path, 'w') as f:
            json.dump(log, f)
        print("Exporting baobab log to {:s}".format(self.log_path))

    def interpret_magnification_cfg(self):
        """Normalize ``bnn_omega.magnification.frac_error_sigma``.

        After this call ``bnn_omega.magnification`` is a mapping that contains
        ``frac_error_sigma``:

        * if the images contain no AGN (``'agn_light'`` not in ``components``),
          the value is forced to 0.0, with a warning when a non-zero/non-empty
          value had been configured;
        * otherwise a missing, ``None`` or zero value becomes 0.0 and any other
          configured value is kept.
        """
        magnification = self.bnn_omega.get('magnification')
        if magnification is None:
            # Section missing or explicitly None: create it with the default.
            self.bnn_omega['magnification'] = Dict(frac_error_sigma=0.0)
            return
        # .get() rather than attribute access, so a missing key reads as None
        # instead of an auto-created empty addict.Dict; truthiness then covers
        # None, {}, 0 and 0.0 without calling len() on a float.
        frac_error_sigma = magnification.get('frac_error_sigma')
        if 'agn_light' not in self.components:
            if frac_error_sigma:
                warnings.warn("`bnn_omega.magnification.frac_error_sigma` field is ignored as the images do not contain AGN.")
            magnification['frac_error_sigma'] = 0.0
        elif not frac_error_sigma:
            magnification['frac_error_sigma'] = 0.0

    def interpret_kinematics_cfg(self):
        """Validate the kinematics config.

        Emits a warning when ``bnn_omega.kinematics_cfg.anisotropy_model`` is
        ``'analytic'``; nothing is modified.
        """
        kinematics_cfg = self.bnn_omega.kinematics_cfg
        if kinematics_cfg.anisotropy_model == 'analytic':
            warnings.warn("Since velocity dispersion computation is analytic, any entry other than `sampling_number` in `kinematics.numerics_kwargs` will be ignored.")

    def get_survey_info(self, survey_info, psf_type):
        """Fetch the camera and instrument information corresponding to the survey string identifier

        Parameters
        ----------
        survey_info : dict or Dict
            must contain ``survey_name`` (a module in
            ``lenstronomy.SimulationAPI.ObservationConfig`` defining a class of
            the same name) and ``bandpass_list``; may contain ``coadd_years``,
            ``override_camera_kwargs`` and ``override_obs_kwargs``
        psf_type : str
            PSF type passed to the survey class constructor

        Sets ``self.survey_object_dict`` (bandpass -> survey object, in
        ``bandpass_list`` order) and ``self.instrument`` (camera dict).

        Raises
        ------
        ValueError
            if ``bandpass_list`` is empty, or if the PSF type is ``'PIXEL'`` and
            the PSF config (or its ``kernel_size`` / ``which_psf_maps``) is missing
        """
        # Survey modules are imported by bare name from lenstronomy's
        # ObservationConfig directory; add it to sys.path only once.
        obs_cfg_dir = obs_cfg.__path__[0]
        if obs_cfg_dir not in sys.path:
            sys.path.insert(0, obs_cfg_dir)
        survey_name = survey_info['survey_name']
        survey_class = getattr(import_module(survey_name), survey_name)
        coadd_years = survey_info.get('coadd_years')
        override_camera_kwargs = survey_info.get('override_camera_kwargs') or {}
        override_obs_kwargs = survey_info.get('override_obs_kwargs') or {}
        bandpass_list = survey_info['bandpass_list']
        if not bandpass_list:
            raise ValueError("`survey_info.bandpass_list` must contain at least one bandpass.")
        has_psf_cfg = hasattr(self, 'psf')

        self.survey_object_dict = OrderedDict()
        for bp in bandpass_list:
            survey_object = survey_class(band=bp, psf_type=psf_type, coadd_years=coadd_years)
            if has_psf_cfg:
                # Overwrite ObservationConfig PSF type with user-configured PSF type
                survey_object.obs['psf_type'] = self.psf.type
            if survey_object.obs['psf_type'] == 'PIXEL':
                if not has_psf_cfg:
                    raise ValueError("User must supply PSF kwargs in the Baobab config if PSF type is PIXEL.")
                # Key membership, not hasattr(): addict.Dict returns an empty
                # Dict for any missing attribute, so hasattr() is always True.
                if 'kernel_size' not in self.psf:
                    raise ValueError("Observation dictionary must specify PSF kernel size if psf_type is PIXEL.")
                if 'which_psf_maps' not in self.psf:
                    raise ValueError("Observation dictionary must specify indices of PSF kernel maps if psf_type is PIXEL.")
                survey_object.psf_kernel_size = self.psf.kernel_size
                survey_object.which_psf_maps = self.psf.which_psf_maps
            else:  # 'GAUSSIAN'
                survey_object.psf_kernel_size = None
                survey_object.which_psf_maps = None
            # Override default survey specs with user-specified kwargs
            survey_object.camera.update(override_camera_kwargs)
            survey_object.obs.update(override_obs_kwargs)
            self.survey_object_dict[bp] = survey_object
        # Camera dict is same across bands, so arbitrarily take the last band
        self.instrument = survey_object.camera

    def get_noise_kwargs(self, bandpass):
        """Return the noise kwargs defined in the baobab config, e.g. for passing to the noise model for online data augmentation.

        Parameters
        ----------
        bandpass : str
            the bandpass to pull the noise information for; must be a key of
            ``self.survey_object_dict``

        Returns
        -------
        dict
            a new dict with the entries of ``self.instrument``, updated with the
            survey object's ``kwargs_single_band()`` for ``bandpass`` (the
            latter take precedence on shared keys)
        """
        noise_kwargs = dict(self.instrument)
        noise_kwargs.update(self.survey_object_dict[bandpass].kwargs_single_band())
        return noise_kwargs