Skip to content

Combine bits via xor when bitcasting from larger to smaller type#734

Merged
patrick-kidger merged 3 commits intopatrick-kidger:mainfrom
alexander-de-ranitz:fix_bitcast_from_float64_to_int32
Feb 23, 2026
Merged

Combine bits via xor when bitcasting from larger to smaller type#734
patrick-kidger merged 3 commits intopatrick-kidger:mainfrom
alexander-de-ranitz:fix_bitcast_from_float64_to_int32

Conversation

@alexander-de-ranitz
Copy link
Contributor

Fixes #733.

Previously, when downcasting to a type with fewer bits, the larger type would be converted to a smaller intermediate dtype, which was subsequently bitcasted to the requested type. The casting to a smaller intermediate type loses precision, and can result in distinct values being cast to identical values. Since these values are used as keys to generate random sequences, this is problematic, as it results in identical noise being generated in subsequent timesteps. This commit fixes this by not throwing away bits when the input type is larger than the requested output type. Instead, the larger number is bitcast to multiple values in the smaller type, which are then combined using xor.

Previously, when casting e.g. a float64 to an int32, numbers that were close in float64 could be mapped to identical int32's. Since these int's are used as keys to generate random sequences, this is problematic, as it results in identical noise being generated in subsequent timesteps. This commit fixes this by not throwing away bits when the input type is larger than the requested output type. Instead, the larger number is bitcast to multiple values in the smaller type, which are then combined using xor.
@patrick-kidger
Copy link
Owner

Looks reasonable to me! (Though I think the pre-commits are failing, see CONTRIBUTING.md.) It'd also be great to see tests that the new functionality is correct but no worries if you're up for that?

Comment on lines +21 to +24
if result.shape != val.shape:
result = jnp.bitwise_xor.reduce(result, axis=-1)

return result
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe an assert result.shape == val.shape afterwards to be sure that this xor has done its job?

This checks the intended behaviour of mapping nearby numbers to distinct values when downcasting to a smaller dtype.
@alexander-de-ranitz
Copy link
Contributor Author

Added two assertions (shape and dtype) and a simple test which fails with the previous implementation, but passes with these changes. The formatting issues have also been fixed. LMK if any further additions are necessary!

@patrick-kidger patrick-kidger merged commit 5ea5fcb into patrick-kidger:main Feb 23, 2026
2 checks passed
@patrick-kidger
Copy link
Owner

Looks perfect! Merged 🎉 Thank you for the fix.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

UnsafeBrownianPath.evaluate() bitcasts to int32, which results in unexpected results when using float64 timesteps

2 participants