Skip to main content

Grove/WASM/
HostBridge.rs

1//! Host Bridge
2//!
3//! Provides bidirectional communication between the host (Grove) and WASM
4//! modules. Handles function calls, data transfer, and marshalling between the
5//! two environments.
6
7use std::{collections::HashMap, sync::Arc};
8
9use anyhow::Result;
10use bytes::Bytes;
11use serde::{Serialize, de::DeserializeOwned};
12use tokio::sync::{RwLock, mpsc, oneshot};
13#[allow(unused_imports)]
14use wasmtime::{Caller, Extern, Func, Linker, Store};
15
16use crate::dev_log;
17
18/// Host bridge error types
19#[derive(Debug, thiserror::Error)]
20pub enum BridgeError {
21	/// Function not found error
22	#[error("Function not found: {0}")]
23	FunctionNotFound(String),
24
25	/// Invalid function signature error
26	#[error("Invalid function signature: {0}")]
27	InvalidSignature(String),
28
29	/// Serialization failed error
30	#[error("Serialization failed: {0}")]
31	SerializationError(String),
32
33	/// Deserialization failed error
34	#[error("Deserialization failed: {0}")]
35	DeserializationError(String),
36
37	/// Host function error
38	#[error("Host function error: {0}")]
39	HostFunctionError(String),
40
41	/// Communication timeout error
42	#[error("Communication timeout")]
43	Timeout,
44
45	/// Bridge closed error
46	#[error("Bridge closed")]
47	BridgeClosed,
48}
49
50/// Type-safe result for operations
51pub type BridgeResult<T> = Result<T, BridgeError>;
52
53/// Function signature information
54#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
55pub struct FunctionSignature {
56	/// Function name
57	pub name:String,
58	/// Parameter types
59	pub param_types:Vec<ParamType>,
60	/// Return type
61	pub return_type:Option<ReturnType>,
62	/// Whether this is an async function
63	pub is_async:bool,
64}
65
66/// Parameter types for WASM functions
67#[derive(Debug, Clone, Copy, serde::Serialize, serde::Deserialize, PartialEq, Eq)]
68pub enum ParamType {
69	/// 32-bit signed integer parameter
70	I32,
71	/// 64-bit signed integer parameter
72	I64,
73	/// 32-bit floating point parameter
74	F32,
75	/// 64-bit floating point parameter
76	F64,
77	/// Pointer to memory
78	Ptr,
79	/// Length parameter following a pointer
80	Len,
81}
82
83/// Return types for WASM functions
84#[derive(Debug, Clone, Copy, serde::Serialize, serde::Deserialize, PartialEq, Eq)]
85pub enum ReturnType {
86	/// 32-bit signed integer return type
87	I32,
88	/// 64-bit signed integer return type
89	I64,
90	/// 32-bit floating point return type
91	F32,
92	/// 64-bit floating point return type
93	F64,
94	/// No return value (void)
95	Void,
96}
97
98/// Message sent from WASM to host
99#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
100pub struct HostMessage {
101	/// Message ID for correlation
102	pub message_id:String,
103	/// Function name to call
104	pub function:String,
105	/// Serialized arguments
106	pub args:Vec<Bytes>,
107	/// Callback token for async responses
108	pub callback_token:Option<u64>,
109}
110
111/// Response sent from host to WASM
112#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
113pub struct HostResponse {
114	/// Correlating message ID
115	pub message_id:String,
116	/// Success flag
117	pub success:bool,
118	/// Response data
119	pub data:Option<Bytes>,
120	/// Error message if failed
121	pub error:Option<String>,
122}
123
124/// Callback for async function responses
125#[derive(Clone)]
126pub struct AsyncCallback {
127	/// Sender for transmitting the response
128	sender:Arc<tokio::sync::Mutex<Option<tokio::sync::oneshot::Sender<HostResponse>>>>,
129	/// Message ID for correlation
130	message_id:String,
131}
132
133impl std::fmt::Debug for AsyncCallback {
134	fn fmt(&self, f:&mut std::fmt::Formatter<'_>) -> std::fmt::Result {
135		f.debug_struct("AsyncCallback").field("message_id", &self.message_id).finish()
136	}
137}
138
139impl AsyncCallback {
140	/// Send response through the callback
141	pub async fn send(self, response:HostResponse) -> Result<()> {
142		let mut sender_opt = self.sender.lock().await;
143		if let Some(sender) = sender_opt.take() {
144			sender.send(response).map_err(|_| BridgeError::BridgeClosed)?;
145			Ok(())
146		} else {
147			Err(BridgeError::BridgeClosed.into())
148		}
149	}
150}
151
152/// Message from host to WASM
153#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
154pub struct WASMMessage {
155	/// Target function in WASM
156	pub function:String,
157	/// Arguments
158	pub args:Vec<Bytes>,
159}
160
161/// Host function callback type
162pub type HostFunctionCallback = fn(Vec<Bytes>) -> Result<Bytes>;
163
164/// Async host function callback type
165pub type AsyncHostFunctionCallback =
166	fn(Vec<Bytes>) -> Box<dyn std::future::Future<Output = Result<Bytes>> + Send + Unpin>;
167
168/// Host function definition
169#[derive(Debug)]
170pub struct HostFunction {
171	/// Function name
172	pub name:String,
173	/// Function signature
174	pub signature:FunctionSignature,
175	/// Synchronous callback - not serializable (skip serde derive)
176	#[allow(dead_code)]
177	pub callback:Option<HostFunctionCallback>,
178	/// Async callback - not serializable (skip serde derive)
179	#[allow(dead_code)]
180	pub async_callback:Option<AsyncHostFunctionCallback>,
181}
182
183/// Host Bridge for WASM communication
184#[derive(Debug)]
185pub struct HostBridgeImpl {
186	/// Registry of host functions exported to WASM
187	host_functions:Arc<RwLock<HashMap<String, HostFunction>>>,
188	/// Channel for receiving messages from WASM
189	wasm_to_host_rx:mpsc::UnboundedReceiver<WASMMessage>,
190	/// Channel for sending messages to WASM
191	host_to_wasm_tx:mpsc::UnboundedSender<WASMMessage>,
192	/// Active async callbacks
193	async_callbacks:Arc<RwLock<HashMap<u64, AsyncCallback>>>,
194	/// Next callback token
195	next_callback_token:Arc<std::sync::atomic::AtomicU64>,
196}
197
198impl HostBridgeImpl {
199	/// Create a new host bridge
200	pub fn new() -> Self {
201		let (_wasm_to_host_tx, wasm_to_host_rx) = mpsc::unbounded_channel();
202		let (host_to_wasm_tx, host_to_wasm_rx) = mpsc::unbounded_channel();
203
204		// In a real implementation, we'd need to wire these up properly
205		// For now, we drop the receiver to avoid unused warnings
206		drop(host_to_wasm_rx);
207
208		Self {
209			host_functions:Arc::new(RwLock::new(HashMap::new())),
210			wasm_to_host_rx,
211			host_to_wasm_tx,
212			async_callbacks:Arc::new(RwLock::new(HashMap::new())),
213			next_callback_token:Arc::new(std::sync::atomic::AtomicU64::new(0)),
214		}
215	}
216
217	/// Register a host function to be exported to WASM
218	pub async fn register_host_function(
219		&self,
220		name:&str,
221		signature:FunctionSignature,
222		callback:HostFunctionCallback,
223	) -> BridgeResult<()> {
224		dev_log!("wasm", "Registering host function: {}", name);
225
226		let mut functions = self.host_functions.write().await;
227
228		if functions.contains_key(name) {
229			dev_log!("wasm", "warn: host function already registered: {}", name);
230		}
231
232		functions.insert(
233			name.to_string(),
234			HostFunction { name:name.to_string(), signature, callback:Some(callback), async_callback:None },
235		);
236
237		dev_log!("wasm", "Host function registered successfully: {}", name);
238		Ok(())
239	}
240
241	/// Register an async host function
242	pub async fn register_async_host_function(
243		&self,
244		name:&str,
245		signature:FunctionSignature,
246		callback:AsyncHostFunctionCallback,
247	) -> BridgeResult<()> {
248		dev_log!("wasm", "Registering async host function: {}", name);
249
250		let mut functions = self.host_functions.write().await;
251
252		functions.insert(
253			name.to_string(),
254			HostFunction { name:name.to_string(), signature, callback:None, async_callback:Some(callback) },
255		);
256
257		dev_log!("wasm", "Async host function registered successfully: {}", name);
258		Ok(())
259	}
260
261	/// Call a host function from WASM
262	pub async fn call_host_function(&self, function_name:&str, args:Vec<Bytes>) -> BridgeResult<Bytes> {
263		dev_log!("wasm", "Calling host function: {}", function_name);
264
265		let functions = self.host_functions.read().await;
266		let func = functions
267			.get(function_name)
268			.ok_or_else(|| BridgeError::FunctionNotFound(function_name.to_string()))?;
269
270		if let Some(callback) = func.callback {
271			// Synchronous call
272			let result =
273				callback(args).map_err(|e| BridgeError::HostFunctionError(format!("{}: {}", function_name, e)))?;
274			dev_log!("wasm", "Host function call completed: {}", function_name);
275			Ok(result)
276		} else if let Some(async_callback) = func.async_callback {
277			// Async call
278			let future = async_callback(args);
279			let result = future
280				.await
281				.map_err(|e| BridgeError::HostFunctionError(format!("{}: {}", function_name, e)))?;
282			dev_log!("wasm", "Async host function call completed: {}", function_name);
283			Ok(result)
284		} else {
285			Err(BridgeError::FunctionNotFound(format!(
286				"No callback for function: {}",
287				function_name
288			)))
289		}
290	}
291
292	/// Send a message to WASM
293	pub async fn send_to_wasm(&self, message:WASMMessage) -> BridgeResult<()> {
294		let function_name = message.function.clone();
295		self.host_to_wasm_tx.send(message).map_err(|_| BridgeError::BridgeClosed)?;
296		dev_log!("wasm", "Message sent to WASM: {}", function_name);
297		Ok(())
298	}
299
300	/// Receive a message from WASM (blocking)
301	pub async fn receive_from_wasm(&mut self) -> Option<WASMMessage> { self.wasm_to_host_rx.recv().await }
302
303	/// Create async callback
304	pub async fn create_async_callback(&self, message_id:String) -> (AsyncCallback, u64) {
305		let token = self.next_callback_token.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
306		let (tx, _rx) = oneshot::channel();
307
308		// Create callback with Arc-wrapped sender
309		let callback = AsyncCallback {
310			sender:Arc::new(tokio::sync::Mutex::new(Some(tx))),
311			message_id:message_id.clone(),
312		};
313
314		self.async_callbacks.write().await.insert(token, callback.clone());
315
316		(callback, token)
317	}
318
319	/// Get callback by token
320	pub async fn get_callback(&self, token:u64) -> Option<AsyncCallback> {
321		self.async_callbacks.write().await.remove(&token)
322	}
323
324	/// Get all registered host functions
325	pub async fn get_host_functions(&self) -> Vec<String> { self.host_functions.read().await.keys().cloned().collect() }
326
327	/// Unregister a host function
328	pub async fn unregister_host_function(&self, name:&str) -> bool {
329		let mut functions = self.host_functions.write().await;
330		let removed = functions.remove(name).is_some();
331		if removed {
332			dev_log!("wasm", "Host function unregistered: {}", name);
333		}
334		removed
335	}
336
337	/// Clear all registered functions
338	pub async fn clear(&self) {
339		dev_log!("wasm", "Clearing all registered host functions");
340		self.host_functions.write().await.clear();
341		self.async_callbacks.write().await.clear();
342	}
343}
344
345impl Default for HostBridgeImpl {
346	fn default() -> Self { Self::new() }
347}
348
349/// Utility function to serialize data to Bytes
350pub fn serialize_to_bytes<T:Serialize>(data:&T) -> Result<Bytes> {
351	serde_json::to_vec(data)
352		.map(Bytes::from)
353		.map_err(|e| anyhow::anyhow!("Serialization error: {}", e))
354}
355
356/// Utility function to deserialize Bytes to data
357pub fn deserialize_from_bytes<T:DeserializeOwned>(bytes:&Bytes) -> Result<T> {
358	serde_json::from_slice(bytes).map_err(|e| anyhow::anyhow!("Deserialization error: {}", e))
359}
360
361/// Marshal arguments for WASM function call
362pub fn marshal_args(args:Vec<Bytes>) -> Result<Vec<wasmtime::Val>> {
363	args.iter()
364		.map(|bytes| {
365			let value:serde_json::Value = serde_json::from_slice(bytes)?;
366			match value {
367				serde_json::Value::Number(n) => {
368					if let Some(i) = n.as_i64() {
369						Ok(wasmtime::Val::I32(i as i32))
370					} else if let Some(f) = n.as_f64() {
371						Ok(wasmtime::Val::F64(f.to_bits()))
372					} else {
373						Err(anyhow::anyhow!("Invalid number value"))
374					}
375				},
376				_ => Err(anyhow::anyhow!("Unsupported argument type")),
377			}
378		})
379		.collect()
380}
381
382/// Unmarshal return values from WASM function call
383pub fn unmarshal_return(val:wasmtime::Val) -> Result<Bytes> {
384	match val {
385		wasmtime::Val::I32(i) => {
386			let json = serde_json::to_string(&i)?;
387			Ok(Bytes::from(json))
388		},
389		wasmtime::Val::I64(i) => {
390			let json = serde_json::to_string(&i)?;
391			Ok(Bytes::from(json))
392		},
393		wasmtime::Val::F32(f) => {
394			let json = serde_json::to_string(&f)?;
395			Ok(Bytes::from(json))
396		},
397		wasmtime::Val::F64(f) => {
398			let json = serde_json::to_string(&f)?;
399			Ok(Bytes::from(json))
400		},
401		_ => Err(anyhow::anyhow!("Unsupported return type")),
402	}
403}
404
#[cfg(test)]
mod tests {
	use super::*;

