Skip to main content

Grove/WASM/
FunctionExport.rs

1//! Function Export Module
2//!
3//! Handles exporting host functions to WASM modules.
4//! Provides registration and management of functions that WASM can call.
5
6use std::{collections::HashMap, sync::Arc};
7
8use anyhow::Result;
9use serde::{Deserialize, Serialize};
10use tokio::sync::RwLock;
11use wasmtime::{Caller, Linker};
12
13use crate::{
14	WASM::HostBridge::{
15		FunctionSignature,
16		HostBridgeImpl,
17		HostBridgeImpl as HostBridge,
18		HostFunctionCallback,
19		ParamType,
20		ReturnType,
21	},
22	dev_log,
23};
24
25/// Host function registry for WASM exports
26pub struct HostFunctionRegistry {
27	/// Registered host functions
28	functions:Arc<RwLock<HashMap<String, RegisteredHostFunction>>>,
29	/// Associated host bridge
30	#[allow(dead_code)]
31	bridge:Arc<HostBridge>,
32}
33
34/// Registered host function with metadata
35#[derive(Debug, Clone)]
36struct RegisteredHostFunction {
37	/// Function name
38	#[allow(dead_code)]
39	name:String,
40	/// Function signature
41	#[allow(dead_code)]
42	signature:FunctionSignature,
43	/// Synchronous callback
44	callback:Option<HostFunctionCallback>,
45	/// Registration timestamp
46	#[allow(dead_code)]
47	registered_at:u64,
48	/// Call statistics
49	stats:FunctionStats,
50}
51
/// Usage counters for a single host function.
///
/// [`Default`] yields all-zero counters and no last-call timestamp.
#[derive(Clone, Debug, Default)]
pub struct FunctionStats {
	/// How many times the function has been invoked.
	pub call_count:u64,
	/// Cumulative execution time, in nanoseconds.
	pub total_execution_ns:u64,
	/// Timestamp of the most recent invocation, if there has been one.
	pub last_call_at:Option<u64>,
	/// How many invocations ended in an error.
	pub error_count:u64,
}
64
65/// Export configuration for WASM functions
66#[derive(Debug, Clone, Serialize, Deserialize)]
67pub struct ExportConfig {
68	/// Enable function export by default
69	pub auto_export:bool,
70	/// Enable timing statistics
71	pub enable_stats:bool,
72	/// Maximum number of functions that can be exported
73	pub max_functions:usize,
74	/// Function name prefix for exports
75	pub name_prefix:Option<String>,
76}
77
78impl Default for ExportConfig {
79	fn default() -> Self {
80		Self {
81			auto_export:true,
82			enable_stats:true,
83			max_functions:1000,
84			name_prefix:Some("host_".to_string()),
85		}
86	}
87}
88
89/// Function export for WASM
90pub struct FunctionExportImpl {
91	registry:Arc<HostFunctionRegistry>,
92	config:ExportConfig,
93}
94
95impl FunctionExportImpl {
96	/// Create a new function export manager
97	pub fn new(bridge:Arc<HostBridge>) -> Self {
98		Self {
99			registry:Arc::new(HostFunctionRegistry { functions:Arc::new(RwLock::new(HashMap::new())), bridge }),
100			config:ExportConfig::default(),
101		}
102	}
103
104	/// Create with custom configuration
105	pub fn with_config(bridge:Arc<HostBridge>, config:ExportConfig) -> Self {
106		Self {
107			registry:Arc::new(HostFunctionRegistry { functions:Arc::new(RwLock::new(HashMap::new())), bridge }),
108			config,
109		}
110	}
111
112	/// Register a host function for export to WASM
113	pub async fn register_function(
114		&self,
115		name:&str,
116		signature:FunctionSignature,
117		callback:HostFunctionCallback,
118	) -> Result<()> {
119		dev_log!("wasm", "Registering host function for export: {}", name);
120
121		let functions = self.registry.functions.read().await;
122
123		// Check max function limit
124		if functions.len() >= self.config.max_functions {
125			return Err(anyhow::anyhow!(
126				"Maximum number of exported functions reached: {}",
127				self.config.max_functions
128			));
129		}
130
131		drop(functions);
132
133		let mut functions = self.registry.functions.write().await;
134
135		let registered_at = std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH)?.as_secs();
136
137		functions.insert(
138			name.to_string(),
139			RegisteredHostFunction {
140				name:name.to_string(),
141				signature,
142				callback:Some(callback),
143				registered_at,
144				stats:FunctionStats::default(),
145			},
146		);
147
148		dev_log!("wasm", "Host function registered for WASM export: {}", name);
149		Ok(())
150	}
151
152	/// Register multiple host functions
153	pub async fn register_functions(
154		&self,
155		signatures:Vec<FunctionSignature>,
156		callbacks:Vec<HostFunctionCallback>,
157	) -> Result<()> {
158		if signatures.len() != callbacks.len() {
159			return Err(anyhow::anyhow!("Number of signatures must match number of callbacks"));
160		}
161
162		for (sig, callback) in signatures.into_iter().zip(callbacks) {
163			let name = sig.name.clone();
164			self.register_function(&name, sig, callback).await?;
165		}
166
167		Ok(())
168	}
169
170	/// Export all registered functions to a WASMtime linker
171	pub async fn export_to_linker<T>(&self, linker:&mut Linker<T>) -> Result<()>
172	where
173		T: Send + 'static, {
174		dev_log!(
175			"wasm",
176			"Exporting {} host functions to linker",
177			self.registry.functions.read().await.len()
178		);
179
180		let functions = self.registry.functions.read().await;
181
182		for (name, func) in functions.iter() {
183			self.export_single_function(linker, name, func)?;
184		}
185
186		dev_log!("wasm", "All host functions exported to linker");
187		Ok(())
188	}
189
190	/// Export a single function to the linker
191	fn export_single_function<T>(&self, linker:&mut Linker<T>, name:&str, func:&RegisteredHostFunction) -> Result<()>
192	where
193		T: Send + 'static, {
194		dev_log!("wasm", "Exporting function: {}", name);
195
196		let callback = func
197			.callback
198			.ok_or_else(|| anyhow::anyhow!("No callback available for function: {}", name))?;
199
200		let func_name = if let Some(prefix) = &self.config.name_prefix {
201			format!("{}{}", prefix, name)
202		} else {
203			name.to_string()
204		};
205
206		let func_name_for_debug = func_name.clone();
207		let func_name_inner = func_name.clone();
208
209		// Create a wrapper function that handles stats and error handling
210		let _wrapped_callback =
211			move |_caller:Caller<'_, T>, args:&[wasmtime::Val]| -> Result<Vec<wasmtime::Val>, wasmtime::Trap> {
212				let _start = std::time::Instant::now();
213
214				// Convert args to bytes
215				let args_bytes:Result<Vec<bytes::Bytes>, _> = args
216					.iter()
217					.map(|arg| {
218						match arg {
219							wasmtime::Val::I32(i) => {
220								serde_json::to_vec(i)
221									.map(bytes::Bytes::from)
222									.map_err(|e| anyhow::anyhow!("Serialization error: {}", e))
223							},
224							wasmtime::Val::I64(i) => {
225								serde_json::to_vec(i)
226									.map(bytes::Bytes::from)
227									.map_err(|e| anyhow::anyhow!("Serialization error: {}", e))
228							},
229							wasmtime::Val::F32(f) => {
230								serde_json::to_vec(f)
231									.map(bytes::Bytes::from)
232									.map_err(|e| anyhow::anyhow!("Serialization error: {}", e))
233							},
234							wasmtime::Val::F64(f) => {
235								serde_json::to_vec(f)
236									.map(bytes::Bytes::from)
237									.map_err(|e| anyhow::anyhow!("Serialization error: {}", e))
238							},
239							_ => Err(anyhow::anyhow!("Unsupported argument type")),
240						}
241					})
242					.collect();
243
244				let args_bytes = args_bytes.map_err(|_| {
245					dev_log!("wasm", "warn: error converting arguments for function '{}'", func_name_inner);
246					wasmtime::Trap::StackOverflow
247				})?;
248
249				// Call the callback
250				let result = callback(args_bytes);
251
252				match result {
253					Ok(response_bytes) => {
254						// Deserialize response
255						let result_val:serde_json::Value = serde_json::from_slice(&response_bytes).map_err(|_| {
256							dev_log!("wasm", "warn: error deserializing response for function '{}'", func_name_inner);
257							wasmtime::Trap::StackOverflow
258						})?;
259
260						let ret_val = match result_val {
261							serde_json::Value::Number(n) => {
262								if let Some(i) = n.as_i64() {
263									wasmtime::Val::I32(i as i32)
264								} else if let Some(f) = n.as_f64() {
265									wasmtime::Val::I64(f as i64)
266								} else {
267									dev_log!("wasm", "warn: invalid number format for function '{}'", func_name_inner);
268									return Err(wasmtime::Trap::StackOverflow);
269								}
270							},
271							_ => {
272								dev_log!("wasm", "warn: unsupported response type for function '{}'", func_name_inner);
273								return Err(wasmtime::Trap::StackOverflow);
274							},
275						};
276
277						Ok(vec![ret_val])
278					},
279					Err(e) => {
280						// Error handling
281						dev_log!("wasm", "host function '{}' returned error: {}", func_name_inner, e);
282						Err(wasmtime::Trap::StackOverflow)
283					},
284				}
285			};
286
287		// Define the function signature for WASMtime
288		let _wasmparser_signature = wasmparser::FuncType::new([wasmparser::ValType::I32], [wasmparser::ValType::I32]);
289
290		// Register host function with the linker using simple i32->i32 signature
291		// In Wasmtime 20, func_wrap expects parameters to be inferred from the closure
292		// signature
293		let func_name_for_logging = func_name.clone();
294		linker
295			.func_wrap(
296				"_host", // Module name for host functions
297				&func_name,
298				move |_caller:wasmtime::Caller<'_, T>, input_param:i32| -> i32 {
299					// Track function call for metrics
300					let start = std::time::Instant::now();
301
302					// Convert input parameter to bytes for callback
303					let args_bytes = match serde_json::to_vec(&input_param).map(bytes::Bytes::from) {
304						Ok(b) => b,
305						Err(e) => {
306							dev_log!(
307								"wasm",
308								"warn: serialization error for function '{}': {}",
309								func_name_for_logging,
310								e
311							);
312							return -1i32;
313						},
314					};
315
316					// Call the registered callback
317					let result = callback(vec![args_bytes]);
318
319					match result {
320						Ok(response_bytes) => {
321							// Deserialize response
322							let result_val:serde_json::Value = match serde_json::from_slice(&response_bytes) {
323								Ok(v) => v,
324								Err(_) => {
325									dev_log!(
326										"wasm",
327										"warn: error deserializing response for function '{}'",
328										func_name_for_logging
329									);
330									return -1i32;
331								},
332							};
333
334							// Extract result value
335							let ret_val = match result_val {
336								serde_json::Value::Number(n) => {
337									if let Some(i) = n.as_i64() {
338										i as i32
339									} else if let Some(f) = n.as_f64() {
340										f as i32
341									} else {
342										dev_log!(
343											"wasm",
344											"warn: invalid number format for function '{}'",
345											func_name_for_logging
346										);
347										-1i32
348									}
349								},
350								serde_json::Value::Bool(b) => {
351									if b {
352										1i32
353									} else {
354										0i32
355									}
356								},
357								_ => {
358									dev_log!(
359										"wasm",
360										"warn: unsupported response type for function '{}', expected number or bool",
361										func_name_for_logging
362									);
363									-1i32
364								},
365							};
366
367							// Log successful call
368							let duration = start.elapsed();
369							dev_log!(
370								"wasm",
371								"[FunctionExport] Host function '{}' executed successfully in {}µs",
372								func_name_for_logging,
373								duration.as_micros()
374							);
375
376							ret_val
377						},
378						Err(e) => {
379							// Error handling - return error code to WASM caller
380							dev_log!(
381								"wasm",
382								"[FunctionExport] Host function '{}' returned error: {}",
383								func_name_for_logging,
384								e
385							);
386							// Return -1 to indicate error to WASM caller
387							-1i32
388						},
389					}
390				},
391			)
392			.map_err(|e| {
393				dev_log!(
394					"wasm",
395					"warn: [FunctionExport] failed to wrap host function '{}': {}",
396					func_name_for_debug,
397					e
398				);
399				e
400			})?;
401
402		dev_log!(
403			"wasm",
404			"[FunctionExport] Host function '{}' registered successfully",
405			func_name_for_debug
406		);
407
408		Ok(())
409	}
410
411	/// Convert our signature to WASMtime signature type
412	#[allow(dead_code)]
413	fn wasmtime_signature_from_signature(&self, _sig:&FunctionSignature) -> Result<wasmparser::FuncType> {
414		// This is a placeholder - actual implementation depends on the exact types
415		// In production, this would map ParamType and ReturnType to WASMtime types
416		Ok(wasmparser::FuncType::new([], []))
417	}
418
419	/// Get all registered function names
420	pub async fn get_function_names(&self) -> Vec<String> {
421		self.registry.functions.read().await.keys().cloned().collect()
422	}
423
424	/// Get function statistics
425	pub async fn get_function_stats(&self, name:&str) -> Option<FunctionStats> {
426		self.registry.functions.read().await.get(name).map(|f| f.stats.clone())
427	}
428
429	/// Unregister a function
430	pub async fn unregister_function(&self, name:&str) -> Result<bool> {
431		let mut functions = self.registry.functions.write().await;
432		let removed = functions.remove(name).is_some();
433
434		if removed {
435			dev_log!("wasm", "Unregistered host function: {}", name);
436		} else {
437			dev_log!("wasm", "warn: attempted to unregister non-existent function: {}", name);
438		}
439
440		Ok(removed)
441	}
442
443	/// Clear all registered functions
444	pub async fn clear(&self) {
445		dev_log!("wasm", "Clearing all registered host functions");
446		self.registry.functions.write().await.clear();
447	}
448}
449
#[cfg(test)]
mod tests {
	use super::*;

