1use std::{collections::HashMap, sync::Arc};
7
8use anyhow::Result;
9use serde::{Deserialize, Serialize};
10use tokio::sync::RwLock;
11
12use crate::dev_log;
13
14#[derive(Debug, Clone, Serialize, Deserialize)]
16pub struct APICallRequest {
17 pub extension_id:String,
19 pub api_method:String,
21 pub arguments:Vec<serde_json::Value>,
23 pub correlation_id:Option<String>,
25}
26
27#[derive(Debug, Clone, Serialize, Deserialize)]
29pub struct APICallResponse {
30 pub success:bool,
32 pub data:Option<serde_json::Value>,
34 pub error:Option<String>,
36 pub correlation_id:Option<String>,
38}
39
40#[allow(dead_code)]
42pub struct APICall {
43 extension_id:String,
45 api_method:String,
47 arguments:Vec<serde_json::Value>,
49 timestamp:u64,
51}
52
53#[allow(dead_code)]
55type APIMethodHandler = fn(&str, Vec<serde_json::Value>) -> Result<serde_json::Value>;
56
57#[allow(dead_code)]
59type AsyncAPIMethodHandler =
60 fn(&str, Vec<serde_json::Value>) -> Box<dyn std::future::Future<Output = Result<serde_json::Value>> + Send + Unpin>;
61
62#[derive(Clone)]
64pub struct APIMethodInfo {
65 #[allow(dead_code)]
67 name:String,
68 #[allow(dead_code)]
70 description:String,
71 #[allow(dead_code)]
73 parameters:Option<serde_json::Value>,
74 #[allow(dead_code)]
76 returns:Option<serde_json::Value>,
77 #[allow(dead_code)]
79 is_async:bool,
80 call_count:u64,
82 total_time_us:u64,
84}
85
86pub struct APIBridgeImpl {
88 api_methods:Arc<RwLock<HashMap<String, APIMethodInfo>>>,
90 stats:Arc<RwLock<APIStats>>,
92 contexts:Arc<RwLock<HashMap<String, APIContext>>>,
94}
95
96#[derive(Debug, Clone, Default, Serialize, Deserialize)]
98pub struct APIStats {
99 pub total_calls:u64,
101 pub successful_calls:u64,
103 pub failed_calls:u64,
105 pub avg_latency_us:u64,
107 pub active_contexts:usize,
109}
110
111#[derive(Debug, Clone)]
113pub struct APIContext {
114 pub extension_id:String,
116 pub context_id:String,
118 pub workspace_folder:Option<String>,
120 pub active_editor:Option<String>,
122 pub selections:Vec<Selection>,
124 pub created_at:u64,
126}
127
128#[derive(Debug, Clone, Serialize, Deserialize)]
130pub struct Selection {
131 pub start_line:u32,
133 pub start_character:u32,
135 pub end_line:u32,
137 pub end_character:u32,
139}
140
141impl Default for Selection {
142 fn default() -> Self { Self { start_line:0, start_character:0, end_line:0, end_character:0 } }
143}
144
145impl APIBridgeImpl {
146 pub fn new() -> Self {
148 let bridge = Self {
149 api_methods:Arc::new(RwLock::new(HashMap::new())),
150 stats:Arc::new(RwLock::new(APIStats::default())),
151 contexts:Arc::new(RwLock::new(HashMap::new())),
152 };
153
154 bridge.register_builtin_methods();
155
156 bridge
157 }
158
159 fn register_builtin_methods(&self) {
161 dev_log!("extensions", "Registered built-in VS Code API methods");
170 }
171
172 pub async fn register_method(
174 &self,
175 name:&str,
176 description:&str,
177 parameters:Option<serde_json::Value>,
178 returns:Option<serde_json::Value>,
179 is_async:bool,
180 ) -> Result<()> {
181 let mut methods = self.api_methods.write().await;
182
183 if methods.contains_key(name) {
184 dev_log!("extensions", "warn: API method already registered: {}", name);
185 }
186
187 methods.insert(
188 name.to_string(),
189 APIMethodInfo {
190 name:name.to_string(),
191 description:description.to_string(),
192 parameters,
193 returns,
194 is_async,
195 call_count:0,
196 total_time_us:0,
197 },
198 );
199
200 dev_log!("extensions", "Registered API method: {}", name);
201
202 Ok(())
203 }
204
205 pub async fn create_context(&self, extension_id:&str) -> Result<APIContext> {
207 let context_id = format!("{}-{}", extension_id, uuid::Uuid::new_v4());
208
209 let context = APIContext {
210 extension_id:extension_id.to_string(),
211 context_id:context_id.clone(),
212 workspace_folder:None,
213 active_editor:None,
214 selections:Vec::new(),
215 created_at:std::time::SystemTime::now()
216 .duration_since(std::time::UNIX_EPOCH)
217 .map(|d| d.as_secs())
218 .unwrap_or(0),
219 };
220
221 let mut contexts = self.contexts.write().await;
222 contexts.insert(context_id.clone(), context.clone());
223
224 let mut stats = self.stats.write().await;
226 stats.active_contexts = contexts.len();
227
228 dev_log!("extensions", "Created API context for extension: {}", extension_id);
229
230 Ok(context)
231 }
232
233 pub async fn get_context(&self, context_id:&str) -> Option<APIContext> {
235 self.contexts.read().await.get(context_id).cloned()
236 }
237
238 pub async fn update_context(&self, context:APIContext) -> Result<()> {
240 let mut contexts = self.contexts.write().await;
241 contexts.insert(context.context_id.clone(), context);
242 Ok(())
243 }
244
245 pub async fn remove_context(&self, context_id:&str) -> Result<bool> {
247 let mut contexts = self.contexts.write().await;
248 let removed = contexts.remove(context_id).is_some();
249
250 if removed {
251 let mut stats = self.stats.write().await;
252 stats.active_contexts = contexts.len();
253 }
254
255 Ok(removed)
256 }
257
258 pub async fn Call(&self, request:APICallRequest) -> Result<APICallResponse> {
260 let start = std::time::Instant::now();
261
262 dev_log!(
263 "extensions",
264 "Handling API call: {} from extension {}",
265 request.api_method,
266 request.extension_id
267 );
268
269 let exists = {
271 let methods = self.api_methods.read().await;
272 methods.contains_key(&request.api_method)
273 };
274
275 if !exists {
276 return Ok(APICallResponse {
277 success:false,
278 data:None,
279 error:Some(format!("API method not found: {}", request.api_method)),
280 correlation_id:request.correlation_id,
281 });
282 }
283
284 let result = self
287 .execute_method(&request.extension_id, &request.api_method, &request.arguments)
288 .await;
289
290 let elapsed_us = start.elapsed().as_micros() as u64;
291
292 let mut stats = self.stats.write().await;
294 stats.total_calls += 1;
295 stats.total_calls += 1;
296 if exists {
297 stats.successful_calls += 1;
298 stats.avg_latency_us =
300 (stats.avg_latency_us * (stats.successful_calls - 1) + elapsed_us) / stats.successful_calls;
301 }
302
303 {
305 let mut methods = self.api_methods.write().await;
306 if let Some(method) = methods.get_mut(&request.api_method) {
307 method.call_count += 1;
308 method.total_time_us += elapsed_us;
309 }
310 }
311
312 dev_log!("extensions", "API call {} completed in {}µs", request.api_method, elapsed_us);
313
314 match result {
315 Ok(data) => {
316 Ok(
317 APICallResponse {
318 success:true,
319 data:Some(data),
320 error:None,
321 correlation_id:request.correlation_id,
322 },
323 )
324 },
325 Err(e) => {
326 Ok(APICallResponse {
327 success:false,
328 data:None,
329 error:Some(e.to_string()),
330 correlation_id:request.correlation_id,
331 })
332 },
333 }
334 }
335
336 async fn execute_method(
338 &self,
339 _extension_id:&str,
340 _method_name:&str,
341 _arguments:&[serde_json::Value],
342 ) -> Result<serde_json::Value> {
343 Ok(serde_json::Value::Null)
352 }
353
354 pub async fn stats(&self) -> APIStats { self.stats.read().await.clone() }
356
357 pub async fn get_methods(&self) -> Vec<APIMethodInfo> { self.api_methods.read().await.values().cloned().collect() }
359
360 pub async fn unregister_method(&self, name:&str) -> Result<bool> {
362 let mut methods = self.api_methods.write().await;
363 let removed = methods.remove(name).is_some();
364
365 if removed {
366 dev_log!("extensions", "Unregistered API method: {}", name);
367 }
368
369 Ok(removed)
370 }
371}
372
373impl Default for APIBridgeImpl {
374 fn default() -> Self { Self::new() }
375}
376
#[cfg(test)]
mod tests {
	use super::*;

