1use std::{collections::HashMap, sync::Arc};
8
9use anyhow::Result;
10use bytes::Bytes;
11use serde::{Serialize, de::DeserializeOwned};
12use tokio::sync::{RwLock, mpsc, oneshot};
13#[allow(unused_imports)]
14use wasmtime::{Caller, Extern, Func, Linker, Store};
15
16use crate::dev_log;
17
18#[derive(Debug, thiserror::Error)]
20pub enum BridgeError {
21 #[error("Function not found: {0}")]
23 FunctionNotFound(String),
24
25 #[error("Invalid function signature: {0}")]
27 InvalidSignature(String),
28
29 #[error("Serialization failed: {0}")]
31 SerializationError(String),
32
33 #[error("Deserialization failed: {0}")]
35 DeserializationError(String),
36
37 #[error("Host function error: {0}")]
39 HostFunctionError(String),
40
41 #[error("Communication timeout")]
43 Timeout,
44
45 #[error("Bridge closed")]
47 BridgeClosed,
48}
49
50pub type BridgeResult<T> = Result<T, BridgeError>;
52
/// Declared shape of a host function: its name, parameter kinds, optional return kind,
/// and whether it is meant to be asynchronous.
///
/// Field order is part of the serialized form; keep it stable.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct FunctionSignature {
	/// Function name.
	pub name:String,
	/// Parameter kinds, in call order.
	pub param_types:Vec<ParamType>,
	/// Return kind, or `None` when no return kind is declared.
	pub return_type:Option<ReturnType>,
	/// Whether the function is declared asynchronous.
	/// NOTE(review): `HostBridgeImpl::call_host_function` does not read this flag; it
	/// dispatches on which callback is set.
	pub is_async:bool,
}
65
/// Parameter kinds a [`FunctionSignature`] can declare.
///
/// Variant order is part of the serialized form for index-based serde formats; append new
/// variants at the end.
#[derive(Debug, Clone, Copy, serde::Serialize, serde::Deserialize, PartialEq, Eq)]
pub enum ParamType {
	/// 32-bit integer.
	I32,
	/// 64-bit integer.
	I64,
	/// 32-bit float.
	F32,
	/// 64-bit float.
	F64,
	/// Presumably an offset into guest linear memory — TODO confirm.
	Ptr,
	/// Presumably the byte length paired with a `Ptr` — TODO confirm.
	Len,
}
82
/// Return kinds a [`FunctionSignature`] can declare.
///
/// Variant order is part of the serialized form for index-based serde formats; append new
/// variants at the end.
#[derive(Debug, Clone, Copy, serde::Serialize, serde::Deserialize, PartialEq, Eq)]
pub enum ReturnType {
	/// 32-bit integer.
	I32,
	/// 64-bit integer.
	I64,
	/// 32-bit float.
	F32,
	/// 64-bit float.
	F64,
	/// Presumably "returns nothing" — TODO confirm against callers.
	Void,
}
97
/// A request addressed to a host function.
///
/// Field order is part of the serialized form; keep it stable.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct HostMessage {
	/// Identifier used to correlate this request with its [`HostResponse`].
	pub message_id:String,
	/// Name of the host function to invoke.
	pub function:String,
	/// Encoded arguments, one buffer per argument.
	pub args:Vec<Bytes>,
	/// Presumably a token from `HostBridgeImpl::create_async_callback` for async replies — TODO confirm.
	pub callback_token:Option<u64>,
}
110
/// Outcome of a host-function request.
///
/// Field order is part of the serialized form; keep it stable.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct HostResponse {
	/// Identifier of the request this answers.
	pub message_id:String,
	/// Whether the call succeeded.
	pub success:bool,
	/// Result payload; presumably set only on success (not enforced by the type).
	pub data:Option<Bytes>,
	/// Error text; presumably set only on failure (not enforced by the type).
	pub error:Option<String>,
}
123
124#[derive(Clone)]
126pub struct AsyncCallback {
127 sender:Arc<tokio::sync::Mutex<Option<tokio::sync::oneshot::Sender<HostResponse>>>>,
129 message_id:String,
131}
132
133impl std::fmt::Debug for AsyncCallback {
134 fn fmt(&self, f:&mut std::fmt::Formatter<'_>) -> std::fmt::Result {
135 f.debug_struct("AsyncCallback").field("message_id", &self.message_id).finish()
136 }
137}
138
139impl AsyncCallback {
140 pub async fn send(self, response:HostResponse) -> Result<()> {
142 let mut sender_opt = self.sender.lock().await;
143 if let Some(sender) = sender_opt.take() {
144 sender.send(response).map_err(|_| BridgeError::BridgeClosed)?;
145 Ok(())
146 } else {
147 Err(BridgeError::BridgeClosed.into())
148 }
149 }
150}
151
/// A message carried on the bridge's host↔WASM channels.
///
/// Field order is part of the serialized form; keep it stable.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct WASMMessage {
	/// Target function name.
	pub function:String,
	/// Encoded arguments, one buffer per argument.
	pub args:Vec<Bytes>,
}
160
161pub type HostFunctionCallback = fn(Vec<Bytes>) -> Result<Bytes>;
163
164pub type AsyncHostFunctionCallback =
166 fn(Vec<Bytes>) -> Box<dyn std::future::Future<Output = Result<Bytes>> + Send + Unpin>;
167
/// A registered host function.
///
/// `HostBridgeImpl::register_host_function` sets only `callback`, and
/// `register_async_host_function` sets only `async_callback`.
#[derive(Debug)]
pub struct HostFunction {
	/// Name the function was registered under.
	pub name:String,
	/// Signature supplied at registration.
	pub signature:FunctionSignature,
	// NOTE(review): this field is read by `HostBridgeImpl::call_host_function`, so this
	// `dead_code` allow looks unnecessary — confirm before removing.
	#[allow(dead_code)]
	pub callback:Option<HostFunctionCallback>,
	// NOTE(review): also read by `call_host_function`; same remark as above.
	#[allow(dead_code)]
	pub async_callback:Option<AsyncHostFunctionCallback>,
}
182
/// Registry of host functions callable from WASM, plus the message channels and the
/// pending async-callback table.
#[derive(Debug)]
pub struct HostBridgeImpl {
	/// Registered host functions keyed by registration name.
	host_functions:Arc<RwLock<HashMap<String, HostFunction>>>,
	/// Receiving half of the WASM→host channel.
	wasm_to_host_rx:mpsc::UnboundedReceiver<WASMMessage>,
	/// Sending half of the host→WASM channel.
	host_to_wasm_tx:mpsc::UnboundedSender<WASMMessage>,
	/// Callbacks created by `create_async_callback`, keyed by their token.
	async_callbacks:Arc<RwLock<HashMap<u64, AsyncCallback>>>,
	/// Counter from which callback tokens are allocated via `fetch_add`.
	next_callback_token:Arc<std::sync::atomic::AtomicU64>,
}
197
198impl HostBridgeImpl {
199 pub fn new() -> Self {
201 let (_wasm_to_host_tx, wasm_to_host_rx) = mpsc::unbounded_channel();
202 let (host_to_wasm_tx, host_to_wasm_rx) = mpsc::unbounded_channel();
203
204 drop(host_to_wasm_rx);
207
208 Self {
209 host_functions:Arc::new(RwLock::new(HashMap::new())),
210 wasm_to_host_rx,
211 host_to_wasm_tx,
212 async_callbacks:Arc::new(RwLock::new(HashMap::new())),
213 next_callback_token:Arc::new(std::sync::atomic::AtomicU64::new(0)),
214 }
215 }
216
217 pub async fn register_host_function(
219 &self,
220 name:&str,
221 signature:FunctionSignature,
222 callback:HostFunctionCallback,
223 ) -> BridgeResult<()> {
224 dev_log!("wasm", "Registering host function: {}", name);
225
226 let mut functions = self.host_functions.write().await;
227
228 if functions.contains_key(name) {
229 dev_log!("wasm", "warn: host function already registered: {}", name);
230 }
231
232 functions.insert(
233 name.to_string(),
234 HostFunction { name:name.to_string(), signature, callback:Some(callback), async_callback:None },
235 );
236
237 dev_log!("wasm", "Host function registered successfully: {}", name);
238 Ok(())
239 }
240
241 pub async fn register_async_host_function(
243 &self,
244 name:&str,
245 signature:FunctionSignature,
246 callback:AsyncHostFunctionCallback,
247 ) -> BridgeResult<()> {
248 dev_log!("wasm", "Registering async host function: {}", name);
249
250 let mut functions = self.host_functions.write().await;
251
252 functions.insert(
253 name.to_string(),
254 HostFunction { name:name.to_string(), signature, callback:None, async_callback:Some(callback) },
255 );
256
257 dev_log!("wasm", "Async host function registered successfully: {}", name);
258 Ok(())
259 }
260
261 pub async fn call_host_function(&self, function_name:&str, args:Vec<Bytes>) -> BridgeResult<Bytes> {
263 dev_log!("wasm", "Calling host function: {}", function_name);
264
265 let functions = self.host_functions.read().await;
266 let func = functions
267 .get(function_name)
268 .ok_or_else(|| BridgeError::FunctionNotFound(function_name.to_string()))?;
269
270 if let Some(callback) = func.callback {
271 let result =
273 callback(args).map_err(|e| BridgeError::HostFunctionError(format!("{}: {}", function_name, e)))?;
274 dev_log!("wasm", "Host function call completed: {}", function_name);
275 Ok(result)
276 } else if let Some(async_callback) = func.async_callback {
277 let future = async_callback(args);
279 let result = future
280 .await
281 .map_err(|e| BridgeError::HostFunctionError(format!("{}: {}", function_name, e)))?;
282 dev_log!("wasm", "Async host function call completed: {}", function_name);
283 Ok(result)
284 } else {
285 Err(BridgeError::FunctionNotFound(format!(
286 "No callback for function: {}",
287 function_name
288 )))
289 }
290 }
291
292 pub async fn send_to_wasm(&self, message:WASMMessage) -> BridgeResult<()> {
294 let function_name = message.function.clone();
295 self.host_to_wasm_tx.send(message).map_err(|_| BridgeError::BridgeClosed)?;
296 dev_log!("wasm", "Message sent to WASM: {}", function_name);
297 Ok(())
298 }
299
300 pub async fn receive_from_wasm(&mut self) -> Option<WASMMessage> { self.wasm_to_host_rx.recv().await }
302
303 pub async fn create_async_callback(&self, message_id:String) -> (AsyncCallback, u64) {
305 let token = self.next_callback_token.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
306 let (tx, _rx) = oneshot::channel();
307
308 let callback = AsyncCallback {
310 sender:Arc::new(tokio::sync::Mutex::new(Some(tx))),
311 message_id:message_id.clone(),
312 };
313
314 self.async_callbacks.write().await.insert(token, callback.clone());
315
316 (callback, token)
317 }
318
319 pub async fn get_callback(&self, token:u64) -> Option<AsyncCallback> {
321 self.async_callbacks.write().await.remove(&token)
322 }
323
324 pub async fn get_host_functions(&self) -> Vec<String> { self.host_functions.read().await.keys().cloned().collect() }
326
327 pub async fn unregister_host_function(&self, name:&str) -> bool {
329 let mut functions = self.host_functions.write().await;
330 let removed = functions.remove(name).is_some();
331 if removed {
332 dev_log!("wasm", "Host function unregistered: {}", name);
333 }
334 removed
335 }
336
337 pub async fn clear(&self) {
339 dev_log!("wasm", "Clearing all registered host functions");
340 self.host_functions.write().await.clear();
341 self.async_callbacks.write().await.clear();
342 }
343}
344
345impl Default for HostBridgeImpl {
346 fn default() -> Self { Self::new() }
347}
348
349pub fn serialize_to_bytes<T:Serialize>(data:&T) -> Result<Bytes> {
351 serde_json::to_vec(data)
352 .map(Bytes::from)
353 .map_err(|e| anyhow::anyhow!("Serialization error: {}", e))
354}
355
356pub fn deserialize_from_bytes<T:DeserializeOwned>(bytes:&Bytes) -> Result<T> {
358 serde_json::from_slice(bytes).map_err(|e| anyhow::anyhow!("Deserialization error: {}", e))
359}
360
361pub fn marshal_args(args:Vec<Bytes>) -> Result<Vec<wasmtime::Val>> {
363 args.iter()
364 .map(|bytes| {
365 let value:serde_json::Value = serde_json::from_slice(bytes)?;
366 match value {
367 serde_json::Value::Number(n) => {
368 if let Some(i) = n.as_i64() {
369 Ok(wasmtime::Val::I32(i as i32))
370 } else if let Some(f) = n.as_f64() {
371 Ok(wasmtime::Val::F64(f.to_bits()))
372 } else {
373 Err(anyhow::anyhow!("Invalid number value"))
374 }
375 },
376 _ => Err(anyhow::anyhow!("Unsupported argument type")),
377 }
378 })
379 .collect()
380}
381
382pub fn unmarshal_return(val:wasmtime::Val) -> Result<Bytes> {
384 match val {
385 wasmtime::Val::I32(i) => {
386 let json = serde_json::to_string(&i)?;
387 Ok(Bytes::from(json))
388 },
389 wasmtime::Val::I64(i) => {
390 let json = serde_json::to_string(&i)?;
391 Ok(Bytes::from(json))
392 },
393 wasmtime::Val::F32(f) => {
394 let json = serde_json::to_string(&f)?;
395 Ok(Bytes::from(json))
396 },
397 wasmtime::Val::F64(f) => {
398 let json = serde_json::to_string(&f)?;
399 Ok(Bytes::from(json))
400 },
401 _ => Err(anyhow::anyhow!("Unsupported return type")),
402 }
403}
404
#[cfg(test)]
mod tests {
	use super::*;

