Compare commits
	
		
			5 Commits
		
	
	
		
			be481eaa37
			...
			0842b4d0bc
		
	
	| Author | SHA1 | Date | |
|---|---|---|---|
| 0842b4d0bc | |||
| d1f845b068 | |||
| e6db6bcc68 | |||
| 86fdef4b38 | |||
| 4269799f8f | 
							
								
								
									
										1719
									
								
								Cargo.lock
									
									
									
										generated
									
									
									
								
							
							
						
						
									
										1719
									
								
								Cargo.lock
									
									
									
										generated
									
									
									
								
							
										
											
												File diff suppressed because it is too large
												Load Diff
											
										
									
								
							| @ -6,9 +6,14 @@ description = "Core traits and types for building API clients with http_derive" | ||||
| license = "MIT OR Apache-2.0" | ||||
|  | ||||
| [dependencies] | ||||
| anyhow = "1.0.100" | ||||
| async-trait = "0.1" | ||||
| awc = "3"  | ||||
| reqwest = { version = "0.12.23" , features = ["json", "rustls-tls"] } | ||||
| serde = { version = "1", features = ["derive"] }  | ||||
| serde_json = "1.0.145" | ||||
| toml = "0.9.5" | ||||
| urlencoding = "2"  | ||||
| tokio = { version = "1.47", features = ["full"] } | ||||
| once_cell = "1.19" | ||||
| num_cpus = "1.16" | ||||
|  | ||||
|  | ||||
							
								
								
									
										449
									
								
								src/lib.rs
									
									
									
									
									
								
							
							
						
						
									
										449
									
								
								src/lib.rs
									
									
									
									
									
								
							| @ -1,64 +1,24 @@ | ||||
