Skip to content

Commit 8afb3ce

Browse files
add storage_texture option to as_bind_group macro (#9943)
# Objective - Add the ability to describe storage texture bindings when deriving `AsBindGroup`. - This is especially valuable for the compute story of bevy which deserves some extra love imo. ## Solution - This add the ability to annotate struct fields with a `#[storage_texture(0)]` annotation. - Instead of adding specific option parsing for all the image formats and access modes, I simply accept a token stream and defer checking to see if the option is valid to the compiler. This still results in useful and friendly errors and is free to maintain and always compatible with wgpu changes. --- ## Changelog - The `#[storage_texture(..)]` annotation is now accepted for fields of `Handle<Image>` in structs that derive `AsBindGroup`. - The game_of_life compute shader example has been updated to use `AsBindGroup` together with `[storage_texture(..)]` to obtain the `BindGroupLayout`. ## Migration Guide
1 parent 0fa14c8 commit 8afb3ce

File tree

4 files changed

+146
-13
lines changed

4 files changed

+146
-13
lines changed

crates/bevy_render/macros/src/as_bind_group.rs

Lines changed: 119 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ use syn::{
1111

1212
const UNIFORM_ATTRIBUTE_NAME: Symbol = Symbol("uniform");
1313
const TEXTURE_ATTRIBUTE_NAME: Symbol = Symbol("texture");
14+
const STORAGE_TEXTURE_ATTRIBUTE_NAME: Symbol = Symbol("storage_texture");
1415
const SAMPLER_ATTRIBUTE_NAME: Symbol = Symbol("sampler");
1516
const STORAGE_ATTRIBUTE_NAME: Symbol = Symbol("storage");
1617
const BIND_GROUP_DATA_ATTRIBUTE_NAME: Symbol = Symbol("bind_group_data");
@@ -19,6 +20,7 @@ const BIND_GROUP_DATA_ATTRIBUTE_NAME: Symbol = Symbol("bind_group_data");
1920
enum BindingType {
2021
Uniform,
2122
Texture,
23+
StorageTexture,
2224
Sampler,
2325
Storage,
2426
}
@@ -133,6 +135,8 @@ pub fn derive_as_bind_group(ast: syn::DeriveInput) -> Result<TokenStream> {
133135
BindingType::Uniform
134136
} else if attr_ident == TEXTURE_ATTRIBUTE_NAME {
135137
BindingType::Texture
138+
} else if attr_ident == STORAGE_TEXTURE_ATTRIBUTE_NAME {
139+
BindingType::StorageTexture
136140
} else if attr_ident == SAMPLER_ATTRIBUTE_NAME {
137141
BindingType::Sampler
138142
} else if attr_ident == STORAGE_ATTRIBUTE_NAME {
@@ -255,6 +259,45 @@ pub fn derive_as_bind_group(ast: syn::DeriveInput) -> Result<TokenStream> {
255259
}
256260
});
257261
}
262+
BindingType::StorageTexture => {
263+
let StorageTextureAttrs {
264+
dimension,
265+
image_format,
266+
access,
267+
visibility,
268+
} = get_storage_texture_binding_attr(nested_meta_items)?;
269+
270+
let visibility =
271+
visibility.hygienic_quote(&quote! { #render_path::render_resource });
272+
273+
let fallback_image = get_fallback_image(&render_path, dimension);
274+
275+
binding_impls.push(quote! {
276+
( #binding_index,
277+
#render_path::render_resource::OwnedBindingResource::TextureView({
278+
let handle: Option<&#asset_path::Handle<#render_path::texture::Image>> = (&self.#field_name).into();
279+
if let Some(handle) = handle {
280+
images.get(handle).ok_or_else(|| #render_path::render_resource::AsBindGroupError::RetryNextUpdate)?.texture_view.clone()
281+
} else {
282+
#fallback_image.texture_view.clone()
283+
}
284+
})
285+
)
286+
});
287+
288+
binding_layouts.push(quote! {
289+
#render_path::render_resource::BindGroupLayoutEntry {
290+
binding: #binding_index,
291+
visibility: #visibility,
292+
ty: #render_path::render_resource::BindingType::StorageTexture {
293+
access: #render_path::render_resource::StorageTextureAccess::#access,
294+
format: #render_path::render_resource::TextureFormat::#image_format,
295+
view_dimension: #render_path::render_resource::#dimension,
296+
},
297+
count: None,
298+
}
299+
});
300+
}
258301
BindingType::Texture => {
259302
let TextureAttrs {
260303
dimension,
@@ -585,6 +628,10 @@ impl ShaderStageVisibility {
585628
fn vertex_fragment() -> Self {
586629
Self::Flags(VisibilityFlags::vertex_fragment())
587630
}
631+
632+
fn compute() -> Self {
633+
Self::Flags(VisibilityFlags::compute())
634+
}
588635
}
589636

590637
impl VisibilityFlags {
@@ -595,6 +642,13 @@ impl VisibilityFlags {
595642
..Default::default()
596643
}
597644
}
645+
646+
fn compute() -> Self {
647+
Self {
648+
compute: true,
649+
..Default::default()
650+
}
651+
}
598652
}
599653

600654
impl ShaderStageVisibility {
@@ -741,7 +795,72 @@ impl Default for TextureAttrs {
741795
}
742796
}
743797

798+
struct StorageTextureAttrs {
799+
dimension: BindingTextureDimension,
800+
// Parsing of the image_format parameter is deferred to the type checker,
801+
// which will error if the format is not member of the TextureFormat enum.
802+
image_format: proc_macro2::TokenStream,
803+
// Parsing of the access parameter is deferred to the type checker,
804+
// which will error if the access is not member of the StorageTextureAccess enum.
805+
access: proc_macro2::TokenStream,
806+
visibility: ShaderStageVisibility,
807+
}
808+
809+
impl Default for StorageTextureAttrs {
810+
fn default() -> Self {
811+
Self {
812+
dimension: Default::default(),
813+
image_format: quote! { Rgba8Unorm },
814+
access: quote! { ReadWrite },
815+
visibility: ShaderStageVisibility::compute(),
816+
}
817+
}
818+
}
819+
820+
fn get_storage_texture_binding_attr(metas: Vec<Meta>) -> Result<StorageTextureAttrs> {
821+
let mut storage_texture_attrs = StorageTextureAttrs::default();
822+
823+
for meta in metas {
824+
use syn::Meta::{List, NameValue};
825+
match meta {
826+
// Parse #[storage_texture(0, dimension = "...")].
827+
NameValue(m) if m.path == DIMENSION => {
828+
let value = get_lit_str(DIMENSION, &m.value)?;
829+
storage_texture_attrs.dimension = get_texture_dimension_value(value)?;
830+
}
831+
// Parse #[storage_texture(0, format = ...))].
832+
NameValue(m) if m.path == IMAGE_FORMAT => {
833+
storage_texture_attrs.image_format = m.value.into_token_stream();
834+
}
835+
// Parse #[storage_texture(0, access = ...))].
836+
NameValue(m) if m.path == ACCESS => {
837+
storage_texture_attrs.access = m.value.into_token_stream();
838+
}
839+
// Parse #[storage_texture(0, visibility(...))].
840+
List(m) if m.path == VISIBILITY => {
841+
storage_texture_attrs.visibility = get_visibility_flag_value(&m)?;
842+
}
843+
NameValue(m) => {
844+
return Err(Error::new_spanned(
845+
m.path,
846+
"Not a valid name. Available attributes: `dimension`, `image_format`, `access`.",
847+
));
848+
}
849+
_ => {
850+
return Err(Error::new_spanned(
851+
meta,
852+
"Not a name value pair: `foo = \"...\"`",
853+
));
854+
}
855+
}
856+
}
857+
858+
Ok(storage_texture_attrs)
859+
}
860+
744861
const DIMENSION: Symbol = Symbol("dimension");
862+
const IMAGE_FORMAT: Symbol = Symbol("image_format");
863+
const ACCESS: Symbol = Symbol("access");
745864
const SAMPLE_TYPE: Symbol = Symbol("sample_type");
746865
const FILTERABLE: Symbol = Symbol("filterable");
747866
const MULTISAMPLED: Symbol = Symbol("multisampled");

crates/bevy_render/macros/src/lib.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ pub fn derive_extract_component(input: TokenStream) -> TokenStream {
5151

5252
#[proc_macro_derive(
5353
AsBindGroup,
54-
attributes(uniform, texture, sampler, bind_group_data, storage)
54+
attributes(uniform, storage_texture, texture, sampler, bind_group_data, storage)
5555
)]
5656
pub fn derive_as_bind_group(input: TokenStream) -> TokenStream {
5757
let input = parse_macro_input!(input as DeriveInput);

crates/bevy_render/src/render_resource/bind_group.rs

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,8 @@ impl Deref for BindGroup {
8787
/// values: Vec<f32>,
8888
/// #[storage(4, read_only, buffer)]
8989
/// buffer: Buffer,
90+
/// #[storage_texture(5)]
91+
/// storage_texture: Handle<Image>,
9092
/// }
9193
/// ```
9294
///
@@ -97,6 +99,7 @@ impl Deref for BindGroup {
9799
/// @group(2) @binding(1) var color_texture: texture_2d<f32>;
98100
/// @group(2) @binding(2) var color_sampler: sampler;
99101
/// @group(2) @binding(3) var<storage> values: array<f32>;
102+
/// @group(2) @binding(5) var storage_texture: texture_storage_2d<rgba8unorm, read_write>;
100103
/// ```
101104
/// Note that the "group" index is determined by the usage context. It is not defined in [`AsBindGroup`]. For example, in Bevy material bind groups
102105
/// are generally bound to group 2.
@@ -123,6 +126,19 @@ impl Deref for BindGroup {
123126
/// | `multisampled` = ... | `true`, `false` | `false` |
124127
/// | `visibility(...)` | `all`, `none`, or a list-combination of `vertex`, `fragment`, `compute` | `vertex`, `fragment` |
125128
///
129+
/// * `storage_texture(BINDING_INDEX, arguments)`
130+
/// * This field's [`Handle<Image>`](bevy_asset::Handle) will be used to look up the matching [`Texture`](crate::render_resource::Texture)
131+
/// GPU resource, which will be bound as a storage texture in shaders. The field will be assumed to implement [`Into<Option<Handle<Image>>>`]. In practice,
132+
/// most fields should be a [`Handle<Image>`](bevy_asset::Handle) or [`Option<Handle<Image>>`]. If the value of an [`Option<Handle<Image>>`] is
133+
/// [`None`], the [`FallbackImage`] resource will be used instead.
134+
///
135+
/// | Arguments | Values | Default |
136+
/// |------------------------|--------------------------------------------------------------------------------------------|---------------|
137+
/// | `dimension` = "..." | `"1d"`, `"2d"`, `"2d_array"`, `"3d"`, `"cube"`, `"cube_array"` | `"2d"` |
138+
/// | `image_format` = ... | any member of [`TextureFormat`](crate::render_resource::TextureFormat) | `Rgba8Unorm` |
139+
/// | `access` = ... | any member of [`StorageTextureAccess`](crate::render_resource::StorageTextureAccess) | `ReadWrite` |
140+
/// | `visibility(...)` | `all`, `none`, or a list-combination of `vertex`, `fragment`, `compute` | `compute` |
141+
///
126142
/// * `sampler(BINDING_INDEX, arguments)`
127143
/// * This field's [`Handle<Image>`](bevy_asset::Handle) will be used to look up the matching [`Sampler`] GPU
128144
/// resource, which will be bound as a sampler in shaders. The field will be assumed to implement [`Into<Option<Handle<Image>>>`]. In practice,

examples/shader/compute_shader_game_of_life.rs

Lines changed: 10 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ use bevy::{
1010
render_asset::RenderAssetPersistencePolicy,
1111
render_asset::RenderAssets,
1212
render_graph::{self, RenderGraph},
13-
render_resource::{binding_types::texture_storage_2d, *},
13+
render_resource::*,
1414
renderer::{RenderContext, RenderDevice},
1515
Render, RenderApp, RenderSet,
1616
},
@@ -65,7 +65,7 @@ fn setup(mut commands: Commands, mut images: ResMut<Assets<Image>>) {
6565
});
6666
commands.spawn(Camera2dBundle::default());
6767

68-
commands.insert_resource(GameOfLifeImage(image));
68+
commands.insert_resource(GameOfLifeImage { texture: image });
6969
}
7070

7171
pub struct GameOfLifeComputePlugin;
@@ -95,8 +95,11 @@ impl Plugin for GameOfLifeComputePlugin {
9595
}
9696
}
9797

98-
#[derive(Resource, Clone, Deref, ExtractResource)]
99-
struct GameOfLifeImage(Handle<Image>);
98+
#[derive(Resource, Clone, Deref, ExtractResource, AsBindGroup)]
99+
struct GameOfLifeImage {
100+
#[storage_texture(0, image_format = Rgba8Unorm, access = ReadWrite)]
101+
texture: Handle<Image>,
102+
}
100103

101104
#[derive(Resource)]
102105
struct GameOfLifeImageBindGroup(BindGroup);
@@ -108,7 +111,7 @@ fn prepare_bind_group(
108111
game_of_life_image: Res<GameOfLifeImage>,
109112
render_device: Res<RenderDevice>,
110113
) {
111-
let view = gpu_images.get(&game_of_life_image.0).unwrap();
114+
let view = gpu_images.get(&game_of_life_image.texture).unwrap();
112115
let bind_group = render_device.create_bind_group(
113116
None,
114117
&pipeline.texture_bind_group_layout,
@@ -126,13 +129,8 @@ pub struct GameOfLifePipeline {
126129

127130
impl FromWorld for GameOfLifePipeline {
128131
fn from_world(world: &mut World) -> Self {
129-
let texture_bind_group_layout = world.resource::<RenderDevice>().create_bind_group_layout(
130-
None,
131-
&BindGroupLayoutEntries::single(
132-
ShaderStages::COMPUTE,
133-
texture_storage_2d(TextureFormat::Rgba8Unorm, StorageTextureAccess::ReadWrite),
134-
),
135-
);
132+
let render_device = world.resource::<RenderDevice>();
133+
let texture_bind_group_layout = GameOfLifeImage::bind_group_layout(render_device);
136134
let shader = world
137135
.resource::<AssetServer>()
138136
.load("shaders/game_of_life.wgsl");

0 commit comments

Comments
 (0)