From 0f4d39d383cfe8295e45e1f0a8acb36860f889da Mon Sep 17 00:00:00 2001 From: Dzmitry Malyshau Date: Sun, 22 Mar 2026 11:26:40 -0700 Subject: [PATCH] Change the shape of cooperative matrix to support AMD --- Cargo.toml | 1 + blade-graphics/src/gles/mod.rs | 3 +- blade-graphics/src/lib.rs | 31 +++++++-- blade-graphics/src/metal/mod.rs | 14 +++- blade-graphics/src/shader.rs | 6 +- blade-graphics/src/vulkan/init.rs | 96 ++++++++++++++++++++++++---- blade-graphics/src/vulkan/mod.rs | 4 +- blade-graphics/src/vulkan/surface.rs | 2 +- blade-particle/src/system.rs | 12 ++-- examples/bunnymark/example.rs | 5 +- examples/matmul/main.rs | 70 ++++++++++++++------ examples/matmul/matmul.wgsl | 37 +++++++---- examples/particle/main.rs | 9 ++- examples/ray-query/main.rs | 7 +- examples/scene/main.rs | 19 +++--- tests/gpu_examples.rs | 1 + tests/parse_shaders.rs | 13 +++- tests/snapshot.rs | 3 +- 18 files changed, 245 insertions(+), 88 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index bca84eb4..1ab48714 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -69,6 +69,7 @@ egui-winit = { version = "0.33", default-features = false, features = [ ] } transform-gizmo-egui = "0.8" env_logger = "0.11" +half = { version = "2", features = ["bytemuck"] } num_cpus = { workspace = true } glam = { workspace = true } log = { workspace = true } diff --git a/blade-graphics/src/gles/mod.rs b/blade-graphics/src/gles/mod.rs index 7e5928f2..c17cef7c 100644 --- a/blade-graphics/src/gles/mod.rs +++ b/blade-graphics/src/gles/mod.rs @@ -453,7 +453,8 @@ impl Context { ray_query: crate::ShaderVisibility::empty(), sample_count_mask: 0x1 | 0x4, //TODO: accurate info dual_source_blending: false, - cooperative_matrix: false, + shader_float16: false, + cooperative_matrix: crate::CooperativeMatrix::default(), } } diff --git a/blade-graphics/src/lib.rs b/blade-graphics/src/lib.rs index 534e3f98..47f92184 100644 --- a/blade-graphics/src/lib.rs +++ b/blade-graphics/src/lib.rs @@ -169,8 +169,8 @@ pub enum NotSupportedError { impl fmt::Display for NotSupportedError { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - match self { - Self::Platform(e) => write!(f, "platform error: {}", e), + match *self { + Self::Platform(ref e) => write!(f, "platform error: {}", e), Self::NoSupportedDeviceFound => f.write_str("no supported device found"), Self::PlatformNotSupported => f.write_str("platform not supported"), } @@ -196,7 +196,7 @@ pub enum DeviceError { impl fmt::Display for DeviceError { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - match self { + match *self { Self::DeviceLost => f.write_str("device lost"), Self::OutOfMemory => f.write_str("out of memory"), } @@ -216,6 +216,25 @@ pub struct MemoryStats { pub usage: u64, } +/// Cooperative matrix support information. +/// +/// Each field is a tile size (8 or 16), or 0 if that configuration +/// is not supported. Naga supports square tiles only (8×8 and 16×16). +#[derive(Clone, Copy, Debug, Default, PartialEq)] +pub struct CooperativeMatrix { + /// Tile size for all-f32 operations. + pub f32_tile: u32, + /// Tile size for f16-input, f32-accumulator operations. + pub f16_tile: u32, +} + +impl CooperativeMatrix { + /// Returns true if any cooperative matrix configuration is supported. + pub fn is_supported(&self) -> bool { + self.f32_tile > 0 || self.f16_tile > 0 + } +} + #[derive(Clone, Debug, Default, PartialEq)] pub struct Capabilities { /// Support binding arrays of handles. @@ -226,8 +245,10 @@ pub struct Capabilities { pub sample_count_mask: u32, /// Support for dual-source blending. pub dual_source_blending: bool, - /// Support for cooperative matrix operations. - pub cooperative_matrix: bool, + /// Support for 16-bit floating-point types in shaders. + pub shader_float16: bool, + /// Cooperative matrix support. + pub cooperative_matrix: CooperativeMatrix, } #[derive(Clone, Debug, Default)] diff --git a/blade-graphics/src/metal/mod.rs b/blade-graphics/src/metal/mod.rs index dde980af..e6d2e340 100644 --- a/blade-graphics/src/metal/mod.rs +++ b/blade-graphics/src/metal/mod.rs @@ -538,9 +538,19 @@ impl Context { .filter(|&count| device.supportsTextureSampleCount(count as _)) .sum(), dual_source_blending: true, - cooperative_matrix: device.supportsFamily(metal::MTLGPUFamily::Apple7) + // Metal Shading Language supports half-precision floats on all supported devices. + shader_float16: true, + cooperative_matrix: if device.supportsFamily(metal::MTLGPUFamily::Apple7) || device.supportsFamily(metal::MTLGPUFamily::Mac2) - || device.supportsFamily(metal::MTLGPUFamily::Metal3), + || device.supportsFamily(metal::MTLGPUFamily::Metal3) + { + crate::CooperativeMatrix { + f32_tile: 8, + f16_tile: 0, + } + } else { + crate::CooperativeMatrix::default() + }, } } diff --git a/blade-graphics/src/shader.rs b/blade-graphics/src/shader.rs index 4263e897..3b94cdea 100644 --- a/blade-graphics/src/shader.rs +++ b/blade-graphics/src/shader.rs @@ -39,9 +39,13 @@ impl super::Context { naga::valid::Capabilities::DUAL_SOURCE_BLENDING, device_caps.dual_source_blending, ); + caps.set( + naga::valid::Capabilities::SHADER_FLOAT16, + device_caps.shader_float16, + ); caps.set( naga::valid::Capabilities::COOPERATIVE_MATRIX, - device_caps.cooperative_matrix, + device_caps.cooperative_matrix.is_supported(), ); naga::valid::Validator::new(flags, caps) .validate(module) diff --git a/blade-graphics/src/vulkan/init.rs b/blade-graphics/src/vulkan/init.rs index 9ddd70ad..f07a2d20 100644 --- a/blade-graphics/src/vulkan/init.rs +++ b/blade-graphics/src/vulkan/init.rs @@ -71,7 +71,8 @@ struct AdapterCapabilities { external_memory: bool, timing: bool, dual_source_blending: bool, - cooperative_matrix: bool, + shader_float16: bool, + cooperative_matrix: crate::CooperativeMatrix, memory_budget: bool, bugs: SystemBugs, } @@ -256,6 +257,8 @@ unsafe fn inspect_adapter( let mut ray_query_features = vk::PhysicalDeviceRayQueryFeaturesKHR::default(); let mut cooperative_matrix_features = vk::PhysicalDeviceCooperativeMatrixFeaturesKHR::default(); let mut vulkan_memory_model_features = vk::PhysicalDeviceVulkanMemoryModelFeatures::default(); + let mut float16_int8_features = vk::PhysicalDeviceShaderFloat16Int8Features::default(); + let mut storage_16bit_features = vk::PhysicalDevice16BitStorageFeatures::default(); let mut features2_khr = vk::PhysicalDeviceFeatures2::default() .push_next(&mut inline_uniform_block_features) .push_next(&mut timeline_semaphore_features) @@ -265,12 +268,15 @@ unsafe fn inspect_adapter( .push_next(&mut acceleration_structure_features) .push_next(&mut ray_query_features) .push_next(&mut cooperative_matrix_features) - .push_next(&mut vulkan_memory_model_features); + .push_next(&mut vulkan_memory_model_features) + .push_next(&mut float16_int8_features) + .push_next(&mut storage_16bit_features); instance .get_physical_device_properties2 .get_physical_device_features2(phd, &mut features2_khr); let dual_source_blending = features2_khr.features.dual_src_blend != 0; + let shader_float16 = float16_int8_features.shader_float16 != 0; let has_inline_ub = supported_extensions.contains(&vk::EXT_INLINE_UNIFORM_BLOCK_NAME) && inline_uniform_block_properties.max_inline_uniform_block_size @@ -382,23 +388,67 @@ unsafe fn inspect_adapter( let cooperative_matrix = if !supported_extensions.contains(&vk::KHR_COOPERATIVE_MATRIX_NAME) { log::info!("No cooperative matrix extension support"); - false + crate::CooperativeMatrix::default() } else if cooperative_matrix_features.cooperative_matrix == vk::FALSE { log::info!( "No cooperative matrix feature support. Features = {:?}", cooperative_matrix_features ); - false + crate::CooperativeMatrix::default() } else if vulkan_memory_model_features.vulkan_memory_model == vk::FALSE { log::info!( "No Vulkan memory model support (required for cooperative matrix). Features = {:?}", vulkan_memory_model_features ); - false + crate::CooperativeMatrix::default() } else { - log::info!("Cooperative matrix is supported"); - true + // Query supported cooperative matrix configurations and find + // square float configurations (Naga supports 8x8 and 16x16). + let coop_props = instance + .cooperative_matrix + .get_physical_device_cooperative_matrix_properties(phd) + .unwrap_or_default(); + let find_tile = |a_type, b_type, c_type, result_type| { + [8u32, 16].into_iter().find(|&size| { + coop_props.iter().any(|p| { + p.m_size == size + && p.n_size == size + && p.k_size == size + && p.a_type == a_type + && p.b_type == b_type + && p.c_type == c_type + && p.result_type == result_type + && p.scope == vk::ScopeKHR::SUBGROUP + }) + }) + }; + let f32t = vk::ComponentTypeKHR::FLOAT32; + let f16t = vk::ComponentTypeKHR::FLOAT16; + let f32_tile = find_tile(f32t, f32t, f32t, f32t).unwrap_or(0); + let f16_tile = if float16_int8_features.shader_float16 != 0 + && storage_16bit_features.storage_buffer16_bit_access != 0 + { + find_tile(f16t, f16t, f32t, f32t).unwrap_or(0) + } else { + 0 + }; + let cm = crate::CooperativeMatrix { f32_tile, f16_tile }; + if cm.is_supported() { + log::info!( + "Cooperative matrix: f32 tile={}, f16 tile={}", + cm.f32_tile, + cm.f16_tile, + ); + } else { + log::info!( + "Cooperative matrix extension present but no usable config. Properties: {:?}", + coop_props + ); + } + cm }; + // Auto-enable shader_float16 when cooperative matrix has f16 support. + let shader_float16 = shader_float16 || cooperative_matrix.f16_tile > 0; let buffer_marker = supported_extensions.contains(&vk::AMD_BUFFER_MARKER_NAME); let shader_info = supported_extensions.contains(&vk::AMD_SHADER_INFO_NAME); @@ -434,6 +484,7 @@ unsafe fn inspect_adapter( external_memory, timing, dual_source_blending, + shader_float16, cooperative_matrix, memory_budget, bugs, @@ -601,7 +652,7 @@ impl super::Context { } } else { unsafe { entry.create_instance(&create_info, None) } - .map_err(|e| crate::PlatformError::init(e))? + .map_err(crate::PlatformError::init)? } }; @@ -610,6 +661,7 @@ impl super::Context { _debug_utils: ext::debug_utils::Instance::new(&entry, &core_instance), get_physical_device_properties2: khr::get_physical_device_properties2::Instance::new(&entry, &core_instance), + cooperative_matrix: khr::cooperative_matrix::Instance::new(&entry, &core_instance), get_surface_capabilities2: if desc.presentation { Some(khr::get_surface_capabilities2::Instance::new( &entry, @@ -663,7 +715,7 @@ impl super::Context { instance .core .enumerate_physical_devices() - .map_err(|e| crate::PlatformError::init(e))? + .map_err(crate::PlatformError::init)? .into_iter() .find_map(|phd| { inspect_adapter( @@ -738,7 +790,7 @@ impl super::Context { vk::KHR_EXTERNAL_MEMORY_FD_NAME }); } - if capabilities.cooperative_matrix { + if capabilities.cooperative_matrix.is_supported() { device_extensions.push(vk::KHR_COOPERATIVE_MATRIX_NAME); if capabilities.api_version < vk::API_VERSION_1_2 { device_extensions.push(vk::KHR_VULKAN_MEMORY_MODEL_NAME); @@ -810,9 +862,27 @@ impl super::Context { .push_next(&mut khr_ray_query); } + let mut khr_float16_int8; + let mut storage_16bit; + if capabilities.shader_float16 { + khr_float16_int8 = vk::PhysicalDeviceShaderFloat16Int8Features { + shader_float16: vk::TRUE, + ..Default::default() + }; + device_create_info = device_create_info.push_next(&mut khr_float16_int8); + } + if capabilities.cooperative_matrix.f16_tile > 0 { + storage_16bit = vk::PhysicalDevice16BitStorageFeatures { + storage_buffer16_bit_access: vk::TRUE, + uniform_and_storage_buffer16_bit_access: vk::TRUE, + ..Default::default() + }; + device_create_info = device_create_info.push_next(&mut storage_16bit); + } + let mut khr_cooperative_matrix; let mut vulkan_memory_model; - if capabilities.cooperative_matrix { + if capabilities.cooperative_matrix.is_supported() { khr_cooperative_matrix = vk::PhysicalDeviceCooperativeMatrixFeaturesKHR { cooperative_matrix: vk::TRUE, ..Default::default() @@ -861,7 +931,7 @@ impl super::Context { instance .core .create_device(physical_device, &device_create_info, None) - .map_err(|e| crate::PlatformError::init(e))? + .map_err(crate::PlatformError::init)? } }; @@ -1123,6 +1193,7 @@ impl super::Context { .limits .framebuffer_depth_sample_counts, dual_source_blending: capabilities.dual_source_blending, + shader_float16: capabilities.shader_float16, cooperative_matrix: capabilities.cooperative_matrix, binding_array: capabilities.binding_array, memory_budget: capabilities.memory_budget, @@ -1153,6 +1224,7 @@ impl super::Context { }, sample_count_mask: self.sample_count_flags.as_raw(), dual_source_blending: self.dual_source_blending, + shader_float16: self.shader_float16, cooperative_matrix: self.cooperative_matrix, } } diff --git a/blade-graphics/src/vulkan/mod.rs b/blade-graphics/src/vulkan/mod.rs index 7ac24bba..69295830 100644 --- a/blade-graphics/src/vulkan/mod.rs +++ b/blade-graphics/src/vulkan/mod.rs @@ -19,6 +19,7 @@ struct Instance { core: ash::Instance, _debug_utils: ash::ext::debug_utils::Instance, get_physical_device_properties2: khr::get_physical_device_properties2::Instance, + cooperative_matrix: khr::cooperative_matrix::Instance, get_surface_capabilities2: Option, surface: Option, } @@ -265,7 +266,8 @@ pub struct Context { min_uniform_buffer_offset_alignment: u64, sample_count_flags: vk::SampleCountFlags, dual_source_blending: bool, - cooperative_matrix: bool, + shader_float16: bool, + cooperative_matrix: crate::CooperativeMatrix, binding_array: bool, memory_budget: bool, instance: Instance, diff --git a/blade-graphics/src/vulkan/surface.rs b/blade-graphics/src/vulkan/surface.rs index 32cc05c5..97da7128 100644 --- a/blade-graphics/src/vulkan/surface.rs +++ b/blade-graphics/src/vulkan/surface.rs @@ -198,7 +198,7 @@ impl super::Context { window.window_handle().unwrap().as_raw(), None, ) - .map_err(|e| crate::PlatformError::init(e))? + .map_err(crate::PlatformError::init)? }; let khr_surface = self diff --git a/blade-particle/src/system.rs b/blade-particle/src/system.rs index 5d196664..e73c2370 100644 --- a/blade-particle/src/system.rs +++ b/blade-particle/src/system.rs @@ -215,12 +215,12 @@ impl ParticleSystem { EmitterShape::Sphere { radius } => radius, }; - let (colors, color_count) = match &self.effect.particle.color { + let (colors, color_count) = match self.effect.particle.color { ColorConfig::Solid(c) => { - let packed = pack_color(*c); + let packed = pack_color(c); ([packed, packed, packed, packed], 1u32) } - ColorConfig::Palette(palette) => { + ColorConfig::Palette(ref palette) => { let mut colors = [0u32; 4]; let count = palette.len().min(4); for i in 0..count { @@ -267,7 +267,7 @@ impl ParticleSystem { let mut pc = pass.with(&pipeline.reset_pipeline); pc.bind(0, &self.main_data()); let group_size = pipeline.reset_pipeline.get_workgroup_size(); - let group_count = (self.capacity as u32 + group_size[0] - 1) / group_size[0]; + let group_count = (self.capacity as u32).div_ceil(group_size[0]); pc.dispatch([group_count, 1, 1]); self.needs_reset = false; } @@ -298,7 +298,7 @@ impl ParticleSystem { self.emit_accumulator -= emit_count as f32; let params = self.make_emit_params(emit_count, self.origin); let wg_size = pipeline.emit_pipeline.get_workgroup_size()[0]; - let groups = (emit_count + wg_size - 1) / wg_size; + let groups = emit_count.div_ceil(wg_size); let mut pass = encoder.compute("particle emit continuous"); let mut pc = pass.with(&pipeline.emit_pipeline); pc.bind(0, &main_data); @@ -317,7 +317,7 @@ impl ParticleSystem { for burst in bursts { let params = self.make_emit_params(burst.count, burst.position); let wg_size = pipeline.emit_pipeline.get_workgroup_size()[0]; - let groups = (burst.count + wg_size - 1) / wg_size; + let groups = burst.count.div_ceil(wg_size); let mut pass = encoder.compute("particle emit burst"); let mut pc = pass.with(&pipeline.emit_pipeline); pc.bind(0, &main_data); diff --git a/examples/bunnymark/example.rs b/examples/bunnymark/example.rs index ce8fb471..2736c212 100644 --- a/examples/bunnymark/example.rs +++ b/examples/bunnymark/example.rs @@ -159,8 +159,7 @@ impl Example { } context.sync_buffer(vertex_buf); - let mut bunnies = Vec::new(); - bunnies.push(Sprite { + let bunnies = vec![Sprite { data: SpriteData { locals: Locals { position: [-100.0, 100.0], @@ -170,7 +169,7 @@ impl Example { }, }, vertex_buf: vertex_buf.into(), - }); + }]; let mut command_encoder = context.create_command_encoder(gpu::CommandEncoderDesc { name: "init", diff --git a/examples/matmul/main.rs b/examples/matmul/main.rs index c4f15a4e..e6ccee46 100644 --- a/examples/matmul/main.rs +++ b/examples/matmul/main.rs @@ -1,9 +1,12 @@ //! Fast matrix multiplication using cooperative matrix operations. //! -//! Computes C = A * B for 64x64 f32 matrices using 8x8 cooperative +//! Computes C = A * B for 64x64 matrices using cooperative //! matrix tiles (tensor cores / simdgroup matrix), then verifies //! against a CPU reference. //! +//! Adapts to the device's supported tile size (8 or 16) and scalar +//! type (f32 or f16 inputs with f32 accumulator). +//! //! Requires VK_KHR_cooperative_matrix (Vulkan) or Apple7+ (Metal). use blade_graphics as gpu; @@ -13,7 +16,6 @@ use std::mem; const M: u32 = 64; const N: u32 = 64; const K: u32 = 64; -const TILE: u32 = 8; #[repr(C)] #[derive(Clone, Copy, bytemuck::Zeroable, bytemuck::Pod)] @@ -43,23 +45,38 @@ fn main() { }; let caps = context.capabilities(); - if !caps.cooperative_matrix { + let cm = caps.cooperative_matrix; + // Prefer f32 inputs, fall back to f16 inputs with f32 accumulator. + let (tile, f16_input) = if cm.f32_tile > 0 { + (cm.f32_tile, false) + } else if cm.f16_tile > 0 { + (cm.f16_tile, true) + } else { eprintln!( "Cooperative matrix not supported on this device ({}).", context.device_information().device_name ); eprintln!("Requires VK_KHR_cooperative_matrix (Vulkan) or Apple7+ (Metal)."); return; - } + }; + let input_type = if f16_input { "f16" } else { "f32" }; println!( - "Device: {} (cooperative matrix supported)", + "Device: {} (cooperative matrix {tile}x{tile}, {input_type} input)", context.device_information().device_name ); + // Specialize shader source for the device's capabilities + let coop_type = format!("coop_mat{tile}x{tile}"); + let source_template = include_str!("matmul.wgsl"); + let source = source_template + .replace("ENABLE_F16", if f16_input { "enable f16;" } else { "" }) + .replace("COOP_MAT", &coop_type) + .replace("INPUT_SCALAR", input_type) + .replace("TILE_SIZE", &format!("{tile}u")); + // Create shader and pipeline - let source = include_str!("matmul.wgsl"); let shader = context.create_shader(gpu::ShaderDesc { - source, + source: &source, naga_module: None, }); let mut pipeline = context.create_compute_pipeline(gpu::ComputePipelineDesc { @@ -68,20 +85,34 @@ fn main() { compute: shader.at("main"), }); - // Prepare input matrices - let a_data: Vec = (0..M * K).map(|i| (i % 7) as f32 * 0.1).collect(); - let b_data: Vec = (0..K * N).map(|i| (i % 11) as f32 * 0.1).collect(); + // Prepare input matrices (always compute in f32, convert to f16 if needed) + let a_f32: Vec = (0..M * K).map(|i| (i % 7) as f32 * 0.1).collect(); + let b_f32: Vec = (0..K * N).map(|i| (i % 11) as f32 * 0.1).collect(); let c_data: Vec = vec![0.0; (M * N) as usize]; let params = Params { m: M, n: N, k: K }; + let (a_bytes, b_bytes): (Vec, Vec) = if f16_input { + let a_f16: Vec = a_f32.iter().map(|&v| half::f16::from_f32(v)).collect(); + let b_f16: Vec = b_f32.iter().map(|&v| half::f16::from_f32(v)).collect(); + ( + bytemuck::cast_slice(&a_f16).to_vec(), + bytemuck::cast_slice(&b_f16).to_vec(), + ) + } else { + ( + bytemuck::cast_slice(&a_f32).to_vec(), + bytemuck::cast_slice(&b_f32).to_vec(), + ) + }; + let buf_a = context.create_buffer(gpu::BufferDesc { name: "matrix_a", - size: (a_data.len() * mem::size_of::()) as u64, + size: a_bytes.len() as u64, memory: gpu::Memory::Shared, }); let buf_b = context.create_buffer(gpu::BufferDesc { name: "matrix_b", - size: (b_data.len() * mem::size_of::()) as u64, + size: b_bytes.len() as u64, memory: gpu::Memory::Shared, }); let buf_c = context.create_buffer(gpu::BufferDesc { @@ -97,8 +128,8 @@ fn main() { // Upload data unsafe { - std::ptr::copy_nonoverlapping(a_data.as_ptr() as *const u8, buf_a.data(), a_data.len() * 4); - std::ptr::copy_nonoverlapping(b_data.as_ptr() as *const u8, buf_b.data(), b_data.len() * 4); + std::ptr::copy_nonoverlapping(a_bytes.as_ptr(), buf_a.data(), a_bytes.len()); + std::ptr::copy_nonoverlapping(b_bytes.as_ptr(), buf_b.data(), b_bytes.len()); std::ptr::copy_nonoverlapping(c_data.as_ptr() as *const u8, buf_c.data(), c_data.len() * 4); std::ptr::copy_nonoverlapping( ¶ms as *const Params as *const u8, @@ -125,7 +156,7 @@ fn main() { params: buf_params.into(), }, ); - pe.dispatch([M / TILE, N / TILE, 1]); + pe.dispatch([M / tile, N / tile, 1]); } let sp = context.submit(&mut encoder); let _ = context.wait_for(&sp, !0); @@ -140,13 +171,14 @@ fn main() { for j in 0..N { let mut sum = 0.0f32; for ki in 0..K { - sum += a_data[(i * K + ki) as usize] * b_data[(ki * N + j) as usize]; + sum += a_f32[(i * K + ki) as usize] * b_f32[(ki * N + j) as usize]; } reference[(i * N + j) as usize] = sum; } } - // Verify + // Verify (f16 inputs lose precision, so use a wider tolerance) + let tolerance = if f16_input { 0.5 } else { 0.01 }; let mut max_error = 0.0f32; for i in 0..(M * N) as usize { max_error = max_error.max((result[i] - reference[i]).abs()); @@ -154,10 +186,10 @@ fn main() { println!("Matrix multiplication {M}x{K}x{N} complete."); println!("Max error vs CPU reference: {max_error:.6}"); - if max_error < 0.01 { + if max_error < tolerance { println!("PASS"); } else { - println!("FAIL (tolerance: 0.01)"); + println!("FAIL (tolerance: {tolerance})"); } // Print top-left 4x4 of result diff --git a/examples/matmul/matmul.wgsl b/examples/matmul/matmul.wgsl index 332698a9..7a72c16e 100644 --- a/examples/matmul/matmul.wgsl +++ b/examples/matmul/matmul.wgsl @@ -1,15 +1,22 @@ -// Cooperative matrix multiplication: C = A * B +// Cooperative matrix multiplication: C = A * B + C // -// Each workgroup handles one 8x8 output tile. -// The K dimension is iterated in 8-wide steps, -// loading 8x8 tiles of A and B and accumulating into C. +// Each workgroup handles one output tile. +// The K dimension is iterated in TILE-wide steps, +// loading tiles of A and B and accumulating into C. +// +// The host substitutes placeholders based on device capabilities: +// ENABLE_F16 - "enable f16;" or empty +// TILE_SIZE - 8u or 16u +// INPUT_SCALAR - f32 or f16 +// COOP_MAT - coop_mat8x8 or coop_mat16x16 enable wgpu_cooperative_matrix; +ENABLE_F16 -const TILE: u32 = 8u; +const TILE: u32 = TILE_SIZE; -var matrix_a: array; -var matrix_b: array; +var matrix_a: array; +var matrix_b: array; var matrix_c: array; struct Params { @@ -19,23 +26,25 @@ struct Params { } var params: Params; -@compute @workgroup_size(8, 8, 1) +// Workgroup X must be a multiple of the subgroup size (32 or 64). +// 64 is the LCM of common subgroup sizes. +@compute @workgroup_size(64, 1, 1) fn main(@builtin(workgroup_id) wg: vec3) { let row = wg.x * TILE; let col = wg.y * TILE; let n = params.n; let k = params.k; - // Zero-initialize accumulator + // Zero-initialize accumulator (row-major load with stride = n) let c_offset = row * n + col; - var acc = coopLoad>(&matrix_c[c_offset], n); + var acc = coopLoadT>(&matrix_c[c_offset], n); - // Accumulate tiles along K + // Accumulate tiles along K (row-major loads) for (var t: u32 = 0u; t < k; t += TILE) { - let a = coopLoad>(&matrix_a[row * k + t], k); - let b = coopLoad>(&matrix_b[t * n + col], n); + let a = coopLoadT>(&matrix_a[row * k + t], k); + let b = coopLoadT>(&matrix_b[t * n + col], n); acc = coopMultiplyAdd(a, b, acc); } - coopStore(acc, &matrix_c[c_offset], n); + coopStoreT(acc, &matrix_c[c_offset], n); } diff --git a/examples/particle/main.rs b/examples/particle/main.rs index 1d5d2123..a29831ba 100644 --- a/examples/particle/main.rs +++ b/examples/particle/main.rs @@ -458,12 +458,11 @@ impl winit::application::ApplicationHandler for App { .. }, .. - } => match key_code { - winit::keyboard::KeyCode::Escape => { + } => { + if key_code == winit::keyboard::KeyCode::Escape { event_loop.exit(); } - _ => {} - }, + } winit::event::WindowEvent::CloseRequested => { event_loop.exit(); } @@ -497,7 +496,7 @@ impl winit::application::ApplicationHandler for App { let control_flow = if let Some(repaint_after_instant) = std::time::Instant::now().checked_add(repaint_delay) { - winit::event_loop::ControlFlow::WaitUntil(repaint_after_instant.into()) + winit::event_loop::ControlFlow::WaitUntil(repaint_after_instant) } else { winit::event_loop::ControlFlow::Wait }; diff --git a/examples/ray-query/main.rs b/examples/ray-query/main.rs index faff40b3..abd2a82e 100644 --- a/examples/ray-query/main.rs +++ b/examples/ray-query/main.rs @@ -87,12 +87,11 @@ impl winit::application::ApplicationHandler for App { .. }, .. - } => match key_code { - winit::keyboard::KeyCode::Escape => { + } => { + if key_code == winit::keyboard::KeyCode::Escape { event_loop.exit(); } - _ => {} - }, + } winit::event::WindowEvent::RedrawRequested => { let example = self.example.as_mut().unwrap(); let context = self.context.as_ref().unwrap(); diff --git a/examples/scene/main.rs b/examples/scene/main.rs index eff3b56c..239a9d7f 100644 --- a/examples/scene/main.rs +++ b/examples/scene/main.rs @@ -625,11 +625,10 @@ impl Example { self.debug_blit = if let Some(view) = blit_view { let mut db = match self.debug_blit.take() { Some(db) => db, - None => { - let mut db = blade_render::DebugBlit::default(); - db.target_size = [min_size, min_size]; - db - } + None => blade_render::DebugBlit { + target_size: [min_size, min_size], + ..Default::default() + }, }; db.input = view; let style = ui.style(); @@ -734,11 +733,9 @@ impl Example { tc.rotation = glam::Quat::from_euler(glam::EulerRot::default(), a1, a2, a3); let transform = tc.to_blade(); - if object.transform != transform { - if tc.is_inversible() { - object.transform = transform; - self.have_objects_changed = true; - } + if object.transform != transform && tc.is_inversible() { + object.transform = transform; + self.have_objects_changed = true; } }); } @@ -915,7 +912,7 @@ impl winit::application::ApplicationHandler for App { example.is_point_selected = false; } winit::event::WindowEvent::CursorMoved { position, .. } => { - if let Some(_) = self.drag_start { + if self.drag_start.is_some() { let prev = glam::Quat::from(example.camera.inner.rot); let rotation_local = glam::Quat::from_rotation_x( (self.last_mouse_pos[1] as f32 - position.y as f32) * drag_speed, diff --git a/tests/gpu_examples.rs b/tests/gpu_examples.rs index 450b52c8..23027fa8 100644 --- a/tests/gpu_examples.rs +++ b/tests/gpu_examples.rs @@ -10,6 +10,7 @@ use blade_graphics as gpu; use blade_graphics::ShaderData; use std::slice; +#[allow(dead_code)] #[path = "../examples/bunnymark/example.rs"] mod bunnymark_example; #[cfg(not(gles))] diff --git a/tests/parse_shaders.rs b/tests/parse_shaders.rs index 249ee2ba..7cee4d68 100644 --- a/tests/parse_shaders.rs +++ b/tests/parse_shaders.rs @@ -37,7 +37,7 @@ fn parse_wgsl() { } }; let shader_raw = match path.extension() { - Some(ostr) if &*ostr == "wgsl" => { + Some(ostr) if ostr == "wgsl" => { println!("Validating {:?}", path); fs::read(&path).unwrap_or_default() } @@ -45,7 +45,16 @@ fn parse_wgsl() { }; let cooker = blade_asset::Cooker::new(&example, Default::default()); - let text_out = blade_render::shader::parse_shader(&shader_raw, &cooker, &expansions); + let mut text_out = + blade_render::shader::parse_shader(&shader_raw, &cooker, &expansions); + + // Substitute cooperative matrix template placeholders with defaults + // so the shader parses as valid WGSL. + text_out = text_out + .replace("ENABLE_F16", "") + .replace("COOP_MAT", "coop_mat8x8") + .replace("INPUT_SCALAR", "f32") + .replace("TILE_SIZE", "8u"); let module = match wgsl::parse_str(&text_out) { Ok(module) => module, diff --git a/tests/snapshot.rs b/tests/snapshot.rs index c2afad0c..51721d72 100644 --- a/tests/snapshot.rs +++ b/tests/snapshot.rs @@ -56,7 +56,8 @@ impl OffscreenTarget { context: &gpu::Context, encoder: &mut gpu::CommandEncoder, ) -> Vec { - if let mut transfer = encoder.transfer("snapshot-readback") { + { + let mut transfer = encoder.transfer("snapshot-readback"); transfer.copy_texture_to_buffer( self.texture.into(), self.readback.into(),