	/// Builds a simple one-`I32`-parameter, `I32`-returning synchronous signature.
	fn i32_signature(name:&str) -> FunctionSignature {
		FunctionSignature {
			name:name.to_string(),
			param_types:vec![ParamType::I32],
			return_type:Some(ReturnType::I32),
			is_async:false,
		}
	}

	#[test]
	fn test_function_signature_creation() {
		let signature = FunctionSignature {
			name:"test_func".to_string(),
			param_types:vec![ParamType::I32, ParamType::Ptr],
			return_type:Some(ReturnType::I32),
			is_async:false,
		};

		assert_eq!(signature.name, "test_func");
		assert_eq!(signature.param_types, vec![ParamType::I32, ParamType::Ptr]);
	}

	#[tokio::test]
	async fn test_host_bridge_creation() {
		let bridge = HostBridgeImpl::new();
		assert!(bridge.get_host_functions().await.is_empty());
	}

	#[tokio::test]
	async fn test_register_host_function() {
		let bridge = HostBridgeImpl::new();

		let result = bridge
			.register_host_function("echo", i32_signature("echo"), |args| Ok(args[0].clone()))
			.await;

		assert!(result.is_ok());
		assert_eq!(bridge.get_host_functions().await, vec!["echo".to_string()]);

		let echoed = bridge.call_host_function("echo", vec![Bytes::from_static(b"7")]).await.unwrap();
		assert_eq!(echoed, Bytes::from_static(b"7"));
	}

