Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .vulture_whitelist.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,3 +65,4 @@
# Image processing functions used in prediction scripts (root level, excluded from scan)
process_single_wrapper # noqa - Used in prediction_utils.py, prediction_process_hdf5.py
_.n_expected_channels # noqa - fitsbolt config attribute set dynamically
_.channel_combination_dict # noqa - Used in prediction_process_cutana.py (outside scan path)
19 changes: 18 additions & 1 deletion anomaly_match/utils/validate_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -294,8 +294,25 @@ def _format_constraints():
cc = cfg.normalisation.channel_combination
fits_ext = cfg.normalisation.fits_extension

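# Dict form channel_combination: keys are filter names, values are
# per-output-channel weight lists. Stack them column-wise (one column per
# filter, in dict key order; one row per output channel) into a numpy array
# for fitsbolt compatibility, while preserving the original dict in the
# config (as channel_combination_dict) for streaming.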
if cc is not None and isinstance(cc, dict):
keys = list(cc.keys())
cc_array = np.column_stack([np.array(cc[k]) for k in keys])
inferred = cc_array.shape[0]
if inferred != cfg.normalisation.n_output_channels:
logger.info(
f"Setting n_output_channels to {inferred} "
f"from dict channel_combination ({keys})"
)
cfg.normalisation.n_output_channels = inferred
# Store numpy array for fitsbolt, keep dict in config for streaming
cfg.normalisation.channel_combination = cc_array
cfg.normalisation.channel_combination_dict = cc

# Infer n_output_channels from channel_combination matrix if provided
if cc is not None and hasattr(cc, "shape") and len(cc.shape) == 2:
elif cc is not None and hasattr(cc, "shape") and len(cc.shape) == 2:
inferred = cc.shape[0]
if inferred != cfg.normalisation.n_output_channels:
logger.info(
Expand Down
111 changes: 64 additions & 47 deletions prediction_process_cutana.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,44 +97,31 @@ def evaluate_images_from_cutana(

# Configure FITS extensions for cutana.
#
# AnomalyMatch's fits_extension uses integer HDU indices (for multi-extension
# cutout files loaded by fitsbolt). Cutana operates on large mosaic tiles
# referenced in the catalogue — each source has separate FITS files per band.
# Cutana identifies bands by filter name (e.g. "VIS", "NIR-H") extracted from
# the file paths, so we must resolve integer indices to filter names here.
#
# NOTE: filter name extraction currently relies on Euclid naming conventions
# (via cutana.catalogue_preprocessor.extract_filter_name). If your catalogue
# uses non-Euclid file naming, set cfg.normalisation.fits_extension to
# explicit filter name strings instead of integer indices.
# AnomalyMatch's fits_extension uses integer HDU indices or cutout extension
# names (e.g. CHANNEL_1) for reading multi-extension cutout files via
# fitsbolt. Cutana streaming operates on large mosaic tiles referenced in
# the catalogue — each source has separate FITS files per band with a single
# PRIMARY HDU. These cutout-level extension values are NOT valid for the
# source mosaics, so for multi-extension streaming we always resolve filter
# names (e.g. "VIS", "NIR-H") from the catalogue file paths.
fits_ext = cfg.normalisation.fits_extension
if fits_ext is None:
fits_ext = ["PRIMARY"]
elif isinstance(fits_ext, (str, int)):
fits_ext = [fits_ext]

# When fits_extension contains integers, resolve to filter names from the
# catalogue's fits_file_paths column.
has_integer_indices = any(isinstance(e, int) for e in fits_ext)
if has_integer_indices:
if len(fits_ext) > 1:
extension_names = _resolve_filter_names_from_catalogue(
cutana_sources_path, len(fits_ext)
)
else:
# Single integer index (e.g. [0]) maps to the PRIMARY HDU
extension_names = ["PRIMARY"]
else:
extension_names = [str(e) for e in fits_ext]

# Build selected_extensions for cutana.
# For multi-file catalogues (separate FITS per band), each file has only a
# PRIMARY HDU, so fits_extensions must be ["PRIMARY"]. The filter names go
# into channel_weights and selected_extensions for channel identification.
if has_integer_indices:
cutana_config.fits_extensions = ["PRIMARY"]
if len(fits_ext) > 1:
# Multi-extension: resolve filter names from catalogue file paths.
# fits_extension values (integer HDU indices like [1,2,3] or cutout
# extension names like ['CHANNEL_1','CHANNEL_2','CHANNEL_3']) are only
# meaningful for reading cutout files — they must not be passed to
# cutana as-is because source mosaics have different HDU structure.
extension_names = _resolve_filter_names_from_catalogue(cutana_sources_path, len(fits_ext))
else:
cutana_config.fits_extensions = extension_names
extension_names = ["PRIMARY"]

# Source mosaics have a single PRIMARY HDU per file.
cutana_config.fits_extensions = ["PRIMARY"]

selected_extensions = []
for name in extension_names:
Expand All @@ -146,9 +133,36 @@ def evaluate_images_from_cutana(
# Channel combination must happen BEFORE normalisation (cutana's pipeline
# ensures this) so that ZSCALE/ASINH see the same data shape as training.
n_out = cfg.normalisation.n_output_channels
if cfg.normalisation.channel_combination is not None:
# Multi-extension: convert numpy matrix (n_out x n_in) to cutana dict
combo = cfg.normalisation.channel_combination
# Prefer the original dict form (preserved by validate_config when the user
# provides channel_combination as a dict). Falls back to the numpy array.
combo_dict = cfg.normalisation.channel_combination_dict
# DotMap subclasses dict, so an empty, auto-created DotMap passes the isinstance check.
# Only use combo_dict when it is a real, non-empty dict set by validate_config.
combo = (
combo_dict
if isinstance(combo_dict, dict) and len(combo_dict) > 0
else cfg.normalisation.channel_combination
)
if combo is not None and isinstance(combo, dict) and len(combo) > 0:
# Dict form: keys are filter names, values are weight lists.
# This is order-independent — no risk of channel jumbling when
# the streaming catalogue has a different file order than training.
missing = set(extension_names) - set(combo.keys())
if missing:
raise ValueError(
f"channel_combination dict is missing keys for resolved filter names: "
f"{missing}. Dict keys: {list(combo.keys())}, "
f"resolved filters: {extension_names}"
)
channel_weights = {}
for name in extension_names:
weights = combo[name]
channel_weights[name] = list(weights) if not isinstance(weights, list) else weights
cutana_config.channel_weights = channel_weights
elif combo is not None and hasattr(combo, "shape"):
# Numpy array form: columns are positionally mapped to extension_names
# (order depends on catalogue file path order — use dict form to avoid
# ambiguity when streaming and training catalogues differ).
channel_weights = {}
for j, ext_name in enumerate(extension_names):
channel_weights[str(ext_name)] = combo[:, j].tolist()
Expand All @@ -164,28 +178,31 @@ def evaluate_images_from_cutana(

# Verify channel configuration consistency
if len(extension_names) > 1:
combo = cfg.normalisation.channel_combination
n_in = combo.shape[1] if hasattr(combo, "shape") else len(extension_names)
if isinstance(combo, dict):
n_in = len(combo)
elif hasattr(combo, "shape"):
n_in = combo.shape[1]
else:
n_in = len(extension_names)
if len(extension_names) != n_in:
raise ValueError(
f"Number of resolved filter names ({len(extension_names)}) does not match "
f"channel_combination input dimension ({n_in}). "
f"Filter names: {extension_names}, matrix shape: {combo.shape}"
f"Filter names: {extension_names}"
)
if combo.shape[0] != n_out:
if isinstance(combo, dict):
lengths = {k: len(v) for k, v in combo.items()}
bad = {k: length for k, length in lengths.items() if length != n_out}
if bad:
raise ValueError(
f"channel_combination dict values must have length {n_out} "
f"(n_output_channels), but got: {bad}"
)
elif hasattr(combo, "shape") and combo.shape[0] != n_out:
raise ValueError(
f"channel_combination output dimension ({combo.shape[0]}) does not match "
f"n_output_channels ({n_out})"
)
# For non-diagonal matrices, verify all input channels contribute
# (a zero column means an extension is loaded but never used)
for j, ext_name in enumerate(extension_names):
col_sum = abs(combo[:, j]).sum()
if col_sum == 0:
logger.warning(
f"Extension '{ext_name}' (column {j}) has zero weight in "
f"channel_combination — this channel will be loaded but ignored"
)
logger.info(
f"Channel configuration: {len(extension_names)} inputs -> {n_out} outputs, "
f"filter order: {extension_names}"
Expand Down
Loading