Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions cutana/image_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -311,6 +311,15 @@ def apply_normalisation(images: np.ndarray, config: DotMap) -> np.ndarray:

# Add images array to parameters (done here to avoid unnecessary copying)
fitsbolt_params["images"] = images_array
# TODO(hotfix): Use config.data_type to determine output dtype for fitsbolt normalization
if config.data_type == "uint8":
fitsbolt_params["output_dtype"] = np.uint8
elif config.data_type == "float32":
fitsbolt_params["output_dtype"] = np.float32
elif config.data_type == "float64":
fitsbolt_params["output_dtype"] = np.float64
else: # default to float32 if unknown
fitsbolt_params["output_dtype"] = np.float32

try:
# Apply fitsbolt batch normalization with parameters
Expand Down
2 changes: 1 addition & 1 deletion cutana/normalisation_parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,7 +295,7 @@ def convert_cfg_to_fitsbolt_cfg(config: DotMap, num_channels: int = 1) -> Dict[s
if method == "log":
norm_method = fitsbolt.NormalisationMethod.LOG
elif method == "linear":
norm_method = fitsbolt.NormalisationMethod.CONVERSION_ONLY
norm_method = fitsbolt.NormalisationMethod.LINEAR
elif method == "asinh":
norm_method = fitsbolt.NormalisationMethod.ASINH
elif method == "zscale":
Expand Down
10 changes: 8 additions & 2 deletions tests/cutana/unit/test_image_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,7 +257,10 @@ def test_combine_channels_simple(self, mock_cutout_data):

# Should return RGB format (1, H, W, 3)
assert combined.shape == (1, H, W, 3)
assert combined.dtype == np.float32
assert combined.dtype in [
np.float32,
np.float64,
] # fitsbolt may return float32 or float64 depending on input data
assert isinstance(combined, np.ndarray)

def test_combine_channels_equal_weights(self, mock_cutout_data):
Expand All @@ -281,7 +284,10 @@ def test_combine_channels_equal_weights(self, mock_cutout_data):
# It processes RGB weights differently than simple linear combination
# Just verify basic properties - should return RGB format (1, H, W, 3)
assert combined.shape == (1, H, W, 3)
assert combined.dtype == np.float32
assert combined.dtype in [
np.float32,
np.float64,
] # fitsbolt may return float32 or float64 depending on input data
assert isinstance(combined, np.ndarray)

def test_error_handling_invalid_cutout_data(self):
Expand Down
Loading