1use std::{collections::HashMap, sync::Arc};
7
8use anyhow::Result;
9use serde::{Deserialize, Serialize};
10use tokio::sync::RwLock;
11use wasmtime::{Caller, Linker};
12
13use crate::{
14 WASM::HostBridge::{
15 FunctionSignature,
16 HostBridgeImpl,
17 HostBridgeImpl as HostBridge,
18 HostFunctionCallback,
19 ParamType,
20 ReturnType,
21 },
22 dev_log,
23};
24
25pub struct HostFunctionRegistry {
27 functions:Arc<RwLock<HashMap<String, RegisteredHostFunction>>>,
29 #[allow(dead_code)]
31 bridge:Arc<HostBridge>,
32}
33
34#[derive(Debug, Clone)]
36struct RegisteredHostFunction {
37 #[allow(dead_code)]
39 name:String,
40 #[allow(dead_code)]
42 signature:FunctionSignature,
43 callback:Option<HostFunctionCallback>,
45 #[allow(dead_code)]
47 registered_at:u64,
48 stats:FunctionStats,
50}
51
/// Call statistics for a single exported host function.
#[derive(Default, Clone, Debug)]
pub struct FunctionStats {
	/// Number of recorded invocations.
	pub call_count:u64,
	/// Accumulated execution time, in nanoseconds.
	pub total_execution_ns:u64,
	/// Time of the most recent invocation, if any
	/// (presumably seconds since the UNIX epoch — confirm with the writer).
	pub last_call_at:Option<u64>,
	/// Number of invocations that ended in an error.
	pub error_count:u64,
}
64
65#[derive(Debug, Clone, Serialize, Deserialize)]
67pub struct ExportConfig {
68 pub auto_export:bool,
70 pub enable_stats:bool,
72 pub max_functions:usize,
74 pub name_prefix:Option<String>,
76}
77
78impl Default for ExportConfig {
79 fn default() -> Self {
80 Self {
81 auto_export:true,
82 enable_stats:true,
83 max_functions:1000,
84 name_prefix:Some("host_".to_string()),
85 }
86 }
87}
88
89pub struct FunctionExportImpl {
91 registry:Arc<HostFunctionRegistry>,
92 config:ExportConfig,
93}
94
95impl FunctionExportImpl {
96 pub fn new(bridge:Arc<HostBridge>) -> Self {
98 Self {
99 registry:Arc::new(HostFunctionRegistry { functions:Arc::new(RwLock::new(HashMap::new())), bridge }),
100 config:ExportConfig::default(),
101 }
102 }
103
104 pub fn with_config(bridge:Arc<HostBridge>, config:ExportConfig) -> Self {
106 Self {
107 registry:Arc::new(HostFunctionRegistry { functions:Arc::new(RwLock::new(HashMap::new())), bridge }),
108 config,
109 }
110 }
111
112 pub async fn register_function(
114 &self,
115 name:&str,
116 signature:FunctionSignature,
117 callback:HostFunctionCallback,
118 ) -> Result<()> {
119 dev_log!("wasm", "Registering host function for export: {}", name);
120
121 let functions = self.registry.functions.read().await;
122
123 if functions.len() >= self.config.max_functions {
125 return Err(anyhow::anyhow!(
126 "Maximum number of exported functions reached: {}",
127 self.config.max_functions
128 ));
129 }
130
131 drop(functions);
132
133 let mut functions = self.registry.functions.write().await;
134
135 let registered_at = std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH)?.as_secs();
136
137 functions.insert(
138 name.to_string(),
139 RegisteredHostFunction {
140 name:name.to_string(),
141 signature,
142 callback:Some(callback),
143 registered_at,
144 stats:FunctionStats::default(),
145 },
146 );
147
148 dev_log!("wasm", "Host function registered for WASM export: {}", name);
149 Ok(())
150 }
151
152 pub async fn register_functions(
154 &self,
155 signatures:Vec<FunctionSignature>,
156 callbacks:Vec<HostFunctionCallback>,
157 ) -> Result<()> {
158 if signatures.len() != callbacks.len() {
159 return Err(anyhow::anyhow!("Number of signatures must match number of callbacks"));
160 }
161
162 for (sig, callback) in signatures.into_iter().zip(callbacks) {
163 let name = sig.name.clone();
164 self.register_function(&name, sig, callback).await?;
165 }
166
167 Ok(())
168 }
169
170 pub async fn export_to_linker<T>(&self, linker:&mut Linker<T>) -> Result<()>
172 where
173 T: Send + 'static, {
174 dev_log!(
175 "wasm",
176 "Exporting {} host functions to linker",
177 self.registry.functions.read().await.len()
178 );
179
180 let functions = self.registry.functions.read().await;
181
182 for (name, func) in functions.iter() {
183 self.export_single_function(linker, name, func)?;
184 }
185
186 dev_log!("wasm", "All host functions exported to linker");
187 Ok(())
188 }
189
190 fn export_single_function<T>(&self, linker:&mut Linker<T>, name:&str, func:&RegisteredHostFunction) -> Result<()>
192 where
193 T: Send + 'static, {
194 dev_log!("wasm", "Exporting function: {}", name);
195
196 let callback = func
197 .callback
198 .ok_or_else(|| anyhow::anyhow!("No callback available for function: {}", name))?;
199
200 let func_name = if let Some(prefix) = &self.config.name_prefix {
201 format!("{}{}", prefix, name)
202 } else {
203 name.to_string()
204 };
205
206 let func_name_for_debug = func_name.clone();
207 let func_name_inner = func_name.clone();
208
209 let _wrapped_callback =
211 move |_caller:Caller<'_, T>, args:&[wasmtime::Val]| -> Result<Vec<wasmtime::Val>, wasmtime::Trap> {
212 let _start = std::time::Instant::now();
213
214 let args_bytes:Result<Vec<bytes::Bytes>, _> = args
216 .iter()
217 .map(|arg| {
218 match arg {
219 wasmtime::Val::I32(i) => {
220 serde_json::to_vec(i)
221 .map(bytes::Bytes::from)
222 .map_err(|e| anyhow::anyhow!("Serialization error: {}", e))
223 },
224 wasmtime::Val::I64(i) => {
225 serde_json::to_vec(i)
226 .map(bytes::Bytes::from)
227 .map_err(|e| anyhow::anyhow!("Serialization error: {}", e))
228 },
229 wasmtime::Val::F32(f) => {
230 serde_json::to_vec(f)
231 .map(bytes::Bytes::from)
232 .map_err(|e| anyhow::anyhow!("Serialization error: {}", e))
233 },
234 wasmtime::Val::F64(f) => {
235 serde_json::to_vec(f)
236 .map(bytes::Bytes::from)
237 .map_err(|e| anyhow::anyhow!("Serialization error: {}", e))
238 },
239 _ => Err(anyhow::anyhow!("Unsupported argument type")),
240 }
241 })
242 .collect();
243
244 let args_bytes = args_bytes.map_err(|_| {
245 dev_log!("wasm", "warn: error converting arguments for function '{}'", func_name_inner);
246 wasmtime::Trap::StackOverflow
247 })?;
248
249 let result = callback(args_bytes);
251
252 match result {
253 Ok(response_bytes) => {
254 let result_val:serde_json::Value = serde_json::from_slice(&response_bytes).map_err(|_| {
256 dev_log!("wasm", "warn: error deserializing response for function '{}'", func_name_inner);
257 wasmtime::Trap::StackOverflow
258 })?;
259
260 let ret_val = match result_val {
261 serde_json::Value::Number(n) => {
262 if let Some(i) = n.as_i64() {
263 wasmtime::Val::I32(i as i32)
264 } else if let Some(f) = n.as_f64() {
265 wasmtime::Val::I64(f as i64)
266 } else {
267 dev_log!("wasm", "warn: invalid number format for function '{}'", func_name_inner);
268 return Err(wasmtime::Trap::StackOverflow);
269 }
270 },
271 _ => {
272 dev_log!("wasm", "warn: unsupported response type for function '{}'", func_name_inner);
273 return Err(wasmtime::Trap::StackOverflow);
274 },
275 };
276
277 Ok(vec![ret_val])
278 },
279 Err(e) => {
280 dev_log!("wasm", "host function '{}' returned error: {}", func_name_inner, e);
282 Err(wasmtime::Trap::StackOverflow)
283 },
284 }
285 };
286
287 let _wasmparser_signature = wasmparser::FuncType::new([wasmparser::ValType::I32], [wasmparser::ValType::I32]);
289
290 let func_name_for_logging = func_name.clone();
294 linker
295 .func_wrap(
296 "_host", &func_name,
298 move |_caller:wasmtime::Caller<'_, T>, input_param:i32| -> i32 {
299 let start = std::time::Instant::now();
301
302 let args_bytes = match serde_json::to_vec(&input_param).map(bytes::Bytes::from) {
304 Ok(b) => b,
305 Err(e) => {
306 dev_log!(
307 "wasm",
308 "warn: serialization error for function '{}': {}",
309 func_name_for_logging,
310 e
311 );
312 return -1i32;
313 },
314 };
315
316 let result = callback(vec![args_bytes]);
318
319 match result {
320 Ok(response_bytes) => {
321 let result_val:serde_json::Value = match serde_json::from_slice(&response_bytes) {
323 Ok(v) => v,
324 Err(_) => {
325 dev_log!(
326 "wasm",
327 "warn: error deserializing response for function '{}'",
328 func_name_for_logging
329 );
330 return -1i32;
331 },
332 };
333
334 let ret_val = match result_val {
336 serde_json::Value::Number(n) => {
337 if let Some(i) = n.as_i64() {
338 i as i32
339 } else if let Some(f) = n.as_f64() {
340 f as i32
341 } else {
342 dev_log!(
343 "wasm",
344 "warn: invalid number format for function '{}'",
345 func_name_for_logging
346 );
347 -1i32
348 }
349 },
350 serde_json::Value::Bool(b) => {
351 if b {
352 1i32
353 } else {
354 0i32
355 }
356 },
357 _ => {
358 dev_log!(
359 "wasm",
360 "warn: unsupported response type for function '{}', expected number or bool",
361 func_name_for_logging
362 );
363 -1i32
364 },
365 };
366
367 let duration = start.elapsed();
369 dev_log!(
370 "wasm",
371 "[FunctionExport] Host function '{}' executed successfully in {}µs",
372 func_name_for_logging,
373 duration.as_micros()
374 );
375
376 ret_val
377 },
378 Err(e) => {
379 dev_log!(
381 "wasm",
382 "[FunctionExport] Host function '{}' returned error: {}",
383 func_name_for_logging,
384 e
385 );
386 -1i32
388 },
389 }
390 },
391 )
392 .map_err(|e| {
393 dev_log!(
394 "wasm",
395 "warn: [FunctionExport] failed to wrap host function '{}': {}",
396 func_name_for_debug,
397 e
398 );
399 e
400 })?;
401
402 dev_log!(
403 "wasm",
404 "[FunctionExport] Host function '{}' registered successfully",
405 func_name_for_debug
406 );
407
408 Ok(())
409 }
410
411 #[allow(dead_code)]
413 fn wasmtime_signature_from_signature(&self, _sig:&FunctionSignature) -> Result<wasmparser::FuncType> {
414 Ok(wasmparser::FuncType::new([], []))
417 }
418
419 pub async fn get_function_names(&self) -> Vec<String> {
421 self.registry.functions.read().await.keys().cloned().collect()
422 }
423
424 pub async fn get_function_stats(&self, name:&str) -> Option<FunctionStats> {
426 self.registry.functions.read().await.get(name).map(|f| f.stats.clone())
427 }
428
429 pub async fn unregister_function(&self, name:&str) -> Result<bool> {
431 let mut functions = self.registry.functions.write().await;
432 let removed = functions.remove(name).is_some();
433
434 if removed {
435 dev_log!("wasm", "Unregistered host function: {}", name);
436 } else {
437 dev_log!("wasm", "warn: attempted to unregister non-existent function: {}", name);
438 }
439
440 Ok(removed)
441 }
442
443 pub async fn clear(&self) {
445 dev_log!("wasm", "Clearing all registered host functions");
446 self.registry.functions.write().await.clear();
447 }
448}
449
#[cfg(test)]
mod tests {
	use super::*;

