|
| 1 | +//! The `wasm-bindgen` multi-value transformation. |
| 2 | +//! |
| 3 | +//! This crate provides a transformation to turn exported functions that use a |
| 4 | +//! return pointer into exported functions that use multi-value. |
| 5 | +//! |
| 6 | +//! Consider the following function: |
| 7 | +//! |
| 8 | +//! ``` |
| 9 | +//! #[no_mangle] |
| 10 | +//! pub extern "C" fn pair(a: u32, b: u32) -> [u32; 2] { |
| 11 | +//! [a, b] |
| 12 | +//! } |
| 13 | +//! ``` |
| 14 | +//! |
| 15 | +//! LLVM will by default compile this down into the following Wasm: |
| 16 | +//! |
| 17 | +//! ```wasm |
| 18 | +//! (func $pair (param i32 i32 i32) |
| 19 | +//! local.get 0 |
| 20 | +//! local.get 2 |
| 21 | +//! i32.store offset=4 |
| 22 | +//! local.get 0 |
| 23 | +//! local.get 1 |
| 24 | +//! i32.store) |
| 25 | +//! ``` |
| 26 | +//! |
| 27 | +//! What's happening here is that the function is not directly returning the |
| 28 | +//! pair at all, but instead the first `i32` parameter is a pointer to some |
| 29 | +//! scratch space, and the return value is written into the scratch space. LLVM |
| 30 | +//! does this because it doesn't yet have support for multi-value Wasm, and so |
| 31 | +//! it only knows how to return a single value at a time. |
| 32 | +//! |
| 33 | +//! Ideally, with multi-value, what we would like instead is this: |
| 34 | +//! |
| 35 | +//! ```wasm |
| 36 | +//! (func $pair (param i32 i32) (result i32 i32) |
| 37 | +//! local.get 0 |
| 38 | +//! local.get 1) |
| 39 | +//! ``` |
| 40 | +//! |
| 41 | +//! However, that's not what this transformation does at the moment. This |
| 42 | +//! transformation is a little simpler than mutating existing functions to |
| 43 | +//! produce a multi-value result. Instead, it introduces new functions that |
| 44 | +//! wrap the original function and translate the return pointer into |
| 45 | +//! multi-value results inside the wrapper function. |
| 46 | +//! |
| 47 | +//! With our running example, we end up with this: |
| 48 | +//! |
| 49 | +//! ```wasm |
| 50 | +//! ;; The original function. |
| 51 | +//! (func $pair (param i32 i32 i32) |
| 52 | +//! local.get 0 |
| 53 | +//! local.get 2 |
| 54 | +//! i32.store offset=4 |
| 55 | +//! local.get 0 |
| 56 | +//! local.get 1 |
| 57 | +//! i32.store) |
| 58 | +//! |
| 59 | +//! (func $pairWrapper (param i32 i32) (result i32 i32) |
| 60 | +//! ;; Our return pointer that points to the scratch space we are allocating |
| 61 | +//! ;; on the shadow stack for calling `$pair`. |
| 62 | +//! (local i32) |
| 63 | +//! |
| 64 | +//! ;; Allocate space on the shadow stack for the result. |
| 65 | +//! global.get $shadowStackPointer |
| 66 | +//! i32.const 8 |
| 67 | +//! i32.sub |
| 68 | +//! local.tee 2 |
| 69 | +//! global.set $shadowStackPointer |
| 70 | +//! |
| 71 | +//! ;; Call `$pair` with our allocated shadow stack space for its results. |
| 72 | +//! local.get 2 |
| 73 | +//! local.get 0 |
| 74 | +//! local.get 1 |
| 75 | +//! call $pair |
| 76 | +//! |
| 77 | +//! ;; Copy the return values from the shadow stack to the wasm stack. |
| 78 | +//! local.get 2 |
| 79 | +//! i32.load |
| 80 | +//! local.get 2 |
| 81 | +//! i32.load offset=4 |
| 82 | +//! |
| 83 | +//! ;; Finally, restore the shadow stack pointer. |
| 84 | +//! local.get 2 |
| 85 | +//! i32.const 8 |
| 86 | +//! i32.add |
| 87 | +//! global.set $shadowStackPointer) |
| 88 | +//! ``` |
| 89 | +//! |
| 90 | +//! This `$pairWrapper` function is what we actually end up exporting instead of |
| 91 | +//! `$pair`. |
| 92 | +
|
| 93 | +#![deny(missing_docs, missing_debug_implementations)] |
| 94 | + |
| 95 | +/// Run the transformation. |
| 96 | +/// |
| 97 | +/// See the module-level docs for details on the transformation. |
| 98 | +/// |
| 99 | +/// * `memory` is the module's memory that has the shadow stack where return |
| 100 | +/// pointers are allocated within. |
| 101 | +/// |
| 102 | +/// * `shadow_stack_pointer` is the global that is being used as the stack |
| 103 | +/// pointer for the shadow stack. With LLVM, this is typically the first |
| 104 | +/// global. |
| 105 | +/// |
| 106 | +/// * `to_xform` is the set of exported functions we want to transform and |
| 107 | +/// information required to transform them. The `usize` is the index of the |
| 108 | +/// return pointer parameter that will be removed. The `Vec<walrus::ValType>` |
| 109 | +/// is the new result type that will be returned directly instead of via the |
| 110 | +/// return pointer. |
| 111 | +pub fn run( |
| 112 | + module: &mut walrus::Module, |
| 113 | + memory: walrus::MemoryId, |
| 114 | + shadow_stack_pointer: walrus::GlobalId, |
| 115 | + to_xform: &[(walrus::ExportId, usize, &[walrus::ValType])], |
| 116 | +) -> Result<(), failure::Error> { |
| 117 | + for &(export, return_pointer_index, results) in to_xform { |
| 118 | + xform_one( |
| 119 | + module, |
| 120 | + memory, |
| 121 | + shadow_stack_pointer, |
| 122 | + export, |
| 123 | + return_pointer_index, |
| 124 | + results, |
| 125 | + )?; |
| 126 | + } |
| 127 | + Ok(()) |
| 128 | +} |
| 129 | + |
/// Round `n` up to the nearest multiple of `align`.
///
/// `align` must be a power of two; this is only checked in debug builds.
///
/// The mask `align - 1` is computed before it is added to `n`, so the
/// intermediate sum cannot overflow when the rounded result itself fits in a
/// `u32` (for example `n = u32::MAX, align = 1`). If the rounded result does
/// not fit in a `u32`, the addition still overflows.
fn round_up_to_alignment(n: u32, align: u32) -> u32 {
    debug_assert!(align.is_power_of_two());
    let mask = align - 1;
    (n + mask) & !mask
}
| 135 | + |
| 136 | +fn xform_one( |
| 137 | + module: &mut walrus::Module, |
| 138 | + memory: walrus::MemoryId, |
| 139 | + shadow_stack_pointer: walrus::GlobalId, |
| 140 | + export: walrus::ExportId, |
| 141 | + return_pointer_index: usize, |
| 142 | + results: &[walrus::ValType], |
| 143 | +) -> Result<(), failure::Error> { |
| 144 | + if module.globals.get(shadow_stack_pointer).ty != walrus::ValType::I32 { |
| 145 | + failure::bail!("shadow stack pointer global does not have type `i32`"); |
| 146 | + } |
| 147 | + |
| 148 | + let func = match module.exports.get(export).item { |
| 149 | + walrus::ExportItem::Function(f) => f, |
| 150 | + _ => { |
| 151 | + failure::bail!("can only multi-value transform exported functions, found non-function") |
| 152 | + } |
| 153 | + }; |
| 154 | + |
| 155 | + // Compute the total size of all results, potentially with padding to ensure |
| 156 | + // that each result is aligned. |
| 157 | + let mut results_size = 0; |
| 158 | + for ty in results { |
| 159 | + results_size = match ty { |
| 160 | + walrus::ValType::I32 | walrus::ValType::F32 => { |
| 161 | + debug_assert_eq!(results_size % 4, 0); |
| 162 | + results_size + 4 |
| 163 | + } |
| 164 | + walrus::ValType::I64 | walrus::ValType::F64 => { |
| 165 | + round_up_to_alignment(results_size, 8) + 8 |
| 166 | + } |
| 167 | + walrus::ValType::V128 => round_up_to_alignment(results_size, 16) + 16, |
| 168 | + walrus::ValType::Anyref => failure::bail!( |
| 169 | + "cannot multi-value transform functions that return \ |
| 170 | + anyref, since they can't go into linear memory" |
| 171 | + ), |
| 172 | + }; |
| 173 | + } |
| 174 | + // Round up to 16-byte alignment, since that's what LLVM's emitted Wasm code |
| 175 | + // seems to expect. |
| 176 | + let results_size = round_up_to_alignment(results_size, 16); |
| 177 | + |
| 178 | + let ty = module.funcs.get(func).ty(); |
| 179 | + let (ty_params, ty_results) = module.types.params_results(ty); |
| 180 | + |
| 181 | + if !ty_results.is_empty() { |
| 182 | + failure::bail!( |
| 183 | + "can only multi-value transform functions that don't return any \ |
| 184 | + results (since they should be returned on the stack via a pointer)" |
| 185 | + ); |
| 186 | + } |
| 187 | + |
| 188 | + match ty_params.get(return_pointer_index) { |
| 189 | + Some(walrus::ValType::I32) => {} |
| 190 | + None => failure::bail!("the return pointer parameter doesn't exist"), |
| 191 | + Some(_) => failure::bail!("the return pointer parameter is not `i32`"), |
| 192 | + } |
| 193 | + |
| 194 | + let new_params: Vec<_> = ty_params |
| 195 | + .iter() |
| 196 | + .cloned() |
| 197 | + .enumerate() |
| 198 | + .filter_map(|(i, ty)| { |
| 199 | + if i == return_pointer_index { |
| 200 | + None |
| 201 | + } else { |
| 202 | + Some(ty) |
| 203 | + } |
| 204 | + }) |
| 205 | + .collect(); |
| 206 | + |
| 207 | + // The locals for the function parameters. |
| 208 | + let params: Vec<_> = new_params.iter().map(|ty| module.locals.add(*ty)).collect(); |
| 209 | + |
| 210 | + // A local to hold our shadow stack-allocated return pointer. |
| 211 | + let return_pointer = module.locals.add(walrus::ValType::I32); |
| 212 | + |
| 213 | + let mut wrapper = walrus::FunctionBuilder::new(&mut module.types, &new_params, results); |
| 214 | + let mut body = wrapper.func_body(); |
| 215 | + |
| 216 | + // Allocate space in the shadow stack for the call. |
| 217 | + body.global_get(shadow_stack_pointer) |
| 218 | + .i32_const(results_size as i32) |
| 219 | + .binop(walrus::ir::BinaryOp::I32Sub) |
| 220 | + .local_tee(return_pointer) |
| 221 | + .global_set(shadow_stack_pointer); |
| 222 | + |
| 223 | + // Push the parameters for calling our wrapped function -- including the |
| 224 | + // return pointer! -- on to the stack. |
| 225 | + for (i, local) in params.iter().enumerate() { |
| 226 | + if i == return_pointer_index { |
| 227 | + body.local_get(return_pointer); |
| 228 | + } |
| 229 | + body.local_get(*local); |
| 230 | + } |
| 231 | + if return_pointer_index == params.len() { |
| 232 | + body.local_get(return_pointer); |
| 233 | + } |
| 234 | + |
| 235 | + // Call our wrapped function. |
| 236 | + body.call(func); |
| 237 | + |
| 238 | + // Copy the return values from our shadow stack-allocated space and onto the |
| 239 | + // Wasm stack. |
| 240 | + let mut offset = 0; |
| 241 | + for ty in results { |
| 242 | + debug_assert!(offset < results_size); |
| 243 | + body.local_get(return_pointer); |
| 244 | + match ty { |
| 245 | + walrus::ValType::I32 => { |
| 246 | + debug_assert_eq!(offset % 4, 0); |
| 247 | + body.load( |
| 248 | + memory, |
| 249 | + walrus::ir::LoadKind::I32 { atomic: false }, |
| 250 | + walrus::ir::MemArg { align: 4, offset }, |
| 251 | + ); |
| 252 | + offset += 4; |
| 253 | + } |
| 254 | + walrus::ValType::I64 => { |
| 255 | + offset = round_up_to_alignment(offset, 8); |
| 256 | + body.load( |
| 257 | + memory, |
| 258 | + walrus::ir::LoadKind::I64 { atomic: false }, |
| 259 | + walrus::ir::MemArg { align: 8, offset }, |
| 260 | + ); |
| 261 | + offset += 8; |
| 262 | + } |
| 263 | + walrus::ValType::F32 => { |
| 264 | + debug_assert_eq!(offset % 4, 0); |
| 265 | + body.load( |
| 266 | + memory, |
| 267 | + walrus::ir::LoadKind::F32, |
| 268 | + walrus::ir::MemArg { align: 4, offset }, |
| 269 | + ); |
| 270 | + offset += 4; |
| 271 | + } |
| 272 | + walrus::ValType::F64 => { |
| 273 | + offset = round_up_to_alignment(offset, 8); |
| 274 | + body.load( |
| 275 | + memory, |
| 276 | + walrus::ir::LoadKind::F64, |
| 277 | + walrus::ir::MemArg { align: 8, offset }, |
| 278 | + ); |
| 279 | + offset += 8; |
| 280 | + } |
| 281 | + walrus::ValType::V128 => { |
| 282 | + offset = round_up_to_alignment(offset, 16); |
| 283 | + body.load( |
| 284 | + memory, |
| 285 | + walrus::ir::LoadKind::V128, |
| 286 | + walrus::ir::MemArg { align: 16, offset }, |
| 287 | + ); |
| 288 | + offset += 16; |
| 289 | + } |
| 290 | + walrus::ValType::Anyref => unreachable!(), |
| 291 | + } |
| 292 | + } |
| 293 | + |
| 294 | + // Finally, restore the shadow stack pointer. |
| 295 | + body.local_get(return_pointer) |
| 296 | + .i32_const(results_size as i32) |
| 297 | + .binop(walrus::ir::BinaryOp::I32Add) |
| 298 | + .global_set(shadow_stack_pointer); |
| 299 | + |
| 300 | + let wrapper = wrapper.finish(params, &mut module.funcs); |
| 301 | + |
| 302 | + // Replace the old export with our new multi-value wrapper for it! |
| 303 | + match module.exports.get_mut(export).item { |
| 304 | + walrus::ExportItem::Function(ref mut f) => *f = wrapper, |
| 305 | + _ => unreachable!(), |
| 306 | + } |
| 307 | + |
| 308 | + Ok(()) |
| 309 | +} |
| 310 | + |
#[cfg(test)]
mod tests {
    use super::round_up_to_alignment;

    /// Checks `round_up_to_alignment` against a table of
    /// `(n, align, expected)` cases for alignments 1, 2, and 4.
    #[test]
    fn round_up_to_alignment_works() {
        const CASES: &[(u32, u32, u32)] = &[
            (0, 1, 0),
            (1, 1, 1),
            (2, 1, 2),
            (0, 2, 0),
            (1, 2, 2),
            (2, 2, 2),
            (3, 2, 4),
            (0, 4, 0),
            (1, 4, 4),
            (2, 4, 4),
            (3, 4, 4),
            (4, 4, 4),
            (5, 4, 8),
        ];

        for &(n, align, expected) in CASES {
            let actual = round_up_to_alignment(n, align);
            // Printed so that a failing case is easy to spot in the test
            // output.
            println!(
                "round_up_to_alignment(n = {}, align = {}) = {} (expected {})",
                n, align, actual, expected
            );
            assert_eq!(actual, expected);
        }
    }
}
0 commit comments