| use anyhow::{anyhow, Context, Result}; | ||||
| use async_trait::async_trait; | ||||
| use serde::{ Deserialize, Serialize }; | ||||
| use std::collections::HashMap; | ||||
| use std::path::Path; | ||||
| use std::fs; | ||||
| use once_cell::sync::Lazy; | ||||
| use num_cpus; | ||||
| use reqwest::{Client, Method, header::HeaderMap }; | ||||
| use serde::{de::DeserializeOwned, Deserialize, Serialize}; | ||||
| use serde_json::Value; | ||||
| use std::{collections::HashMap, fs, path::Path, sync::Arc}; | ||||
| use tokio::sync::Semaphore; | ||||
|  | ||||
| /// A trait every API request type should implement (via macro). | ||||
| /// Provides info about the HTTP protocol, URLs, and supported methods. | ||||
| pub trait HasHttp { | ||||
|     /// Methods this request supports (GET, POST, etc.) | ||||
|     fn http_methods() -> &'static [&'static str]; | ||||
| /// --------------------------------------------------------------------------- | ||||
| /// GLOBAL CONCURRENCY CONTROL | ||||
| /// --------------------------------------------------------------------------- | ||||
|  | ||||
|     /// Production endpoint (base URL). | ||||
|     fn live_url() -> &'static str; | ||||
| /// Shared semaphore to limit concurrent requests (2 × CPU threads). | ||||
| static SEMAPHORE: Lazy<Arc<Semaphore>> = | ||||
|     Lazy::new(|| Arc::new(Semaphore::new(num_cpus::get() * 2))); | ||||
|  | ||||
|     /// Sandbox endpoint (base URL). | ||||
|     fn sandbox_url() -> &'static str; | ||||
| } | ||||
|  | ||||
| /// Trait implemented by all request types that can be sent to an HTTP API. | ||||
| /// Usually derived via `#[derive(HttpRequest)]`. | ||||
| #[async_trait] | ||||
| pub trait Queryable: HasHttp + Send + Sync { | ||||
|     /// The response type for this query. | ||||
|     type R; | ||||
|     /// The error type for this query. | ||||
|     type E: std::error::Error + Send + Sync + 'static; | ||||
|  | ||||
|     /// Send the query, with flexibility for environment and overrides. | ||||
|     /// | ||||
|     /// - `override_url`: if `Some`, use this instead of live/sandbox. | ||||
|     /// - `sandbox`: if true, use sandbox_url. | ||||
|     /// - `method_override`: optional HTTP method override (GET, POST, etc.). | ||||
|     /// - `headers`: optional headers. | ||||
|     async fn send( | ||||
|         &self, | ||||
|         override_url: Option<&str>, | ||||
|         sandbox: bool, | ||||
|         method_override: Option<&str>, | ||||
|         headers: Option<Vec<(&str, &str)>>, | ||||
|     ) -> Result<Self::R, Self::E>; | ||||
| } | ||||
|  | ||||
| /// Trait for dispatching API calls in bulk or single-shot mode. | ||||
| /// | ||||
| /// Implemented automatically for generated enums by the `#[api_dispatch]` macro. | ||||
| #[async_trait] | ||||
| pub trait ApiDispatch { | ||||
|     /// Send all queries represented by this enum variant. | ||||
|     /// | ||||
|     /// - `client`: A preconfigured `awc::Client` | ||||
|     /// - `keys`: A shared API key store (depends on your app crate) | ||||
|     async fn send_all( | ||||
|         &self, | ||||
|         client: &awc::Client, | ||||
|         keys: &HashMap<String, crate::Keys>, // ⚠️ placeholder, you may need to re-export Keys | ||||
|         override_url: Option<&str>, | ||||
|         sandbox: bool, | ||||
|         method_override: Option<&str>, | ||||
|     ) -> Result<(), Box<dyn std::error::Error + Send + Sync>>; | ||||
| } | ||||
| /// --------------------------------------------------------------------------- | ||||
| /// API KEY MANAGEMENT | ||||
| /// --------------------------------------------------------------------------- | ||||
|  | ||||
| #[derive(Debug, Clone, Deserialize, Serialize)] | ||||
| pub struct Key { | ||||
| @ -72,17 +32,16 @@ pub struct Keys { | ||||
| } | ||||
|  | ||||
| impl Keys { | ||||
|     /// Convenience method to get a key by name (e.g. "alpaca") | ||||
|     pub fn get(&self, name: &str) -> Option<&Key> { | ||||
|         self.keys.get(name) | ||||
|     } | ||||
| } | ||||
|  | ||||
| /// Loads API keys from ~/.config/keys or fallback ./keys directory. | ||||
| /// Each `.toml` file should contain a `Key { key, secret }` struct. | ||||
| /// The filename (without `.toml`) becomes the key name in the map. | ||||
| pub fn load_api_keys() -> Result<Keys, Box<dyn std::error::Error>> { | ||||
|     let home_dir = std::env::var("HOME")?; | ||||
| /// Load API keys from either: | ||||
| /// - `$HOME/.config/keys` | ||||
| /// - or local `./keys` directory | ||||
| pub fn load_api_keys() -> Result<Keys> { | ||||
|     let home_dir = std::env::var("HOME").context("Failed to read $HOME")?; | ||||
|     let config_dir = Path::new(&home_dir).join(".config/keys"); | ||||
|     let fallback_dir = Path::new("keys"); | ||||
|  | ||||
| @ -94,14 +53,17 @@ pub fn load_api_keys() -> Result<Keys, Box<dyn std::error::Error>> { | ||||
|  | ||||
|     let mut keys_map = HashMap::new(); | ||||
|  | ||||
|     for entry in fs::read_dir(dir_path)? { | ||||
|         let entry = entry?; | ||||
|     for entry in fs::read_dir(&dir_path) | ||||
|         .with_context(|| format!("Failed to read keys directory: {}", dir_path.display()))? | ||||
|     { | ||||
|         let entry = entry.context("Failed to read directory entry")?; | ||||
|         let path = entry.path(); | ||||
|  | ||||
|         if path.extension().and_then(|s| s.to_str()) == Some("toml") { | ||||
|             let contents = fs::read_to_string(&path)?; | ||||
|             let key: Key = toml::from_str(&contents)?; | ||||
|  | ||||
|             let contents = fs::read_to_string(&path) | ||||
|                 .with_context(|| format!("Failed to read file: {}", path.display()))?; | ||||
|             let key: Key = toml::from_str(&contents) | ||||
|                 .with_context(|| format!("Failed to parse TOML: {}", path.display()))?; | ||||
|             if let Some(stem) = path.file_stem().and_then(|s| s.to_str()) { | ||||
|                 keys_map.insert(stem.to_string(), key); | ||||
|             } | ||||
| @ -111,4 +73,355 @@ pub fn load_api_keys() -> Result<Keys, Box<dyn std::error::Error>> { | ||||
|     Ok(Keys { keys: keys_map }) | ||||
| } | ||||
|  | ||||
| /// --------------------------------------------------------------------------- | ||||
| /// CORE HTTP REQUEST HELPER | ||||
| /// --------------------------------------------------------------------------- | ||||
|  | ||||
| /// Shared helper to send HTTP requests with concurrency limiting, API key injection, | ||||
| /// and optional custom headers. | ||||
| pub async fn send_request<T, R>( | ||||
|     client: Arc<Client>, | ||||
|     keys: Arc<Keys>, | ||||
|     api_name: &str, | ||||
|     base_url: &str, | ||||
|     path_template: &str, | ||||
|     override_url: Option<&str>, | ||||
|     method: &str, | ||||
|     headers: Option<&HeaderMap>, | ||||
|     body: &T, | ||||
| ) -> Result<R> | ||||
| where | ||||
|     T: Serialize + Send + Sync, | ||||
|     R: DeserializeOwned + Send + 'static, | ||||
| { | ||||
|     // Acquire semaphore permit (limits concurrency) | ||||
|     let _permit = SEMAPHORE.acquire().await.expect("Semaphore poisoned"); | ||||
|  | ||||
|     // Build initial URL: override or base + path_template | ||||
|     let mut url = if let Some(o) = override_url { | ||||
|         o.to_string() | ||||
|     } else { | ||||
|         if base_url.ends_with('/') { | ||||
|             format!("{}{}", base_url.trim_end_matches('/'), path_template) | ||||
|         } else { | ||||
|             format!( | ||||
|                 "{}/{}", | ||||
|                 base_url.trim_end_matches('/'), | ||||
|                 path_template.trim_start_matches('/') | ||||
|             ) | ||||
|         } | ||||
|     }; | ||||
|  | ||||
|     // Apply path + query replacements (safe no-ops if none exist) | ||||
|     url = replace_path_params(&url, &[], body) | ||||
|         .context("Failed to replace path params")?; | ||||
|     url = append_query_params(&url, &[], body) | ||||
|         .context("Failed to append query params")?; | ||||
|  | ||||
|     // Load API keys (pre-provided Arc<Keys>) | ||||
|     let (api_key, secret_key) = keys | ||||
|         .get(api_name) | ||||
|         .map(|k| (k.key.clone(), k.secret.clone())) | ||||
|         .unwrap_or_else(|| (String::new(), String::new())); | ||||
|  | ||||
|     // Parse HTTP method into reqwest::Method | ||||
|     let method_obj = | ||||
|         Method::from_bytes(method.as_bytes()).context("Invalid HTTP method string")?; | ||||
|  | ||||
|     // Build request | ||||
|     let mut req = client.request(method_obj.clone(), &url); | ||||
|  | ||||
|     // Attach API credentials if present | ||||
|     if !api_key.is_empty() { | ||||
|         req = req | ||||
|             .header("APCA-API-KEY-ID", api_key) | ||||
|             .header("APCA-API-SECRET-KEY", secret_key); | ||||
|     } | ||||
|  | ||||
|     // ✅ Add custom headers if provided | ||||
|     if let Some(extra_headers) = headers { | ||||
|         for (k, v) in extra_headers { | ||||
|             req = req.header(k, v); | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     // Attach JSON body only for non-GET/HEAD methods | ||||
|     if !(method_obj == Method::GET || method_obj == Method::HEAD) { | ||||
|         req = req.json(body); | ||||
|     } | ||||
|  | ||||
|     // Send | ||||
|     let resp = req.send().await.context("HTTP send failed")?; | ||||
|     let status = resp.status(); | ||||
|  | ||||
|     if status.is_success() { | ||||
|         let parsed: R = resp.json().await.context("Failed to parse JSON response")?; | ||||
|         Ok(parsed) | ||||
|     } else { | ||||
|         let text = resp.text().await.unwrap_or_default(); | ||||
|         Err(anyhow!( | ||||
|             "[{}] HTTP {} failed: {}", | ||||
|             api_name, | ||||
|             status, | ||||
|             text | ||||
|         )) | ||||
|     } | ||||
| } | ||||
|  | ||||
| /// --------------------------------------------------------------------------- | ||||
| /// TRAITS | ||||
| /// --------------------------------------------------------------------------- | ||||
|  | ||||
| pub trait HasHttp { | ||||
|     fn http_methods() -> &'static [&'static str]; | ||||
|     fn live_url() -> &'static str; | ||||
|     fn sandbox_url() -> &'static str; | ||||
| } | ||||
|  | ||||
| #[async_trait] | ||||
| pub trait Queryable: Send + Sync + 'static { | ||||
|     type Response: Send + Sync + 'static; | ||||
|  | ||||
|     /// Primary method: send the request using a provided `keys` set. | ||||
|     async fn send_with_keys( | ||||
|         &self, | ||||
|         client: Arc<Client>, | ||||
|         keys: Arc<Keys>, | ||||
|         override_url: Option<&str>, | ||||
|         sandbox: bool, | ||||
|         method_override: Option<&str>, | ||||
|         headers: Option<Vec<(&str, &str)>>, | ||||
|     ) -> Result<Self::Response>; | ||||
|  | ||||
|     /// Convenience wrapper: loads keys automatically from disk. | ||||
|     async fn send( | ||||
|         &self, | ||||
|         client: Arc<Client>, | ||||
|         override_url: Option<&str>, | ||||
|         sandbox: bool, | ||||
|         method_override: Option<&str>, | ||||
|         headers: Option<Vec<(&str, &str)>>, | ||||
|     ) -> Result<Self::Response> { | ||||
|         let keys = load_api_keys().map_err(|e| anyhow!("Failed to load API keys: {}", e))?; | ||||
|         self.send_with_keys(client, Arc::new(keys), override_url, sandbox, method_override, headers) | ||||
|             .await | ||||
|     } | ||||
| } | ||||
|  | ||||
| #[async_trait] | ||||
| pub trait ApiDispatch: Send + Sync + 'static { | ||||
|     type Request; | ||||
|     type Response; | ||||
|     type Error; | ||||
|  | ||||
|     async fn send_all( | ||||
|         &self, | ||||
|         client: Arc<Client>, | ||||
|         keys: Arc<Keys>, | ||||
|         override_url: Option<&str>, | ||||
|         sandbox: bool, | ||||
|         method_override: Option<&str>, | ||||
|     ) -> Result<Vec<anyhow::Result<Value>>>; | ||||
| } | ||||
|  | ||||
| /// --------------------------------------------------------------------------- | ||||
| /// API ENDPOINT STRUCTURE + HELPERS | ||||
| /// --------------------------------------------------------------------------- | ||||
|  | ||||
| #[derive(Debug, Clone, Serialize, Deserialize)] | ||||
| pub struct ApiEndpoint { | ||||
|     pub name: String, | ||||
|     pub operation_id: Option<String>, | ||||
|     pub method: String, | ||||
|     pub path_template: String, | ||||
|     pub path_params: Vec<String>, | ||||
|     pub query_params: Vec<String>, | ||||
|     pub live_url: String, | ||||
|     pub sandbox_url: String, | ||||
|     pub request_type: Option<String>, | ||||
|     pub responses: Vec<(String, String)>, | ||||
|     pub requires_auth: bool, | ||||
| } | ||||
|  | ||||
| pub fn replace_path_params<T: Serialize>(url: &str, params: &[String], req: &T) -> Result<String> { | ||||
|     let mut url = url.to_string(); | ||||
|     let val = serde_json::to_value(req)?; | ||||
|     let obj = val.as_object().ok_or_else(|| anyhow!("Expected object in path param serialization"))?; | ||||
|  | ||||
|     for p in params { | ||||
|         if let Some(v) = obj.get(p) { | ||||
|             let s = v.as_str().map(|s| s.to_string()).unwrap_or_else(|| v.to_string()); | ||||
|             url = url.replace(&format!("{{{}}}", p), &s); | ||||
|         } | ||||
|     } | ||||
|     Ok(url) | ||||
| } | ||||
|  | ||||
| pub fn append_query_params<T: Serialize>(url: &str, params: &[String], req: &T) -> Result<String> { | ||||
|     if params.is_empty() { | ||||
|         return Ok(url.to_string()); | ||||
|     } | ||||
|  | ||||
|     let val = serde_json::to_value(req)?; | ||||
|     let obj = val.as_object().ok_or_else(|| anyhow!("Expected object in query param serialization"))?; | ||||
|  | ||||
|     let mut query_pairs = Vec::new(); | ||||
|     for p in params { | ||||
|         if let Some(v) = obj.get(p) { | ||||
|             let s = v.as_str().map(|s| s.to_string()).unwrap_or_else(|| v.to_string()); | ||||
|             query_pairs.push(format!("{}={}", p, urlencoding::encode(&s))); | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     if query_pairs.is_empty() { | ||||
|         Ok(url.to_string()) | ||||
|     } else if url.contains('?') { | ||||
|         Ok(format!("{}&{}", url, query_pairs.join("&"))) | ||||
|     } else { | ||||
|         Ok(format!("{}?{}", url, query_pairs.join("&"))) | ||||
|     } | ||||
| } | ||||
|  | ||||
| /// --------------------------------------------------------------------------- | ||||
| /// BLANKET ADAPTER — make any Queryable act as an ApiDispatch | ||||
| /// --------------------------------------------------------------------------- | ||||
|  | ||||
| #[async_trait] | ||||
| impl<T> ApiDispatch for T | ||||
| where | ||||
|     T: Queryable + Serialize + Send + Sync, | ||||
|     T::Response: Serialize + Send + Sync + 'static, | ||||
| { | ||||
|     type Request = (); | ||||
|     type Response = (); | ||||
|     type Error = anyhow::Error; | ||||
|  | ||||
|     async fn send_all( | ||||
|         &self, | ||||
|         client: Arc<Client>, | ||||
|         keys: Arc<Keys>, | ||||
|         override_url: Option<&str>, | ||||
|         sandbox: bool, | ||||
|         method_override: Option<&str>, | ||||
|     ) -> Result<Vec<anyhow::Result<Value>>> { | ||||
|         let transport_res = self | ||||
|             .send_with_keys(client, keys, override_url, sandbox, method_override, None) | ||||
|             .await; | ||||
|  | ||||
|         let res: anyhow::Result<Value> = match transport_res { | ||||
|             Ok(v) => Ok(serde_json::to_value(v).unwrap_or(Value::Null)), | ||||
|             Err(e) => Err(e), | ||||
|         }; | ||||
|  | ||||
|         Ok(vec![res]) | ||||
|     } | ||||
| } | ||||
|  | ||||
| /// --------------------------------------------------------------------------- | ||||
| /// UNIVERSAL HTTP HELPER (for generated API structs) | ||||
| /// --------------------------------------------------------------------------- | ||||
|  | ||||
| pub async fn http_helper<T>( | ||||
|     req: &T, | ||||
|     client: Arc<Client>, | ||||
|     keys: Arc<Keys>, | ||||
|     override_url: Option<String>, | ||||
|     sandbox: bool, | ||||
|     method_override: Option<&str>, | ||||
| ) -> Result<T::Response> | ||||
| where | ||||
|     T: Queryable + Serialize + Sync, | ||||
|     T::Response: Send + Sync + 'static, | ||||
| { | ||||
|     let override_str = override_url.as_deref(); | ||||
|     req.send_with_keys(client, keys, override_str, sandbox, method_override, None) | ||||
|         .await | ||||
| } | ||||
|  | ||||
| /// Execute a list of ApiDispatch objects concurrently using a shared Client and Keys. | ||||
| /// - `dispatches`: vector of dispatchable objects (Arc<dyn ApiDispatch>). | ||||
| /// - `client`: shared reqwest client (Arc) — cloned for each task cheaply. | ||||
| /// - `keys`: shared keys (Arc). | ||||
| /// - `override_url` and `sandbox` forwarded to each dispatch call. | ||||
| /// - `concurrency_limit`: optional cap; if None, uses SEMAPHORE global limit. | ||||
| /// | ||||
| /// Returns flattened Vec<anyhow::Result<Value>> (each dispatch may yield multiple Values). | ||||
| /// Execute a list of ApiDispatch objects concurrently using a shared Client and Keys. | ||||
| pub async fn execute_dispatches_concurrent( | ||||
|     dispatches: Vec< | ||||
|         Arc< | ||||
|             dyn ApiDispatch< | ||||
|                 Request = (), | ||||
|                 Response = serde_json::Value, | ||||
|                 Error = anyhow::Error, | ||||
|             >, | ||||
|         >, | ||||
|     >, | ||||
|     client: Arc<Client>, | ||||
|     keys: Arc<Keys>, | ||||
|     override_url: Option<&str>, | ||||
|     sandbox: bool, | ||||
|     concurrency_limit: Option<usize>, | ||||
| ) -> anyhow::Result<Vec<anyhow::Result<serde_json::Value>>> { | ||||
|     use tokio::task::JoinSet; | ||||
|     use std::sync::Arc; | ||||
|  | ||||
|     // Optional per-call semaphore (create local one if a specific limit is provided) | ||||
|     let local_sem = concurrency_limit.map(|n| Arc::new(Semaphore::new(n))); | ||||
|  | ||||
|     let mut set = JoinSet::new(); | ||||
|  | ||||
|     for dispatch in dispatches.into_iter() { | ||||
|         let client = client.clone(); | ||||
|         let keys = keys.clone(); | ||||
|         let override_url = override_url.map(|s| s.to_string()); | ||||
|         let local_sem_clone = local_sem.clone(); | ||||
|         let dispatch_arc = dispatch.clone(); | ||||
|  | ||||
|         set.spawn(async move { | ||||
|             // ✅ FIXED: use acquire_owned() so permit owns the semaphore Arc | ||||
|             let _permit = if let Some(ls) = local_sem_clone { | ||||
|                 let sem = Arc::clone(&ls); | ||||
|                 let p = sem.acquire_owned().await.expect("local semaphore poisoned"); | ||||
|                 Some(p) | ||||
|             } else { | ||||
|                 let sem = Arc::clone(&SEMAPHORE); | ||||
|                 let p = sem.acquire_owned().await.expect("global semaphore poisoned"); | ||||
|                 Some(p) | ||||
|             }; | ||||
|  | ||||
|             // Call send_all (each dispatch returns Vec<anyhow::Result<Value>>) | ||||
|             let res = dispatch_arc | ||||
|                 .send_all( | ||||
|                     client, | ||||
|                     keys, | ||||
|                     override_url.as_deref(), | ||||
|                     sandbox, | ||||
|                     None, // method_override not used here | ||||
|                 ) | ||||
|                 .await; | ||||
|  | ||||
|             match res { | ||||
|                 Ok(vec_vals) => Ok(vec_vals), | ||||
|                 Err(e) => Err(e), | ||||
|             } | ||||
|         }); | ||||
|     } | ||||
|  | ||||
|     // Collect results, flatten Vec<Vec<Result<Value>>> -> Vec<Result<Value>> | ||||
|     let mut out: Vec<anyhow::Result<serde_json::Value>> = Vec::new(); | ||||
|  | ||||
|     while let Some(join_res) = set.join_next().await { | ||||
|         match join_res { | ||||
|             Ok(inner_res) => match inner_res { | ||||
|                 Ok(vec_vals) => out.extend(vec_vals), | ||||
|                 Err(e) => out.push(Err(e)), | ||||
|             }, | ||||
|             Err(join_err) => { | ||||
|                 out.push(Err(anyhow::anyhow!("task join error: {}", join_err))); | ||||
|             } | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     Ok(out) | ||||
| } | ||||
|  | ||||
		Reference in New Issue
	
	Block a user