	/// Builds an `i32 -> i32` signature named `name` for the tests below.
	fn i32_signature(name:&str) -> FunctionSignature {
		FunctionSignature {
			name:name.to_string(),
			param_types:vec![ParamType::I32],
			return_type:Some(ReturnType::I32),
			is_async:false,
		}
	}

	#[tokio::test]
	async fn test_function_export_creation() {
		let bridge = Arc::new(HostBridgeImpl::new());
		let export = FunctionExportImpl::new(bridge);

		assert!(export.get_function_names().await.is_empty());
	}

	#[tokio::test]
	async fn test_register_function() {
		let bridge = Arc::new(HostBridgeImpl::new());
		let export = FunctionExportImpl::new(bridge);

		let callback = |args:Vec<bytes::Bytes>| Ok(args.first().cloned().unwrap_or_default());

		let result:anyhow::Result<()> = export.register_function("echo", i32_signature("echo"), callback).await;
		assert!(result.is_ok());
		assert_eq!(export.get_function_names().await, vec!["echo".to_string()]);
	}

	#[tokio::test]
	async fn test_unregister_function() {
		let bridge = Arc::new(HostBridgeImpl::new());
		let export = FunctionExportImpl::new(bridge);

		let callback = |_:Vec<bytes::Bytes>| Ok(bytes::Bytes::new());
		let registered:anyhow::Result<()> = export.register_function("test", i32_signature("test"), callback).await;
		assert!(registered.is_ok());

		let removed:bool = export.unregister_function("test").await.unwrap();
		assert!(removed);
		assert!(export.get_function_names().await.is_empty());

		// A second removal of the same name reports that nothing was removed.
		assert!(!export.unregister_function("test").await.unwrap());
	}

	#[tokio::test]
	async fn test_register_function_enforces_max_functions() {
		let bridge = Arc::new(HostBridgeImpl::new());
		let config = ExportConfig { max_functions:1, ..ExportConfig::default() };
		let export = FunctionExportImpl::with_config(bridge, config);

		let callback = |_:Vec<bytes::Bytes>| Ok(bytes::Bytes::new());

		let first:anyhow::Result<()> = export.register_function("first", i32_signature("first"), callback).await;
		assert!(first.is_ok());

		let second:anyhow::Result<()> = export.register_function("second", i32_signature("second"), callback).await;
		assert!(second.is_err());
		assert_eq!(export.get_function_names().await.len(), 1);
	}

	#[test]
	fn test_export_config_default() {
		let config = ExportConfig::default();
		assert!(config.auto_export);
		assert!(config.enable_stats);
		assert_eq!(config.max_functions, 1000);
		assert_eq!(config.name_prefix.as_deref(), Some("host_"));
	}

	#[test]
	fn test_function_stats_default() {
		let stats = FunctionStats::default();
		assert_eq!(stats.call_count, 0);
		assert_eq!(stats.total_execution_ns, 0);
		assert!(stats.last_call_at.is_none());
		assert_eq!(stats.error_count, 0);
	}
}