	/// Build an `(i32) -> i32`, non-async signature with the given name.
	fn i32_signature(name:&str) -> FunctionSignature {
		FunctionSignature {
			name:name.to_string(),
			param_types:vec![ParamType::I32],
			return_type:Some(ReturnType::I32),
			is_async:false,
		}
	}

	/// Build an export manager with the default configuration.
	fn new_export() -> FunctionExportImpl { FunctionExportImpl::new(Arc::new(HostBridgeImpl::new())) }

	#[tokio::test]
	async fn test_function_export_creation() {
		let export = new_export();

		assert!(export.get_function_names().await.is_empty());
	}

	#[tokio::test]
	async fn test_register_function() {
		let export = new_export();

		let callback = |args:Vec<bytes::Bytes>| Ok(args.first().cloned().unwrap_or_default());

		let result:anyhow::Result<()> = export.register_function("echo", i32_signature("echo"), callback).await;
		assert!(result.is_ok());
		assert_eq!(export.get_function_names().await, vec!["echo".to_string()]);
	}

	#[tokio::test]
	async fn test_unregister_function() {
		let export = new_export();

		let callback = |_:Vec<bytes::Bytes>| Ok(bytes::Bytes::new());
		let registered:anyhow::Result<()> = export.register_function("test", i32_signature("test"), callback).await;
		assert!(registered.is_ok());

		let removed:bool = export.unregister_function("test").await.unwrap();
		assert!(removed);
		assert!(export.get_function_names().await.is_empty());
	}