	#[tokio::test]
	async fn test_call_unknown_function_fails() {
		let bridge = HostBridgeImpl::new();
		let err = bridge.call_host_function("missing", Vec::new()).await.unwrap_err();
		assert!(matches!(err, BridgeError::FunctionNotFound(name) if name == "missing"));
	}

	#[tokio::test]
	async fn test_unregister_host_function() {
		let bridge = HostBridgeImpl::new();
		bridge
			.register_host_function("echo", i32_signature("echo"), |args| Ok(args[0].clone()))
			.await
			.unwrap();

		assert!(bridge.unregister_host_function("echo").await);
		assert!(!bridge.unregister_host_function("echo").await);
		assert!(bridge.get_host_functions().await.is_empty());
	}

	#[test]
	fn test_serialize_deserialize() {
		let data = vec![1, 2, 3, 4, 5];
		let bytes = serialize_to_bytes(&data).unwrap();
		let recovered:Vec<i32> = deserialize_from_bytes(&bytes).unwrap();
		assert_eq!(data, recovered);
	}

	#[test]
	fn test_marshal_unmarshal() {
		// 2.5 is exactly representable and, unlike 3.14, does not trip clippy's
		// deny-by-default `approx_constant` lint.
		let args = vec![serialize_to_bytes(&42i32).unwrap(), serialize_to_bytes(&2.5f64).unwrap()];

		let marshaled = marshal_args(args).unwrap();
		assert_eq!(marshaled.len(), 2);
		assert_eq!(marshaled[0].i32(), Some(42));
		assert_eq!(marshaled[1].f64(), Some(2.5));

		assert_eq!(unmarshal_return(wasmtime::Val::I32(42)).unwrap(), Bytes::from_static(b"42"));
	}
}