	/// Build a synchronous, I32-returning signature with the given name and parameters.
	fn sync_i32_signature(name:&str, param_types:Vec<ParamType>) -> FunctionSignature {
		FunctionSignature { name:name.to_string(), param_types, return_type:Some(ReturnType::I32), is_async:false }
	}

	#[test]
	fn test_function_signature_creation() {
		let sig = sync_i32_signature("test_func", vec![ParamType::I32, ParamType::Ptr]);

		assert_eq!(sig.name, "test_func");
		assert_eq!(sig.param_types.len(), 2);
	}

	#[tokio::test]
	async fn test_host_bridge_creation() {
		let bridge = HostBridgeImpl::new();
		let names = bridge.get_host_functions().await;
		assert_eq!(names.len(), 0);
	}

	#[tokio::test]
	async fn test_register_host_function() {
		let bridge = HostBridgeImpl::new();
		let signature = sync_i32_signature("echo", vec![ParamType::I32]);

		let outcome = bridge
			.register_host_function("echo", signature, |args| Ok(args[0].clone()))
			.await;

		assert!(outcome.is_ok());
		assert_eq!(bridge.get_host_functions().await.len(), 1);
	}

	#[test]
	fn test_serialize_deserialize() {
		let original = vec![1, 2, 3, 4, 5];
		let encoded = serialize_to_bytes(&original).unwrap();
		let decoded:Vec<i32> = deserialize_from_bytes(&encoded).unwrap();
		assert_eq!(original, decoded);
	}

	#[test]
	fn test_marshal_unmarshal() {
		let int_arg = serialize_to_bytes(&42i32).unwrap();
		let float_arg = serialize_to_bytes(&3.14f64).unwrap();

		// Only checks that marshalling succeeds; the produced Val variants are not asserted.
		let marshaled = marshal_args(vec![int_arg, float_arg]);
		assert!(marshaled.is_ok());
	}
}