	#[tokio::test]
	async fn test_unregister_missing_function_returns_false() {
		let export = new_export();

		let removed:bool = export.unregister_function("missing").await.unwrap();
		assert!(!removed);
	}

	#[tokio::test]
	async fn test_max_functions_limit() {
		let config = ExportConfig { max_functions:1, ..ExportConfig::default() };
		let export = FunctionExportImpl::with_config(Arc::new(HostBridgeImpl::new()), config);
		let callback = |_:Vec<bytes::Bytes>| Ok(bytes::Bytes::new());

		let first:anyhow::Result<()> = export.register_function("a", i32_signature("a"), callback).await;
		assert!(first.is_ok());

		let second:anyhow::Result<()> = export.register_function("b", i32_signature("b"), callback).await;
		assert!(second.is_err());
		assert_eq!(export.get_function_names().await, vec!["a".to_string()]);
	}

	#[tokio::test]
	async fn test_register_functions_length_mismatch() {
		let export = new_export();
		let callbacks:Vec<HostFunctionCallback> = Vec::new();

		let result:anyhow::Result<()> = export.register_functions(vec![i32_signature("a")], callbacks).await;
		assert!(result.is_err());
		assert!(export.get_function_names().await.is_empty());
	}

	#[tokio::test]
	async fn test_stats_lookup_and_clear() {
		let export = new_export();
		let callback = |_:Vec<bytes::Bytes>| Ok(bytes::Bytes::new());

		let registered:anyhow::Result<()> = export.register_function("f", i32_signature("f"), callback).await;
		assert!(registered.is_ok());

		let stats = export.get_function_stats("f").await.expect("registered function has stats");
		assert_eq!(stats.call_count, 0);
		assert!(export.get_function_stats("missing").await.is_none());

		export.clear().await;
		assert!(export.get_function_names().await.is_empty());
	}

	#[test]
	fn test_export_config_default() {
		let config = ExportConfig::default();
		assert!(config.auto_export);
		assert!(config.enable_stats);
		assert_eq!(config.max_functions, 1000);
		assert_eq!(config.name_prefix.as_deref(), Some("host_"));
	}

	#[test]
	fn test_function_stats_default() {
		let stats = FunctionStats::default();
		assert_eq!(stats.call_count, 0);
		assert_eq!(stats.total_execution_ns, 0);
		assert_eq!(stats.last_call_at, None);
		assert_eq!(stats.error_count, 0);
	}
}