	/// Builds a request for `method` from a fixed test extension.
	fn request(method:&str, correlation_id:Option<&str>) -> APICallRequest {
		APICallRequest {
			extension_id:"test.ext".to_string(),
			api_method:method.to_string(),
			arguments:Vec::new(),
			correlation_id:correlation_id.map(str::to_string),
		}
	}

	#[tokio::test]
	async fn test_api_bridge_creation() {
		let bridge = APIBridgeImpl::new();
		let stats = bridge.stats().await;
		assert_eq!(stats.total_calls, 0);
		assert_eq!(stats.successful_calls, 0);
	}

	#[tokio::test]
	async fn test_context_creation() {
		let bridge = APIBridgeImpl::new();
		let context = bridge.create_context("test.ext").await.unwrap();
		assert_eq!(context.extension_id, "test.ext");
		assert!(!context.context_id.is_empty());
		// The id is namespaced by the owning extension.
		assert!(context.context_id.starts_with("test.ext-"));
	}

	#[tokio::test]
	async fn test_context_lifecycle() {
		let bridge = APIBridgeImpl::new();
		let context = bridge.create_context("test.ext").await.unwrap();
		assert_eq!(bridge.stats().await.active_contexts, 1);
		assert!(bridge.get_context(&context.context_id).await.is_some());

		assert!(bridge.remove_context(&context.context_id).await.unwrap());
		// A second removal finds nothing.
		assert!(!bridge.remove_context(&context.context_id).await.unwrap());
		assert!(bridge.get_context(&context.context_id).await.is_none());
		assert_eq!(bridge.stats().await.active_contexts, 0);
	}

