Skip to content

Commit 24f9137

Browse files
committed
moved permutations to complete ImageWithMethods trait impls
1 parent 061727b commit 24f9137

File tree

4 files changed

+207
-115
lines changed

4 files changed

+207
-115
lines changed

crates/spirv-std/macros/src/lib.rs

Lines changed: 41 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -631,30 +631,42 @@ const SAMPLE_PARAM_TYPES: [&str; SAMPLE_PARAM_COUNT] = ["B", "L", "S"];
631631
const SAMPLE_PARAM_OPERANDS: [&str; SAMPLE_PARAM_COUNT] = ["Bias", "Lod", "Sample"];
632632
const SAMPLE_PARAM_NAMES: [&str; SAMPLE_PARAM_COUNT] = ["bias", "lod", "sample_index"];
633633

634-
struct SampleFnRewriter(usize);
634+
struct SampleImplRewriter(usize, syn::Type);
635635

636-
impl SampleFnRewriter {
637-
pub fn rewrite(mask: usize, f: &syn::ItemFn) -> syn::ItemFn {
638-
let mut new_f = f.clone();
639-
let mut ty = String::from("SampleParams<");
636+
impl SampleImplRewriter {
637+
pub fn rewrite(mask: usize, f: &syn::ItemImpl) -> syn::ItemImpl {
638+
let mut new_impl = f.clone();
639+
let mut ty_str = String::from("SampleParams<");
640640

641+
// based on the mask, form a `SampleParams` type string and add the generic parameters to the `impl<>` generics
642+
// example type string: `"SampleParams<SomeTy<B>, NoneTy, NoneTy>"`
641643
for i in 0..SAMPLE_PARAM_COUNT {
642644
if mask & (1 << i) != 0 {
643-
new_f.sig.generics.params.push(syn::GenericParam::Type(
645+
new_impl.generics.params.push(syn::GenericParam::Type(
644646
syn::Ident::new(SAMPLE_PARAM_TYPES[i], Span::call_site()).into(),
645647
));
646-
ty.push_str(SAMPLE_PARAM_TYPES[i]);
648+
ty_str.push_str("SomeTy<");
649+
ty_str.push_str(SAMPLE_PARAM_TYPES[i]);
650+
ty_str.push('>');
647651
} else {
648-
ty.push_str("()");
652+
ty_str.push_str("NoneTy");
649653
}
650-
ty.push(',');
654+
ty_str.push(',');
651655
}
652-
ty.push('>');
653-
if let Some(syn::FnArg::Typed(p)) = new_f.sig.inputs.last_mut() {
654-
*p.ty.as_mut() = syn::parse(ty.parse().unwrap()).unwrap();
656+
ty_str.push_str(">");
657+
let ty: syn::Type = syn::parse(ty_str.parse().unwrap()).unwrap();
658+
659+
if let Some(t) = &mut new_impl.trait_ {
660+
if let syn::PathArguments::AngleBracketed(a) =
661+
&mut t.1.segments.last_mut().unwrap().arguments
662+
{
663+
if let Some(syn::GenericArgument::Type(t)) = a.args.last_mut() {
664+
*t = ty.clone();
665+
}
666+
}
655667
}
656-
SampleFnRewriter(mask).visit_item_fn_mut(&mut new_f);
657-
new_f
668+
SampleImplRewriter(mask, ty).visit_item_impl_mut(&mut new_impl);
669+
new_impl
658670
}
659671

660672
fn get_operands(&self) -> String {
@@ -686,18 +698,22 @@ impl SampleFnRewriter {
686698
fn add_regs(&self, t: &mut Vec<TokenTree>) {
687699
for i in 0..SAMPLE_PARAM_COUNT {
688700
if self.0 & (1 << i) != 0 {
689-
let s = format!("{0} = in(reg) &param.{0}", SAMPLE_PARAM_NAMES[i]);
690-
t.push(TokenTree::Literal(proc_macro2::Literal::string(s.as_str())));
691-
t.push(TokenTree::Punct(proc_macro2::Punct::new(
692-
',',
693-
proc_macro2::Spacing::Alone,
694-
)))
701+
let s = format!("{0} = in(reg) &param.{0},", SAMPLE_PARAM_NAMES[i]);
702+
let ts: proc_macro2::TokenStream = s.parse().unwrap();
703+
t.extend(ts);
695704
}
696705
}
697706
}
698707
}
699708

700-
impl syn::visit_mut::VisitMut for SampleFnRewriter {
709+
impl VisitMut for SampleImplRewriter {
710+
fn visit_impl_item_method_mut(&mut self, item: &mut syn::ImplItemMethod) {
711+
if let Some(syn::FnArg::Typed(p)) = item.sig.inputs.last_mut() {
712+
*p.ty.as_mut() = self.1.clone();
713+
}
714+
syn::visit_mut::visit_impl_item_method_mut(self, item);
715+
}
716+
701717
fn visit_macro_mut(&mut self, m: &mut syn::Macro) {
702718
if m.path.is_ident("asm") {
703719
let t = m.tokens.clone();
@@ -745,12 +761,13 @@ impl syn::visit_mut::VisitMut for SampleFnRewriter {
745761
#[proc_macro_attribute]
746762
#[doc(hidden)]
747763
pub fn gen_sample_param_permutations(_attr: TokenStream, item: TokenStream) -> TokenStream {
748-
let item_fn = syn::parse_macro_input!(item as syn::ItemFn);
764+
let item_impl = syn::parse_macro_input!(item as syn::ItemImpl);
749765
let mut fns = Vec::new();
750766

751-
for m in 1..((1 << SAMPLE_PARAM_COUNT) - 1) {
752-
fns.push(SampleFnRewriter::rewrite(m, &item_fn));
767+
for m in 1..(1 << SAMPLE_PARAM_COUNT) {
768+
fns.push(SampleImplRewriter::rewrite(m, &item_impl));
753769
}
754770

771+
println!("{}", quote! { #(#fns)* }.to_string());
755772
quote! { #(#fns)* }.into()
756773
}

crates/spirv-std/src/image.rs

Lines changed: 79 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -10,12 +10,17 @@ use core::arch::asm;
1010
#[rustfmt::skip]
1111
mod params;
1212

13-
pub use self::params::{ImageCoordinate, ImageCoordinateSubpassData, SampleParams, SampleType};
13+
/// Contains extra image operands
14+
pub mod sample;
15+
16+
pub use self::params::{ImageCoordinate, ImageCoordinateSubpassData, SampleType};
1417
pub use crate::macros::Image;
1518
pub use spirv_std_types::image_params::{
1619
AccessQualifier, Arrayed, Dimensionality, ImageDepth, ImageFormat, Multisampled, Sampled,
1720
};
1821

22+
use sample::{NoneTy, SampleParams, SomeTy};
23+
1924
use crate::{float::Float, integer::Integer, vector::Vector, Sampler};
2025

2126
/// Re-export of primitive types to ensure the `Image` proc macro always points
@@ -158,33 +163,6 @@ impl<
158163
}
159164
result.truncate_into()
160165
}
161-
162-
#[crate::macros::gen_sample_param_permutations]
163-
#[crate::macros::gpu_only]
164-
#[doc(alias = "OpImageFetch")]
165-
pub fn fetch_with<I>(
166-
&self,
167-
coordinate: impl ImageCoordinate<I, DIM, ARRAYED>,
168-
params: SampleParams,
169-
) -> SampledType::SampleResult
170-
where
171-
I: Integer,
172-
B: Integer,
173-
{
174-
let mut result = SampledType::Vec4::default();
175-
unsafe {
176-
asm! {
177-
"%image = OpLoad _ {this}",
178-
"%coordinate = OpLoad _ {coordinate}",
179-
"%result = OpImageFetch typeof*{result} %image %coordinate $PARAMS",
180-
"OpStore {result} %result",
181-
result = in(reg) &mut result,
182-
this = in(reg) self,
183-
coordinate = in(reg) &coordinate,
184-
}
185-
}
186-
result.truncate_into()
187-
}
188166
}
189167

190168
impl<
@@ -1142,6 +1120,79 @@ impl<
11421120
}
11431121
}
11441122

1123+
/// Helper trait that defines all `*_with` methods on an `Image` that use the extra image operands,
1124+
/// such as bias or lod, defined by the `SampleParams` struct.
1125+
pub trait ImageWithMethods<
1126+
SampledType: SampleType<FORMAT, COMPONENTS>,
1127+
const DIM: u32,
1128+
const DEPTH: u32,
1129+
const ARRAYED: u32,
1130+
const MULTISAMPLED: u32,
1131+
const SAMPLED: u32,
1132+
const FORMAT: u32,
1133+
const COMPONENTS: u32,
1134+
Params,
1135+
>
1136+
{
1137+
/// Fetch a single texel with a sampler set at compile time
1138+
fn fetch_with<I>(
1139+
&self,
1140+
coordinate: impl ImageCoordinate<I, DIM, ARRAYED>,
1141+
params: Params,
1142+
) -> SampledType::SampleResult
1143+
where
1144+
I: Integer;
1145+
}
1146+
1147+
#[crate::macros::gen_sample_param_permutations]
1148+
impl<
1149+
SampledType: SampleType<FORMAT, COMPONENTS>,
1150+
const DIM: u32,
1151+
const DEPTH: u32,
1152+
const ARRAYED: u32,
1153+
const MULTISAMPLED: u32,
1154+
const SAMPLED: u32,
1155+
const FORMAT: u32,
1156+
const COMPONENTS: u32,
1157+
>
1158+
ImageWithMethods<
1159+
SampledType,
1160+
DIM,
1161+
DEPTH,
1162+
ARRAYED,
1163+
MULTISAMPLED,
1164+
SAMPLED,
1165+
FORMAT,
1166+
COMPONENTS,
1167+
SampleParams,
1168+
> for Image<SampledType, DIM, DEPTH, ARRAYED, MULTISAMPLED, SAMPLED, FORMAT, COMPONENTS>
1169+
{
1170+
#[crate::macros::gpu_only]
1171+
#[doc(alias = "OpImageFetch")]
1172+
fn fetch_with<I>(
1173+
&self,
1174+
coordinate: impl ImageCoordinate<I, DIM, ARRAYED>,
1175+
params: SampleParams,
1176+
) -> SampledType::SampleResult
1177+
where
1178+
I: Integer,
1179+
{
1180+
let mut result = SampledType::Vec4::default();
1181+
unsafe {
1182+
asm! {
1183+
"%image = OpLoad _ {this}",
1184+
"%coordinate = OpLoad _ {coordinate}",
1185+
"%result = OpImageFetch typeof*{result} %image %coordinate $PARAMS",
1186+
"OpStore {result} %result",
1187+
result = in(reg) &mut result,
1188+
this = in(reg) self,
1189+
coordinate = in(reg) &coordinate,
1190+
}
1191+
}
1192+
result.truncate_into()
1193+
}
1194+
}
1195+
11451196
/// This is a marker trait to represent the constraints on `OpImageGather` too complex to be
11461197
/// represented by const generics. Specifically:
11471198
///

crates/spirv-std/src/image/params.rs

Lines changed: 0 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -194,66 +194,3 @@ impl<V: Vector<S, 4>, S: Scalar>
194194
pub trait ImageCoordinateSubpassData<T, const ARRAYED: u32> {}
195195
impl<V: Vector<I, 2>, I: Integer> ImageCoordinateSubpassData<I, { Arrayed::False as u32 }> for V {}
196196
impl<V: Vector<I, 3>, I: Integer> ImageCoordinateSubpassData<I, { Arrayed::True as u32 }> for V {}
197-
198-
/// Helper struct that allows building image operands. Start with a global function that returns this
199-
/// struct, and then chain additional calls.
200-
/// Example: `image.sample_with(coords, params::bias(3.0).sample_index(1))`
201-
pub struct SampleParams<B, L, S> {
202-
bias: B,
203-
lod: L,
204-
sample_index: S,
205-
}
206-
207-
pub fn bias<B>(bias: B) -> SampleParams<B, (), ()> {
208-
SampleParams {
209-
bias,
210-
lod: (),
211-
sample_index: (),
212-
}
213-
}
214-
215-
pub fn lod<L>(lod: L) -> SampleParams<(), L, ()> {
216-
SampleParams {
217-
bias: (),
218-
lod,
219-
sample_index: (),
220-
}
221-
}
222-
223-
pub fn sample_index<S>(sample_index: S) -> SampleParams<(), (), S> {
224-
SampleParams {
225-
bias: (),
226-
lod: (),
227-
sample_index,
228-
}
229-
}
230-
231-
impl<L, S> SampleParams<(), L, S> {
232-
pub fn bias<B>(self, bias: B) -> SampleParams<B, L, S> {
233-
SampleParams {
234-
bias,
235-
lod: self.lod,
236-
sample_index: self.sample_index,
237-
}
238-
}
239-
}
240-
241-
impl<B, S> SampleParams<B, (), S> {
242-
pub fn lod<L>(self, lod: L) -> SampleParams<B, L, S> {
243-
SampleParams {
244-
bias: self.bias,
245-
lod: lod,
246-
sample_index: self.sample_index,
247-
}
248-
}
249-
}
250-
251-
impl<B, L> SampleParams<B, L, ()> {
252-
pub fn sample_index<S>(self, sample_index: S) -> SampleParams<B, L, S> {
253-
SampleParams {
254-
bias: self.bias,
255-
lod: self.lod,
256-
sample_index: sample_index,
257-
}
258-
}
259-
}

crates/spirv-std/src/image/sample.rs

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
/// Helper trait to mimic `Option<T>`, but where the variant are types
2+
pub trait OptionTy {
3+
/// Whether this is a `NoneTy` (when false) or a `SomeTy<T>` (when true)
4+
const EXISTS: bool;
5+
}
6+
7+
impl OptionTy for NoneTy {
8+
const EXISTS: bool = false;
9+
}
10+
11+
impl<T> OptionTy for SomeTy<T> {
12+
const EXISTS: bool = true;
13+
}
14+
/// Helper struct that denotes that the type doesn't exist, analog to `Option::None`
15+
pub struct NoneTy;
16+
17+
/// Helper struct that denotes that the type does exist and is of type T, analog to `Option::Some(T)`
18+
pub struct SomeTy<T>(T);
19+
20+
/// Helper struct that allows building image operands. Start with a global function that returns this
21+
/// struct, and then chain additional calls.
22+
/// Example: `image.sample_with(coords, params::bias(3.0).sample_index(1))`
23+
pub struct SampleParams<B: OptionTy, L: OptionTy, S: OptionTy> {
24+
bias: B,
25+
lod: L,
26+
sample_index: S,
27+
}
28+
29+
/// Sets the 'Bias' image operand
30+
pub fn bias<B>(bias: B) -> SampleParams<SomeTy<B>, NoneTy, NoneTy> {
31+
SampleParams {
32+
bias: SomeTy(bias),
33+
lod: NoneTy,
34+
sample_index: NoneTy,
35+
}
36+
}
37+
38+
/// Sets the 'Lod' image operand
39+
pub fn lod<L>(lod: L) -> SampleParams<NoneTy, SomeTy<L>, NoneTy> {
40+
SampleParams {
41+
bias: NoneTy,
42+
lod: SomeTy(lod),
43+
sample_index: NoneTy,
44+
}
45+
}
46+
47+
/// Sets the 'Sample' image operand
48+
pub fn sample_index<S>(sample_index: S) -> SampleParams<NoneTy, NoneTy, SomeTy<S>> {
49+
SampleParams {
50+
bias: NoneTy,
51+
lod: NoneTy,
52+
sample_index: SomeTy(sample_index),
53+
}
54+
}
55+
56+
impl<L: OptionTy, S: OptionTy> SampleParams<NoneTy, L, S> {
57+
/// Sets the 'Bias' image operand
58+
pub fn bias<B>(self, bias: B) -> SampleParams<SomeTy<B>, L, S> {
59+
SampleParams {
60+
bias: SomeTy(bias),
61+
lod: self.lod,
62+
sample_index: self.sample_index,
63+
}
64+
}
65+
}
66+
67+
impl<B: OptionTy, S: OptionTy> SampleParams<B, NoneTy, S> {
68+
/// Sets the 'Lod' image operand
69+
pub fn lod<L>(self, lod: L) -> SampleParams<B, SomeTy<L>, S> {
70+
SampleParams {
71+
bias: self.bias,
72+
lod: SomeTy(lod),
73+
sample_index: self.sample_index,
74+
}
75+
}
76+
}
77+
78+
impl<B: OptionTy, L: OptionTy> SampleParams<B, L, NoneTy> {
79+
/// Sets the 'Sample' image operand
80+
pub fn sample_index<S>(self, sample_index: S) -> SampleParams<B, L, SomeTy<S>> {
81+
SampleParams {
82+
bias: self.bias,
83+
lod: self.lod,
84+
sample_index: SomeTy(sample_index),
85+
}
86+
}
87+
}

0 commit comments

Comments
 (0)