	#[tokio::test]
	async fn test_method_registration() {
		let bridge = APIBridgeImpl::new();
		let result:Result<()> = bridge.register_method("test.method", "Test method", None, None, false).await;
		assert!(result.is_ok());

		let methods:Vec<APIMethodInfo> = bridge.get_methods().await;
		assert!(methods.iter().any(|m| m.name == "test.method"));
	}

	#[tokio::test]
	async fn test_unregister_method() {
		let bridge = APIBridgeImpl::new();
		bridge.register_method("test.method", "Test method", None, None, false).await.unwrap();

		assert!(bridge.unregister_method("test.method").await.unwrap());
		assert!(!bridge.unregister_method("test.method").await.unwrap());
		assert!(bridge.get_methods().await.is_empty());
	}

	#[tokio::test]
	async fn test_call_unknown_method_reports_error() {
		let bridge = APIBridgeImpl::new();
		let response = bridge.Call(request("missing.method", Some("corr-1"))).await.unwrap();

		assert!(!response.success);
		assert!(response.data.is_none());
		assert_eq!(response.error.as_deref(), Some("API method not found: missing.method"));
		assert_eq!(response.correlation_id.as_deref(), Some("corr-1"));
	}

	#[tokio::test]
	async fn test_call_registered_method_succeeds() {
		let bridge = APIBridgeImpl::new();
		bridge.register_method("test.method", "Test method", None, None, false).await.unwrap();

		let response = bridge.Call(request("test.method", Some("corr-2"))).await.unwrap();

		assert!(response.success);
		assert_eq!(response.data, Some(serde_json::Value::Null));
		assert!(response.error.is_none());
		assert_eq!(response.correlation_id.as_deref(), Some("corr-2"));
		assert_eq!(bridge.stats().await.successful_calls, 1);
	}

	#[tokio::test]
	async fn test_api_call_request() {
		let request = APICallRequest {
			extension_id:"test.ext".to_string(),
			api_method:"test.method".to_string(),
			arguments:vec![serde_json::json!("arg1")],
			correlation_id:Some("test-id".to_string()),
		};

		assert_eq!(request.extension_id, "test.ext");
		assert_eq!(request.api_method, "test.method");
		assert_eq!(request.arguments.len(), 1);
	}

	#[test]
	fn test_selection_default() {
		let selection = Selection::default();
		assert_eq!(selection.start_line, 0);
		assert_eq!(selection.end_line, 0);
	}
}